test_samples_generator.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688
  1. import re
  2. from collections import defaultdict
  3. from functools import partial
  4. import numpy as np
  5. import pytest
  6. import scipy.sparse as sp
  7. from sklearn.datasets import (
  8. make_biclusters,
  9. make_blobs,
  10. make_checkerboard,
  11. make_circles,
  12. make_classification,
  13. make_friedman1,
  14. make_friedman2,
  15. make_friedman3,
  16. make_hastie_10_2,
  17. make_low_rank_matrix,
  18. make_moons,
  19. make_multilabel_classification,
  20. make_regression,
  21. make_s_curve,
  22. make_sparse_coded_signal,
  23. make_sparse_uncorrelated,
  24. make_spd_matrix,
  25. make_swiss_roll,
  26. )
  27. from sklearn.utils._testing import (
  28. assert_allclose,
  29. assert_almost_equal,
  30. assert_array_almost_equal,
  31. assert_array_equal,
  32. ignore_warnings,
  33. )
  34. from sklearn.utils.validation import assert_all_finite
  35. def test_make_classification():
  36. weights = [0.1, 0.25]
  37. X, y = make_classification(
  38. n_samples=100,
  39. n_features=20,
  40. n_informative=5,
  41. n_redundant=1,
  42. n_repeated=1,
  43. n_classes=3,
  44. n_clusters_per_class=1,
  45. hypercube=False,
  46. shift=None,
  47. scale=None,
  48. weights=weights,
  49. random_state=0,
  50. )
  51. assert weights == [0.1, 0.25]
  52. assert X.shape == (100, 20), "X shape mismatch"
  53. assert y.shape == (100,), "y shape mismatch"
  54. assert np.unique(y).shape == (3,), "Unexpected number of classes"
  55. assert sum(y == 0) == 10, "Unexpected number of samples in class #0"
  56. assert sum(y == 1) == 25, "Unexpected number of samples in class #1"
  57. assert sum(y == 2) == 65, "Unexpected number of samples in class #2"
  58. # Test for n_features > 30
  59. X, y = make_classification(
  60. n_samples=2000,
  61. n_features=31,
  62. n_informative=31,
  63. n_redundant=0,
  64. n_repeated=0,
  65. hypercube=True,
  66. scale=0.5,
  67. random_state=0,
  68. )
  69. assert X.shape == (2000, 31), "X shape mismatch"
  70. assert y.shape == (2000,), "y shape mismatch"
  71. assert (
  72. np.unique(X.view([("", X.dtype)] * X.shape[1]))
  73. .view(X.dtype)
  74. .reshape(-1, X.shape[1])
  75. .shape[0]
  76. == 2000
  77. ), "Unexpected number of unique rows"
  78. def test_make_classification_informative_features():
  79. """Test the construction of informative features in make_classification
  80. Also tests `n_clusters_per_class`, `n_classes`, `hypercube` and
  81. fully-specified `weights`.
  82. """
  83. # Create very separate clusters; check that vertices are unique and
  84. # correspond to classes
  85. class_sep = 1e6
  86. make = partial(
  87. make_classification,
  88. class_sep=class_sep,
  89. n_redundant=0,
  90. n_repeated=0,
  91. flip_y=0,
  92. shift=0,
  93. scale=1,
  94. shuffle=False,
  95. )
  96. for n_informative, weights, n_clusters_per_class in [
  97. (2, [1], 1),
  98. (2, [1 / 3] * 3, 1),
  99. (2, [1 / 4] * 4, 1),
  100. (2, [1 / 2] * 2, 2),
  101. (2, [3 / 4, 1 / 4], 2),
  102. (10, [1 / 3] * 3, 10),
  103. (int(64), [1], 1),
  104. ]:
  105. n_classes = len(weights)
  106. n_clusters = n_classes * n_clusters_per_class
  107. n_samples = n_clusters * 50
  108. for hypercube in (False, True):
  109. X, y = make(
  110. n_samples=n_samples,
  111. n_classes=n_classes,
  112. weights=weights,
  113. n_features=n_informative,
  114. n_informative=n_informative,
  115. n_clusters_per_class=n_clusters_per_class,
  116. hypercube=hypercube,
  117. random_state=0,
  118. )
  119. assert X.shape == (n_samples, n_informative)
  120. assert y.shape == (n_samples,)
  121. # Cluster by sign, viewed as strings to allow uniquing
  122. signs = np.sign(X)
  123. signs = signs.view(dtype="|S{0}".format(signs.strides[0]))
  124. unique_signs, cluster_index = np.unique(signs, return_inverse=True)
  125. assert (
  126. len(unique_signs) == n_clusters
  127. ), "Wrong number of clusters, or not in distinct quadrants"
  128. clusters_by_class = defaultdict(set)
  129. for cluster, cls in zip(cluster_index, y):
  130. clusters_by_class[cls].add(cluster)
  131. for clusters in clusters_by_class.values():
  132. assert (
  133. len(clusters) == n_clusters_per_class
  134. ), "Wrong number of clusters per class"
  135. assert len(clusters_by_class) == n_classes, "Wrong number of classes"
  136. assert_array_almost_equal(
  137. np.bincount(y) / len(y) // weights,
  138. [1] * n_classes,
  139. err_msg="Wrong number of samples per class",
  140. )
  141. # Ensure on vertices of hypercube
  142. for cluster in range(len(unique_signs)):
  143. centroid = X[cluster_index == cluster].mean(axis=0)
  144. if hypercube:
  145. assert_array_almost_equal(
  146. np.abs(centroid) / class_sep,
  147. np.ones(n_informative),
  148. decimal=5,
  149. err_msg="Clusters are not centered on hypercube vertices",
  150. )
  151. else:
  152. with pytest.raises(AssertionError):
  153. assert_array_almost_equal(
  154. np.abs(centroid) / class_sep,
  155. np.ones(n_informative),
  156. decimal=5,
  157. err_msg=(
  158. "Clusters should not be centered on hypercube vertices"
  159. ),
  160. )
  161. with pytest.raises(ValueError):
  162. make(n_features=2, n_informative=2, n_classes=5, n_clusters_per_class=1)
  163. with pytest.raises(ValueError):
  164. make(n_features=2, n_informative=2, n_classes=3, n_clusters_per_class=2)
  165. @pytest.mark.parametrize(
  166. "weights, err_type, err_msg",
  167. [
  168. ([], ValueError, "Weights specified but incompatible with number of classes."),
  169. (
  170. [0.25, 0.75, 0.1],
  171. ValueError,
  172. "Weights specified but incompatible with number of classes.",
  173. ),
  174. (
  175. np.array([]),
  176. ValueError,
  177. "Weights specified but incompatible with number of classes.",
  178. ),
  179. (
  180. np.array([0.25, 0.75, 0.1]),
  181. ValueError,
  182. "Weights specified but incompatible with number of classes.",
  183. ),
  184. (
  185. np.random.random(3),
  186. ValueError,
  187. "Weights specified but incompatible with number of classes.",
  188. ),
  189. ],
  190. )
  191. def test_make_classification_weights_type(weights, err_type, err_msg):
  192. with pytest.raises(err_type, match=err_msg):
  193. make_classification(weights=weights)
  194. @pytest.mark.parametrize("kwargs", [{}, {"n_classes": 3, "n_informative": 3}])
  195. def test_make_classification_weights_array_or_list_ok(kwargs):
  196. X1, y1 = make_classification(weights=[0.1, 0.9], random_state=0, **kwargs)
  197. X2, y2 = make_classification(weights=np.array([0.1, 0.9]), random_state=0, **kwargs)
  198. assert_almost_equal(X1, X2)
  199. assert_almost_equal(y1, y2)
  200. def test_make_multilabel_classification_return_sequences():
  201. for allow_unlabeled, min_length in zip((True, False), (0, 1)):
  202. X, Y = make_multilabel_classification(
  203. n_samples=100,
  204. n_features=20,
  205. n_classes=3,
  206. random_state=0,
  207. return_indicator=False,
  208. allow_unlabeled=allow_unlabeled,
  209. )
  210. assert X.shape == (100, 20), "X shape mismatch"
  211. if not allow_unlabeled:
  212. assert max([max(y) for y in Y]) == 2
  213. assert min([len(y) for y in Y]) == min_length
  214. assert max([len(y) for y in Y]) <= 3
  215. def test_make_multilabel_classification_return_indicator():
  216. for allow_unlabeled, min_length in zip((True, False), (0, 1)):
  217. X, Y = make_multilabel_classification(
  218. n_samples=25,
  219. n_features=20,
  220. n_classes=3,
  221. random_state=0,
  222. allow_unlabeled=allow_unlabeled,
  223. )
  224. assert X.shape == (25, 20), "X shape mismatch"
  225. assert Y.shape == (25, 3), "Y shape mismatch"
  226. assert np.all(np.sum(Y, axis=0) > min_length)
  227. # Also test return_distributions and return_indicator with True
  228. X2, Y2, p_c, p_w_c = make_multilabel_classification(
  229. n_samples=25,
  230. n_features=20,
  231. n_classes=3,
  232. random_state=0,
  233. allow_unlabeled=allow_unlabeled,
  234. return_distributions=True,
  235. )
  236. assert_array_almost_equal(X, X2)
  237. assert_array_equal(Y, Y2)
  238. assert p_c.shape == (3,)
  239. assert_almost_equal(p_c.sum(), 1)
  240. assert p_w_c.shape == (20, 3)
  241. assert_almost_equal(p_w_c.sum(axis=0), [1] * 3)
  242. def test_make_multilabel_classification_return_indicator_sparse():
  243. for allow_unlabeled, min_length in zip((True, False), (0, 1)):
  244. X, Y = make_multilabel_classification(
  245. n_samples=25,
  246. n_features=20,
  247. n_classes=3,
  248. random_state=0,
  249. return_indicator="sparse",
  250. allow_unlabeled=allow_unlabeled,
  251. )
  252. assert X.shape == (25, 20), "X shape mismatch"
  253. assert Y.shape == (25, 3), "Y shape mismatch"
  254. assert sp.issparse(Y)
  255. def test_make_hastie_10_2():
  256. X, y = make_hastie_10_2(n_samples=100, random_state=0)
  257. assert X.shape == (100, 10), "X shape mismatch"
  258. assert y.shape == (100,), "y shape mismatch"
  259. assert np.unique(y).shape == (2,), "Unexpected number of classes"
  260. def test_make_regression():
  261. X, y, c = make_regression(
  262. n_samples=100,
  263. n_features=10,
  264. n_informative=3,
  265. effective_rank=5,
  266. coef=True,
  267. bias=0.0,
  268. noise=1.0,
  269. random_state=0,
  270. )
  271. assert X.shape == (100, 10), "X shape mismatch"
  272. assert y.shape == (100,), "y shape mismatch"
  273. assert c.shape == (10,), "coef shape mismatch"
  274. assert sum(c != 0.0) == 3, "Unexpected number of informative features"
  275. # Test that y ~= np.dot(X, c) + bias + N(0, 1.0).
  276. assert_almost_equal(np.std(y - np.dot(X, c)), 1.0, decimal=1)
  277. # Test with small number of features.
  278. X, y = make_regression(n_samples=100, n_features=1) # n_informative=3
  279. assert X.shape == (100, 1)
  280. def test_make_regression_multitarget():
  281. X, y, c = make_regression(
  282. n_samples=100,
  283. n_features=10,
  284. n_informative=3,
  285. n_targets=3,
  286. coef=True,
  287. noise=1.0,
  288. random_state=0,
  289. )
  290. assert X.shape == (100, 10), "X shape mismatch"
  291. assert y.shape == (100, 3), "y shape mismatch"
  292. assert c.shape == (10, 3), "coef shape mismatch"
  293. assert_array_equal(sum(c != 0.0), 3, "Unexpected number of informative features")
  294. # Test that y ~= np.dot(X, c) + bias + N(0, 1.0)
  295. assert_almost_equal(np.std(y - np.dot(X, c)), 1.0, decimal=1)
  296. def test_make_blobs():
  297. cluster_stds = np.array([0.05, 0.2, 0.4])
  298. cluster_centers = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
  299. X, y = make_blobs(
  300. random_state=0,
  301. n_samples=50,
  302. n_features=2,
  303. centers=cluster_centers,
  304. cluster_std=cluster_stds,
  305. )
  306. assert X.shape == (50, 2), "X shape mismatch"
  307. assert y.shape == (50,), "y shape mismatch"
  308. assert np.unique(y).shape == (3,), "Unexpected number of blobs"
  309. for i, (ctr, std) in enumerate(zip(cluster_centers, cluster_stds)):
  310. assert_almost_equal((X[y == i] - ctr).std(), std, 1, "Unexpected std")
  311. def test_make_blobs_n_samples_list():
  312. n_samples = [50, 30, 20]
  313. X, y = make_blobs(n_samples=n_samples, n_features=2, random_state=0)
  314. assert X.shape == (sum(n_samples), 2), "X shape mismatch"
  315. assert all(
  316. np.bincount(y, minlength=len(n_samples)) == n_samples
  317. ), "Incorrect number of samples per blob"
  318. def test_make_blobs_n_samples_list_with_centers():
  319. n_samples = [20, 20, 20]
  320. centers = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
  321. cluster_stds = np.array([0.05, 0.2, 0.4])
  322. X, y = make_blobs(
  323. n_samples=n_samples, centers=centers, cluster_std=cluster_stds, random_state=0
  324. )
  325. assert X.shape == (sum(n_samples), 2), "X shape mismatch"
  326. assert all(
  327. np.bincount(y, minlength=len(n_samples)) == n_samples
  328. ), "Incorrect number of samples per blob"
  329. for i, (ctr, std) in enumerate(zip(centers, cluster_stds)):
  330. assert_almost_equal((X[y == i] - ctr).std(), std, 1, "Unexpected std")
  331. @pytest.mark.parametrize(
  332. "n_samples", [[5, 3, 0], np.array([5, 3, 0]), tuple([5, 3, 0])]
  333. )
  334. def test_make_blobs_n_samples_centers_none(n_samples):
  335. centers = None
  336. X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=0)
  337. assert X.shape == (sum(n_samples), 2), "X shape mismatch"
  338. assert all(
  339. np.bincount(y, minlength=len(n_samples)) == n_samples
  340. ), "Incorrect number of samples per blob"
  341. def test_make_blobs_return_centers():
  342. n_samples = [10, 20]
  343. n_features = 3
  344. X, y, centers = make_blobs(
  345. n_samples=n_samples, n_features=n_features, return_centers=True, random_state=0
  346. )
  347. assert centers.shape == (len(n_samples), n_features)
  348. def test_make_blobs_error():
  349. n_samples = [20, 20, 20]
  350. centers = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
  351. cluster_stds = np.array([0.05, 0.2, 0.4])
  352. wrong_centers_msg = re.escape(
  353. "Length of `n_samples` not consistent with number of centers. "
  354. f"Got n_samples = {n_samples} and centers = {centers[:-1]}"
  355. )
  356. with pytest.raises(ValueError, match=wrong_centers_msg):
  357. make_blobs(n_samples, centers=centers[:-1])
  358. wrong_std_msg = re.escape(
  359. "Length of `clusters_std` not consistent with number of centers. "
  360. f"Got centers = {centers} and cluster_std = {cluster_stds[:-1]}"
  361. )
  362. with pytest.raises(ValueError, match=wrong_std_msg):
  363. make_blobs(n_samples, centers=centers, cluster_std=cluster_stds[:-1])
  364. wrong_type_msg = "Parameter `centers` must be array-like. Got {!r} instead".format(
  365. 3
  366. )
  367. with pytest.raises(ValueError, match=wrong_type_msg):
  368. make_blobs(n_samples, centers=3)
  369. def test_make_friedman1():
  370. X, y = make_friedman1(n_samples=5, n_features=10, noise=0.0, random_state=0)
  371. assert X.shape == (5, 10), "X shape mismatch"
  372. assert y.shape == (5,), "y shape mismatch"
  373. assert_array_almost_equal(
  374. y,
  375. 10 * np.sin(np.pi * X[:, 0] * X[:, 1])
  376. + 20 * (X[:, 2] - 0.5) ** 2
  377. + 10 * X[:, 3]
  378. + 5 * X[:, 4],
  379. )
  380. def test_make_friedman2():
  381. X, y = make_friedman2(n_samples=5, noise=0.0, random_state=0)
  382. assert X.shape == (5, 4), "X shape mismatch"
  383. assert y.shape == (5,), "y shape mismatch"
  384. assert_array_almost_equal(
  385. y, (X[:, 0] ** 2 + (X[:, 1] * X[:, 2] - 1 / (X[:, 1] * X[:, 3])) ** 2) ** 0.5
  386. )
  387. def test_make_friedman3():
  388. X, y = make_friedman3(n_samples=5, noise=0.0, random_state=0)
  389. assert X.shape == (5, 4), "X shape mismatch"
  390. assert y.shape == (5,), "y shape mismatch"
  391. assert_array_almost_equal(
  392. y, np.arctan((X[:, 1] * X[:, 2] - 1 / (X[:, 1] * X[:, 3])) / X[:, 0])
  393. )
  394. def test_make_low_rank_matrix():
  395. X = make_low_rank_matrix(
  396. n_samples=50,
  397. n_features=25,
  398. effective_rank=5,
  399. tail_strength=0.01,
  400. random_state=0,
  401. )
  402. assert X.shape == (50, 25), "X shape mismatch"
  403. from numpy.linalg import svd
  404. u, s, v = svd(X)
  405. assert sum(s) - 5 < 0.1, "X rank is not approximately 5"
  406. def test_make_sparse_coded_signal():
  407. Y, D, X = make_sparse_coded_signal(
  408. n_samples=5,
  409. n_components=8,
  410. n_features=10,
  411. n_nonzero_coefs=3,
  412. random_state=0,
  413. )
  414. assert Y.shape == (5, 10), "Y shape mismatch"
  415. assert D.shape == (8, 10), "D shape mismatch"
  416. assert X.shape == (5, 8), "X shape mismatch"
  417. for row in X:
  418. assert len(np.flatnonzero(row)) == 3, "Non-zero coefs mismatch"
  419. assert_allclose(Y, X @ D)
  420. assert_allclose(np.sqrt((D**2).sum(axis=1)), np.ones(D.shape[0]))
  421. # TODO(1.5): remove
  422. @ignore_warnings(category=FutureWarning)
  423. def test_make_sparse_coded_signal_transposed():
  424. Y, D, X = make_sparse_coded_signal(
  425. n_samples=5,
  426. n_components=8,
  427. n_features=10,
  428. n_nonzero_coefs=3,
  429. random_state=0,
  430. data_transposed=True,
  431. )
  432. assert Y.shape == (10, 5), "Y shape mismatch"
  433. assert D.shape == (10, 8), "D shape mismatch"
  434. assert X.shape == (8, 5), "X shape mismatch"
  435. for col in X.T:
  436. assert len(np.flatnonzero(col)) == 3, "Non-zero coefs mismatch"
  437. assert_allclose(Y, D @ X)
  438. assert_allclose(np.sqrt((D**2).sum(axis=0)), np.ones(D.shape[1]))
  439. # TODO(1.5): remove
  440. def test_make_sparse_code_signal_deprecation_warning():
  441. """Check the message for future deprecation."""
  442. warn_msg = "data_transposed was deprecated in version 1.3"
  443. with pytest.warns(FutureWarning, match=warn_msg):
  444. make_sparse_coded_signal(
  445. n_samples=1,
  446. n_components=1,
  447. n_features=1,
  448. n_nonzero_coefs=1,
  449. random_state=0,
  450. data_transposed=True,
  451. )
  452. def test_make_sparse_uncorrelated():
  453. X, y = make_sparse_uncorrelated(n_samples=5, n_features=10, random_state=0)
  454. assert X.shape == (5, 10), "X shape mismatch"
  455. assert y.shape == (5,), "y shape mismatch"
  456. def test_make_spd_matrix():
  457. X = make_spd_matrix(n_dim=5, random_state=0)
  458. assert X.shape == (5, 5), "X shape mismatch"
  459. assert_array_almost_equal(X, X.T)
  460. from numpy.linalg import eig
  461. eigenvalues, _ = eig(X)
  462. assert_array_equal(
  463. eigenvalues > 0, np.array([True] * 5), "X is not positive-definite"
  464. )
  465. @pytest.mark.parametrize("hole", [False, True])
  466. def test_make_swiss_roll(hole):
  467. X, t = make_swiss_roll(n_samples=5, noise=0.0, random_state=0, hole=hole)
  468. assert X.shape == (5, 3)
  469. assert t.shape == (5,)
  470. assert_array_almost_equal(X[:, 0], t * np.cos(t))
  471. assert_array_almost_equal(X[:, 2], t * np.sin(t))
  472. def test_make_s_curve():
  473. X, t = make_s_curve(n_samples=5, noise=0.0, random_state=0)
  474. assert X.shape == (5, 3), "X shape mismatch"
  475. assert t.shape == (5,), "t shape mismatch"
  476. assert_array_almost_equal(X[:, 0], np.sin(t))
  477. assert_array_almost_equal(X[:, 2], np.sign(t) * (np.cos(t) - 1))
  478. def test_make_biclusters():
  479. X, rows, cols = make_biclusters(
  480. shape=(100, 100), n_clusters=4, shuffle=True, random_state=0
  481. )
  482. assert X.shape == (100, 100), "X shape mismatch"
  483. assert rows.shape == (4, 100), "rows shape mismatch"
  484. assert cols.shape == (
  485. 4,
  486. 100,
  487. ), "columns shape mismatch"
  488. assert_all_finite(X)
  489. assert_all_finite(rows)
  490. assert_all_finite(cols)
  491. X2, _, _ = make_biclusters(
  492. shape=(100, 100), n_clusters=4, shuffle=True, random_state=0
  493. )
  494. assert_array_almost_equal(X, X2)
  495. def test_make_checkerboard():
  496. X, rows, cols = make_checkerboard(
  497. shape=(100, 100), n_clusters=(20, 5), shuffle=True, random_state=0
  498. )
  499. assert X.shape == (100, 100), "X shape mismatch"
  500. assert rows.shape == (100, 100), "rows shape mismatch"
  501. assert cols.shape == (
  502. 100,
  503. 100,
  504. ), "columns shape mismatch"
  505. X, rows, cols = make_checkerboard(
  506. shape=(100, 100), n_clusters=2, shuffle=True, random_state=0
  507. )
  508. assert_all_finite(X)
  509. assert_all_finite(rows)
  510. assert_all_finite(cols)
  511. X1, _, _ = make_checkerboard(
  512. shape=(100, 100), n_clusters=2, shuffle=True, random_state=0
  513. )
  514. X2, _, _ = make_checkerboard(
  515. shape=(100, 100), n_clusters=2, shuffle=True, random_state=0
  516. )
  517. assert_array_almost_equal(X1, X2)
  518. def test_make_moons():
  519. X, y = make_moons(3, shuffle=False)
  520. for x, label in zip(X, y):
  521. center = [0.0, 0.0] if label == 0 else [1.0, 0.5]
  522. dist_sqr = ((x - center) ** 2).sum()
  523. assert_almost_equal(
  524. dist_sqr, 1.0, err_msg="Point is not on expected unit circle"
  525. )
  526. def test_make_moons_unbalanced():
  527. X, y = make_moons(n_samples=(7, 5))
  528. assert (
  529. np.sum(y == 0) == 7 and np.sum(y == 1) == 5
  530. ), "Number of samples in a moon is wrong"
  531. assert X.shape == (12, 2), "X shape mismatch"
  532. assert y.shape == (12,), "y shape mismatch"
  533. with pytest.raises(
  534. ValueError,
  535. match=r"`n_samples` can be either an int " r"or a two-element tuple.",
  536. ):
  537. make_moons(n_samples=(10,))
  538. def test_make_circles():
  539. factor = 0.3
  540. for n_samples, n_outer, n_inner in [(7, 3, 4), (8, 4, 4)]:
  541. # Testing odd and even case, because in the past make_circles always
  542. # created an even number of samples.
  543. X, y = make_circles(n_samples, shuffle=False, noise=None, factor=factor)
  544. assert X.shape == (n_samples, 2), "X shape mismatch"
  545. assert y.shape == (n_samples,), "y shape mismatch"
  546. center = [0.0, 0.0]
  547. for x, label in zip(X, y):
  548. dist_sqr = ((x - center) ** 2).sum()
  549. dist_exp = 1.0 if label == 0 else factor**2
  550. dist_exp = 1.0 if label == 0 else factor**2
  551. assert_almost_equal(
  552. dist_sqr, dist_exp, err_msg="Point is not on expected circle"
  553. )
  554. assert X[y == 0].shape == (
  555. n_outer,
  556. 2,
  557. ), "Samples not correctly distributed across circles."
  558. assert X[y == 1].shape == (
  559. n_inner,
  560. 2,
  561. ), "Samples not correctly distributed across circles."
  562. def test_make_circles_unbalanced():
  563. X, y = make_circles(n_samples=(2, 8))
  564. assert np.sum(y == 0) == 2, "Number of samples in inner circle is wrong"
  565. assert np.sum(y == 1) == 8, "Number of samples in outer circle is wrong"
  566. assert X.shape == (10, 2), "X shape mismatch"
  567. assert y.shape == (10,), "y shape mismatch"
  568. with pytest.raises(
  569. ValueError,
  570. match="When a tuple, n_samples must have exactly two elements.",
  571. ):
  572. make_circles(n_samples=(10,))