test_utils.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. import string
  2. import timeit
  3. import warnings
  4. from copy import copy
  5. from itertools import chain
  6. import numpy as np
  7. import pytest
  8. import scipy.sparse as sp
  9. from sklearn import config_context
  10. from sklearn.utils import (
  11. _approximate_mode,
  12. _determine_key_type,
  13. _get_column_indices,
  14. _message_with_time,
  15. _print_elapsed_time,
  16. _safe_assign,
  17. _safe_indexing,
  18. _to_object_array,
  19. check_random_state,
  20. column_or_1d,
  21. deprecated,
  22. gen_even_slices,
  23. get_chunk_n_rows,
  24. is_scalar_nan,
  25. resample,
  26. safe_mask,
  27. shuffle,
  28. )
  29. from sklearn.utils._mocking import MockDataFrame
  30. from sklearn.utils._testing import (
  31. _convert_container,
  32. assert_allclose_dense_sparse,
  33. assert_array_equal,
  34. assert_no_warnings,
  35. )
  # Small 3x3 integer matrix (values 0..8) reused by several tests in this file.
  X_toy = np.arange(9).reshape(3, 3)
  def test_make_rng():
      """Exercise ``check_random_state`` with each kind of seed it is given here."""
      global_rng = np.random.mtrand._rand

      # ``None`` and the ``np.random`` module both map to the global RandomState.
      assert check_random_state(None) is global_rng
      assert check_random_state(np.random) is global_rng

      # An integer seed behaves like a RandomState built from that seed.
      reference = np.random.RandomState(42)
      assert check_random_state(42).randint(100) == reference.randint(100)

      # An existing RandomState instance is handed back untouched.
      existing = np.random.RandomState(42)
      assert check_random_state(existing) is existing

      # A different seed yields a different first draw.
      reference = np.random.RandomState(42)
      assert check_random_state(43).randint(100) != reference.randint(100)

      # Anything that is not a valid seed is rejected.
      with pytest.raises(ValueError):
          check_random_state("some invalid seed")
  def test_deprecated():
      """Check that ``deprecated`` warns when a decorated function or class is used.

      Adapted almost verbatim from https://docs.python.org/library/warnings.html
      """

      def check_single_deprecation(records):
          # Exactly one FutureWarning whose message mentions the deprecation.
          assert len(records) == 1
          record = records[0]
          assert issubclass(record.category, FutureWarning)
          assert "deprecated" in str(record.message).lower()

      # First a function...
      with warnings.catch_warnings(record=True) as caught:
          warnings.simplefilter("always")

          @deprecated()
          def ham():
              return "spam"

          result = ham()

      assert result == "spam"  # function must remain usable
      check_single_deprecation(caught)

      # ... then a class.
      with warnings.catch_warnings(record=True) as caught:
          warnings.simplefilter("always")

          @deprecated("don't use this")
          class Ham:
              SPAM = 1

          instance = Ham()

      assert hasattr(instance, "SPAM")
      check_single_deprecation(caught)
  def test_resample():
      """Check border cases and argument validation of ``resample``."""
      # Border case not worth mentioning in doctests
      assert resample() is None

      # Inconsistent lengths must be rejected.
      with pytest.raises(ValueError):
          resample([0], [0, 1])

      # Without replacement, asking for more samples than available is an error.
      with pytest.raises(ValueError):
          resample([0, 1], [0, 1], replace=False, n_samples=3)

      # Issue:6581, n_samples can be more when replace is True (default).
      oversampled = resample([1, 2], n_samples=5)
      assert len(oversampled) == 5
  def test_resample_stratified():
      """Make sure ``resample`` can stratify on an imbalanced binary target."""
      rng = np.random.RandomState(0)
      n_samples = 100
      positive_rate = 0.9
      X = rng.normal(size=(n_samples, 1))
      y = rng.binomial(1, positive_rate, size=n_samples)

      # Without stratification this draw contains only the majority class.
      _, y_not_stratified = resample(
          X, y, n_samples=10, random_state=0, stratify=None
      )
      assert np.all(y_not_stratified == 1)

      # With stratification the minority class must be represented.
      _, y_stratified = resample(X, y, n_samples=10, random_state=0, stratify=y)
      assert not np.all(y_stratified == 1)
      assert np.sum(y_stratified) == 9  # all 1s, one 0
  def test_resample_stratified_replace():
      """Make sure stratified resampling supports the ``replace`` parameter."""
      rng = np.random.RandomState(0)
      n_samples = 100
      X = rng.normal(size=(n_samples, 1))
      y = rng.randint(0, 2, size=n_samples)

      # The same ``rng`` instance is threaded through every call, so the order
      # of the calls below matters.
      X_replace, _ = resample(
          X, y, replace=True, n_samples=50, random_state=rng, stratify=y
      )
      X_no_replace, _ = resample(
          X, y, replace=False, n_samples=50, random_state=rng, stratify=y
      )

      n_unique_with_replacement = np.unique(X_replace).shape[0]
      n_unique_without_replacement = np.unique(X_no_replace).shape[0]
      assert n_unique_with_replacement < 50
      assert n_unique_without_replacement == 50

      # make sure n_samples can be greater than X.shape[0] if we sample with
      # replacement
      X_replace, _ = resample(
          X, y, replace=True, n_samples=1000, random_state=rng, stratify=y
      )
      assert X_replace.shape[0] == 1000
      assert np.unique(X_replace).shape[0] == 100
  def test_resample_stratify_2dy():
      """Make sure ``y`` can be 2d when it is used for stratification."""
      rng = np.random.RandomState(0)
      n_samples = 100
      X = rng.normal(size=(n_samples, 1))
      y = rng.randint(0, 2, size=(n_samples, 2))

      _, y_resampled = resample(
          X, y, n_samples=50, random_state=rng, stratify=y
      )
      assert y_resampled.ndim == 2
  def test_resample_stratify_sparse_error():
      """Check that a sparse ``stratify`` argument raises a ``TypeError``."""
      rng = np.random.RandomState(0)
      n_samples = 100
      X = rng.normal(size=(n_samples, 2))
      y = rng.randint(0, 2, size=n_samples)
      sparse_stratify = sp.csr_matrix(y)

      with pytest.raises(TypeError, match="A sparse matrix was passed"):
          resample(X, y, n_samples=50, random_state=rng, stratify=sparse_stratify)
  def test_safe_mask():
      """Check that ``safe_mask`` output can index both dense and CSR data."""
      random_state = check_random_state(0)
      X_dense = random_state.rand(5, 4)
      X_csr = sp.csr_matrix(X_dense)
      boolean_mask = [False, False, True, True, True]

      dense_mask = safe_mask(X_dense, boolean_mask)
      assert X_dense[dense_mask].shape[0] == 3

      # The mask produced for the dense input is deliberately fed back in.
      sparse_mask = safe_mask(X_csr, dense_mask)
      assert X_csr[sparse_mask].shape[0] == 3
  def test_column_or_1d():
      """Check which target layouts ``column_or_1d`` accepts.

      Targets labelled binary, multiclass or continuous must be returned
      flattened; every other labelled layout must raise ``ValueError``.
      """
      # The duplicated ("multiclass-multioutput", [[1, 2, 3]]) entry was dropped:
      # it exercised exactly the same input twice.
      EXAMPLES = [
          ("binary", ["spam", "egg", "spam"]),
          ("binary", [0, 1, 0, 1]),
          ("continuous", np.arange(10) / 20.0),
          ("multiclass", [1, 2, 3]),
          ("multiclass", [0, 1, 2, 2, 0]),
          ("multiclass", [[1], [2], [3]]),
          ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
          ("multiclass-multioutput", [[1, 2, 3]]),
          ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
          ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
          ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
      ]
      accepted_types = ("binary", "multiclass", "continuous")

      for y_type, y in EXAMPLES:
          if y_type in accepted_types:
              assert_array_equal(column_or_1d(y), np.ravel(y))
          else:
              with pytest.raises(ValueError):
                  column_or_1d(y)
  @pytest.mark.parametrize(
      "key, dtype",
      [
          # scalars
          (0, "int"),
          ("0", "str"),
          (True, "bool"),
          (np.bool_(True), "bool"),
          # integer / string containers and slices
          ([0, 1, 2], "int"),
          (["0", "1", "2"], "str"),
          ((0, 1, 2), "int"),
          (("0", "1", "2"), "str"),
          (slice(None, None), None),
          (slice(0, 2), "int"),
          (np.array([0, 1, 2], dtype=np.int32), "int"),
          (np.array([0, 1, 2], dtype=np.int64), "int"),
          (np.array([0, 1, 2], dtype=np.uint8), "int"),
          # boolean containers
          ([True, False], "bool"),
          ((True, False), "bool"),
          (np.array([True, False]), "bool"),
          # column-name style keys
          ("col_0", "str"),
          (["col_0", "col_1", "col_2"], "str"),
          (("col_0", "col_1", "col_2"), "str"),
          (slice("begin", "end"), "str"),
          (np.array(["col_0", "col_1", "col_2"]), "str"),
          (np.array(["col_0", "col_1", "col_2"], dtype=object), "str"),
      ],
  )
  def test_determine_key_type(key, dtype):
      """Check the key kind reported by ``_determine_key_type`` for each key."""
      detected = _determine_key_type(key)
      assert detected == dtype
  def test_determine_key_type_error():
      """A float key is not a valid specification and must raise ``ValueError``."""
      expected_msg = "No valid specification of the"
      with pytest.raises(ValueError, match=expected_msg):
          _determine_key_type(1.0)
  def test_determine_key_type_slice_error():
      """A slice key must raise ``TypeError`` when ``accept_slice=False``."""
      expected_msg = "Only array-like or scalar are"
      with pytest.raises(TypeError, match=expected_msg):
          _determine_key_type(slice(0, 2, 1), accept_slice=False)
  @pytest.mark.parametrize("array_type", ["list", "array", "sparse", "dataframe"])
  @pytest.mark.parametrize(
      "indices_type", ["list", "tuple", "array", "series", "slice"]
  )
  def test_safe_indexing_2d_container_axis_0(array_type, indices_type):
      """Select rows 1 and 2 of a 3x3 container with each kind of indexer."""
      # For the "slice" case the second bound is bumped by one so that the
      # same two rows are selected.
      raw_indices = [1, 3] if indices_type == "slice" else [1, 2]

      array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type)
      indices = _convert_container(raw_indices, indices_type)

      subset = _safe_indexing(array, indices, axis=0)

      expected = _convert_container([[4, 5, 6], [7, 8, 9]], array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize("array_type", ["list", "array", "series"])
  @pytest.mark.parametrize(
      "indices_type", ["list", "tuple", "array", "series", "slice"]
  )
  def test_safe_indexing_1d_container(array_type, indices_type):
      """Select positions 1 and 2 of a 1d container with each kind of indexer."""
      # For the "slice" case the second bound is bumped by one so that the
      # same two positions are selected.
      raw_indices = [1, 3] if indices_type == "slice" else [1, 2]

      array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type)
      indices = _convert_container(raw_indices, indices_type)

      subset = _safe_indexing(array, indices, axis=0)

      expected = _convert_container([2, 3], array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"])
  @pytest.mark.parametrize(
      "indices_type", ["list", "tuple", "array", "series", "slice"]
  )
  @pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]])
  def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices):
      """Select the last two columns by position or by name.

      String column keys are only expected to work for dataframes; for other
      containers a ``ValueError`` must be raised.
      """
      # validation of the indices
      # we make a copy because indices is mutable and shared between tests
      indices_converted = copy(indices)
      wants_int_slice = indices_type == "slice" and isinstance(indices[1], int)
      if wants_int_slice:
          indices_converted[1] += 1

      columns_name = ["col_0", "col_1", "col_2"]
      array = _convert_container(
          [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name
      )
      indices_converted = _convert_container(indices_converted, indices_type)

      string_keys_unsupported = (
          isinstance(indices[0], str) and array_type != "dataframe"
      )
      if string_keys_unsupported:
          err_msg = (
              "Specifying the columns using strings is only supported "
              "for pandas DataFrames"
          )
          with pytest.raises(ValueError, match=err_msg):
              _safe_indexing(array, indices_converted, axis=1)
          return

      subset = _safe_indexing(array, indices_converted, axis=1)
      expected = _convert_container([[2, 3], [5, 6], [8, 9]], array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize("array_read_only", [True, False])
  @pytest.mark.parametrize("indices_read_only", [True, False])
  @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"])
  @pytest.mark.parametrize("indices_type", ["array", "series"])
  @pytest.mark.parametrize(
      "axis, expected_array",
      [(0, [[4, 5, 6], [7, 8, 9]]), (1, [[2, 3], [5, 6], [8, 9]])],
  )
  def test_safe_indexing_2d_read_only_axis_1(
      array_read_only, indices_read_only, array_type, indices_type, axis, expected_array
  ):
      """Index with read-only data and/or read-only indices.

      Despite its name, this test is parametrized over both ``axis=0`` and
      ``axis=1``.
      """
      base_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      if array_read_only:
          base_array.setflags(write=False)
      array = _convert_container(base_array, array_type)

      base_indices = np.array([1, 2])
      if indices_read_only:
          base_indices.setflags(write=False)
      indices = _convert_container(base_indices, indices_type)

      subset = _safe_indexing(array, indices, axis=axis)

      expected = _convert_container(expected_array, array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize("array_type", ["list", "array", "series"])
  @pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"])
  def test_safe_indexing_1d_container_mask(array_type, indices_type):
      """Select positions 1 and 2 of a 1d container with a boolean mask."""
      raw_mask = [False, True, True, False, False, False, False, False, False]

      array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type)
      mask = _convert_container(raw_mask, indices_type)

      subset = _safe_indexing(array, mask, axis=0)

      expected = _convert_container([2, 3], array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"])
  @pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"])
  @pytest.mark.parametrize(
      "axis, expected_subset",
      [(0, [[4, 5, 6], [7, 8, 9]]), (1, [[2, 3], [5, 6], [8, 9]])],
  )
  def test_safe_indexing_2d_mask(array_type, indices_type, axis, expected_subset):
      """Select the last two rows or columns of a 3x3 container with a mask."""
      columns_name = ["col_0", "col_1", "col_2"]
      array = _convert_container(
          [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name
      )
      mask = _convert_container([False, True, True], indices_type)

      subset = _safe_indexing(array, mask, axis=axis)

      expected = _convert_container(expected_subset, array_type)
      assert_allclose_dense_sparse(subset, expected)
  @pytest.mark.parametrize(
      "array_type, expected_output_type",
      [
          ("list", "list"),
          ("array", "array"),
          ("sparse", "sparse"),
          ("dataframe", "series"),
      ],
  )
  def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type):
      """Select the last row of a 3x3 container with a scalar index."""
      array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type)
      row = 2

      subset = _safe_indexing(array, row, axis=0)

      expected_array = _convert_container([7, 8, 9], expected_output_type)
      assert_allclose_dense_sparse(subset, expected_array)
  @pytest.mark.parametrize("array_type", ["list", "array", "series"])
  def test_safe_indexing_1d_scalar(array_type):
      """Select a single element of a 1d container with a scalar index."""
      array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type)
      position = 2

      element = _safe_indexing(array, position, axis=0)

      assert element == 3
  @pytest.mark.parametrize(
      "array_type, expected_output_type",
      [("array", "array"), ("sparse", "sparse"), ("dataframe", "series")],
  )
  @pytest.mark.parametrize("indices", [2, "col_2"])
  def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, indices):
      """Select the last column with a scalar position or a column name.

      A string key on a non-dataframe container must raise ``ValueError``.
      """
      columns_name = ["col_0", "col_1", "col_2"]
      array = _convert_container(
          [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name
      )

      string_key_unsupported = isinstance(indices, str) and array_type != "dataframe"
      if string_key_unsupported:
          err_msg = (
              "Specifying the columns using strings is only supported "
              "for pandas DataFrames"
          )
          with pytest.raises(ValueError, match=err_msg):
              _safe_indexing(array, indices, axis=1)
          return

      subset = _safe_indexing(array, indices, axis=1)

      if expected_output_type == "sparse":
          # sparse matrix are keeping the 2D shape
          expected_output = [[3], [6], [9]]
      else:
          expected_output = [3, 6, 9]
      expected_array = _convert_container(expected_output, expected_output_type)
      assert_allclose_dense_sparse(subset, expected_array)
  @pytest.mark.parametrize("array_type", ["list", "array", "sparse"])
  def test_safe_indexing_None_axis_0(array_type):
      """Indexing with ``None`` on axis 0 must give back the full container."""
      container = _convert_container(
          [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type
      )
      selected = _safe_indexing(container, None, axis=0)
      assert_allclose_dense_sparse(selected, container)
  def test_safe_indexing_pandas_no_matching_cols_error():
      """A float column key on a dataframe must raise ``ValueError``."""
      pd = pytest.importorskip("pandas")
      frame = pd.DataFrame(X_toy)
      err_msg = "No valid specification of the columns."

      with pytest.raises(ValueError, match=err_msg):
          _safe_indexing(frame, [1.0], axis=1)
  @pytest.mark.parametrize("axis", [None, 3])
  def test_safe_indexing_error_axis(axis):
      """The ``axis`` values ``None`` and ``3`` must be rejected."""
      expected_msg = "'axis' should be either 0"
      with pytest.raises(ValueError, match=expected_msg):
          _safe_indexing(X_toy, [0, 1], axis=axis)
  @pytest.mark.parametrize("X_constructor", ["array", "series"])
  def test_safe_indexing_1d_array_error(X_constructor):
      """Indexing a 1d array-like along ``axis=1`` must raise ``ValueError``."""
      # check that we are raising an error if the array-like passed is 1D and
      # we try to index on the 2nd dimension
      values = list(range(5))
      if X_constructor == "array":
          container = np.asarray(values)
      elif X_constructor == "series":
          pd = pytest.importorskip("pandas")
          container = pd.Series(values)

      err_msg = "'X' should be a 2D NumPy array, 2D sparse matrix or pandas"
      with pytest.raises(ValueError, match=err_msg):
          _safe_indexing(container, [0, 1], axis=1)
  def test_safe_indexing_container_axis_0_unsupported_type():
      """String indices with ``axis=0`` must raise ``ValueError``."""
      rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
      string_indices = ["col_1", "col_2"]
      err_msg = "String indexing is not supported with 'axis=0'"

      with pytest.raises(ValueError, match=err_msg):
          _safe_indexing(rows, string_indices, axis=0)
  def test_safe_indexing_pandas_no_settingwithcopy_warning():
      """Modifying an indexed dataframe subset must not warn nor touch the source."""
      # Using safe_indexing with an array-like indexer gives a copy of the
      # DataFrame -> ensure it doesn't raise a warning if modified
      pd = pytest.importorskip("pandas")

      frame = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
      subset = _safe_indexing(frame, [0, 1], axis=0)

      try:
          SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
      except AttributeError:
          # backward compatibility for pandas < 1.5
          SettingWithCopyWarning = pd.core.common.SettingWithCopyWarning

      with warnings.catch_warnings():
          # Turn the warning into an error so the assignment below would fail.
          warnings.simplefilter("error", SettingWithCopyWarning)
          subset.iloc[0, 0] = 10

      # The original dataframe is unaffected by the assignment on the subset:
      assert frame.iloc[0, 0] == 1
  @pytest.mark.parametrize(
      "key, err_msg",
      [
          (10, r"all features must be in \[0, 2\]"),
          ("whatever", "A given column is not a column of the dataframe"),
      ],
  )
  def test_get_column_indices_error(key, err_msg):
      """Out-of-range positions and unknown column names must raise ``ValueError``."""
      pd = pytest.importorskip("pandas")
      column_names = ["col_0", "col_1", "col_2"]
      X_df = pd.DataFrame(X_toy, columns=column_names)

      with pytest.raises(ValueError, match=err_msg):
          _get_column_indices(X_df, key)
  @pytest.mark.parametrize(
      "key",
      [["col1"], ["col2"], ["col1", "col2"], ["col1", "col3"], ["col2", "col3"]],
  )
  def test_get_column_indices_pandas_nonunique_columns_error(key):
      """Selecting a duplicated column name must raise with an exact message."""
      pd = pytest.importorskip("pandas")

      # "col1" and "col2" each appear twice in the frame.
      duplicated_columns = ["col1", "col1", "col2", "col3", "col2"]
      frame = pd.DataFrame(np.zeros((1, 5), dtype=int), columns=duplicated_columns)

      err_msg = "Selected columns, {}, are not unique in dataframe".format(key)
      with pytest.raises(ValueError) as exc_info:
          _get_column_indices(frame, key)

      # Compare the full message rather than a regex match.
      assert str(exc_info.value) == err_msg
  def test_shuffle_on_ndim_equals_three():
      """Check that ``shuffle`` handles a 3-dimensional array.

      The call must not raise, the returned array must hold the same 2d
      sub-arrays as the input, and the input must be left intact.
      """

      def to_tuple(A):  # to make the inner arrays hashable
          return tuple(tuple(tuple(C) for C in B) for B in A)

      A = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # A.shape = (2,2,2)
      S = set(to_tuple(A))

      A_shuffled = shuffle(A)  # shouldn't raise a ValueError for dim = 3

      # Previously only ``A`` was checked and the return value was discarded,
      # so the shuffled output was never verified.
      assert set(to_tuple(A_shuffled)) == S
      assert set(to_tuple(A)) == S
  def test_shuffle_dont_convert_to_array():
      """Check that ``shuffle`` keeps the container type of each input."""
      # Check that shuffle does not try to convert to numpy arrays with float
      # dtypes can let any indexable datastructure pass-through.
      str_list = ["a", "b", "c"]
      obj_array = np.array(["a", "b", "c"], dtype=object)
      int_list = [1, 2, 3]
      mock_frame = MockDataFrame(
          np.array([["a", 0], ["b", 1], ["c", 2]], dtype=object)
      )
      csc = sp.csc_matrix(np.arange(6).reshape(3, 2))

      shuffled = shuffle(
          str_list, obj_array, int_list, mock_frame, csc, random_state=0
      )
      str_list_s, obj_array_s, int_list_s, mock_frame_s, csc_s = shuffled

      assert str_list_s == ["c", "b", "a"]
      assert type(str_list_s) == list  # noqa: E721

      assert_array_equal(obj_array_s, ["c", "b", "a"])
      assert obj_array_s.dtype == object

      assert int_list_s == [3, 2, 1]
      assert type(int_list_s) == list  # noqa: E721

      assert_array_equal(
          mock_frame_s, np.array([["c", 2], ["b", 1], ["a", 0]], dtype=object)
      )
      assert type(mock_frame_s) == MockDataFrame  # noqa: E721

      assert_array_equal(csc_s.toarray(), np.array([[4, 5], [2, 3], [0, 1]]))
  def test_gen_even_slices():
      """Check coverage and argument validation of ``gen_even_slices``."""
      # check that gen_even_slices contains all samples
      some_range = range(10)
      chunks = (some_range[chunk] for chunk in gen_even_slices(10, 3))
      joined_range = list(chain.from_iterable(chunks))
      assert_array_equal(some_range, joined_range)

      # check that passing negative n_chunks raises an error
      # (the generator is lazy: the error only surfaces on the first ``next``)
      slices = gen_even_slices(10, -1)
      with pytest.raises(
          ValueError, match="gen_even_slices got n_packs=-1, must be >=1"
      ):
          next(slices)
  @pytest.mark.parametrize(
      ("row_bytes", "max_n_rows", "working_memory", "expected"),
      [
          (1024, None, 1, 1024),
          (1024, None, 0.99999999, 1023),
          (1023, None, 1, 1025),
          (1025, None, 1, 1023),
          (1024, None, 2, 2048),
          (1024, 7, 1, 7),
          (1024 * 1024, None, 1, 1),
      ],
  )
  def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected):
      """Check ``get_chunk_n_rows`` results when no warning is expected.

      ``working_memory`` is supplied once as an explicit argument and once
      through ``config_context``; any ``UserWarning`` is turned into an error.
      """
      # 1) working_memory given explicitly
      with warnings.catch_warnings():
          warnings.simplefilter("error", UserWarning)
          n_rows_from_argument = get_chunk_n_rows(
              row_bytes=row_bytes,
              max_n_rows=max_n_rows,
              working_memory=working_memory,
          )
      assert n_rows_from_argument == expected
      assert type(n_rows_from_argument) is type(expected)

      # 2) working_memory taken from the global configuration
      with config_context(working_memory=working_memory), warnings.catch_warnings():
          warnings.simplefilter("error", UserWarning)
          n_rows_from_config = get_chunk_n_rows(
              row_bytes=row_bytes, max_n_rows=max_n_rows
          )
      assert n_rows_from_config == expected
      assert type(n_rows_from_config) is type(expected)
  def test_get_chunk_n_rows_warns():
      """Check that warning is raised when working_memory is too low."""
      row_bytes = 1024 * 1024 + 1
      max_n_rows = None
      working_memory = 1
      expected = 1

      warn_msg = (
          "Could not adhere to working_memory config. Currently 1MiB, 2MiB required."
      )

      # 1) working_memory given explicitly
      with pytest.warns(UserWarning, match=warn_msg):
          n_rows_from_argument = get_chunk_n_rows(
              row_bytes=row_bytes,
              max_n_rows=max_n_rows,
              working_memory=working_memory,
          )
      assert n_rows_from_argument == expected
      assert type(n_rows_from_argument) is type(expected)

      # 2) working_memory taken from the global configuration
      with config_context(working_memory=working_memory):
          with pytest.warns(UserWarning, match=warn_msg):
              n_rows_from_config = get_chunk_n_rows(
                  row_bytes=row_bytes, max_n_rows=max_n_rows
              )
      assert n_rows_from_config == expected
      assert type(n_rows_from_config) is type(expected)
  @pytest.mark.parametrize(
      ["source", "message", "is_long"],
      [
          ("ABC", string.ascii_lowercase, False),
          ("ABCDEF", string.ascii_lowercase, False),
          ("ABC", string.ascii_lowercase * 3, True),
          ("ABC" * 10, string.ascii_lowercase, True),
          ("ABC", string.ascii_lowercase + "\u1048", False),
      ],
  )
  @pytest.mark.parametrize(
      ["time", "time_str"],
      [
          # The assertions below require ``time_str`` to be the *entire* text
          # that follows ", total=", so left padding is significant here.
          # NOTE(review): the padding assumes `_message_with_time` formats
          # seconds as " %5.1fs" and minutes as "%4.1fmin" -- confirm against
          # the implementation.
          (0.2, "   0.2s"),
          (20, "  20.0s"),
          (2000, "33.3min"),
          (20000, "333.3min"),
      ],
  )
  def test_message_with_time(source, message, is_long, time, time_str):
      """Check the layout of the line produced by ``_message_with_time``.

      The line is expected to be ``[source] `` followed by dot padding, a
      space, the message, ``, total=`` and the formatted time. Short messages
      are padded to exactly 70 characters; long ones overflow and carry no
      dots.
      """

      def strip_expected_suffix(text, suffix):
          # Assert that ``text`` ends with ``suffix`` and return what precedes it.
          assert text.endswith(suffix)
          return text[: -len(suffix)]

      out = _message_with_time(source, message, time)
      if is_long:
          assert len(out) > 70
      else:
          assert len(out) == 70

      # Peel off the "[source] " prefix ...
      assert out.startswith("[" + source + "] ")
      out = out[len(source) + 3 :]

      # ... then each trailing component, from right to left.
      out = strip_expected_suffix(out, time_str)
      out = strip_expected_suffix(out, ", total=")
      out = strip_expected_suffix(out, message)
      out = strip_expected_suffix(out, " ")

      # Whatever is left is the padding between prefix and message.
      if is_long:
          assert not out
      else:
          assert list(set(out)) == ["."]
  @pytest.mark.parametrize(
      ["message", "expected"],
      [
          ("hello", _message_with_time("ABC", "hello", 0.1) + "\n"),
          ("", _message_with_time("ABC", "", 0.1) + "\n"),
          (None, ""),
      ],
  )
  def test_print_elapsed_time(message, expected, capsys, monkeypatch):
      """Check what ``_print_elapsed_time`` prints for a fixed elapsed time."""
      # Freeze the clock at 0 before entering the block and move it to 0.1
      # inside, so the expected strings above are built with 0.1.
      monkeypatch.setattr(timeit, "default_timer", lambda: 0)
      with _print_elapsed_time("ABC", message):
          monkeypatch.setattr(timeit, "default_timer", lambda: 0.1)

      printed = capsys.readouterr().out
      assert printed == expected
  @pytest.mark.parametrize(
      "value, result",
      [
          # NaN scalars of various float types
          (float("nan"), True),
          (np.nan, True),
          (float(np.nan), True),
          (np.float32(np.nan), True),
          (np.float64(np.nan), True),
          # everything else
          (0, False),
          (0.0, False),
          (None, False),
          ("", False),
          ("nan", False),
          ([np.nan], False),
          (9867966753463435747313673, False),  # Python int that overflows with C type
      ],
  )
  def test_is_scalar_nan(value, result):
      """Check ``is_scalar_nan`` on NaN scalars and on non-NaN values."""
      outcome = is_scalar_nan(value)
      assert outcome is result
      # make sure that we are returning a Python bool
      assert isinstance(outcome, bool)
  def test_approximate_mode():
      """Check ``_approximate_mode`` when ``class_counts * n_draws`` is large.

      The product is big enough to overflow a 32-bit signed integer.

      Non-regression test for:
      https://github.com/scikit-learn/scikit-learn/issues/20774
      """
      class_counts = np.array([99000, 1000], dtype=np.int32)
      draws = _approximate_mode(class_counts=class_counts, n_draws=25000, rng=0)

      # Draws 25% of the total population, so in this case a fair draw means:
      # 25% * 99,000 = 24,750
      # 25% * 1,000 = 250
      assert_array_equal(draws, [24750, 250])
  def dummy_func():
      """Do nothing and return ``None``."""
      return None
  def test_deprecation_joblib_api(tmpdir):
      """Check that the joblib backend helpers in ``sklearn.utils`` do not warn."""
      # Only parallel_backend and register_parallel_backend are not deprecated in
      # sklearn.utils
      from sklearn.utils import parallel_backend, register_parallel_backend

      for helper, first_arg in (
          (parallel_backend, "loky"),
          (register_parallel_backend, "failing"),
      ):
          assert_no_warnings(helper, first_arg, None)

      from sklearn.utils._joblib import joblib

      # Remove the "failing" entry from joblib's backend table again.
      del joblib.parallel.BACKENDS["failing"]
  @pytest.mark.parametrize(
      "sequence", [[np.array(1), np.array(2)], [[1, 2], [3, 4]]]
  )
  def test_to_object_array(sequence):
      """Check that ``_to_object_array`` returns a 1d object-dtype ndarray."""
      converted = _to_object_array(sequence)

      assert isinstance(converted, np.ndarray)
      assert converted.dtype.kind == "O"
      assert converted.ndim == 1
  @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"])
  def test_safe_assign(array_type):
      """Check that `_safe_assign` works as expected."""
      rng = np.random.RandomState(0)
      X_array = rng.randn(10, 5)
      n_rows, n_columns = X_array.shape

      # --- assign a block of rows -------------------------------------------
      row_indexer = [1, 2]
      values = rng.randn(len(row_indexer), n_columns)
      X = _convert_container(X_array, array_type)
      _safe_assign(X, values, row_indexer=row_indexer)

      assigned_portion = _safe_indexing(X, row_indexer, axis=0)
      expected_portion = _convert_container(values, array_type)
      assert_allclose_dense_sparse(assigned_portion, expected_portion)

      # --- assign a block of columns ----------------------------------------
      column_indexer = [1, 2]
      values = rng.randn(n_rows, len(column_indexer))
      X = _convert_container(X_array, array_type)
      _safe_assign(X, values, column_indexer=column_indexer)

      assigned_portion = _safe_indexing(X, column_indexer, axis=1)
      expected_portion = _convert_container(values, array_type)
      assert_allclose_dense_sparse(assigned_portion, expected_portion)

      # --- no indexer: the whole container is overwritten ---------------------
      row_indexer = column_indexer = None
      values = rng.randn(*X.shape)
      X = _convert_container(X_array, array_type)
      _safe_assign(X, values, column_indexer=column_indexer)

      assert_allclose_dense_sparse(X, _convert_container(values, array_type))