test_testing.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780
  1. import atexit
  2. import os
  3. import unittest
  4. import warnings
  5. import numpy as np
  6. import pytest
  7. from scipy import sparse
  8. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
  9. from sklearn.tree import DecisionTreeClassifier
  10. from sklearn.utils._testing import (
  11. TempMemmap,
  12. _convert_container,
  13. _delete_folder,
  14. assert_allclose,
  15. assert_allclose_dense_sparse,
  16. assert_no_warnings,
  17. assert_raise_message,
  18. assert_raises,
  19. assert_raises_regex,
  20. check_docstring_parameters,
  21. create_memmap_backed_data,
  22. ignore_warnings,
  23. raises,
  24. set_random_state,
  25. )
  26. from sklearn.utils.deprecation import deprecated
  27. from sklearn.utils.metaestimators import available_if
  28. def test_set_random_state():
  29. lda = LinearDiscriminantAnalysis()
  30. tree = DecisionTreeClassifier()
  31. # Linear Discriminant Analysis doesn't have random state: smoke test
  32. set_random_state(lda, 3)
  33. set_random_state(tree, 3)
  34. assert tree.random_state == 3
  35. def test_assert_allclose_dense_sparse():
  36. x = np.arange(9).reshape(3, 3)
  37. msg = "Not equal to tolerance "
  38. y = sparse.csc_matrix(x)
  39. for X in [x, y]:
  40. # basic compare
  41. with pytest.raises(AssertionError, match=msg):
  42. assert_allclose_dense_sparse(X, X * 2)
  43. assert_allclose_dense_sparse(X, X)
  44. with pytest.raises(ValueError, match="Can only compare two sparse"):
  45. assert_allclose_dense_sparse(x, y)
  46. A = sparse.diags(np.ones(5), offsets=0).tocsr()
  47. B = sparse.csr_matrix(np.ones((1, 5)))
  48. with pytest.raises(AssertionError, match="Arrays are not equal"):
  49. assert_allclose_dense_sparse(B, A)
  50. def test_assert_raises_msg():
  51. with assert_raises_regex(AssertionError, "Hello world"):
  52. with assert_raises(ValueError, msg="Hello world"):
  53. pass
  54. def test_assert_raise_message():
  55. def _raise_ValueError(message):
  56. raise ValueError(message)
  57. def _no_raise():
  58. pass
  59. assert_raise_message(ValueError, "test", _raise_ValueError, "test")
  60. assert_raises(
  61. AssertionError,
  62. assert_raise_message,
  63. ValueError,
  64. "something else",
  65. _raise_ValueError,
  66. "test",
  67. )
  68. assert_raises(
  69. ValueError,
  70. assert_raise_message,
  71. TypeError,
  72. "something else",
  73. _raise_ValueError,
  74. "test",
  75. )
  76. assert_raises(AssertionError, assert_raise_message, ValueError, "test", _no_raise)
  77. # multiple exceptions in a tuple
  78. assert_raises(
  79. AssertionError,
  80. assert_raise_message,
  81. (ValueError, AttributeError),
  82. "test",
  83. _no_raise,
  84. )
  85. def test_ignore_warning():
  86. # This check that ignore_warning decorator and context manager are working
  87. # as expected
  88. def _warning_function():
  89. warnings.warn("deprecation warning", DeprecationWarning)
  90. def _multiple_warning_function():
  91. warnings.warn("deprecation warning", DeprecationWarning)
  92. warnings.warn("deprecation warning")
  93. # Check the function directly
  94. assert_no_warnings(ignore_warnings(_warning_function))
  95. assert_no_warnings(ignore_warnings(_warning_function, category=DeprecationWarning))
  96. with pytest.warns(DeprecationWarning):
  97. ignore_warnings(_warning_function, category=UserWarning)()
  98. with pytest.warns(UserWarning):
  99. ignore_warnings(_multiple_warning_function, category=FutureWarning)()
  100. with pytest.warns(DeprecationWarning):
  101. ignore_warnings(_multiple_warning_function, category=UserWarning)()
  102. assert_no_warnings(
  103. ignore_warnings(_warning_function, category=(DeprecationWarning, UserWarning))
  104. )
  105. # Check the decorator
  106. @ignore_warnings
  107. def decorator_no_warning():
  108. _warning_function()
  109. _multiple_warning_function()
  110. @ignore_warnings(category=(DeprecationWarning, UserWarning))
  111. def decorator_no_warning_multiple():
  112. _multiple_warning_function()
  113. @ignore_warnings(category=DeprecationWarning)
  114. def decorator_no_deprecation_warning():
  115. _warning_function()
  116. @ignore_warnings(category=UserWarning)
  117. def decorator_no_user_warning():
  118. _warning_function()
  119. @ignore_warnings(category=DeprecationWarning)
  120. def decorator_no_deprecation_multiple_warning():
  121. _multiple_warning_function()
  122. @ignore_warnings(category=UserWarning)
  123. def decorator_no_user_multiple_warning():
  124. _multiple_warning_function()
  125. assert_no_warnings(decorator_no_warning)
  126. assert_no_warnings(decorator_no_warning_multiple)
  127. assert_no_warnings(decorator_no_deprecation_warning)
  128. with pytest.warns(DeprecationWarning):
  129. decorator_no_user_warning()
  130. with pytest.warns(UserWarning):
  131. decorator_no_deprecation_multiple_warning()
  132. with pytest.warns(DeprecationWarning):
  133. decorator_no_user_multiple_warning()
  134. # Check the context manager
  135. def context_manager_no_warning():
  136. with ignore_warnings():
  137. _warning_function()
  138. def context_manager_no_warning_multiple():
  139. with ignore_warnings(category=(DeprecationWarning, UserWarning)):
  140. _multiple_warning_function()
  141. def context_manager_no_deprecation_warning():
  142. with ignore_warnings(category=DeprecationWarning):
  143. _warning_function()
  144. def context_manager_no_user_warning():
  145. with ignore_warnings(category=UserWarning):
  146. _warning_function()
  147. def context_manager_no_deprecation_multiple_warning():
  148. with ignore_warnings(category=DeprecationWarning):
  149. _multiple_warning_function()
  150. def context_manager_no_user_multiple_warning():
  151. with ignore_warnings(category=UserWarning):
  152. _multiple_warning_function()
  153. assert_no_warnings(context_manager_no_warning)
  154. assert_no_warnings(context_manager_no_warning_multiple)
  155. assert_no_warnings(context_manager_no_deprecation_warning)
  156. with pytest.warns(DeprecationWarning):
  157. context_manager_no_user_warning()
  158. with pytest.warns(UserWarning):
  159. context_manager_no_deprecation_multiple_warning()
  160. with pytest.warns(DeprecationWarning):
  161. context_manager_no_user_multiple_warning()
  162. # Check that passing warning class as first positional argument
  163. warning_class = UserWarning
  164. match = "'obj' should be a callable.+you should use 'category=UserWarning'"
  165. with pytest.raises(ValueError, match=match):
  166. silence_warnings_func = ignore_warnings(warning_class)(_warning_function)
  167. silence_warnings_func()
  168. with pytest.raises(ValueError, match=match):
  169. @ignore_warnings(warning_class)
  170. def test():
  171. pass
  172. class TestWarns(unittest.TestCase):
  173. def test_warn(self):
  174. def f():
  175. warnings.warn("yo")
  176. return 3
  177. with pytest.raises(AssertionError):
  178. assert_no_warnings(f)
  179. assert assert_no_warnings(lambda x: x, 1) == 1
