# test_public_functions.py

from importlib import import_module
from inspect import signature
from numbers import Integral, Real

import pytest

from sklearn.utils._param_validation import (
    Interval,
    InvalidParameterError,
    generate_invalid_param_val,
    generate_valid_param,
    make_constraint,
)


  12. def _get_func_info(func_module):
  13. module_name, func_name = func_module.rsplit(".", 1)
  14. module = import_module(module_name)
  15. func = getattr(module, func_name)
  16. func_sig = signature(func)
  17. func_params = [
  18. p.name
  19. for p in func_sig.parameters.values()
  20. if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
  21. ]
  22. # The parameters `*args` and `**kwargs` are ignored since we cannot generate
  23. # constraints.
  24. required_params = [
  25. p.name
  26. for p in func_sig.parameters.values()
  27. if p.default is p.empty and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
  28. ]
  29. return func, func_name, func_params, required_params
  30. def _check_function_param_validation(
  31. func, func_name, func_params, required_params, parameter_constraints
  32. ):
  33. """Check that an informative error is raised when the value of a parameter does not
  34. have an appropriate type or value.
  35. """
  36. # generate valid values for the required parameters
  37. valid_required_params = {}
  38. for param_name in required_params:
  39. if parameter_constraints[param_name] == "no_validation":
  40. valid_required_params[param_name] = 1
  41. else:
  42. valid_required_params[param_name] = generate_valid_param(
  43. make_constraint(parameter_constraints[param_name][0])
  44. )
  45. # check that there is a constraint for each parameter
  46. if func_params:
  47. validation_params = parameter_constraints.keys()
  48. unexpected_params = set(validation_params) - set(func_params)
  49. missing_params = set(func_params) - set(validation_params)
  50. err_msg = (
  51. "Mismatch between _parameter_constraints and the parameters of"
  52. f" {func_name}.\nConsider the unexpected parameters {unexpected_params} and"
  53. f" expected but missing parameters {missing_params}\n"
  54. )
  55. assert set(validation_params) == set(func_params), err_msg
  56. # this object does not have a valid type for sure for all params
  57. param_with_bad_type = type("BadType", (), {})()
  58. for param_name in func_params:
  59. constraints = parameter_constraints[param_name]
  60. if constraints == "no_validation":
  61. # This parameter is not validated
  62. continue
  63. # Mixing an interval of reals and an interval of integers must be avoided.
  64. if any(
  65. isinstance(constraint, Interval) and constraint.type == Integral
  66. for constraint in constraints
  67. ) and any(
  68. isinstance(constraint, Interval) and constraint.type == Real
  69. for constraint in constraints
  70. ):
  71. raise ValueError(
  72. f"The constraint for parameter {param_name} of {func_name} can't have a"
  73. " mix of intervals of Integral and Real types. Use the type"
  74. " RealNotInt instead of Real."
  75. )
  76. match = (
  77. rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
  78. )
  79. # First, check that the error is raised if param doesn't match any valid type.
  80. with pytest.raises(InvalidParameterError, match=match):
  81. func(**{**valid_required_params, param_name: param_with_bad_type})
  82. # Then, for constraints that are more than a type constraint, check that the
  83. # error is raised if param does match a valid type but does not match any valid
  84. # value for this type.
  85. constraints = [make_constraint(constraint) for constraint in constraints]
  86. for constraint in constraints:
  87. try:
  88. bad_value = generate_invalid_param_val(constraint)
  89. except NotImplementedError:
  90. continue
  91. with pytest.raises(InvalidParameterError, match=match):
  92. func(**{**valid_required_params, param_name: bad_value})
  93. PARAM_VALIDATION_FUNCTION_LIST = [
  94. "sklearn.calibration.calibration_curve",
  95. "sklearn.cluster.cluster_optics_dbscan",
  96. "sklearn.cluster.compute_optics_graph",
  97. "sklearn.cluster.estimate_bandwidth",
  98. "sklearn.cluster.kmeans_plusplus",
  99. "sklearn.cluster.cluster_optics_xi",
  100. "sklearn.cluster.ward_tree",
  101. "sklearn.covariance.empirical_covariance",
  102. "sklearn.covariance.ledoit_wolf_shrinkage",
  103. "sklearn.covariance.shrunk_covariance",
  104. "sklearn.datasets.clear_data_home",
  105. "sklearn.datasets.dump_svmlight_file",
  106. "sklearn.datasets.fetch_20newsgroups",
  107. "sklearn.datasets.fetch_20newsgroups_vectorized",
  108. "sklearn.datasets.fetch_california_housing",
  109. "sklearn.datasets.fetch_covtype",
  110. "sklearn.datasets.fetch_kddcup99",
  111. "sklearn.datasets.fetch_lfw_pairs",
  112. "sklearn.datasets.fetch_lfw_people",
  113. "sklearn.datasets.fetch_olivetti_faces",
  114. "sklearn.datasets.fetch_rcv1",
  115. "sklearn.datasets.fetch_openml",
  116. "sklearn.datasets.fetch_species_distributions",
  117. "sklearn.datasets.get_data_home",
  118. "sklearn.datasets.load_breast_cancer",
  119. "sklearn.datasets.load_diabetes",
  120. "sklearn.datasets.load_digits",
  121. "sklearn.datasets.load_files",
  122. "sklearn.datasets.load_iris",
  123. "sklearn.datasets.load_linnerud",
  124. "sklearn.datasets.load_sample_image",
  125. "sklearn.datasets.load_svmlight_file",
  126. "sklearn.datasets.load_svmlight_files",
  127. "sklearn.datasets.load_wine",
  128. "sklearn.datasets.make_biclusters",
  129. "sklearn.datasets.make_blobs",
  130. "sklearn.datasets.make_checkerboard",
  131. "sklearn.datasets.make_circles",
  132. "sklearn.datasets.make_classification",
  133. "sklearn.datasets.make_friedman1",
  134. "sklearn.datasets.make_friedman2",
  135. "sklearn.datasets.make_friedman3",
  136. "sklearn.datasets.make_gaussian_quantiles",
  137. "sklearn.datasets.make_hastie_10_2",
  138. "sklearn.datasets.make_low_rank_matrix",
  139. "sklearn.datasets.make_moons",
  140. "sklearn.datasets.make_multilabel_classification",
  141. "sklearn.datasets.make_regression",
  142. "sklearn.datasets.make_s_curve",
  143. "sklearn.datasets.make_sparse_coded_signal",
  144. "sklearn.datasets.make_sparse_spd_matrix",
  145. "sklearn.datasets.make_sparse_uncorrelated",
  146. "sklearn.datasets.make_spd_matrix",
  147. "sklearn.datasets.make_swiss_roll",
  148. "sklearn.decomposition.sparse_encode",
  149. "sklearn.feature_extraction.grid_to_graph",
  150. "sklearn.feature_extraction.img_to_graph",
  151. "sklearn.feature_extraction.image.extract_patches_2d",
  152. "sklearn.feature_extraction.image.reconstruct_from_patches_2d",
  153. "sklearn.feature_selection.chi2",
  154. "sklearn.feature_selection.f_classif",
  155. "sklearn.feature_selection.f_regression",
  156. "sklearn.feature_selection.mutual_info_classif",
  157. "sklearn.feature_selection.mutual_info_regression",
  158. "sklearn.feature_selection.r_regression",
  159. "sklearn.inspection.partial_dependence",
  160. "sklearn.inspection.permutation_importance",
  161. "sklearn.linear_model.orthogonal_mp",
  162. "sklearn.metrics.accuracy_score",
  163. "sklearn.metrics.auc",
  164. "sklearn.metrics.average_precision_score",
  165. "sklearn.metrics.balanced_accuracy_score",
  166. "sklearn.metrics.brier_score_loss",
  167. "sklearn.metrics.calinski_harabasz_score",
  168. "sklearn.metrics.check_scoring",
  169. "sklearn.metrics.completeness_score",
  170. "sklearn.metrics.class_likelihood_ratios",
  171. "sklearn.metrics.classification_report",
  172. "sklearn.metrics.cluster.adjusted_mutual_info_score",
  173. "sklearn.metrics.cluster.contingency_matrix",
  174. "sklearn.metrics.cluster.entropy",
  175. "sklearn.metrics.cluster.fowlkes_mallows_score",
  176. "sklearn.metrics.cluster.homogeneity_completeness_v_measure",
  177. "sklearn.metrics.cluster.normalized_mutual_info_score",
  178. "sklearn.metrics.cluster.silhouette_samples",
  179. "sklearn.metrics.cluster.silhouette_score",
  180. "sklearn.metrics.cohen_kappa_score",
  181. "sklearn.metrics.confusion_matrix",
  182. "sklearn.metrics.coverage_error",
  183. "sklearn.metrics.d2_absolute_error_score",
  184. "sklearn.metrics.d2_pinball_score",
  185. "sklearn.metrics.d2_tweedie_score",
  186. "sklearn.metrics.davies_bouldin_score",
  187. "sklearn.metrics.dcg_score",
  188. "sklearn.metrics.det_curve",
  189. "sklearn.metrics.explained_variance_score",
  190. "sklearn.metrics.f1_score",
  191. "sklearn.metrics.fbeta_score",
  192. "sklearn.metrics.get_scorer",
  193. "sklearn.metrics.hamming_loss",
  194. "sklearn.metrics.hinge_loss",
  195. "sklearn.metrics.homogeneity_score",
  196. "sklearn.metrics.jaccard_score",
  197. "sklearn.metrics.label_ranking_average_precision_score",
  198. "sklearn.metrics.label_ranking_loss",
  199. "sklearn.metrics.log_loss",
  200. "sklearn.metrics.make_scorer",
  201. "sklearn.metrics.matthews_corrcoef",
  202. "sklearn.metrics.max_error",
  203. "sklearn.metrics.mean_absolute_error",
  204. "sklearn.metrics.mean_absolute_percentage_error",
  205. "sklearn.metrics.mean_gamma_deviance",
  206. "sklearn.metrics.mean_pinball_loss",
  207. "sklearn.metrics.mean_poisson_deviance",
  208. "sklearn.metrics.mean_squared_error",
  209. "sklearn.metrics.mean_squared_log_error",
  210. "sklearn.metrics.mean_tweedie_deviance",
  211. "sklearn.metrics.median_absolute_error",
  212. "sklearn.metrics.multilabel_confusion_matrix",
  213. "sklearn.metrics.mutual_info_score",
  214. "sklearn.metrics.ndcg_score",
  215. "sklearn.metrics.pair_confusion_matrix",
  216. "sklearn.metrics.adjusted_rand_score",
  217. "sklearn.metrics.pairwise.additive_chi2_kernel",
  218. "sklearn.metrics.pairwise.cosine_distances",
  219. "sklearn.metrics.pairwise.cosine_similarity",
  220. "sklearn.metrics.pairwise.haversine_distances",
  221. "sklearn.metrics.pairwise.laplacian_kernel",
  222. "sklearn.metrics.pairwise.linear_kernel",
  223. "sklearn.metrics.pairwise.manhattan_distances",
  224. "sklearn.metrics.pairwise.nan_euclidean_distances",
  225. "sklearn.metrics.pairwise.paired_cosine_distances",
  226. "sklearn.metrics.pairwise.paired_euclidean_distances",
  227. "sklearn.metrics.pairwise.paired_manhattan_distances",
  228. "sklearn.metrics.pairwise.polynomial_kernel",
  229. "sklearn.metrics.pairwise.rbf_kernel",
  230. "sklearn.metrics.pairwise.sigmoid_kernel",
  231. "sklearn.metrics.pairwise_distances_argmin",
  232. "sklearn.metrics.precision_recall_curve",
  233. "sklearn.metrics.precision_recall_fscore_support",
  234. "sklearn.metrics.precision_score",
  235. "sklearn.metrics.r2_score",
  236. "sklearn.metrics.rand_score",
  237. "sklearn.metrics.recall_score",
  238. "sklearn.metrics.roc_auc_score",
  239. "sklearn.metrics.roc_curve",
  240. "sklearn.metrics.top_k_accuracy_score",
  241. "sklearn.metrics.v_measure_score",
  242. "sklearn.metrics.zero_one_loss",
  243. "sklearn.model_selection.cross_validate",
  244. "sklearn.model_selection.learning_curve",
  245. "sklearn.model_selection.permutation_test_score",
  246. "sklearn.model_selection.train_test_split",
  247. "sklearn.model_selection.validation_curve",
  248. "sklearn.neighbors.sort_graph_by_row_values",
  249. "sklearn.preprocessing.add_dummy_feature",
  250. "sklearn.preprocessing.binarize",
  251. "sklearn.preprocessing.label_binarize",
  252. "sklearn.preprocessing.normalize",
  253. "sklearn.preprocessing.scale",
  254. "sklearn.random_projection.johnson_lindenstrauss_min_dim",
  255. "sklearn.svm.l1_min_c",
  256. "sklearn.tree.export_graphviz",
  257. "sklearn.tree.export_text",
  258. "sklearn.tree.plot_tree",
  259. "sklearn.utils.gen_batches",
  260. "sklearn.utils.resample",
  261. ]
  262. @pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST)
  263. def test_function_param_validation(func_module):
  264. """Check param validation for public functions that are not wrappers around
  265. estimators.
  266. """
  267. func, func_name, func_params, required_params = _get_func_info(func_module)
  268. parameter_constraints = getattr(func, "_skl_parameter_constraints")
  269. _check_function_param_validation(
  270. func, func_name, func_params, required_params, parameter_constraints
  271. )
  272. PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
  273. ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"),
  274. ("sklearn.cluster.k_means", "sklearn.cluster.KMeans"),
  275. ("sklearn.cluster.mean_shift", "sklearn.cluster.MeanShift"),
  276. ("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"),
  277. ("sklearn.covariance.graphical_lasso", "sklearn.covariance.GraphicalLasso"),
  278. ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
  279. ("sklearn.covariance.oas", "sklearn.covariance.OAS"),
  280. ("sklearn.decomposition.dict_learning", "sklearn.decomposition.DictionaryLearning"),
  281. ("sklearn.decomposition.fastica", "sklearn.decomposition.FastICA"),
  282. ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
  283. ("sklearn.preprocessing.maxabs_scale", "sklearn.preprocessing.MaxAbsScaler"),
  284. ("sklearn.preprocessing.minmax_scale", "sklearn.preprocessing.MinMaxScaler"),
  285. ("sklearn.preprocessing.power_transform", "sklearn.preprocessing.PowerTransformer"),
  286. (
  287. "sklearn.preprocessing.quantile_transform",
  288. "sklearn.preprocessing.QuantileTransformer",
  289. ),
  290. ("sklearn.preprocessing.robust_scale", "sklearn.preprocessing.RobustScaler"),
  291. ]
  292. @pytest.mark.parametrize(
  293. "func_module, class_module", PARAM_VALIDATION_CLASS_WRAPPER_LIST
  294. )
  295. def test_class_wrapper_param_validation(func_module, class_module):
  296. """Check param validation for public functions that are wrappers around
  297. estimators.
  298. """
  299. func, func_name, func_params, required_params = _get_func_info(func_module)
  300. module_name, class_name = class_module.rsplit(".", 1)
  301. module = import_module(module_name)
  302. klass = getattr(module, class_name)
  303. parameter_constraints_func = getattr(func, "_skl_parameter_constraints")
  304. parameter_constraints_class = getattr(klass, "_parameter_constraints")
  305. parameter_constraints = {
  306. **parameter_constraints_class,
  307. **parameter_constraints_func,
  308. }
  309. parameter_constraints = {
  310. k: v for k, v in parameter_constraints.items() if k in func_params
  311. }
  312. _check_function_param_validation(
  313. func, func_name, func_params, required_params, parameter_constraints
  314. )