# test_base.py
  1. """
  2. Testing for the base module (sklearn.ensemble.base).
  3. """
  4. # Authors: Gilles Louppe
  5. # License: BSD 3 clause
  6. from collections import OrderedDict
  7. import numpy as np
  8. import pytest
  9. from sklearn import ensemble
  10. from sklearn.datasets import load_iris
  11. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
  12. from sklearn.ensemble import BaggingClassifier
  13. from sklearn.ensemble._base import _set_random_states
  14. from sklearn.feature_selection import SelectFromModel
  15. from sklearn.linear_model import LogisticRegression, Perceptron, Ridge
  16. from sklearn.pipeline import Pipeline
  17. def test_base():
  18. # Check BaseEnsemble methods.
  19. ensemble = BaggingClassifier(
  20. estimator=Perceptron(random_state=None), n_estimators=3
  21. )
  22. iris = load_iris()
  23. ensemble.fit(iris.data, iris.target)
  24. ensemble.estimators_ = [] # empty the list and create estimators manually
  25. ensemble._make_estimator()
  26. random_state = np.random.RandomState(3)
  27. ensemble._make_estimator(random_state=random_state)
  28. ensemble._make_estimator(random_state=random_state)
  29. ensemble._make_estimator(append=False)
  30. assert 3 == len(ensemble)
  31. assert 3 == len(ensemble.estimators_)
  32. assert isinstance(ensemble[0], Perceptron)
  33. assert ensemble[0].random_state is None
  34. assert isinstance(ensemble[1].random_state, int)
  35. assert isinstance(ensemble[2].random_state, int)
  36. assert ensemble[1].random_state != ensemble[2].random_state
  37. np_int_ensemble = BaggingClassifier(
  38. estimator=Perceptron(), n_estimators=np.int32(3)
  39. )
  40. np_int_ensemble.fit(iris.data, iris.target)
  41. def test_set_random_states():
  42. # Linear Discriminant Analysis doesn't have random state: smoke test
  43. _set_random_states(LinearDiscriminantAnalysis(), random_state=17)
  44. clf1 = Perceptron(random_state=None)
  45. assert clf1.random_state is None
  46. # check random_state is None still sets
  47. _set_random_states(clf1, None)
  48. assert isinstance(clf1.random_state, int)
  49. # check random_state fixes results in consistent initialisation
  50. _set_random_states(clf1, 3)
  51. assert isinstance(clf1.random_state, int)
  52. clf2 = Perceptron(random_state=None)
  53. _set_random_states(clf2, 3)
  54. assert clf1.random_state == clf2.random_state
  55. # nested random_state
  56. def make_steps():
  57. return [
  58. ("sel", SelectFromModel(Perceptron(random_state=None))),
  59. ("clf", Perceptron(random_state=None)),
  60. ]
  61. est1 = Pipeline(make_steps())
  62. _set_random_states(est1, 3)
  63. assert isinstance(est1.steps[0][1].estimator.random_state, int)
  64. assert isinstance(est1.steps[1][1].random_state, int)
  65. assert (
  66. est1.get_params()["sel__estimator__random_state"]
  67. != est1.get_params()["clf__random_state"]
  68. )
  69. # ensure multiple random_state parameters are invariant to get_params()
  70. # iteration order
  71. class AlphaParamPipeline(Pipeline):
  72. def get_params(self, *args, **kwargs):
  73. params = Pipeline.get_params(self, *args, **kwargs).items()
  74. return OrderedDict(sorted(params))
  75. class RevParamPipeline(Pipeline):
  76. def get_params(self, *args, **kwargs):
  77. params = Pipeline.get_params(self, *args, **kwargs).items()
  78. return OrderedDict(sorted(params, reverse=True))
  79. for cls in [AlphaParamPipeline, RevParamPipeline]:
  80. est2 = cls(make_steps())
  81. _set_random_states(est2, 3)
  82. assert (
  83. est1.get_params()["sel__estimator__random_state"]
  84. == est2.get_params()["sel__estimator__random_state"]
  85. )
  86. assert (
  87. est1.get_params()["clf__random_state"]
  88. == est2.get_params()["clf__random_state"]
  89. )
  90. # TODO(1.4): remove
  91. def test_validate_estimator_value_error():
  92. X = np.array([[1, 2], [3, 4]])
  93. y = np.array([1, 0])
  94. model = BaggingClassifier(estimator=Perceptron(), base_estimator=Perceptron())
  95. err_msg = "Both `estimator` and `base_estimator` were set. Only set `estimator`."
  96. with pytest.raises(ValueError, match=err_msg):
  97. model.fit(X, y)
  98. # TODO(1.4): remove
  99. @pytest.mark.parametrize(
  100. "model",
  101. [
  102. ensemble.GradientBoostingClassifier(),
  103. ensemble.GradientBoostingRegressor(),
  104. ensemble.HistGradientBoostingClassifier(),
  105. ensemble.HistGradientBoostingRegressor(),
  106. ensemble.VotingClassifier(
  107. [("a", LogisticRegression()), ("b", LogisticRegression())]
  108. ),
  109. ensemble.VotingRegressor([("a", Ridge()), ("b", Ridge())]),
  110. ],
  111. )
  112. def test_estimator_attribute_error(model):
  113. X = [[1], [2]]
  114. y = [0, 1]
  115. model.fit(X, y)
  116. assert not hasattr(model, "estimator_")