test_stats.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import numpy as np
  2. from numpy.testing import assert_allclose
  3. from pytest import approx
  4. from sklearn.utils.stats import _weighted_percentile
  5. def test_weighted_percentile():
  6. y = np.empty(102, dtype=np.float64)
  7. y[:50] = 0
  8. y[-51:] = 2
  9. y[-1] = 100000
  10. y[50] = 1
  11. sw = np.ones(102, dtype=np.float64)
  12. sw[-1] = 0.0
  13. score = _weighted_percentile(y, sw, 50)
  14. assert approx(score) == 1
  15. def test_weighted_percentile_equal():
  16. y = np.empty(102, dtype=np.float64)
  17. y.fill(0.0)
  18. sw = np.ones(102, dtype=np.float64)
  19. sw[-1] = 0.0
  20. score = _weighted_percentile(y, sw, 50)
  21. assert score == 0
  22. def test_weighted_percentile_zero_weight():
  23. y = np.empty(102, dtype=np.float64)
  24. y.fill(1.0)
  25. sw = np.ones(102, dtype=np.float64)
  26. sw.fill(0.0)
  27. score = _weighted_percentile(y, sw, 50)
  28. assert approx(score) == 1.0
  29. def test_weighted_percentile_zero_weight_zero_percentile():
  30. y = np.array([0, 1, 2, 3, 4, 5])
  31. sw = np.array([0, 0, 1, 1, 1, 0])
  32. score = _weighted_percentile(y, sw, 0)
  33. assert approx(score) == 2
  34. score = _weighted_percentile(y, sw, 50)
  35. assert approx(score) == 3
  36. score = _weighted_percentile(y, sw, 100)
  37. assert approx(score) == 4
  38. def test_weighted_median_equal_weights():
  39. # Checks weighted percentile=0.5 is same as median when weights equal
  40. rng = np.random.RandomState(0)
  41. # Odd size as _weighted_percentile takes lower weighted percentile
  42. x = rng.randint(10, size=11)
  43. weights = np.ones(x.shape)
  44. median = np.median(x)
  45. w_median = _weighted_percentile(x, weights)
  46. assert median == approx(w_median)
  47. def test_weighted_median_integer_weights():
  48. # Checks weighted percentile=0.5 is same as median when manually weight
  49. # data
  50. rng = np.random.RandomState(0)
  51. x = rng.randint(20, size=10)
  52. weights = rng.choice(5, size=10)
  53. x_manual = np.repeat(x, weights)
  54. median = np.median(x_manual)
  55. w_median = _weighted_percentile(x, weights)
  56. assert median == approx(w_median)
  57. def test_weighted_percentile_2d():
  58. # Check for when array 2D and sample_weight 1D
  59. rng = np.random.RandomState(0)
  60. x1 = rng.randint(10, size=10)
  61. w1 = rng.choice(5, size=10)
  62. x2 = rng.randint(20, size=10)
  63. x_2d = np.vstack((x1, x2)).T
  64. w_median = _weighted_percentile(x_2d, w1)
  65. p_axis_0 = [_weighted_percentile(x_2d[:, i], w1) for i in range(x_2d.shape[1])]
  66. assert_allclose(w_median, p_axis_0)
  67. # Check when array and sample_weight boht 2D
  68. w2 = rng.choice(5, size=10)
  69. w_2d = np.vstack((w1, w2)).T
  70. w_median = _weighted_percentile(x_2d, w_2d)
  71. p_axis_0 = [
  72. _weighted_percentile(x_2d[:, i], w_2d[:, i]) for i in range(x_2d.shape[1])
  73. ]
  74. assert_allclose(w_median, p_axis_0)