import warnings
from numbers import Integral, Real

import numpy as np

from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone
from ..utils import safe_mask
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ..utils.metaestimators import available_if
from ..utils.validation import check_is_fitted

__all__ = ["SelfTrainingClassifier"]

# Authors: Oliver Rausch <rauscho@ethz.ch>
#          Patrice Becker <beckerp@ethz.ch>
# License: BSD 3 clause


def _estimator_has(attr):
    """Check if `self.base_estimator_` or `self.base_estimator` has `attr`."""
    return lambda self: (
        hasattr(self.base_estimator_, attr)
        if hasattr(self, "base_estimator_")
        else hasattr(self.base_estimator, attr)
    )
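
# `_estimator_has` is paired with `available_if` below so that delegated
# methods only appear on the meta-estimator when the wrapped estimator
# exposes them. An illustrative check (recent scikit-learn versions, where
# SVC hides `predict_proba` unless `probability=True`):
#
#     from sklearn.svm import SVC
#     hasattr(SelfTrainingClassifier(SVC(probability=True)), "predict_proba")
#     # True
#     hasattr(SelfTrainingClassifier(SVC(probability=False)), "predict_proba")
#     # False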


class SelfTrainingClassifier(MetaEstimatorMixin, BaseEstimator):
    """Self-training classifier.

    This :term:`metaestimator` allows a given supervised classifier to function as a
    semi-supervised classifier, allowing it to learn from unlabeled data. It
    does this by iteratively predicting pseudo-labels for the unlabeled data
    and adding them to the training set.

    The classifier will continue iterating until either `max_iter` is reached, or
    no pseudo-labels were added to the training set in the previous iteration.

    Read more in the :ref:`User Guide <self_training>`.

    Parameters
    ----------
    base_estimator : estimator object
        An estimator object implementing `fit` and `predict_proba`.
        Invoking the `fit` method will fit a clone of the passed estimator,
        which will be stored in the `base_estimator_` attribute.

    threshold : float, default=0.75
        The decision threshold for use with `criterion='threshold'`.
        Should be in [0, 1). When using the `'threshold'` criterion, a
        :ref:`well calibrated classifier <calibration>` should be used.

    criterion : {'threshold', 'k_best'}, default='threshold'
        The selection criterion used to select which labels to add to the
        training set. If `'threshold'`, pseudo-labels with prediction
        probabilities above `threshold` are added to the dataset. If `'k_best'`,
        the `k_best` pseudo-labels with highest prediction probabilities are
        added to the dataset. When using the `'threshold'` criterion, a
        :ref:`well calibrated classifier <calibration>` should be used.

    k_best : int, default=10
        The number of samples to add in each iteration. Only used when
        `criterion='k_best'`.

    max_iter : int or None, default=10
        Maximum number of iterations allowed. Should be greater than or equal
        to 0. If it is `None`, the classifier will continue to predict labels
        until no new pseudo-labels are added, or all unlabeled samples have
        been labeled.

    verbose : bool, default=False
        Enable verbose output.

    Attributes
    ----------
    base_estimator_ : estimator object
        The fitted estimator.

    classes_ : ndarray or list of ndarray of shape (n_classes,)
        Class labels for each output. (Taken from the trained
        `base_estimator_`).

    transduction_ : ndarray of shape (n_samples,)
        The labels used for the final fit of the classifier, including
        pseudo-labels added during fit.

    labeled_iter_ : ndarray of shape (n_samples,)
        The iteration in which each sample was labeled. When a sample has
        iteration 0, the sample was already labeled in the original dataset.
        When a sample has iteration -1, the sample was not labeled in any
        iteration.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        The number of rounds of self-training, that is, the number of times the
        base estimator is fitted on relabeled variants of the training set.

    termination_condition_ : {'max_iter', 'no_change', 'all_labeled'}
        The reason that fitting was stopped.

        - `'max_iter'`: `n_iter_` reached `max_iter`.
        - `'no_change'`: no new labels were predicted.
        - `'all_labeled'`: all unlabeled samples were labeled before `max_iter`
          was reached.

    See Also
    --------
    LabelPropagation : Label propagation classifier.
    LabelSpreading : Label spreading model for semi-supervised learning.

    References
    ----------
    :doi:`David Yarowsky. 1995. Unsupervised word sense disambiguation rivaling
    supervised methods. In Proceedings of the 33rd annual meeting on
    Association for Computational Linguistics (ACL '95). Association for
    Computational Linguistics, Stroudsburg, PA, USA, 189-196.
    <10.3115/981658.981684>`

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn import datasets
    >>> from sklearn.semi_supervised import SelfTrainingClassifier
    >>> from sklearn.svm import SVC
    >>> rng = np.random.RandomState(42)
    >>> iris = datasets.load_iris()
    >>> random_unlabeled_points = rng.rand(iris.target.shape[0]) < 0.3
    >>> iris.target[random_unlabeled_points] = -1
    >>> svc = SVC(probability=True, gamma="auto")
    >>> self_training_model = SelfTrainingClassifier(svc)
    >>> self_training_model.fit(iris.data, iris.target)
    SelfTrainingClassifier(...)
    """

    _estimator_type = "classifier"

    _parameter_constraints: dict = {
        # We don't require `predict_proba` here to allow passing a meta-estimator
        # that only exposes `predict_proba` after fitting.
        "base_estimator": [HasMethods(["fit"])],
        "threshold": [Interval(Real, 0.0, 1.0, closed="left")],
        "criterion": [StrOptions({"threshold", "k_best"})],
        "k_best": [Interval(Integral, 1, None, closed="left")],
        "max_iter": [Interval(Integral, 0, None, closed="left"), None],
        "verbose": ["verbose"],
    }
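
    # These constraints are enforced by the `_fit_context` decorator when `fit`
    # is called, not at `__init__` time. An illustrative failure (hypothetical
    # `X`, `y`):
    #
    #     SelfTrainingClassifier(SVC(probability=True), threshold=1.5).fit(X, y)
    #
    # raises `InvalidParameterError`, since `threshold` must lie in [0.0, 1.0).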

    def __init__(
        self,
        base_estimator,
        threshold=0.75,
        criterion="threshold",
        k_best=10,
        max_iter=10,
        verbose=False,
    ):
        self.base_estimator = base_estimator
        self.threshold = threshold
        self.criterion = criterion
        self.k_best = k_best
        self.max_iter = max_iter
        self.verbose = verbose

    @_fit_context(
        # SelfTrainingClassifier.base_estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def fit(self, X, y):
        """
        Fit self-training classifier using `X`, `y` as training data.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        y : {array-like, sparse matrix} of shape (n_samples,)
            Array representing the labels. Unlabeled samples should have the
            label -1.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # we need row slicing support for sparse matrices, but the costly
        # finiteness check can be delegated to the base estimator
        X, y = self._validate_data(
            X, y, accept_sparse=["csr", "csc", "lil", "dok"], force_all_finite=False
        )

        self.base_estimator_ = clone(self.base_estimator)

        if y.dtype.kind in ["U", "S"]:
            raise ValueError(
                "y has dtype string. If you wish to predict on "
                "string targets, use dtype object, and use -1"
                " as the label for unlabeled samples."
            )

        has_label = y != -1

        if np.all(has_label):
            warnings.warn("y contains no unlabeled samples", UserWarning)

        if self.criterion == "k_best" and (
            self.k_best > X.shape[0] - np.sum(has_label)
        ):
            warnings.warn(
                (
                    "k_best is larger than the amount of unlabeled "
                    "samples. All unlabeled samples will be labeled in "
                    "the first iteration"
                ),
                UserWarning,
            )

        self.transduction_ = np.copy(y)
        self.labeled_iter_ = np.full_like(y, -1)
        self.labeled_iter_[has_label] = 0

        self.n_iter_ = 0

        while not np.all(has_label) and (
            self.max_iter is None or self.n_iter_ < self.max_iter
        ):
            self.n_iter_ += 1
            self.base_estimator_.fit(
                X[safe_mask(X, has_label)], self.transduction_[has_label]
            )

            # Predict on the unlabeled samples
            prob = self.base_estimator_.predict_proba(X[safe_mask(X, ~has_label)])
            pred = self.base_estimator_.classes_[np.argmax(prob, axis=1)]
            max_proba = np.max(prob, axis=1)

            # Select new labeled samples
            if self.criterion == "threshold":
                selected = max_proba > self.threshold
            else:
                n_to_select = min(self.k_best, max_proba.shape[0])
                if n_to_select == max_proba.shape[0]:
                    selected = np.ones_like(max_proba, dtype=bool)
                else:
                    # NB these are indices, not a mask
                    selected = np.argpartition(-max_proba, n_to_select)[:n_to_select]
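                    # Illustrative: with max_proba = [0.2, 0.9, 0.5] and
                    # n_to_select = 2, argpartition yields indices {1, 2},
                    # the positions of the two highest probabilities.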

            # Map selected indices into original array
            selected_full = np.nonzero(~has_label)[0][selected]

            # Add newly labeled confident predictions to the dataset
            self.transduction_[selected_full] = pred[selected]
            has_label[selected_full] = True
            self.labeled_iter_[selected_full] = self.n_iter_

            if selected_full.shape[0] == 0:
                # no changed labels
                self.termination_condition_ = "no_change"
                break

            if self.verbose:
                print(
                    f"End of iteration {self.n_iter_},"
                    f" added {selected_full.shape[0]} new labels."
                )

        if self.n_iter_ == self.max_iter:
            self.termination_condition_ = "max_iter"
        if np.all(has_label):
            self.termination_condition_ = "all_labeled"

        self.base_estimator_.fit(
            X[safe_mask(X, has_label)], self.transduction_[has_label]
        )
        self.classes_ = self.base_estimator_.classes_
        return self
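
    # A minimal post-fit inspection sketch (hypothetical `X`, `y`, with -1
    # marking unlabeled samples), using the SVC setup from the class docstring:
    #
    #     model = SelfTrainingClassifier(SVC(probability=True)).fit(X, y)
    #     model.transduction_           # labels used in the final fit,
    #                                   # including pseudo-labels
    #     model.labeled_iter_           # iteration each sample was labeled in
    #     model.termination_condition_  # 'max_iter', 'no_change' or 'all_labeled'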

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Predict the classes of `X`.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : ndarray of shape (n_samples,)
            Array with predicted labels.
        """
        check_is_fitted(self)
        X = self._validate_data(
            X,
            accept_sparse=True,
            force_all_finite=False,
            reset=False,
        )
        return self.base_estimator_.predict(X)

    @available_if(_estimator_has("predict_proba"))
    def predict_proba(self, X):
        """Predict probability for each possible outcome.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : ndarray of shape (n_samples, n_classes)
            Array with prediction probabilities.
        """
        check_is_fitted(self)
        X = self._validate_data(
            X,
            accept_sparse=True,
            force_all_finite=False,
            reset=False,
        )
        return self.base_estimator_.predict_proba(X)

    @available_if(_estimator_has("decision_function"))
    def decision_function(self, X):
        """Call decision function of the `base_estimator`.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : ndarray of shape (n_samples,) or (n_samples, n_classes)
            Result of the decision function of the `base_estimator`.
        """
        check_is_fitted(self)
        X = self._validate_data(
            X,
            accept_sparse=True,
            force_all_finite=False,
            reset=False,
        )
        return self.base_estimator_.decision_function(X)

    @available_if(_estimator_has("predict_log_proba"))
    def predict_log_proba(self, X):
        """Predict log probability for each possible outcome.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        Returns
        -------
        y : ndarray of shape (n_samples, n_classes)
            Array with log prediction probabilities.
        """
        check_is_fitted(self)
        X = self._validate_data(
            X,
            accept_sparse=True,
            force_all_finite=False,
            reset=False,
        )
        return self.base_estimator_.predict_log_proba(X)

    @available_if(_estimator_has("score"))
    def score(self, X, y):
        """Call score on the `base_estimator`.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array representing the data.

        y : array-like of shape (n_samples,)
            Array representing the labels.

        Returns
        -------
        score : float
            Result of calling score on the `base_estimator`.
        """
        check_is_fitted(self)
        X = self._validate_data(
            X,
            accept_sparse=True,
            force_all_finite=False,
            reset=False,
        )
        return self.base_estimator_.score(X, y)
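
# Because of the `available_if` guards above, the delegated API mirrors the
# base estimator's. Illustrative: with an `SVC` base estimator,
# `decision_function` is available, whereas with a `RandomForestClassifier`
# (which has no `decision_function`) the attribute is hidden and
# `hasattr(model, "decision_function")` is False.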