__init__.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import typing
  2. from ._plot import LearningCurveDisplay, ValidationCurveDisplay
  3. from ._search import GridSearchCV, ParameterGrid, ParameterSampler, RandomizedSearchCV
  4. from ._split import (
  5. BaseCrossValidator,
  6. BaseShuffleSplit,
  7. GroupKFold,
  8. GroupShuffleSplit,
  9. KFold,
  10. LeaveOneGroupOut,
  11. LeaveOneOut,
  12. LeavePGroupsOut,
  13. LeavePOut,
  14. PredefinedSplit,
  15. RepeatedKFold,
  16. RepeatedStratifiedKFold,
  17. ShuffleSplit,
  18. StratifiedGroupKFold,
  19. StratifiedKFold,
  20. StratifiedShuffleSplit,
  21. TimeSeriesSplit,
  22. check_cv,
  23. train_test_split,
  24. )
  25. from ._validation import (
  26. cross_val_predict,
  27. cross_val_score,
  28. cross_validate,
  29. learning_curve,
  30. permutation_test_score,
  31. validation_curve,
  32. )
  33. if typing.TYPE_CHECKING:
  34. # Avoid errors in type checkers (e.g. mypy) for experimental estimators.
  35. # TODO: remove this check once the estimator is no longer experimental.
  36. from ._search_successive_halving import ( # noqa
  37. HalvingGridSearchCV,
  38. HalvingRandomSearchCV,
  39. )
  40. __all__ = [
  41. "BaseCrossValidator",
  42. "BaseShuffleSplit",
  43. "GridSearchCV",
  44. "TimeSeriesSplit",
  45. "KFold",
  46. "GroupKFold",
  47. "GroupShuffleSplit",
  48. "LeaveOneGroupOut",
  49. "LeaveOneOut",
  50. "LeavePGroupsOut",
  51. "LeavePOut",
  52. "RepeatedKFold",
  53. "RepeatedStratifiedKFold",
  54. "ParameterGrid",
  55. "ParameterSampler",
  56. "PredefinedSplit",
  57. "RandomizedSearchCV",
  58. "ShuffleSplit",
  59. "StratifiedKFold",
  60. "StratifiedGroupKFold",
  61. "StratifiedShuffleSplit",
  62. "check_cv",
  63. "cross_val_predict",
  64. "cross_val_score",
  65. "cross_validate",
  66. "learning_curve",
  67. "LearningCurveDisplay",
  68. "permutation_test_score",
  69. "train_test_split",
  70. "validation_curve",
  71. "ValidationCurveDisplay",
  72. ]
  73. # TODO: remove this check once the estimator is no longer experimental.
  74. def __getattr__(name):
  75. if name in {"HalvingGridSearchCV", "HalvingRandomSearchCV"}:
  76. raise ImportError(
  77. f"{name} is experimental and the API might change without any "
  78. "deprecation cycle. To use it, you need to explicitly import "
  79. "enable_halving_search_cv:\n"
  80. "from sklearn.experimental import enable_halving_search_cv"
  81. )
  82. raise AttributeError(f"module {__name__} has no attribute {name}")