# test_param_validation.py
from numbers import Integral, Real

import numpy as np
import pytest
from scipy.sparse import csr_matrix

from sklearn._config import config_context, get_config
from sklearn.base import BaseEstimator, _fit_context
from sklearn.model_selection import LeaveOneOut
from sklearn.utils import deprecated
from sklearn.utils._param_validation import (
    HasMethods,
    Hidden,
    Interval,
    InvalidParameterError,
    MissingValues,
    Options,
    RealNotInt,
    StrOptions,
    _ArrayLikes,
    _Booleans,
    _Callables,
    _CVObjects,
    _InstancesOf,
    _IterablesNotString,
    _NanConstraint,
    _NoneConstraint,
    _PandasNAConstraint,
    _RandomStates,
    _SparseMatrices,
    _VerboseHelper,
    generate_invalid_param_val,
    generate_valid_param,
    make_constraint,
    validate_params,
)
  35. # Some helpers for the tests
  36. @validate_params(
  37. {"a": [Real], "b": [Real], "c": [Real], "d": [Real]},
  38. prefer_skip_nested_validation=True,
  39. )
  40. def _func(a, b=0, *args, c, d=0, **kwargs):
  41. """A function to test the validation of functions."""
  42. class _Class:
  43. """A class to test the _InstancesOf constraint and the validation of methods."""
  44. @validate_params({"a": [Real]}, prefer_skip_nested_validation=True)
  45. def _method(self, a):
  46. """A validated method"""
  47. @deprecated()
  48. @validate_params({"a": [Real]}, prefer_skip_nested_validation=True)
  49. def _deprecated_method(self, a):
  50. """A deprecated validated method"""
  51. class _Estimator(BaseEstimator):
  52. """An estimator to test the validation of estimator parameters."""
  53. _parameter_constraints: dict = {"a": [Real]}
  54. def __init__(self, a):
  55. self.a = a
  56. @_fit_context(prefer_skip_nested_validation=True)
  57. def fit(self, X=None, y=None):
  58. pass
  59. @pytest.mark.parametrize("interval_type", [Integral, Real])
  60. def test_interval_range(interval_type):
  61. """Check the range of values depending on closed."""
  62. interval = Interval(interval_type, -2, 2, closed="left")
  63. assert -2 in interval and 2 not in interval
  64. interval = Interval(interval_type, -2, 2, closed="right")
  65. assert -2 not in interval and 2 in interval
  66. interval = Interval(interval_type, -2, 2, closed="both")
  67. assert -2 in interval and 2 in interval
  68. interval = Interval(interval_type, -2, 2, closed="neither")
  69. assert -2 not in interval and 2 not in interval
  70. def test_interval_inf_in_bounds():
  71. """Check that inf is included iff a bound is closed and set to None.
  72. Only valid for real intervals.
  73. """
  74. interval = Interval(Real, 0, None, closed="right")
  75. assert np.inf in interval
  76. interval = Interval(Real, None, 0, closed="left")
  77. assert -np.inf in interval
  78. interval = Interval(Real, None, None, closed="neither")
  79. assert np.inf not in interval
  80. assert -np.inf not in interval
  81. @pytest.mark.parametrize(
  82. "interval",
  83. [Interval(Real, 0, 1, closed="left"), Interval(Real, None, None, closed="both")],
  84. )
  85. def test_nan_not_in_interval(interval):
  86. """Check that np.nan is not in any interval."""
  87. assert np.nan not in interval
  88. @pytest.mark.parametrize(
  89. "params, error, match",
  90. [
  91. (
  92. {"type": Integral, "left": 1.0, "right": 2, "closed": "both"},
  93. TypeError,
  94. r"Expecting left to be an int for an interval over the integers",
  95. ),
  96. (
  97. {"type": Integral, "left": 1, "right": 2.0, "closed": "neither"},
  98. TypeError,
  99. "Expecting right to be an int for an interval over the integers",
  100. ),
  101. (
  102. {"type": Integral, "left": None, "right": 0, "closed": "left"},
  103. ValueError,
  104. r"left can't be None when closed == left",
  105. ),
  106. (
  107. {"type": Integral, "left": 0, "right": None, "closed": "right"},
  108. ValueError,
  109. r"right can't be None when closed == right",
  110. ),
  111. (
  112. {"type": Integral, "left": 1, "right": -1, "closed": "both"},
  113. ValueError,
  114. r"right can't be less than left",
  115. ),
  116. ],
  117. )
  118. def test_interval_errors(params, error, match):
  119. """Check that informative errors are raised for invalid combination of parameters"""
  120. with pytest.raises(error, match=match):
  121. Interval(**params)
  122. def test_stroptions():
  123. """Sanity check for the StrOptions constraint"""
  124. options = StrOptions({"a", "b", "c"}, deprecated={"c"})
  125. assert options.is_satisfied_by("a")
  126. assert options.is_satisfied_by("c")
  127. assert not options.is_satisfied_by("d")
  128. assert "'c' (deprecated)" in str(options)
  129. def test_options():
  130. """Sanity check for the Options constraint"""
  131. options = Options(Real, {-0.5, 0.5, np.inf}, deprecated={-0.5})
  132. assert options.is_satisfied_by(-0.5)
  133. assert options.is_satisfied_by(np.inf)
  134. assert not options.is_satisfied_by(1.23)
  135. assert "-0.5 (deprecated)" in str(options)
  136. @pytest.mark.parametrize(
  137. "type, expected_type_name",
  138. [
  139. (int, "int"),
  140. (Integral, "int"),
  141. (Real, "float"),
  142. (np.ndarray, "numpy.ndarray"),
  143. ],
  144. )
  145. def test_instances_of_type_human_readable(type, expected_type_name):
  146. """Check the string representation of the _InstancesOf constraint."""
  147. constraint = _InstancesOf(type)
  148. assert str(constraint) == f"an instance of '{expected_type_name}'"
  149. def test_hasmethods():
  150. """Check the HasMethods constraint."""
  151. constraint = HasMethods(["a", "b"])
  152. class _Good:
  153. def a(self):
  154. pass # pragma: no cover
  155. def b(self):
  156. pass # pragma: no cover
  157. class _Bad:
  158. def a(self):
  159. pass # pragma: no cover
  160. assert constraint.is_satisfied_by(_Good())
  161. assert not constraint.is_satisfied_by(_Bad())
  162. assert str(constraint) == "an object implementing 'a' and 'b'"
  163. @pytest.mark.parametrize(
  164. "constraint",
  165. [
  166. Interval(Real, None, 0, closed="left"),
  167. Interval(Real, 0, None, closed="left"),
  168. Interval(Real, None, None, closed="neither"),
  169. StrOptions({"a", "b", "c"}),
  170. MissingValues(),
  171. MissingValues(numeric_only=True),
  172. _VerboseHelper(),
  173. HasMethods("fit"),
  174. _IterablesNotString(),
  175. _CVObjects(),
  176. ],
  177. )
  178. def test_generate_invalid_param_val(constraint):
  179. """Check that the value generated does not satisfy the constraint"""
  180. bad_value = generate_invalid_param_val(constraint)
  181. assert not constraint.is_satisfied_by(bad_value)
  182. @pytest.mark.parametrize(
  183. "integer_interval, real_interval",
  184. [
  185. (
  186. Interval(Integral, None, 3, closed="right"),
  187. Interval(RealNotInt, -5, 5, closed="both"),
  188. ),
  189. (
  190. Interval(Integral, None, 3, closed="right"),
  191. Interval(RealNotInt, -5, 5, closed="neither"),
  192. ),
  193. (
  194. Interval(Integral, None, 3, closed="right"),
  195. Interval(RealNotInt, 4, 5, closed="both"),
  196. ),
  197. (
  198. Interval(Integral, None, 3, closed="right"),
  199. Interval(RealNotInt, 5, None, closed="left"),
  200. ),
  201. (
  202. Interval(Integral, None, 3, closed="right"),
  203. Interval(RealNotInt, 4, None, closed="neither"),
  204. ),
  205. (
  206. Interval(Integral, 3, None, closed="left"),
  207. Interval(RealNotInt, -5, 5, closed="both"),
  208. ),
  209. (
  210. Interval(Integral, 3, None, closed="left"),
  211. Interval(RealNotInt, -5, 5, closed="neither"),
  212. ),
  213. (
  214. Interval(Integral, 3, None, closed="left"),
  215. Interval(RealNotInt, 1, 2, closed="both"),
  216. ),
  217. (
  218. Interval(Integral, 3, None, closed="left"),
  219. Interval(RealNotInt, None, -5, closed="left"),
  220. ),
  221. (
  222. Interval(Integral, 3, None, closed="left"),
  223. Interval(RealNotInt, None, -4, closed="neither"),
  224. ),
  225. (
  226. Interval(Integral, -5, 5, closed="both"),
  227. Interval(RealNotInt, None, 1, closed="right"),
  228. ),
  229. (
  230. Interval(Integral, -5, 5, closed="both"),
  231. Interval(RealNotInt, 1, None, closed="left"),
  232. ),
  233. (
  234. Interval(Integral, -5, 5, closed="both"),
  235. Interval(RealNotInt, -10, -4, closed="neither"),
  236. ),
  237. (
  238. Interval(Integral, -5, 5, closed="both"),
  239. Interval(RealNotInt, -10, -4, closed="right"),
  240. ),
  241. (
  242. Interval(Integral, -5, 5, closed="neither"),
  243. Interval(RealNotInt, 6, 10, closed="neither"),
  244. ),
  245. (
  246. Interval(Integral, -5, 5, closed="neither"),
  247. Interval(RealNotInt, 6, 10, closed="left"),
  248. ),
  249. (
  250. Interval(Integral, 2, None, closed="left"),
  251. Interval(RealNotInt, 0, 1, closed="both"),
  252. ),
  253. (
  254. Interval(Integral, 1, None, closed="left"),
  255. Interval(RealNotInt, 0, 1, closed="both"),
  256. ),
  257. ],
  258. )
  259. def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval):
  260. """Check that the value generated for an interval constraint does not satisfy any of
  261. the interval constraints.
  262. """
  263. bad_value = generate_invalid_param_val(constraint=real_interval)
  264. assert not real_interval.is_satisfied_by(bad_value)
  265. assert not integer_interval.is_satisfied_by(bad_value)
  266. bad_value = generate_invalid_param_val(constraint=integer_interval)
  267. assert not real_interval.is_satisfied_by(bad_value)
  268. assert not integer_interval.is_satisfied_by(bad_value)
  269. @pytest.mark.parametrize(
  270. "constraint",
  271. [
  272. _ArrayLikes(),
  273. _InstancesOf(list),
  274. _Callables(),
  275. _NoneConstraint(),
  276. _RandomStates(),
  277. _SparseMatrices(),
  278. _Booleans(),
  279. Interval(Integral, None, None, closed="neither"),
  280. ],
  281. )
  282. def test_generate_invalid_param_val_all_valid(constraint):
  283. """Check that the function raises NotImplementedError when there's no invalid value
  284. for the constraint.
  285. """
  286. with pytest.raises(NotImplementedError):
  287. generate_invalid_param_val(constraint)
  288. @pytest.mark.parametrize(
  289. "constraint",
  290. [
  291. _ArrayLikes(),
  292. _Callables(),
  293. _InstancesOf(list),
  294. _NoneConstraint(),
  295. _RandomStates(),
  296. _SparseMatrices(),
  297. _Booleans(),
  298. _VerboseHelper(),
  299. MissingValues(),
  300. MissingValues(numeric_only=True),
  301. StrOptions({"a", "b", "c"}),
  302. Options(Integral, {1, 2, 3}),
  303. Interval(Integral, None, None, closed="neither"),
  304. Interval(Integral, 0, 10, closed="neither"),
  305. Interval(Integral, 0, None, closed="neither"),
  306. Interval(Integral, None, 0, closed="neither"),
  307. Interval(Real, 0, 1, closed="neither"),
  308. Interval(Real, 0, None, closed="both"),
  309. Interval(Real, None, 0, closed="right"),
  310. HasMethods("fit"),
  311. _IterablesNotString(),
  312. _CVObjects(),
  313. ],
  314. )
  315. def test_generate_valid_param(constraint):
  316. """Check that the value generated does satisfy the constraint."""
  317. value = generate_valid_param(constraint)
  318. assert constraint.is_satisfied_by(value)
  319. @pytest.mark.parametrize(
  320. "constraint_declaration, value",
  321. [
  322. (Interval(Real, 0, 1, closed="both"), 0.42),
  323. (Interval(Integral, 0, None, closed="neither"), 42),
  324. (StrOptions({"a", "b", "c"}), "b"),
  325. (Options(type, {np.float32, np.float64}), np.float64),
  326. (callable, lambda x: x + 1),
  327. (None, None),
  328. ("array-like", [[1, 2], [3, 4]]),
  329. ("array-like", np.array([[1, 2], [3, 4]])),
  330. ("sparse matrix", csr_matrix([[1, 2], [3, 4]])),
  331. ("random_state", 0),
  332. ("random_state", np.random.RandomState(0)),
  333. ("random_state", None),
  334. (_Class, _Class()),
  335. (int, 1),
  336. (Real, 0.5),
  337. ("boolean", False),
  338. ("verbose", 1),
  339. ("nan", np.nan),
  340. (MissingValues(), -1),
  341. (MissingValues(), -1.0),
  342. (MissingValues(), None),
  343. (MissingValues(), float("nan")),
  344. (MissingValues(), np.nan),
  345. (MissingValues(), "missing"),
  346. (HasMethods("fit"), _Estimator(a=0)),
  347. ("cv_object", 5),
  348. ],
  349. )
  350. def test_is_satisfied_by(constraint_declaration, value):
  351. """Sanity check for the is_satisfied_by method"""
  352. constraint = make_constraint(constraint_declaration)
  353. assert constraint.is_satisfied_by(value)
  354. @pytest.mark.parametrize(
  355. "constraint_declaration, expected_constraint_class",
  356. [
  357. (Interval(Real, 0, 1, closed="both"), Interval),
  358. (StrOptions({"option1", "option2"}), StrOptions),
  359. (Options(Real, {0.42, 1.23}), Options),
  360. ("array-like", _ArrayLikes),
  361. ("sparse matrix", _SparseMatrices),
  362. ("random_state", _RandomStates),
  363. (None, _NoneConstraint),
  364. (callable, _Callables),
  365. (int, _InstancesOf),
  366. ("boolean", _Booleans),
  367. ("verbose", _VerboseHelper),
  368. (MissingValues(numeric_only=True), MissingValues),
  369. (HasMethods("fit"), HasMethods),
  370. ("cv_object", _CVObjects),
  371. ("nan", _NanConstraint),
  372. ],
  373. )
  374. def test_make_constraint(constraint_declaration, expected_constraint_class):
  375. """Check that make_constraint dispaches to the appropriate constraint class"""
  376. constraint = make_constraint(constraint_declaration)
  377. assert constraint.__class__ is expected_constraint_class
  378. def test_make_constraint_unknown():
  379. """Check that an informative error is raised when an unknown constraint is passed"""
  380. with pytest.raises(ValueError, match="Unknown constraint"):
  381. make_constraint("not a valid constraint")
  382. def test_validate_params():
  383. """Check that validate_params works no matter how the arguments are passed"""
  384. with pytest.raises(
  385. InvalidParameterError, match="The 'a' parameter of _func must be"
  386. ):
  387. _func("wrong", c=1)
  388. with pytest.raises(
  389. InvalidParameterError, match="The 'b' parameter of _func must be"
  390. ):
  391. _func(*[1, "wrong"], c=1)
  392. with pytest.raises(
  393. InvalidParameterError, match="The 'c' parameter of _func must be"
  394. ):
  395. _func(1, **{"c": "wrong"})
  396. with pytest.raises(
  397. InvalidParameterError, match="The 'd' parameter of _func must be"
  398. ):
  399. _func(1, c=1, d="wrong")
  400. # check in the presence of extra positional and keyword args
  401. with pytest.raises(
  402. InvalidParameterError, match="The 'b' parameter of _func must be"
  403. ):
  404. _func(0, *["wrong", 2, 3], c=4, **{"e": 5})
  405. with pytest.raises(
  406. InvalidParameterError, match="The 'c' parameter of _func must be"
  407. ):
  408. _func(0, *[1, 2, 3], c="four", **{"e": 5})
  409. def test_validate_params_missing_params():
  410. """Check that no error is raised when there are parameters without
  411. constraints
  412. """
  413. @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
  414. def func(a, b):
  415. pass
  416. func(1, 2)
  417. def test_decorate_validated_function():
  418. """Check that validate_params functions can be decorated"""
  419. decorated_function = deprecated()(_func)
  420. with pytest.warns(FutureWarning, match="Function _func is deprecated"):
  421. decorated_function(1, 2, c=3)
  422. # outer decorator does not interfere with validation
  423. with pytest.warns(FutureWarning, match="Function _func is deprecated"):
  424. with pytest.raises(
  425. InvalidParameterError, match=r"The 'c' parameter of _func must be"
  426. ):
  427. decorated_function(1, 2, c="wrong")
  428. def test_validate_params_method():
  429. """Check that validate_params works with methods"""
  430. with pytest.raises(
  431. InvalidParameterError, match="The 'a' parameter of _Class._method must be"
  432. ):
  433. _Class()._method("wrong")
  434. # validated method can be decorated
  435. with pytest.warns(FutureWarning, match="Function _deprecated_method is deprecated"):
  436. with pytest.raises(
  437. InvalidParameterError,
  438. match="The 'a' parameter of _Class._deprecated_method must be",
  439. ):
  440. _Class()._deprecated_method("wrong")
  441. def test_validate_params_estimator():
  442. """Check that validate_params works with Estimator instances"""
  443. # no validation in init
  444. est = _Estimator("wrong")
  445. with pytest.raises(
  446. InvalidParameterError, match="The 'a' parameter of _Estimator must be"
  447. ):
  448. est.fit()
  449. def test_stroptions_deprecated_subset():
  450. """Check that the deprecated parameter must be a subset of options."""
  451. with pytest.raises(ValueError, match="deprecated options must be a subset"):
  452. StrOptions({"a", "b", "c"}, deprecated={"a", "d"})
  453. def test_hidden_constraint():
  454. """Check that internal constraints are not exposed in the error message."""
  455. @validate_params(
  456. {"param": [Hidden(list), dict]}, prefer_skip_nested_validation=True
  457. )
  458. def f(param):
  459. pass
  460. # list and dict are valid params
  461. f({"a": 1, "b": 2, "c": 3})
  462. f([1, 2, 3])
  463. with pytest.raises(
  464. InvalidParameterError, match="The 'param' parameter"
  465. ) as exc_info:
  466. f(param="bad")
  467. # the list option is not exposed in the error message
  468. err_msg = str(exc_info.value)
  469. assert "an instance of 'dict'" in err_msg
  470. assert "an instance of 'list'" not in err_msg
  471. def test_hidden_stroptions():
  472. """Check that we can have 2 StrOptions constraints, one being hidden."""
  473. @validate_params(
  474. {"param": [StrOptions({"auto"}), Hidden(StrOptions({"warn"}))]},
  475. prefer_skip_nested_validation=True,
  476. )
  477. def f(param):
  478. pass
  479. # "auto" and "warn" are valid params
  480. f("auto")
  481. f("warn")
  482. with pytest.raises(
  483. InvalidParameterError, match="The 'param' parameter"
  484. ) as exc_info:
  485. f(param="bad")
  486. # the "warn" option is not exposed in the error message
  487. err_msg = str(exc_info.value)
  488. assert "auto" in err_msg
  489. assert "warn" not in err_msg
  490. def test_validate_params_set_param_constraints_attribute():
  491. """Check that the validate_params decorator properly sets the parameter constraints
  492. as attribute of the decorated function/method.
  493. """
  494. assert hasattr(_func, "_skl_parameter_constraints")
  495. assert hasattr(_Class()._method, "_skl_parameter_constraints")
  496. def test_boolean_constraint_deprecated_int():
  497. """Check that validate_params raise a deprecation message but still passes
  498. validation when using an int for a parameter accepting a boolean.
  499. """
  500. @validate_params({"param": ["boolean"]}, prefer_skip_nested_validation=True)
  501. def f(param):
  502. pass
  503. # True/False and np.bool_(True/False) are valid params
  504. f(True)
  505. f(np.bool_(False))
  506. # an int is also valid but deprecated
  507. with pytest.warns(
  508. FutureWarning, match="Passing an int for a boolean parameter is deprecated"
  509. ):
  510. f(1)
  511. def test_no_validation():
  512. """Check that validation can be skipped for a parameter."""
  513. @validate_params(
  514. {"param1": [int, None], "param2": "no_validation"},
  515. prefer_skip_nested_validation=True,
  516. )
  517. def f(param1=None, param2=None):
  518. pass
  519. # param1 is validated
  520. with pytest.raises(InvalidParameterError, match="The 'param1' parameter"):
  521. f(param1="wrong")
  522. # param2 is not validated: any type is valid.
  523. class SomeType:
  524. pass
  525. f(param2=SomeType)
  526. f(param2=SomeType())
  527. def test_pandas_na_constraint_with_pd_na():
  528. """Add a specific test for checking support for `pandas.NA`."""
  529. pd = pytest.importorskip("pandas")
  530. na_constraint = _PandasNAConstraint()
  531. assert na_constraint.is_satisfied_by(pd.NA)
  532. assert not na_constraint.is_satisfied_by(np.array([1, 2, 3]))
  533. def test_iterable_not_string():
  534. """Check that a string does not satisfy the _IterableNotString constraint."""
  535. constraint = _IterablesNotString()
  536. assert constraint.is_satisfied_by([1, 2, 3])
  537. assert constraint.is_satisfied_by(range(10))
  538. assert not constraint.is_satisfied_by("some string")
  539. def test_cv_objects():
  540. """Check that the _CVObjects constraint accepts all current ways
  541. to pass cv objects."""
  542. constraint = _CVObjects()
  543. assert constraint.is_satisfied_by(5)
  544. assert constraint.is_satisfied_by(LeaveOneOut())
  545. assert constraint.is_satisfied_by([([1, 2], [3, 4]), ([3, 4], [1, 2])])
  546. assert constraint.is_satisfied_by(None)
  547. assert not constraint.is_satisfied_by("not a CV object")
  548. def test_third_party_estimator():
  549. """Check that the validation from a scikit-learn estimator inherited by a third
  550. party estimator does not impose a match between the dict of constraints and the
  551. parameters of the estimator.
  552. """
  553. class ThirdPartyEstimator(_Estimator):
  554. def __init__(self, b):
  555. self.b = b
  556. super().__init__(a=0)
  557. def fit(self, X=None, y=None):
  558. super().fit(X, y)
  559. # does not raise, even though "b" is not in the constraints dict and "a" is not
  560. # a parameter of the estimator.
  561. ThirdPartyEstimator(b=0).fit()
  562. def test_interval_real_not_int():
  563. """Check for the type RealNotInt in the Interval constraint."""
  564. constraint = Interval(RealNotInt, 0, 1, closed="both")
  565. assert constraint.is_satisfied_by(1.0)
  566. assert not constraint.is_satisfied_by(1)
  567. def test_real_not_int():
  568. """Check for the RealNotInt type."""
  569. assert isinstance(1.0, RealNotInt)
  570. assert not isinstance(1, RealNotInt)
  571. assert isinstance(np.float64(1), RealNotInt)
  572. assert not isinstance(np.int64(1), RealNotInt)
  573. def test_skip_param_validation():
  574. """Check that param validation can be skipped using config_context."""
  575. @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
  576. def f(a):
  577. pass
  578. with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
  579. f(a="1")
  580. # does not raise
  581. with config_context(skip_parameter_validation=True):
  582. f(a="1")
  583. @pytest.mark.parametrize("prefer_skip_nested_validation", [True, False])
  584. def test_skip_nested_validation(prefer_skip_nested_validation):
  585. """Check that nested validation can be skipped."""
  586. @validate_params({"a": [int]}, prefer_skip_nested_validation=True)
  587. def f(a):
  588. pass
  589. @validate_params(
  590. {"b": [int]},
  591. prefer_skip_nested_validation=prefer_skip_nested_validation,
  592. )
  593. def g(b):
  594. # calls f with a bad parameter type
  595. return f(a="invalid_param_value")
  596. # Validation for g is never skipped.
  597. with pytest.raises(InvalidParameterError, match="The 'b' parameter"):
  598. g(b="invalid_param_value")
  599. if prefer_skip_nested_validation:
  600. g(b=1) # does not raise because inner f is not validated
  601. else:
  602. with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
  603. g(b=1)
  604. @pytest.mark.parametrize(
  605. "skip_parameter_validation, prefer_skip_nested_validation, expected_skipped",
  606. [
  607. (True, True, True),
  608. (True, False, True),
  609. (False, True, True),
  610. (False, False, False),
  611. ],
  612. )
  613. def test_skip_nested_validation_and_config_context(
  614. skip_parameter_validation, prefer_skip_nested_validation, expected_skipped
  615. ):
  616. """Check interaction between global skip and local skip."""
  617. @validate_params(
  618. {"a": [int]}, prefer_skip_nested_validation=prefer_skip_nested_validation
  619. )
  620. def g(a):
  621. return get_config()["skip_parameter_validation"]
  622. with config_context(skip_parameter_validation=skip_parameter_validation):
  623. actual_skipped = g(1)
  624. assert actual_skipped == expected_skipped