test_random.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import numpy as np
  2. import pytest
  3. import scipy.sparse as sp
  4. from numpy.testing import assert_array_almost_equal
  5. from scipy.special import comb
  6. from sklearn.utils._random import _our_rand_r_py
  7. from sklearn.utils.random import _random_choice_csc, sample_without_replacement
  8. ###############################################################################
  9. # test custom sampling without replacement algorithm
  10. ###############################################################################
  11. def test_invalid_sample_without_replacement_algorithm():
  12. with pytest.raises(ValueError):
  13. sample_without_replacement(5, 4, "unknown")
  14. def test_sample_without_replacement_algorithms():
  15. methods = ("auto", "tracking_selection", "reservoir_sampling", "pool")
  16. for m in methods:
  17. def sample_without_replacement_method(
  18. n_population, n_samples, random_state=None
  19. ):
  20. return sample_without_replacement(
  21. n_population, n_samples, method=m, random_state=random_state
  22. )
  23. check_edge_case_of_sample_int(sample_without_replacement_method)
  24. check_sample_int(sample_without_replacement_method)
  25. check_sample_int_distribution(sample_without_replacement_method)
  26. def check_edge_case_of_sample_int(sample_without_replacement):
  27. # n_population < n_sample
  28. with pytest.raises(ValueError):
  29. sample_without_replacement(0, 1)
  30. with pytest.raises(ValueError):
  31. sample_without_replacement(1, 2)
  32. # n_population == n_samples
  33. assert sample_without_replacement(0, 0).shape == (0,)
  34. assert sample_without_replacement(1, 1).shape == (1,)
  35. # n_population >= n_samples
  36. assert sample_without_replacement(5, 0).shape == (0,)
  37. assert sample_without_replacement(5, 1).shape == (1,)
  38. # n_population < 0 or n_samples < 0
  39. with pytest.raises(ValueError):
  40. sample_without_replacement(-1, 5)
  41. with pytest.raises(ValueError):
  42. sample_without_replacement(5, -1)
  43. def check_sample_int(sample_without_replacement):
  44. # This test is heavily inspired from test_random.py of python-core.
  45. #
  46. # For the entire allowable range of 0 <= k <= N, validate that
  47. # the sample is of the correct length and contains only unique items
  48. n_population = 100
  49. for n_samples in range(n_population + 1):
  50. s = sample_without_replacement(n_population, n_samples)
  51. assert len(s) == n_samples
  52. unique = np.unique(s)
  53. assert np.size(unique) == n_samples
  54. assert np.all(unique < n_population)
  55. # test edge case n_population == n_samples == 0
  56. assert np.size(sample_without_replacement(0, 0)) == 0
  57. def check_sample_int_distribution(sample_without_replacement):
  58. # This test is heavily inspired from test_random.py of python-core.
  59. #
  60. # For the entire allowable range of 0 <= k <= N, validate that
  61. # sample generates all possible permutations
  62. n_population = 10
  63. # a large number of trials prevents false negatives without slowing normal
  64. # case
  65. n_trials = 10000
  66. for n_samples in range(n_population):
  67. # Counting the number of combinations is not as good as counting the
  68. # the number of permutations. However, it works with sampling algorithm
  69. # that does not provide a random permutation of the subset of integer.
  70. n_expected = comb(n_population, n_samples, exact=True)
  71. output = {}
  72. for i in range(n_trials):
  73. output[frozenset(sample_without_replacement(n_population, n_samples))] = (
  74. None
  75. )
  76. if len(output) == n_expected:
  77. break
  78. else:
  79. raise AssertionError(
  80. "number of combinations != number of expected (%s != %s)"
  81. % (len(output), n_expected)
  82. )
  83. def test_random_choice_csc(n_samples=10000, random_state=24):
  84. # Explicit class probabilities
  85. classes = [np.array([0, 1]), np.array([0, 1, 2])]
  86. class_probabilities = [np.array([0.5, 0.5]), np.array([0.6, 0.1, 0.3])]
  87. got = _random_choice_csc(n_samples, classes, class_probabilities, random_state)
  88. assert sp.issparse(got)
  89. for k in range(len(classes)):
  90. p = np.bincount(got.getcol(k).toarray().ravel()) / float(n_samples)
  91. assert_array_almost_equal(class_probabilities[k], p, decimal=1)
  92. # Implicit class probabilities
  93. classes = [[0, 1], [1, 2]] # test for array-like support
  94. class_probabilities = [np.array([0.5, 0.5]), np.array([0, 1 / 2, 1 / 2])]
  95. got = _random_choice_csc(
  96. n_samples=n_samples, classes=classes, random_state=random_state
  97. )
  98. assert sp.issparse(got)
  99. for k in range(len(classes)):
  100. p = np.bincount(got.getcol(k).toarray().ravel()) / float(n_samples)
  101. assert_array_almost_equal(class_probabilities[k], p, decimal=1)
  102. # Edge case probabilities 1.0 and 0.0
  103. classes = [np.array([0, 1]), np.array([0, 1, 2])]
  104. class_probabilities = [np.array([0.0, 1.0]), np.array([0.0, 1.0, 0.0])]
  105. got = _random_choice_csc(n_samples, classes, class_probabilities, random_state)
  106. assert sp.issparse(got)
  107. for k in range(len(classes)):
  108. p = (
  109. np.bincount(
  110. got.getcol(k).toarray().ravel(), minlength=len(class_probabilities[k])
  111. )
  112. / n_samples
  113. )
  114. assert_array_almost_equal(class_probabilities[k], p, decimal=1)
  115. # One class target data
  116. classes = [[1], [0]] # test for array-like support
  117. class_probabilities = [np.array([0.0, 1.0]), np.array([1.0])]
  118. got = _random_choice_csc(
  119. n_samples=n_samples, classes=classes, random_state=random_state
  120. )
  121. assert sp.issparse(got)
  122. for k in range(len(classes)):
  123. p = np.bincount(got.getcol(k).toarray().ravel()) / n_samples
  124. assert_array_almost_equal(class_probabilities[k], p, decimal=1)
  125. def test_random_choice_csc_errors():
  126. # the length of an array in classes and class_probabilities is mismatched
  127. classes = [np.array([0, 1]), np.array([0, 1, 2, 3])]
  128. class_probabilities = [np.array([0.5, 0.5]), np.array([0.6, 0.1, 0.3])]
  129. with pytest.raises(ValueError):
  130. _random_choice_csc(4, classes, class_probabilities, 1)
  131. # the class dtype is not supported
  132. classes = [np.array(["a", "1"]), np.array(["z", "1", "2"])]
  133. class_probabilities = [np.array([0.5, 0.5]), np.array([0.6, 0.1, 0.3])]
  134. with pytest.raises(ValueError):
  135. _random_choice_csc(4, classes, class_probabilities, 1)
  136. # the class dtype is not supported
  137. classes = [np.array([4.2, 0.1]), np.array([0.1, 0.2, 9.4])]
  138. class_probabilities = [np.array([0.5, 0.5]), np.array([0.6, 0.1, 0.3])]
  139. with pytest.raises(ValueError):
  140. _random_choice_csc(4, classes, class_probabilities, 1)
  141. # Given probabilities don't sum to 1
  142. classes = [np.array([0, 1]), np.array([0, 1, 2])]
  143. class_probabilities = [np.array([0.5, 0.6]), np.array([0.6, 0.1, 0.3])]
  144. with pytest.raises(ValueError):
  145. _random_choice_csc(4, classes, class_probabilities, 1)
  146. def test_our_rand_r():
  147. assert 131541053 == _our_rand_r_py(1273642419)
  148. assert 270369 == _our_rand_r_py(0)