test_ransac.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. import numpy as np
  2. import pytest
  3. from numpy.testing import assert_array_almost_equal, assert_array_equal
  4. from scipy import sparse
  5. from sklearn.datasets import make_regression
  6. from sklearn.exceptions import ConvergenceWarning
  7. from sklearn.linear_model import (
  8. LinearRegression,
  9. OrthogonalMatchingPursuit,
  10. RANSACRegressor,
  11. Ridge,
  12. )
  13. from sklearn.linear_model._ransac import _dynamic_max_trials
  14. from sklearn.utils import check_random_state
  15. from sklearn.utils._testing import assert_allclose
  16. # Generate coordinates of line
  17. X = np.arange(-200, 200)
  18. y = 0.2 * X + 20
  19. data = np.column_stack([X, y])
  20. # Add some faulty data
  21. rng = np.random.RandomState(1000)
  22. outliers = np.unique(rng.randint(len(X), size=200))
  23. data[outliers, :] += 50 + rng.rand(len(outliers), 2) * 10
  24. X = data[:, 0][:, np.newaxis]
  25. y = data[:, 1]
  26. def test_ransac_inliers_outliers():
  27. estimator = LinearRegression()
  28. ransac_estimator = RANSACRegressor(
  29. estimator, min_samples=2, residual_threshold=5, random_state=0
  30. )
  31. # Estimate parameters of corrupted data
  32. ransac_estimator.fit(X, y)
  33. # Ground truth / reference inlier mask
  34. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  35. ref_inlier_mask[outliers] = False
  36. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  37. def test_ransac_is_data_valid():
  38. def is_data_valid(X, y):
  39. assert X.shape[0] == 2
  40. assert y.shape[0] == 2
  41. return False
  42. rng = np.random.RandomState(0)
  43. X = rng.rand(10, 2)
  44. y = rng.rand(10, 1)
  45. estimator = LinearRegression()
  46. ransac_estimator = RANSACRegressor(
  47. estimator,
  48. min_samples=2,
  49. residual_threshold=5,
  50. is_data_valid=is_data_valid,
  51. random_state=0,
  52. )
  53. with pytest.raises(ValueError):
  54. ransac_estimator.fit(X, y)
  55. def test_ransac_is_model_valid():
  56. def is_model_valid(estimator, X, y):
  57. assert X.shape[0] == 2
  58. assert y.shape[0] == 2
  59. return False
  60. estimator = LinearRegression()
  61. ransac_estimator = RANSACRegressor(
  62. estimator,
  63. min_samples=2,
  64. residual_threshold=5,
  65. is_model_valid=is_model_valid,
  66. random_state=0,
  67. )
  68. with pytest.raises(ValueError):
  69. ransac_estimator.fit(X, y)
  70. def test_ransac_max_trials():
  71. estimator = LinearRegression()
  72. ransac_estimator = RANSACRegressor(
  73. estimator,
  74. min_samples=2,
  75. residual_threshold=5,
  76. max_trials=0,
  77. random_state=0,
  78. )
  79. with pytest.raises(ValueError):
  80. ransac_estimator.fit(X, y)
  81. # there is a 1e-9 chance it will take these many trials. No good reason
  82. # 1e-2 isn't enough, can still happen
  83. # 2 is the what ransac defines as min_samples = X.shape[1] + 1
  84. max_trials = _dynamic_max_trials(len(X) - len(outliers), X.shape[0], 2, 1 - 1e-9)
  85. ransac_estimator = RANSACRegressor(estimator, min_samples=2)
  86. for i in range(50):
  87. ransac_estimator.set_params(min_samples=2, random_state=i)
  88. ransac_estimator.fit(X, y)
  89. assert ransac_estimator.n_trials_ < max_trials + 1
  90. def test_ransac_stop_n_inliers():
  91. estimator = LinearRegression()
  92. ransac_estimator = RANSACRegressor(
  93. estimator,
  94. min_samples=2,
  95. residual_threshold=5,
  96. stop_n_inliers=2,
  97. random_state=0,
  98. )
  99. ransac_estimator.fit(X, y)
  100. assert ransac_estimator.n_trials_ == 1
  101. def test_ransac_stop_score():
  102. estimator = LinearRegression()
  103. ransac_estimator = RANSACRegressor(
  104. estimator,
  105. min_samples=2,
  106. residual_threshold=5,
  107. stop_score=0,
  108. random_state=0,
  109. )
  110. ransac_estimator.fit(X, y)
  111. assert ransac_estimator.n_trials_ == 1
  112. def test_ransac_score():
  113. X = np.arange(100)[:, None]
  114. y = np.zeros((100,))
  115. y[0] = 1
  116. y[1] = 100
  117. estimator = LinearRegression()
  118. ransac_estimator = RANSACRegressor(
  119. estimator, min_samples=2, residual_threshold=0.5, random_state=0
  120. )
  121. ransac_estimator.fit(X, y)
  122. assert ransac_estimator.score(X[2:], y[2:]) == 1
  123. assert ransac_estimator.score(X[:2], y[:2]) < 1
  124. def test_ransac_predict():
  125. X = np.arange(100)[:, None]
  126. y = np.zeros((100,))
  127. y[0] = 1
  128. y[1] = 100
  129. estimator = LinearRegression()
  130. ransac_estimator = RANSACRegressor(
  131. estimator, min_samples=2, residual_threshold=0.5, random_state=0
  132. )
  133. ransac_estimator.fit(X, y)
  134. assert_array_equal(ransac_estimator.predict(X), np.zeros(100))
  135. def test_ransac_no_valid_data():
  136. def is_data_valid(X, y):
  137. return False
  138. estimator = LinearRegression()
  139. ransac_estimator = RANSACRegressor(
  140. estimator, is_data_valid=is_data_valid, max_trials=5
  141. )
  142. msg = "RANSAC could not find a valid consensus set"
  143. with pytest.raises(ValueError, match=msg):
  144. ransac_estimator.fit(X, y)
  145. assert ransac_estimator.n_skips_no_inliers_ == 0
  146. assert ransac_estimator.n_skips_invalid_data_ == 5
  147. assert ransac_estimator.n_skips_invalid_model_ == 0
  148. def test_ransac_no_valid_model():
  149. def is_model_valid(estimator, X, y):
  150. return False
  151. estimator = LinearRegression()
  152. ransac_estimator = RANSACRegressor(
  153. estimator, is_model_valid=is_model_valid, max_trials=5
  154. )
  155. msg = "RANSAC could not find a valid consensus set"
  156. with pytest.raises(ValueError, match=msg):
  157. ransac_estimator.fit(X, y)
  158. assert ransac_estimator.n_skips_no_inliers_ == 0
  159. assert ransac_estimator.n_skips_invalid_data_ == 0
  160. assert ransac_estimator.n_skips_invalid_model_ == 5
  161. def test_ransac_exceed_max_skips():
  162. def is_data_valid(X, y):
  163. return False
  164. estimator = LinearRegression()
  165. ransac_estimator = RANSACRegressor(
  166. estimator, is_data_valid=is_data_valid, max_trials=5, max_skips=3
  167. )
  168. msg = "RANSAC skipped more iterations than `max_skips`"
  169. with pytest.raises(ValueError, match=msg):
  170. ransac_estimator.fit(X, y)
  171. assert ransac_estimator.n_skips_no_inliers_ == 0
  172. assert ransac_estimator.n_skips_invalid_data_ == 4
  173. assert ransac_estimator.n_skips_invalid_model_ == 0
  174. def test_ransac_warn_exceed_max_skips():
  175. global cause_skip
  176. cause_skip = False
  177. def is_data_valid(X, y):
  178. global cause_skip
  179. if not cause_skip:
  180. cause_skip = True
  181. return True
  182. else:
  183. return False
  184. estimator = LinearRegression()
  185. ransac_estimator = RANSACRegressor(
  186. estimator, is_data_valid=is_data_valid, max_skips=3, max_trials=5
  187. )
  188. warning_message = (
  189. "RANSAC found a valid consensus set but exited "
  190. "early due to skipping more iterations than "
  191. "`max_skips`. See estimator attributes for "
  192. "diagnostics."
  193. )
  194. with pytest.warns(ConvergenceWarning, match=warning_message):
  195. ransac_estimator.fit(X, y)
  196. assert ransac_estimator.n_skips_no_inliers_ == 0
  197. assert ransac_estimator.n_skips_invalid_data_ == 4
  198. assert ransac_estimator.n_skips_invalid_model_ == 0
  199. def test_ransac_sparse_coo():
  200. X_sparse = sparse.coo_matrix(X)
  201. estimator = LinearRegression()
  202. ransac_estimator = RANSACRegressor(
  203. estimator, min_samples=2, residual_threshold=5, random_state=0
  204. )
  205. ransac_estimator.fit(X_sparse, y)
  206. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  207. ref_inlier_mask[outliers] = False
  208. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  209. def test_ransac_sparse_csr():
  210. X_sparse = sparse.csr_matrix(X)
  211. estimator = LinearRegression()
  212. ransac_estimator = RANSACRegressor(
  213. estimator, min_samples=2, residual_threshold=5, random_state=0
  214. )
  215. ransac_estimator.fit(X_sparse, y)
  216. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  217. ref_inlier_mask[outliers] = False
  218. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  219. def test_ransac_sparse_csc():
  220. X_sparse = sparse.csc_matrix(X)
  221. estimator = LinearRegression()
  222. ransac_estimator = RANSACRegressor(
  223. estimator, min_samples=2, residual_threshold=5, random_state=0
  224. )
  225. ransac_estimator.fit(X_sparse, y)
  226. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  227. ref_inlier_mask[outliers] = False
  228. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  229. def test_ransac_none_estimator():
  230. estimator = LinearRegression()
  231. ransac_estimator = RANSACRegressor(
  232. estimator, min_samples=2, residual_threshold=5, random_state=0
  233. )
  234. ransac_none_estimator = RANSACRegressor(
  235. None, min_samples=2, residual_threshold=5, random_state=0
  236. )
  237. ransac_estimator.fit(X, y)
  238. ransac_none_estimator.fit(X, y)
  239. assert_array_almost_equal(
  240. ransac_estimator.predict(X), ransac_none_estimator.predict(X)
  241. )
  242. def test_ransac_min_n_samples():
  243. estimator = LinearRegression()
  244. ransac_estimator1 = RANSACRegressor(
  245. estimator, min_samples=2, residual_threshold=5, random_state=0
  246. )
  247. ransac_estimator2 = RANSACRegressor(
  248. estimator,
  249. min_samples=2.0 / X.shape[0],
  250. residual_threshold=5,
  251. random_state=0,
  252. )
  253. ransac_estimator5 = RANSACRegressor(
  254. estimator, min_samples=2, residual_threshold=5, random_state=0
  255. )
  256. ransac_estimator6 = RANSACRegressor(estimator, residual_threshold=5, random_state=0)
  257. ransac_estimator7 = RANSACRegressor(
  258. estimator, min_samples=X.shape[0] + 1, residual_threshold=5, random_state=0
  259. )
  260. # GH #19390
  261. ransac_estimator8 = RANSACRegressor(
  262. Ridge(), min_samples=None, residual_threshold=5, random_state=0
  263. )
  264. ransac_estimator1.fit(X, y)
  265. ransac_estimator2.fit(X, y)
  266. ransac_estimator5.fit(X, y)
  267. ransac_estimator6.fit(X, y)
  268. assert_array_almost_equal(
  269. ransac_estimator1.predict(X), ransac_estimator2.predict(X)
  270. )
  271. assert_array_almost_equal(
  272. ransac_estimator1.predict(X), ransac_estimator5.predict(X)
  273. )
  274. assert_array_almost_equal(
  275. ransac_estimator1.predict(X), ransac_estimator6.predict(X)
  276. )
  277. with pytest.raises(ValueError):
  278. ransac_estimator7.fit(X, y)
  279. err_msg = "`min_samples` needs to be explicitly set"
  280. with pytest.raises(ValueError, match=err_msg):
  281. ransac_estimator8.fit(X, y)
  282. def test_ransac_multi_dimensional_targets():
  283. estimator = LinearRegression()
  284. ransac_estimator = RANSACRegressor(
  285. estimator, min_samples=2, residual_threshold=5, random_state=0
  286. )
  287. # 3-D target values
  288. yyy = np.column_stack([y, y, y])
  289. # Estimate parameters of corrupted data
  290. ransac_estimator.fit(X, yyy)
  291. # Ground truth / reference inlier mask
  292. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  293. ref_inlier_mask[outliers] = False
  294. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  295. def test_ransac_residual_loss():
  296. def loss_multi1(y_true, y_pred):
  297. return np.sum(np.abs(y_true - y_pred), axis=1)
  298. def loss_multi2(y_true, y_pred):
  299. return np.sum((y_true - y_pred) ** 2, axis=1)
  300. def loss_mono(y_true, y_pred):
  301. return np.abs(y_true - y_pred)
  302. yyy = np.column_stack([y, y, y])
  303. estimator = LinearRegression()
  304. ransac_estimator0 = RANSACRegressor(
  305. estimator, min_samples=2, residual_threshold=5, random_state=0
  306. )
  307. ransac_estimator1 = RANSACRegressor(
  308. estimator,
  309. min_samples=2,
  310. residual_threshold=5,
  311. random_state=0,
  312. loss=loss_multi1,
  313. )
  314. ransac_estimator2 = RANSACRegressor(
  315. estimator,
  316. min_samples=2,
  317. residual_threshold=5,
  318. random_state=0,
  319. loss=loss_multi2,
  320. )
  321. # multi-dimensional
  322. ransac_estimator0.fit(X, yyy)
  323. ransac_estimator1.fit(X, yyy)
  324. ransac_estimator2.fit(X, yyy)
  325. assert_array_almost_equal(
  326. ransac_estimator0.predict(X), ransac_estimator1.predict(X)
  327. )
  328. assert_array_almost_equal(
  329. ransac_estimator0.predict(X), ransac_estimator2.predict(X)
  330. )
  331. # one-dimensional
  332. ransac_estimator0.fit(X, y)
  333. ransac_estimator2.loss = loss_mono
  334. ransac_estimator2.fit(X, y)
  335. assert_array_almost_equal(
  336. ransac_estimator0.predict(X), ransac_estimator2.predict(X)
  337. )
  338. ransac_estimator3 = RANSACRegressor(
  339. estimator,
  340. min_samples=2,
  341. residual_threshold=5,
  342. random_state=0,
  343. loss="squared_error",
  344. )
  345. ransac_estimator3.fit(X, y)
  346. assert_array_almost_equal(
  347. ransac_estimator0.predict(X), ransac_estimator2.predict(X)
  348. )
  349. def test_ransac_default_residual_threshold():
  350. estimator = LinearRegression()
  351. ransac_estimator = RANSACRegressor(estimator, min_samples=2, random_state=0)
  352. # Estimate parameters of corrupted data
  353. ransac_estimator.fit(X, y)
  354. # Ground truth / reference inlier mask
  355. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  356. ref_inlier_mask[outliers] = False
  357. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  358. def test_ransac_dynamic_max_trials():
  359. # Numbers hand-calculated and confirmed on page 119 (Table 4.3) in
  360. # Hartley, R.~I. and Zisserman, A., 2004,
  361. # Multiple View Geometry in Computer Vision, Second Edition,
  362. # Cambridge University Press, ISBN: 0521540518
  363. # e = 0%, min_samples = X
  364. assert _dynamic_max_trials(100, 100, 2, 0.99) == 1
  365. # e = 5%, min_samples = 2
  366. assert _dynamic_max_trials(95, 100, 2, 0.99) == 2
  367. # e = 10%, min_samples = 2
  368. assert _dynamic_max_trials(90, 100, 2, 0.99) == 3
  369. # e = 30%, min_samples = 2
  370. assert _dynamic_max_trials(70, 100, 2, 0.99) == 7
  371. # e = 50%, min_samples = 2
  372. assert _dynamic_max_trials(50, 100, 2, 0.99) == 17
  373. # e = 5%, min_samples = 8
  374. assert _dynamic_max_trials(95, 100, 8, 0.99) == 5
  375. # e = 10%, min_samples = 8
  376. assert _dynamic_max_trials(90, 100, 8, 0.99) == 9
  377. # e = 30%, min_samples = 8
  378. assert _dynamic_max_trials(70, 100, 8, 0.99) == 78
  379. # e = 50%, min_samples = 8
  380. assert _dynamic_max_trials(50, 100, 8, 0.99) == 1177
  381. # e = 0%, min_samples = 10
  382. assert _dynamic_max_trials(1, 100, 10, 0) == 0
  383. assert _dynamic_max_trials(1, 100, 10, 1) == float("inf")
  384. def test_ransac_fit_sample_weight():
  385. ransac_estimator = RANSACRegressor(random_state=0)
  386. n_samples = y.shape[0]
  387. weights = np.ones(n_samples)
  388. ransac_estimator.fit(X, y, weights)
  389. # sanity check
  390. assert ransac_estimator.inlier_mask_.shape[0] == n_samples
  391. ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
  392. ref_inlier_mask[outliers] = False
  393. # check that mask is correct
  394. assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
  395. # check that fit(X) = fit([X1, X2, X3],sample_weight = [n1, n2, n3]) where
  396. # X = X1 repeated n1 times, X2 repeated n2 times and so forth
  397. random_state = check_random_state(0)
  398. X_ = random_state.randint(0, 200, [10, 1])
  399. y_ = np.ndarray.flatten(0.2 * X_ + 2)
  400. sample_weight = random_state.randint(0, 10, 10)
  401. outlier_X = random_state.randint(0, 1000, [1, 1])
  402. outlier_weight = random_state.randint(0, 10, 1)
  403. outlier_y = random_state.randint(-1000, 0, 1)
  404. X_flat = np.append(
  405. np.repeat(X_, sample_weight, axis=0),
  406. np.repeat(outlier_X, outlier_weight, axis=0),
  407. axis=0,
  408. )
  409. y_flat = np.ndarray.flatten(
  410. np.append(
  411. np.repeat(y_, sample_weight, axis=0),
  412. np.repeat(outlier_y, outlier_weight, axis=0),
  413. axis=0,
  414. )
  415. )
  416. ransac_estimator.fit(X_flat, y_flat)
  417. ref_coef_ = ransac_estimator.estimator_.coef_
  418. sample_weight = np.append(sample_weight, outlier_weight)
  419. X_ = np.append(X_, outlier_X, axis=0)
  420. y_ = np.append(y_, outlier_y)
  421. ransac_estimator.fit(X_, y_, sample_weight)
  422. assert_allclose(ransac_estimator.estimator_.coef_, ref_coef_)
  423. # check that if estimator.fit doesn't support
  424. # sample_weight, raises error
  425. estimator = OrthogonalMatchingPursuit()
  426. ransac_estimator = RANSACRegressor(estimator, min_samples=10)
  427. err_msg = f"{estimator.__class__.__name__} does not support sample_weight."
  428. with pytest.raises(ValueError, match=err_msg):
  429. ransac_estimator.fit(X, y, weights)
  430. def test_ransac_final_model_fit_sample_weight():
  431. X, y = make_regression(n_samples=1000, random_state=10)
  432. rng = check_random_state(42)
  433. sample_weight = rng.randint(1, 4, size=y.shape[0])
  434. sample_weight = sample_weight / sample_weight.sum()
  435. ransac = RANSACRegressor(estimator=LinearRegression(), random_state=0)
  436. ransac.fit(X, y, sample_weight=sample_weight)
  437. final_model = LinearRegression()
  438. mask_samples = ransac.inlier_mask_
  439. final_model.fit(
  440. X[mask_samples], y[mask_samples], sample_weight=sample_weight[mask_samples]
  441. )
  442. assert_allclose(ransac.estimator_.coef_, final_model.coef_, atol=1e-12)
  443. def test_perfect_horizontal_line():
  444. """Check that we can fit a line where all samples are inliers.
  445. Non-regression test for:
  446. https://github.com/scikit-learn/scikit-learn/issues/19497
  447. """
  448. X = np.arange(100)[:, None]
  449. y = np.zeros((100,))
  450. estimator = LinearRegression()
  451. ransac_estimator = RANSACRegressor(estimator, random_state=0)
  452. ransac_estimator.fit(X, y)
  453. assert_allclose(ransac_estimator.estimator_.coef_, 0.0)
  454. assert_allclose(ransac_estimator.estimator_.intercept_, 0.0)