# Fixture functions and classes for test_check_docstring_parameters:
# Fixture: the signature and the docstring agree, so
# test_check_docstring_parameters expects no reported problem for it.
# The docstring is test input: keep its text and layout unchanged.
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Fixture: uses the unknown section header "Results" (instead of "Returns"),
# which test_check_docstring_parameters expects to raise
# RuntimeError("Unknown section Results ...").
# The docstring is test input: keep its text and layout unchanged.
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Fixture: the signature is (b, a) but the docstring documents a then b, so a
# parameter name mismatch at index 0 ('b' != 'a') is expected.
# The docstring is test input: keep its text and layout unchanged.
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Fixture: the docstring documents an extra parameter ``c`` that is not in
# the signature; "first extra item: c" is expected.
# The docstring is test input: keep its text and layout unchanged.
def f_too_many_param_docstring(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : int
        Parameter b
    c : int
        Parameter c

    Returns
    -------
    d : list
        Parameter c
    """
    d = a + b
    return d
# Fixture: the docstring omits parameter ``b``. It is reported as missing
# unless the check is called with ignore=["b"].
# The docstring is test input: keep its text and layout unchanged.
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Fixture for parameter-definition formatting: the entries "a: int", "b:" and
# "d:int" lack a space before the colon and are each expected to be reported;
# "c :" (empty type) and "e" (no colon) are expected to pass.
# The docstring is test input: keep its text and layout unchanged.
def f_check_param_definition(a, b, c, d, e):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        This is parsed correctly in numpydoc 1.2
    d:int
        Parameter d
    e
        No typespec is allowed without colon
    """
    return a + b + c + d
# Fixture class: exercises check_docstring_parameters on plain methods.
class Klass:
    # No docstring at all: X and y are expected to be reported as missing
    # (``self`` is not part of the expected diff).
    def f_missing(self, X, y):
        pass

    # Uses the unknown section header "Parameter"; the check is expected to
    # raise RuntimeError("Unknown section Parameter ...").
    # The docstring is test input: keep its text and layout unchanged.
    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ---------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass
  289. class MockEst:
  290. def __init__(self):
  291. """MockEstimator"""
  292. def fit(self, X, y):
  293. return X
  294. def predict(self, X):
  295. return X
  296. def predict_proba(self, X):
  297. return X
  298. def score(self, X):
  299. return 1.0
# Fixture meta-estimator. Its methods are wrapped by ``available_if``
# (and ``deprecated`` for score/fit), so check_docstring_parameters is
# exercised on wrapped, delegated methods. Method docstrings are test input:
# keep their text and layout unchanged.
class MockMetaEstimator:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Documents ``y`` while the signature takes ``X``: a name mismatch at
    # index 0 ('X' != 'y') is expected.
    @available_if(lambda self: hasattr(self.delegate, "predict"))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # "Parameters" is underlined with 9 dashes instead of 10: a "potentially
    # wrong underline length" report is expected. The body is intentionally
    # empty (returns None).
    @available_if(lambda self: hasattr(self.delegate, "score"))
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Same 9-dash underline defect as ``score``.
    @available_if(lambda self: hasattr(self.delegate, "predict_proba"))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Despite what its docstring says, test_check_docstring_parameters does
    # check this method and expects X and y to be reported as missing.
    @deprecated("Testing deprecated function with wrong params")
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
  339. def test_check_docstring_parameters():
  340. pytest.importorskip(
  341. "numpydoc",
  342. reason="numpydoc is required to test the docstrings",
  343. minversion="1.2.0",
  344. )
  345. incorrect = check_docstring_parameters(f_ok)
  346. assert incorrect == []
  347. incorrect = check_docstring_parameters(f_ok, ignore=["b"])
  348. assert incorrect == []
  349. incorrect = check_docstring_parameters(f_missing, ignore=["b"])
  350. assert incorrect == []
  351. with pytest.raises(RuntimeError, match="Unknown section Results"):
  352. check_docstring_parameters(f_bad_sections)
  353. with pytest.raises(RuntimeError, match="Unknown section Parameter"):
  354. check_docstring_parameters(Klass.f_bad_sections)
  355. incorrect = check_docstring_parameters(f_check_param_definition)
  356. mock_meta = MockMetaEstimator(delegate=MockEst())
  357. mock_meta_name = mock_meta.__class__.__name__
  358. assert incorrect == [
  359. (
  360. "sklearn.utils.tests.test_testing.f_check_param_definition There "
  361. "was no space between the param name and colon ('a: int')"
  362. ),
  363. (
  364. "sklearn.utils.tests.test_testing.f_check_param_definition There "
  365. "was no space between the param name and colon ('b:')"
  366. ),
  367. (
  368. "sklearn.utils.tests.test_testing.f_check_param_definition There "
  369. "was no space between the param name and colon ('d:int')"
  370. ),
  371. ]
  372. messages = [
  373. [
  374. "In function: sklearn.utils.tests.test_testing.f_bad_order",
  375. (
  376. "There's a parameter name mismatch in function docstring w.r.t."
  377. " function signature, at index 0 diff: 'b' != 'a'"
  378. ),
  379. "Full diff:",
  380. "- ['b', 'a']",
  381. "+ ['a', 'b']",
  382. ],
  383. [
  384. "In function: "
  385. + "sklearn.utils.tests.test_testing.f_too_many_param_docstring",
  386. (
  387. "Parameters in function docstring have more items w.r.t. function"
  388. " signature, first extra item: c"
  389. ),
  390. "Full diff:",
  391. "- ['a', 'b']",
  392. "+ ['a', 'b', 'c']",
  393. "? +++++",
  394. ],
  395. [
  396. "In function: sklearn.utils.tests.test_testing.f_missing",
  397. (
  398. "Parameters in function docstring have less items w.r.t. function"
  399. " signature, first missing item: b"
  400. ),
  401. "Full diff:",
  402. "- ['a', 'b']",
  403. "+ ['a']",
  404. ],
  405. [
  406. "In function: sklearn.utils.tests.test_testing.Klass.f_missing",
  407. (
  408. "Parameters in function docstring have less items w.r.t. function"
  409. " signature, first missing item: X"
  410. ),
  411. "Full diff:",
  412. "- ['X', 'y']",
  413. "+ []",
  414. ],
  415. [
  416. "In function: "
  417. + f"sklearn.utils.tests.test_testing.{mock_meta_name}.predict",
  418. (
  419. "There's a parameter name mismatch in function docstring w.r.t."
  420. " function signature, at index 0 diff: 'X' != 'y'"
  421. ),
  422. "Full diff:",
  423. "- ['X']",
  424. "? ^",
  425. "+ ['y']",
  426. "? ^",
  427. ],
  428. [
  429. "In function: "
  430. + f"sklearn.utils.tests.test_testing.{mock_meta_name}."
  431. + "predict_proba",
  432. "potentially wrong underline length... ",
  433. "Parameters ",
  434. "--------- in ",
  435. ],
  436. [
  437. "In function: "
  438. + f"sklearn.utils.tests.test_testing.{mock_meta_name}.score",
  439. "potentially wrong underline length... ",
  440. "Parameters ",
  441. "--------- in ",
  442. ],
  443. [
  444. "In function: " + f"sklearn.utils.tests.test_testing.{mock_meta_name}.fit",
  445. (
  446. "Parameters in function docstring have less items w.r.t. function"
  447. " signature, first missing item: X"
  448. ),
  449. "Full diff:",
  450. "- ['X', 'y']",
  451. "+ []",
  452. ],
  453. ]
  454. for msg, f in zip(
  455. messages,
  456. [
  457. f_bad_order,
  458. f_too_many_param_docstring,
  459. f_missing,
  460. Klass.f_missing,
  461. mock_meta.predict,
  462. mock_meta.predict_proba,
  463. mock_meta.score,
  464. mock_meta.fit,
  465. ],
  466. ):
  467. incorrect = check_docstring_parameters(f)
  468. assert msg == incorrect, '\n"%s"\n not in \n"%s"' % (msg, incorrect)
  469. class RegistrationCounter:
  470. def __init__(self):
  471. self.nb_calls = 0
  472. def __call__(self, to_register_func):
  473. self.nb_calls += 1
  474. assert to_register_func.func is _delete_folder
  475. def check_memmap(input_array, mmap_data, mmap_mode="r"):
  476. assert isinstance(mmap_data, np.memmap)
  477. writeable = mmap_mode != "r"
  478. assert mmap_data.flags.writeable is writeable
  479. np.testing.assert_array_equal(input_array, mmap_data)
  480. def test_tempmemmap(monkeypatch):
  481. registration_counter = RegistrationCounter()
  482. monkeypatch.setattr(atexit, "register", registration_counter)
  483. input_array = np.ones(3)
  484. with TempMemmap(input_array) as data:
  485. check_memmap(input_array, data)
  486. temp_folder = os.path.dirname(data.filename)
  487. if os.name != "nt":
  488. assert not os.path.exists(temp_folder)
  489. assert registration_counter.nb_calls == 1
  490. mmap_mode = "r+"
  491. with TempMemmap(input_array, mmap_mode=mmap_mode) as data:
  492. check_memmap(input_array, data, mmap_mode=mmap_mode)
  493. temp_folder = os.path.dirname(data.filename)
  494. if os.name != "nt":
  495. assert not os.path.exists(temp_folder)
  496. assert registration_counter.nb_calls == 2
  497. @pytest.mark.parametrize("aligned", [False, True])
  498. def test_create_memmap_backed_data(monkeypatch, aligned):
  499. registration_counter = RegistrationCounter()
  500. monkeypatch.setattr(atexit, "register", registration_counter)
  501. input_array = np.ones(3)
  502. data = create_memmap_backed_data(input_array, aligned=aligned)
  503. check_memmap(input_array, data)
  504. assert registration_counter.nb_calls == 1
  505. data, folder = create_memmap_backed_data(
  506. input_array, return_folder=True, aligned=aligned
  507. )
  508. check_memmap(input_array, data)
  509. assert folder == os.path.dirname(data.filename)
  510. assert registration_counter.nb_calls == 2
  511. mmap_mode = "r+"
  512. data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode, aligned=aligned)
  513. check_memmap(input_array, data, mmap_mode)
  514. assert registration_counter.nb_calls == 3
  515. input_list = [input_array, input_array + 1, input_array + 2]
  516. mmap_data_list = create_memmap_backed_data(input_list, aligned=aligned)
  517. for input_array, data in zip(input_list, mmap_data_list):
  518. check_memmap(input_array, data)
  519. assert registration_counter.nb_calls == 4
  520. with pytest.raises(
  521. ValueError,
  522. match=(
  523. "When creating aligned memmap-backed arrays, input must be a single array"
  524. " or a sequence of arrays"
  525. ),
  526. ):
  527. create_memmap_backed_data([input_array, "not-an-array"], aligned=True)
  528. @pytest.mark.parametrize(
  529. "constructor_name, container_type",
  530. [
  531. ("list", list),
  532. ("tuple", tuple),
  533. ("array", np.ndarray),
  534. ("sparse", sparse.csr_matrix),
  535. ("sparse_csr", sparse.csr_matrix),
  536. ("sparse_csc", sparse.csc_matrix),
  537. ("dataframe", lambda: pytest.importorskip("pandas").DataFrame),
  538. ("series", lambda: pytest.importorskip("pandas").Series),
  539. ("index", lambda: pytest.importorskip("pandas").Index),
  540. ("slice", slice),
  541. ],
  542. )
  543. @pytest.mark.parametrize(
  544. "dtype, superdtype",
  545. [
  546. (np.int32, np.integer),
  547. (np.int64, np.integer),
  548. (np.float32, np.floating),
  549. (np.float64, np.floating),
  550. ],
  551. )
  552. def test_convert_container(
  553. constructor_name,
  554. container_type,
  555. dtype,
  556. superdtype,
  557. ):
  558. """Check that we convert the container to the right type of array with the
  559. right data type."""
  560. if constructor_name in ("dataframe", "series", "index"):
  561. # delay the import of pandas within the function to only skip this test
  562. # instead of the whole file
  563. container_type = container_type()
  564. container = [0, 1]
  565. container_converted = _convert_container(
  566. container,
  567. constructor_name,
  568. dtype=dtype,
  569. )
  570. assert isinstance(container_converted, container_type)
  571. if constructor_name in ("list", "tuple", "index"):
  572. # list and tuple will use Python class dtype: int, float
  573. # pandas index will always use high precision: np.int64 and np.float64
  574. assert np.issubdtype(type(container_converted[0]), superdtype)
  575. elif hasattr(container_converted, "dtype"):
  576. assert container_converted.dtype == dtype
  577. elif hasattr(container_converted, "dtypes"):
  578. assert container_converted.dtypes[0] == dtype
  579. def test_raises():
  580. # Tests for the raises context manager
  581. # Proper type, no match
  582. with raises(TypeError):
  583. raise TypeError()
  584. # Proper type, proper match
  585. with raises(TypeError, match="how are you") as cm:
  586. raise TypeError("hello how are you")
  587. assert cm.raised_and_matched
  588. # Proper type, proper match with multiple patterns
  589. with raises(TypeError, match=["not this one", "how are you"]) as cm:
  590. raise TypeError("hello how are you")
  591. assert cm.raised_and_matched
  592. # bad type, no match
  593. with pytest.raises(ValueError, match="this will be raised"):
  594. with raises(TypeError) as cm:
  595. raise ValueError("this will be raised")
  596. assert not cm.raised_and_matched
  597. # Bad type, no match, with a err_msg
  598. with pytest.raises(AssertionError, match="the failure message"):
  599. with raises(TypeError, err_msg="the failure message") as cm:
  600. raise ValueError()
  601. assert not cm.raised_and_matched
  602. # bad type, with match (is ignored anyway)
  603. with pytest.raises(ValueError, match="this will be raised"):
  604. with raises(TypeError, match="this is ignored") as cm:
  605. raise ValueError("this will be raised")
  606. assert not cm.raised_and_matched
  607. # proper type but bad match
  608. with pytest.raises(
  609. AssertionError, match="should contain one of the following patterns"
  610. ):
  611. with raises(TypeError, match="hello") as cm:
  612. raise TypeError("Bad message")
  613. assert not cm.raised_and_matched
  614. # proper type but bad match, with err_msg
  615. with pytest.raises(AssertionError, match="the failure message"):
  616. with raises(TypeError, match="hello", err_msg="the failure message") as cm:
  617. raise TypeError("Bad message")
  618. assert not cm.raised_and_matched
  619. # no raise with default may_pass=False
  620. with pytest.raises(AssertionError, match="Did not raise"):
  621. with raises(TypeError) as cm:
  622. pass
  623. assert not cm.raised_and_matched
  624. # no raise with may_pass=True
  625. with raises(TypeError, match="hello", may_pass=True) as cm:
  626. pass # still OK
  627. assert not cm.raised_and_matched
  628. # Multiple exception types:
  629. with raises((TypeError, ValueError)):
  630. raise TypeError()
  631. with raises((TypeError, ValueError)):
  632. raise ValueError()
  633. with pytest.raises(AssertionError):
  634. with raises((TypeError, ValueError)):
  635. pass
  636. def test_float32_aware_assert_allclose():
  637. # The relative tolerance for float32 inputs is 1e-4
  638. assert_allclose(np.array([1.0 + 2e-5], dtype=np.float32), 1.0)
  639. with pytest.raises(AssertionError):
  640. assert_allclose(np.array([1.0 + 2e-4], dtype=np.float32), 1.0)
  641. # The relative tolerance for other inputs is left to 1e-7 as in
  642. # the original numpy version.
  643. assert_allclose(np.array([1.0 + 2e-8], dtype=np.float64), 1.0)
  644. with pytest.raises(AssertionError):
  645. assert_allclose(np.array([1.0 + 2e-7], dtype=np.float64), 1.0)
  646. # atol is left to 0.0 by default, even for float32
  647. with pytest.raises(AssertionError):
  648. assert_allclose(np.array([1e-5], dtype=np.float32), 0.0)
  649. assert_allclose(np.array([1e-5], dtype=np.float32), 0.0, atol=2e-5)