_ransac.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # Author: Johannes Schönberger
  2. #
  3. # License: BSD 3 clause
  4. import warnings
  5. from numbers import Integral, Real
  6. import numpy as np
  7. from ..base import (
  8. BaseEstimator,
  9. MetaEstimatorMixin,
  10. MultiOutputMixin,
  11. RegressorMixin,
  12. _fit_context,
  13. clone,
  14. )
  15. from ..exceptions import ConvergenceWarning
  16. from ..utils import check_consistent_length, check_random_state
  17. from ..utils._param_validation import (
  18. HasMethods,
  19. Interval,
  20. Options,
  21. RealNotInt,
  22. StrOptions,
  23. )
  24. from ..utils.random import sample_without_replacement
  25. from ..utils.validation import _check_sample_weight, check_is_fitted, has_fit_parameter
  26. from ._base import LinearRegression
# Spacing between 1.0 and the next representable float. Used as a lower bound
# in `_dynamic_max_trials` so that `np.log` is never evaluated at zero.
_EPSILON = np.spacing(1)
  28. def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
  29. """Determine number trials such that at least one outlier-free subset is
  30. sampled for the given inlier/outlier ratio.
  31. Parameters
  32. ----------
  33. n_inliers : int
  34. Number of inliers in the data.
  35. n_samples : int
  36. Total number of samples in the data.
  37. min_samples : int
  38. Minimum number of samples chosen randomly from original data.
  39. probability : float
  40. Probability (confidence) that one outlier-free sample is generated.
  41. Returns
  42. -------
  43. trials : int
  44. Number of trials.
  45. """
  46. inlier_ratio = n_inliers / float(n_samples)
  47. nom = max(_EPSILON, 1 - probability)
  48. denom = max(_EPSILON, 1 - inlier_ratio**min_samples)
  49. if nom == 1:
  50. return 0
  51. if denom == 1:
  52. return float("inf")
  53. return abs(float(np.ceil(np.log(nom) / np.log(denom))))
  54. class RANSACRegressor(
  55. MetaEstimatorMixin, RegressorMixin, MultiOutputMixin, BaseEstimator
  56. ):
  57. """RANSAC (RANdom SAmple Consensus) algorithm.
  58. RANSAC is an iterative algorithm for the robust estimation of parameters
  59. from a subset of inliers from the complete data set.
  60. Read more in the :ref:`User Guide <ransac_regression>`.
  61. Parameters
  62. ----------
  63. estimator : object, default=None
  64. Base estimator object which implements the following methods:
  65. * `fit(X, y)`: Fit model to given training data and target values.
  66. * `score(X, y)`: Returns the mean accuracy on the given test data,
  67. which is used for the stop criterion defined by `stop_score`.
  68. Additionally, the score is used to decide which of two equally
  69. large consensus sets is chosen as the better one.
  70. * `predict(X)`: Returns predicted values using the linear model,
  71. which is used to compute residual error using loss function.
  72. If `estimator` is None, then
  73. :class:`~sklearn.linear_model.LinearRegression` is used for
  74. target values of dtype float.
  75. Note that the current implementation only supports regression
  76. estimators.
  77. min_samples : int (>= 1) or float ([0, 1]), default=None
  78. Minimum number of samples chosen randomly from original data. Treated
  79. as an absolute number of samples for `min_samples >= 1`, treated as a
  80. relative number `ceil(min_samples * X.shape[0])` for
  81. `min_samples < 1`. This is typically chosen as the minimal number of
  82. samples necessary to estimate the given `estimator`. By default a
  83. :class:`~sklearn.linear_model.LinearRegression` estimator is assumed and
  84. `min_samples` is chosen as ``X.shape[1] + 1``. This parameter is highly
  85. dependent upon the model, so if a `estimator` other than
  86. :class:`~sklearn.linear_model.LinearRegression` is used, the user must
  87. provide a value.
  88. residual_threshold : float, default=None
  89. Maximum residual for a data sample to be classified as an inlier.
  90. By default the threshold is chosen as the MAD (median absolute
  91. deviation) of the target values `y`. Points whose residuals are
  92. strictly equal to the threshold are considered as inliers.
  93. is_data_valid : callable, default=None
  94. This function is called with the randomly selected data before the
  95. model is fitted to it: `is_data_valid(X, y)`. If its return value is
  96. False the current randomly chosen sub-sample is skipped.
  97. is_model_valid : callable, default=None
  98. This function is called with the estimated model and the randomly
  99. selected data: `is_model_valid(model, X, y)`. If its return value is
  100. False the current randomly chosen sub-sample is skipped.
  101. Rejecting samples with this function is computationally costlier than
  102. with `is_data_valid`. `is_model_valid` should therefore only be used if
  103. the estimated model is needed for making the rejection decision.
  104. max_trials : int, default=100
  105. Maximum number of iterations for random sample selection.
  106. max_skips : int, default=np.inf
  107. Maximum number of iterations that can be skipped due to finding zero
  108. inliers or invalid data defined by ``is_data_valid`` or invalid models
  109. defined by ``is_model_valid``.
  110. .. versionadded:: 0.19
  111. stop_n_inliers : int, default=np.inf
  112. Stop iteration if at least this number of inliers are found.
  113. stop_score : float, default=np.inf
  114. Stop iteration if score is greater equal than this threshold.
  115. stop_probability : float in range [0, 1], default=0.99
  116. RANSAC iteration stops if at least one outlier-free set of the training
  117. data is sampled in RANSAC. This requires to generate at least N
  118. samples (iterations)::
  119. N >= log(1 - probability) / log(1 - e**m)
  120. where the probability (confidence) is typically set to high value such
  121. as 0.99 (the default) and e is the current fraction of inliers w.r.t.
  122. the total number of samples.
  123. loss : str, callable, default='absolute_error'
  124. String inputs, 'absolute_error' and 'squared_error' are supported which
  125. find the absolute error and squared error per sample respectively.
  126. If ``loss`` is a callable, then it should be a function that takes
  127. two arrays as inputs, the true and predicted value and returns a 1-D
  128. array with the i-th value of the array corresponding to the loss
  129. on ``X[i]``.
  130. If the loss on a sample is greater than the ``residual_threshold``,
  131. then this sample is classified as an outlier.
  132. .. versionadded:: 0.18
  133. random_state : int, RandomState instance, default=None
  134. The generator used to initialize the centers.
  135. Pass an int for reproducible output across multiple function calls.
  136. See :term:`Glossary <random_state>`.
  137. Attributes
  138. ----------
  139. estimator_ : object
  140. Best fitted model (copy of the `estimator` object).
  141. n_trials_ : int
  142. Number of random selection trials until one of the stop criteria is
  143. met. It is always ``<= max_trials``.
  144. inlier_mask_ : bool array of shape [n_samples]
  145. Boolean mask of inliers classified as ``True``.
  146. n_skips_no_inliers_ : int
  147. Number of iterations skipped due to finding zero inliers.
  148. .. versionadded:: 0.19
  149. n_skips_invalid_data_ : int
  150. Number of iterations skipped due to invalid data defined by
  151. ``is_data_valid``.
  152. .. versionadded:: 0.19
  153. n_skips_invalid_model_ : int
  154. Number of iterations skipped due to an invalid model defined by
  155. ``is_model_valid``.
  156. .. versionadded:: 0.19
  157. n_features_in_ : int
  158. Number of features seen during :term:`fit`.
  159. .. versionadded:: 0.24
  160. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  161. Names of features seen during :term:`fit`. Defined only when `X`
  162. has feature names that are all strings.
  163. .. versionadded:: 1.0
  164. See Also
  165. --------
  166. HuberRegressor : Linear regression model that is robust to outliers.
  167. TheilSenRegressor : Theil-Sen Estimator robust multivariate regression model.
  168. SGDRegressor : Fitted by minimizing a regularized empirical loss with SGD.
  169. References
  170. ----------
  171. .. [1] https://en.wikipedia.org/wiki/RANSAC
  172. .. [2] https://www.sri.com/wp-content/uploads/2021/12/ransac-publication.pdf
  173. .. [3] http://www.bmva.org/bmvc/2009/Papers/Paper355/Paper355.pdf
  174. Examples
  175. --------
  176. >>> from sklearn.linear_model import RANSACRegressor
  177. >>> from sklearn.datasets import make_regression
  178. >>> X, y = make_regression(
  179. ... n_samples=200, n_features=2, noise=4.0, random_state=0)
  180. >>> reg = RANSACRegressor(random_state=0).fit(X, y)
  181. >>> reg.score(X, y)
  182. 0.9885...
  183. >>> reg.predict(X[:1,])
  184. array([-31.9417...])
  185. """ # noqa: E501
  186. _parameter_constraints: dict = {
  187. "estimator": [HasMethods(["fit", "score", "predict"]), None],
  188. "min_samples": [
  189. Interval(Integral, 1, None, closed="left"),
  190. Interval(RealNotInt, 0, 1, closed="both"),
  191. None,
  192. ],
  193. "residual_threshold": [Interval(Real, 0, None, closed="left"), None],
  194. "is_data_valid": [callable, None],
  195. "is_model_valid": [callable, None],
  196. "max_trials": [
  197. Interval(Integral, 0, None, closed="left"),
  198. Options(Real, {np.inf}),
  199. ],
  200. "max_skips": [
  201. Interval(Integral, 0, None, closed="left"),
  202. Options(Real, {np.inf}),
  203. ],
  204. "stop_n_inliers": [
  205. Interval(Integral, 0, None, closed="left"),
  206. Options(Real, {np.inf}),
  207. ],
  208. "stop_score": [Interval(Real, None, None, closed="both")],
  209. "stop_probability": [Interval(Real, 0, 1, closed="both")],
  210. "loss": [StrOptions({"absolute_error", "squared_error"}), callable],
  211. "random_state": ["random_state"],
  212. }
  213. def __init__(
  214. self,
  215. estimator=None,
  216. *,
  217. min_samples=None,
  218. residual_threshold=None,
  219. is_data_valid=None,
  220. is_model_valid=None,
  221. max_trials=100,
  222. max_skips=np.inf,
  223. stop_n_inliers=np.inf,
  224. stop_score=np.inf,
  225. stop_probability=0.99,
  226. loss="absolute_error",
  227. random_state=None,
  228. ):
  229. self.estimator = estimator
  230. self.min_samples = min_samples
  231. self.residual_threshold = residual_threshold
  232. self.is_data_valid = is_data_valid
  233. self.is_model_valid = is_model_valid
  234. self.max_trials = max_trials
  235. self.max_skips = max_skips
  236. self.stop_n_inliers = stop_n_inliers
  237. self.stop_score = stop_score
  238. self.stop_probability = stop_probability
  239. self.random_state = random_state
  240. self.loss = loss
  241. @_fit_context(
  242. # RansacRegressor.estimator is not validated yet
  243. prefer_skip_nested_validation=False
  244. )
  245. def fit(self, X, y, sample_weight=None):
  246. """Fit estimator using RANSAC algorithm.
  247. Parameters
  248. ----------
  249. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  250. Training data.
  251. y : array-like of shape (n_samples,) or (n_samples, n_targets)
  252. Target values.
  253. sample_weight : array-like of shape (n_samples,), default=None
  254. Individual weights for each sample
  255. raises error if sample_weight is passed and estimator
  256. fit method does not support it.
  257. .. versionadded:: 0.18
  258. Returns
  259. -------
  260. self : object
  261. Fitted `RANSACRegressor` estimator.
  262. Raises
  263. ------
  264. ValueError
  265. If no valid consensus set could be found. This occurs if
  266. `is_data_valid` and `is_model_valid` return False for all
  267. `max_trials` randomly chosen sub-samples.
  268. """
  269. # Need to validate separately here. We can't pass multi_output=True
  270. # because that would allow y to be csr. Delay expensive finiteness
  271. # check to the estimator's own input validation.
  272. check_X_params = dict(accept_sparse="csr", force_all_finite=False)
  273. check_y_params = dict(ensure_2d=False)
  274. X, y = self._validate_data(
  275. X, y, validate_separately=(check_X_params, check_y_params)
  276. )
  277. check_consistent_length(X, y)
  278. if self.estimator is not None:
  279. estimator = clone(self.estimator)
  280. else:
  281. estimator = LinearRegression()
  282. if self.min_samples is None:
  283. if not isinstance(estimator, LinearRegression):
  284. raise ValueError(
  285. "`min_samples` needs to be explicitly set when estimator "
  286. "is not a LinearRegression."
  287. )
  288. min_samples = X.shape[1] + 1
  289. elif 0 < self.min_samples < 1:
  290. min_samples = np.ceil(self.min_samples * X.shape[0])
  291. elif self.min_samples >= 1:
  292. min_samples = self.min_samples
  293. if min_samples > X.shape[0]:
  294. raise ValueError(
  295. "`min_samples` may not be larger than number "
  296. "of samples: n_samples = %d." % (X.shape[0])
  297. )
  298. if self.residual_threshold is None:
  299. # MAD (median absolute deviation)
  300. residual_threshold = np.median(np.abs(y - np.median(y)))
  301. else:
  302. residual_threshold = self.residual_threshold
  303. if self.loss == "absolute_error":
  304. if y.ndim == 1:
  305. loss_function = lambda y_true, y_pred: np.abs(y_true - y_pred)
  306. else:
  307. loss_function = lambda y_true, y_pred: np.sum(
  308. np.abs(y_true - y_pred), axis=1
  309. )
  310. elif self.loss == "squared_error":
  311. if y.ndim == 1:
  312. loss_function = lambda y_true, y_pred: (y_true - y_pred) ** 2
  313. else:
  314. loss_function = lambda y_true, y_pred: np.sum(
  315. (y_true - y_pred) ** 2, axis=1
  316. )
  317. elif callable(self.loss):
  318. loss_function = self.loss
  319. random_state = check_random_state(self.random_state)
  320. try: # Not all estimator accept a random_state
  321. estimator.set_params(random_state=random_state)
  322. except ValueError:
  323. pass
  324. estimator_fit_has_sample_weight = has_fit_parameter(estimator, "sample_weight")
  325. estimator_name = type(estimator).__name__
  326. if sample_weight is not None and not estimator_fit_has_sample_weight:
  327. raise ValueError(
  328. "%s does not support sample_weight. Samples"
  329. " weights are only used for the calibration"
  330. " itself." % estimator_name
  331. )
  332. if sample_weight is not None:
  333. sample_weight = _check_sample_weight(sample_weight, X)
  334. n_inliers_best = 1
  335. score_best = -np.inf
  336. inlier_mask_best = None
  337. X_inlier_best = None
  338. y_inlier_best = None
  339. inlier_best_idxs_subset = None
  340. self.n_skips_no_inliers_ = 0
  341. self.n_skips_invalid_data_ = 0
  342. self.n_skips_invalid_model_ = 0
  343. # number of data samples
  344. n_samples = X.shape[0]
  345. sample_idxs = np.arange(n_samples)
  346. self.n_trials_ = 0
  347. max_trials = self.max_trials
  348. while self.n_trials_ < max_trials:
  349. self.n_trials_ += 1
  350. if (
  351. self.n_skips_no_inliers_
  352. + self.n_skips_invalid_data_
  353. + self.n_skips_invalid_model_
  354. ) > self.max_skips:
  355. break
  356. # choose random sample set
  357. subset_idxs = sample_without_replacement(
  358. n_samples, min_samples, random_state=random_state
  359. )
  360. X_subset = X[subset_idxs]
  361. y_subset = y[subset_idxs]
  362. # check if random sample set is valid
  363. if self.is_data_valid is not None and not self.is_data_valid(
  364. X_subset, y_subset
  365. ):
  366. self.n_skips_invalid_data_ += 1
  367. continue
  368. # fit model for current random sample set
  369. if sample_weight is None:
  370. estimator.fit(X_subset, y_subset)
  371. else:
  372. estimator.fit(
  373. X_subset, y_subset, sample_weight=sample_weight[subset_idxs]
  374. )
  375. # check if estimated model is valid
  376. if self.is_model_valid is not None and not self.is_model_valid(
  377. estimator, X_subset, y_subset
  378. ):
  379. self.n_skips_invalid_model_ += 1
  380. continue
  381. # residuals of all data for current random sample model
  382. y_pred = estimator.predict(X)
  383. residuals_subset = loss_function(y, y_pred)
  384. # classify data into inliers and outliers
  385. inlier_mask_subset = residuals_subset <= residual_threshold
  386. n_inliers_subset = np.sum(inlier_mask_subset)
  387. # less inliers -> skip current random sample
  388. if n_inliers_subset < n_inliers_best:
  389. self.n_skips_no_inliers_ += 1
  390. continue
  391. # extract inlier data set
  392. inlier_idxs_subset = sample_idxs[inlier_mask_subset]
  393. X_inlier_subset = X[inlier_idxs_subset]
  394. y_inlier_subset = y[inlier_idxs_subset]
  395. # score of inlier data set
  396. score_subset = estimator.score(X_inlier_subset, y_inlier_subset)
  397. # same number of inliers but worse score -> skip current random
  398. # sample
  399. if n_inliers_subset == n_inliers_best and score_subset < score_best:
  400. continue
  401. # save current random sample as best sample
  402. n_inliers_best = n_inliers_subset
  403. score_best = score_subset
  404. inlier_mask_best = inlier_mask_subset
  405. X_inlier_best = X_inlier_subset
  406. y_inlier_best = y_inlier_subset
  407. inlier_best_idxs_subset = inlier_idxs_subset
  408. max_trials = min(
  409. max_trials,
  410. _dynamic_max_trials(
  411. n_inliers_best, n_samples, min_samples, self.stop_probability
  412. ),
  413. )
  414. # break if sufficient number of inliers or score is reached
  415. if n_inliers_best >= self.stop_n_inliers or score_best >= self.stop_score:
  416. break
  417. # if none of the iterations met the required criteria
  418. if inlier_mask_best is None:
  419. if (
  420. self.n_skips_no_inliers_
  421. + self.n_skips_invalid_data_
  422. + self.n_skips_invalid_model_
  423. ) > self.max_skips:
  424. raise ValueError(
  425. "RANSAC skipped more iterations than `max_skips` without"
  426. " finding a valid consensus set. Iterations were skipped"
  427. " because each randomly chosen sub-sample failed the"
  428. " passing criteria. See estimator attributes for"
  429. " diagnostics (n_skips*)."
  430. )
  431. else:
  432. raise ValueError(
  433. "RANSAC could not find a valid consensus set. All"
  434. " `max_trials` iterations were skipped because each"
  435. " randomly chosen sub-sample failed the passing criteria."
  436. " See estimator attributes for diagnostics (n_skips*)."
  437. )
  438. else:
  439. if (
  440. self.n_skips_no_inliers_
  441. + self.n_skips_invalid_data_
  442. + self.n_skips_invalid_model_
  443. ) > self.max_skips:
  444. warnings.warn(
  445. (
  446. "RANSAC found a valid consensus set but exited"
  447. " early due to skipping more iterations than"
  448. " `max_skips`. See estimator attributes for"
  449. " diagnostics (n_skips*)."
  450. ),
  451. ConvergenceWarning,
  452. )
  453. # estimate final model using all inliers
  454. if sample_weight is None:
  455. estimator.fit(X_inlier_best, y_inlier_best)
  456. else:
  457. estimator.fit(
  458. X_inlier_best,
  459. y_inlier_best,
  460. sample_weight=sample_weight[inlier_best_idxs_subset],
  461. )
  462. self.estimator_ = estimator
  463. self.inlier_mask_ = inlier_mask_best
  464. return self
  465. def predict(self, X):
  466. """Predict using the estimated model.
  467. This is a wrapper for `estimator_.predict(X)`.
  468. Parameters
  469. ----------
  470. X : {array-like or sparse matrix} of shape (n_samples, n_features)
  471. Input data.
  472. Returns
  473. -------
  474. y : array, shape = [n_samples] or [n_samples, n_targets]
  475. Returns predicted values.
  476. """
  477. check_is_fitted(self)
  478. X = self._validate_data(
  479. X,
  480. force_all_finite=False,
  481. accept_sparse=True,
  482. reset=False,
  483. )
  484. return self.estimator_.predict(X)
  485. def score(self, X, y):
  486. """Return the score of the prediction.
  487. This is a wrapper for `estimator_.score(X, y)`.
  488. Parameters
  489. ----------
  490. X : (array-like or sparse matrix} of shape (n_samples, n_features)
  491. Training data.
  492. y : array-like of shape (n_samples,) or (n_samples, n_targets)
  493. Target values.
  494. Returns
  495. -------
  496. z : float
  497. Score of the prediction.
  498. """
  499. check_is_fitted(self)
  500. X = self._validate_data(
  501. X,
  502. force_all_finite=False,
  503. accept_sparse=True,
  504. reset=False,
  505. )
  506. return self.estimator_.score(X, y)
  507. def _more_tags(self):
  508. return {
  509. "_xfail_checks": {
  510. "check_sample_weights_invariance": (
  511. "zero sample_weight is not equivalent to removing samples"
  512. ),
  513. }
  514. }