from math import ceil

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from sklearn.datasets import load_iris, make_blobs
from sklearn.ensemble import StackingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.svm import SVC

# Author: Oliver Rausch <rauscho@ethz.ch>
# License: BSD 3 clause

# load the iris dataset and randomly permute it
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0
)

n_labeled_samples = 50

y_train_missing_labels = y_train.copy()
y_train_missing_labels[n_labeled_samples:] = -1
mapping = {0: "A", 1: "B", 2: "C", -1: "-1"}
y_train_missing_strings = np.vectorize(mapping.get)(y_train_missing_labels).astype(
    object
)
y_train_missing_strings[y_train_missing_labels == -1] = -1
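
# A note on the labeling convention used throughout this file (per the
# SelfTrainingClassifier docs): -1 in `y` marks an unlabeled sample. The
# arrays built above therefore mix 50 true labels with -1 placeholders,
# and fitting will try to pseudo-label the -1 entries.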


def test_warns_k_best():
    st = SelfTrainingClassifier(
        KNeighborsClassifier(), criterion="k_best", k_best=1000
    )
    with pytest.warns(UserWarning, match="k_best is larger than"):
        st.fit(X_train, y_train_missing_labels)

    assert st.termination_condition_ == "all_labeled"


@pytest.mark.parametrize(
    "base_estimator",
    [KNeighborsClassifier(), SVC(gamma="scale", probability=True, random_state=0)],
)
@pytest.mark.parametrize("selection_crit", ["threshold", "k_best"])
def test_classification(base_estimator, selection_crit):
    # Check classification for various parameter settings.
    # Also assert that predictions for string and numerical labels are equal.
    threshold = 0.75
    max_iter = 10
    st = SelfTrainingClassifier(
        base_estimator, max_iter=max_iter, threshold=threshold, criterion=selection_crit
    )
    st.fit(X_train, y_train_missing_labels)
    pred = st.predict(X_test)
    proba = st.predict_proba(X_test)

    st_string = SelfTrainingClassifier(
        base_estimator, max_iter=max_iter, criterion=selection_crit, threshold=threshold
    )
    st_string.fit(X_train, y_train_missing_strings)
    pred_string = st_string.predict(X_test)
    proba_string = st_string.predict_proba(X_test)

    assert_array_equal(np.vectorize(mapping.get)(pred), pred_string)
    assert_array_equal(proba, proba_string)

    assert st.termination_condition_ == st_string.termination_condition_

    # Check consistency between labeled_iter, n_iter and max_iter.
    labeled = y_train_missing_labels != -1
    # Assert that labeled samples have labeled_iter = 0.
    assert_array_equal(st.labeled_iter_ == 0, labeled)
    # Assert that labeled samples do not change label during training.
    assert_array_equal(y_train_missing_labels[labeled], st.transduction_[labeled])

    # Assert that the last iteration at which any sample was labeled does not
    # exceed the number of iterations run, which in turn is capped by max_iter.
    assert np.max(st.labeled_iter_) <= st.n_iter_ <= max_iter
    assert np.max(st_string.labeled_iter_) <= st_string.n_iter_ <= max_iter

    # Check shapes.
    assert st.labeled_iter_.shape == st.transduction_.shape
    assert st_string.labeled_iter_.shape == st_string.transduction_.shape
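

# A brief note on the two selection criteria exercised above (per the
# SelfTrainingClassifier documentation): with criterion="threshold", every
# unlabeled sample whose highest predicted class probability exceeds
# `threshold` is pseudo-labeled in each iteration; with criterion="k_best",
# the `k_best` most confident predictions are added instead.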


def test_k_best():
    st = SelfTrainingClassifier(
        KNeighborsClassifier(n_neighbors=1),
        criterion="k_best",
        k_best=10,
        max_iter=None,
    )
    y_train_only_one_label = np.copy(y_train)
    y_train_only_one_label[1:] = -1
    n_samples = y_train.shape[0]
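
    # Only one sample starts out labeled; each iteration pseudo-labels at most
    # k_best=10 more, so labeling the remaining n_samples - 1 points takes
    # ceil((n_samples - 1) / 10) iterations, the last one adding the
    # remainder, (n_samples - 1) % 10.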
    n_expected_iter = ceil((n_samples - 1) / 10)
    st.fit(X_train, y_train_only_one_label)
    assert st.n_iter_ == n_expected_iter

    # Check labeled_iter_.
    assert np.sum(st.labeled_iter_ == 0) == 1

    for i in range(1, n_expected_iter):
        assert np.sum(st.labeled_iter_ == i) == 10

    assert np.sum(st.labeled_iter_ == n_expected_iter) == (n_samples - 1) % 10
    assert st.termination_condition_ == "all_labeled"


def test_sanity_classification():
    base_estimator = SVC(gamma="scale", probability=True)
    base_estimator.fit(X_train[n_labeled_samples:], y_train[n_labeled_samples:])

    st = SelfTrainingClassifier(base_estimator)
    st.fit(X_train, y_train_missing_labels)

    pred1, pred2 = base_estimator.predict(X_test), st.predict(X_test)
    assert not np.array_equal(pred1, pred2)
    score_supervised = accuracy_score(base_estimator.predict(X_test), y_test)
    score_self_training = accuracy_score(st.predict(X_test), y_test)

    assert score_self_training > score_supervised


def test_none_iter():
    # Check that all samples were labeled after a 'reasonable' number of
    # iterations.
    st = SelfTrainingClassifier(KNeighborsClassifier(), threshold=0.55, max_iter=None)
    st.fit(X_train, y_train_missing_labels)

    assert st.n_iter_ < 10
    assert st.termination_condition_ == "all_labeled"


@pytest.mark.parametrize(
    "base_estimator",
    [KNeighborsClassifier(), SVC(gamma="scale", probability=True, random_state=0)],
)
@pytest.mark.parametrize("y", [y_train_missing_labels, y_train_missing_strings])
def test_zero_iterations(base_estimator, y):
    # Check classification for zero iterations.
    # Fitting a SelfTrainingClassifier with zero iterations should give the
    # same results as fitting a supervised classifier.
    # This also asserts that string arrays work as expected.

    clf1 = SelfTrainingClassifier(base_estimator, max_iter=0)
    clf1.fit(X_train, y)

    clf2 = base_estimator.fit(X_train[:n_labeled_samples], y[:n_labeled_samples])

    assert_array_equal(clf1.predict(X_test), clf2.predict(X_test))
    assert clf1.termination_condition_ == "max_iter"


def test_prefitted_throws_error():
    # Test that passing a pre-fitted classifier and calling predict throws an
    # error.
    knn = KNeighborsClassifier()
    knn.fit(X_train, y_train)
    st = SelfTrainingClassifier(knn)
    with pytest.raises(
        NotFittedError,
        match="This SelfTrainingClassifier instance is not fitted yet",
    ):
        st.predict(X_train)


@pytest.mark.parametrize("max_iter", range(1, 5))
def test_labeled_iter(max_iter):
    # Check that the number of datapoints labeled in iteration 0 is equal to
    # the number of labeled datapoints we passed.
    st = SelfTrainingClassifier(KNeighborsClassifier(), max_iter=max_iter)

    st.fit(X_train, y_train_missing_labels)
    amount_iter_0 = len(st.labeled_iter_[st.labeled_iter_ == 0])
    assert amount_iter_0 == n_labeled_samples
    # Check that the last iteration at which a sample was labeled does not
    # exceed the number of iterations run, which is capped by max_iter.
    assert np.max(st.labeled_iter_) <= st.n_iter_ <= max_iter


def test_no_unlabeled():
    # Test that training on a fully labeled dataset produces the same results
    # as training the classifier by itself.
    knn = KNeighborsClassifier()
    knn.fit(X_train, y_train)
    st = SelfTrainingClassifier(knn)
    with pytest.warns(UserWarning, match="y contains no unlabeled samples"):
        st.fit(X_train, y_train)
    assert_array_equal(knn.predict(X_test), st.predict(X_test))
    # Assert that all samples were labeled in iteration 0 (since there were no
    # unlabeled samples).
    assert np.all(st.labeled_iter_ == 0)
    assert st.termination_condition_ == "all_labeled"


def test_early_stopping():
    svc = SVC(gamma="scale", probability=True)
    st = SelfTrainingClassifier(svc)
    X_train_easy = [[1], [0], [1], [0.5]]
    y_train_easy = [1, 0, -1, -1]
    # X = [[0.5]] cannot be predicted on with a high confidence, so training
    # stops early.
    st.fit(X_train_easy, y_train_easy)

    assert st.n_iter_ == 1
    assert st.termination_condition_ == "no_change"
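

# The three termination_condition_ values asserted throughout this file (per
# the SelfTrainingClassifier docs): "all_labeled" (every sample received a
# label before max_iter was reached), "max_iter" (the iteration cap was
# reached), and "no_change" (no new sample cleared the selection criterion,
# as in test_early_stopping above).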


def test_strings_dtype():
    clf = SelfTrainingClassifier(KNeighborsClassifier())
    X, y = make_blobs(n_samples=30, random_state=0, cluster_std=0.1)
    labels_multiclass = ["one", "two", "three"]

    y_strings = np.take(labels_multiclass, y)
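
    # np.take yields a fixed-width unicode array; SelfTrainingClassifier
    # requires string targets to use dtype=object (so that -1 can mark
    # unlabeled entries), hence the dtype error below.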
    with pytest.raises(ValueError, match="dtype"):
        clf.fit(X, y_strings)


@pytest.mark.parametrize("verbose", [True, False])
def test_verbose(capsys, verbose):
    clf = SelfTrainingClassifier(KNeighborsClassifier(), verbose=verbose)
    clf.fit(X_train, y_train_missing_labels)

    captured = capsys.readouterr()

    if verbose:
        assert "iteration" in captured.out
    else:
        assert "iteration" not in captured.out


def test_verbose_k_best(capsys):
    st = SelfTrainingClassifier(
        KNeighborsClassifier(n_neighbors=1),
        criterion="k_best",
        k_best=10,
        verbose=True,
        max_iter=None,
    )

    y_train_only_one_label = np.copy(y_train)
    y_train_only_one_label[1:] = -1
    n_samples = y_train.shape[0]

    n_expected_iter = ceil((n_samples - 1) / 10)
    st.fit(X_train, y_train_only_one_label)

    captured = capsys.readouterr()

    msg = "End of iteration {}, added {} new labels."
    for i in range(1, n_expected_iter):
        assert msg.format(i, 10) in captured.out

    assert msg.format(n_expected_iter, (n_samples - 1) % 10) in captured.out


def test_k_best_selects_best():
    # Tests that the labels added by st really are the 10 best labels.
    svc = SVC(gamma="scale", probability=True, random_state=0)
    st = SelfTrainingClassifier(svc, criterion="k_best", max_iter=1, k_best=10)
    has_label = y_train_missing_labels != -1
    st.fit(X_train, y_train_missing_labels)
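
    # transduction_ holds the labels used for the final fit, including
    # pseudo-labels added during fitting, so initially-unlabeled samples whose
    # entry is no longer -1 are exactly the ones st pseudo-labeled.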
    got_label = ~has_label & (st.transduction_ != -1)

    svc.fit(X_train[has_label], y_train_missing_labels[has_label])
    pred = svc.predict_proba(X_train[~has_label])
    max_proba = np.max(pred, axis=1)

    most_confident_svc = X_train[~has_label][np.argsort(max_proba)[-10:]]
    added_by_st = X_train[np.where(got_label)].tolist()

    for row in most_confident_svc.tolist():
        assert row in added_by_st


def test_base_estimator_meta_estimator():
    # Check that a meta-estimator relying on an estimator implementing
    # `predict_proba` will work even if it does not expose this method before
    # being fitted.
    # Non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/19119
    base_estimator = StackingClassifier(
        estimators=[
            ("svc_1", SVC(probability=True)),
            ("svc_2", SVC(probability=True)),
        ],
        final_estimator=SVC(probability=True),
        cv=2,
    )

    assert hasattr(base_estimator, "predict_proba")
    clf = SelfTrainingClassifier(base_estimator=base_estimator)
    clf.fit(X_train, y_train_missing_labels)
    clf.predict_proba(X_test)

    base_estimator = StackingClassifier(
        estimators=[
            ("svc_1", SVC(probability=False)),
            ("svc_2", SVC(probability=False)),
        ],
        final_estimator=SVC(probability=False),
        cv=2,
    )

    assert not hasattr(base_estimator, "predict_proba")
    clf = SelfTrainingClassifier(base_estimator=base_estimator)
    with pytest.raises(AttributeError):
        clf.fit(X_train, y_train_missing_labels)


def test_missing_predict_proba():
    # Check that an error is thrown if predict_proba is not implemented.
    base_estimator = SVC(probability=False, gamma="scale")
    self_training = SelfTrainingClassifier(base_estimator)

    with pytest.raises(AttributeError, match="predict_proba is not available"):
        self_training.fit(X_train, y_train_missing_labels)
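

# A minimal end-to-end sketch of the API exercised above (illustrative only,
# not part of the test suite; it reuses the fixtures defined at the top of
# this file):
#
#     base = SVC(gamma="scale", probability=True)
#     st = SelfTrainingClassifier(base, threshold=0.75, max_iter=10)
#     st.fit(X_train, y_train_missing_labels)  # -1 marks unlabeled rows
#     print(st.termination_condition_, st.n_iter_)
#     print(accuracy_score(y_test, st.predict(X_test)))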