# test_multiclass.py
from itertools import product

import numpy as np
import pytest
import scipy.sparse as sp
from scipy.sparse import (
    coo_matrix,
    csc_matrix,
    csr_matrix,
    dok_matrix,
    issparse,
    lil_matrix,
)

from sklearn import datasets
from sklearn.model_selection import ShuffleSplit
from sklearn.svm import SVC
from sklearn.utils._testing import (
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
)
from sklearn.utils.estimator_checks import _NotAnArray
from sklearn.utils.metaestimators import _safe_split
from sklearn.utils.multiclass import (
    _ovr_decision_function,
    check_classification_targets,
    class_distribution,
    is_multilabel,
    type_of_target,
    unique_labels,
)


sparse_multilable_explicit_zero = csc_matrix(np.array([[0, 1], [1, 0]]))
sparse_multilable_explicit_zero[:, 0] = 0


def _generate_sparse(
    matrix,
    matrix_types=(csr_matrix, csc_matrix, coo_matrix, dok_matrix, lil_matrix),
    dtypes=(bool, int, np.int8, np.uint8, float, np.float32),
):
    return [
        matrix_type(matrix, dtype=dtype)
        for matrix_type in matrix_types
        for dtype in dtypes
    ]
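

# A quick illustrative sketch of the helper above: each (sparse format, dtype) pair
# yields one matrix, so restricting it to 2 formats and 2 dtypes should give 4 sparse
# copies of the same indicator.
def test_generate_sparse_cartesian_product():
    generated = _generate_sparse(
        [[0, 1], [1, 0]], matrix_types=(csr_matrix, csc_matrix), dtypes=(int, float)
    )
    assert len(generated) == 4
    assert all(issparse(m) for m in generated)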


EXAMPLES = {
    "multilabel-indicator": [
        # valid when the data is formatted as sparse or dense, identified
        # by CSR format when the testing takes place
        csr_matrix(np.random.RandomState(42).randint(2, size=(10, 10))),
        [[0, 1], [1, 0]],
        [[0, 1]],
        sparse_multilable_explicit_zero,
        *_generate_sparse([[0, 1], [1, 0]]),
        *_generate_sparse([[0, 0], [0, 0]]),
        *_generate_sparse([[0, 1]]),
        # Only valid when data is dense
        [[-1, 1], [1, -1]],
        np.array([[-1, 1], [1, -1]]),
        np.array([[-3, 3], [3, -3]]),
        _NotAnArray(np.array([[-3, 3], [3, -3]])),
    ],
    "multiclass": [
        [1, 0, 2, 2, 1, 4, 2, 4, 4, 4],
        np.array([1, 0, 2]),
        np.array([1, 0, 2], dtype=np.int8),
        np.array([1, 0, 2], dtype=np.uint8),
        np.array([1, 0, 2], dtype=float),
        np.array([1, 0, 2], dtype=np.float32),
        np.array([[1], [0], [2]]),
        _NotAnArray(np.array([1, 0, 2])),
        [0, 1, 2],
        ["a", "b", "c"],
        np.array(["a", "b", "c"]),
        np.array(["a", "b", "c"], dtype=object),
        np.array(["a", "b", "c"], dtype=object),
    ],
    "multiclass-multioutput": [
        [[1, 0, 2, 2], [1, 4, 2, 4]],
        [["a", "b"], ["c", "d"]],
        np.array([[1, 0, 2, 2], [1, 4, 2, 4]]),
        np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.int8),
        np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.uint8),
        np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=float),
        np.array([[1, 0, 2, 2], [1, 4, 2, 4]], dtype=np.float32),
        *_generate_sparse(
            [[1, 0, 2, 2], [1, 4, 2, 4]],
            matrix_types=(csr_matrix, csc_matrix),
            dtypes=(int, np.int8, np.uint8, float, np.float32),
        ),
        np.array([["a", "b"], ["c", "d"]]),
        np.array([["a", "b"], ["c", "d"]]),
        np.array([["a", "b"], ["c", "d"]], dtype=object),
        np.array([[1, 0, 2]]),
        _NotAnArray(np.array([[1, 0, 2]])),
    ],
    "binary": [
        [0, 1],
        [1, 1],
        [],
        [0],
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1]),
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=bool),
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=np.int8),
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=np.uint8),
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=float),
        np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=np.float32),
        np.array([[0], [1]]),
        _NotAnArray(np.array([[0], [1]])),
        [1, -1],
        [3, 5],
        ["a"],
        ["a", "b"],
        ["abc", "def"],
        np.array(["abc", "def"]),
        ["a", "b"],
        np.array(["abc", "def"], dtype=object),
    ],
    "continuous": [
        [1e-5],
        [0, 0.5],
        np.array([[0], [0.5]]),
        np.array([[0], [0.5]], dtype=np.float32),
    ],
    "continuous-multioutput": [
        np.array([[0, 0.5], [0.5, 0]]),
        np.array([[0, 0.5], [0.5, 0]], dtype=np.float32),
        np.array([[0, 0.5]]),
        *_generate_sparse(
            [[0, 0.5], [0.5, 0]],
            matrix_types=(csr_matrix, csc_matrix),
            dtypes=(float, np.float32),
        ),
        *_generate_sparse(
            [[0, 0.5]],
            matrix_types=(csr_matrix, csc_matrix),
            dtypes=(float, np.float32),
        ),
    ],
    "unknown": [
        [[]],
        np.array([[]], dtype=object),
        [()],
        # sequence of sequences that weren't supported even before deprecation
        np.array([np.array([]), np.array([1, 2, 3])], dtype=object),
        [np.array([]), np.array([1, 2, 3])],
        [{1, 2, 3}, {1, 2}],
        [frozenset([1, 2, 3]), frozenset([1, 2])],
        # and also confusable as sequences of sequences
        [{0: "a", 1: "b"}, {0: "a"}],
        # ndim 0
        np.array(0),
        # empty second dimension
        np.array([[], []]),
        # 3d
        np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),
    ],
}

NON_ARRAY_LIKE_EXAMPLES = [
    {1, 2, 3},
    {0: "a", 1: "b"},
    {0: [5], 1: [5]},
    "abc",
    frozenset([1, 2, 3]),
    None,
]

MULTILABEL_SEQUENCES = [
    [[1], [2], [0, 1]],
    [(), (2), (0, 1)],
    np.array([[], [1, 2]], dtype="object"),
    _NotAnArray(np.array([[], [1, 2]], dtype="object")),
]
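

# Illustrative spot-checks of the target taxonomy that EXAMPLES exercises, using one
# small dense example per family.
def test_type_of_target_representative_examples():
    assert type_of_target([0, 1, 1, 0]) == "binary"
    assert type_of_target([1, 0, 2]) == "multiclass"
    assert type_of_target(np.array([[0, 1], [1, 0]])) == "multilabel-indicator"
    assert type_of_target([0.1, 0.6]) == "continuous"
    y_mc_mo = np.array([[1, 0, 2, 2], [1, 4, 2, 4]])
    assert type_of_target(y_mc_mo) == "multiclass-multioutput"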


def test_unique_labels():
    # Empty iterable
    with pytest.raises(ValueError):
        unique_labels()

    # Multiclass problem
    assert_array_equal(unique_labels(range(10)), np.arange(10))
    assert_array_equal(unique_labels(np.arange(10)), np.arange(10))
    assert_array_equal(unique_labels([4, 0, 2]), np.array([0, 2, 4]))

    # Multilabel indicator
    assert_array_equal(
        unique_labels(np.array([[0, 0, 1], [1, 0, 1], [0, 0, 0]])), np.arange(3)
    )
    assert_array_equal(unique_labels(np.array([[0, 0, 1], [0, 0, 0]])), np.arange(3))

    # Several arrays passed
    assert_array_equal(unique_labels([4, 0, 2], range(5)), np.arange(5))
    assert_array_equal(unique_labels((0, 1, 2), (0,), (2, 1)), np.arange(3))

    # Borderline case with binary indicator matrix
    with pytest.raises(ValueError):
        unique_labels([4, 0, 2], np.ones((5, 5)))
    with pytest.raises(ValueError):
        unique_labels(np.ones((5, 4)), np.ones((5, 5)))
    assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))), np.arange(5))
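

# Illustrative sketch: unique_labels also merges and sorts string labels drawn from
# several arrays, mirroring the integer cases above.
def test_unique_labels_string_labels():
    assert_array_equal(
        unique_labels(["a", "b"], np.array(["b", "c"])), np.array(["a", "b", "c"])
    )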


def test_unique_labels_non_specific():
    # Test unique_labels with a variety of collected examples
    # Smoke test for all supported formats
    for format in ["binary", "multiclass", "multilabel-indicator"]:
        for y in EXAMPLES[format]:
            unique_labels(y)

    # We don't support those formats at the moment
    for example in NON_ARRAY_LIKE_EXAMPLES:
        with pytest.raises(ValueError):
            unique_labels(example)

    for y_type in [
        "unknown",
        "continuous",
        "continuous-multioutput",
        "multiclass-multioutput",
    ]:
        for example in EXAMPLES[y_type]:
            with pytest.raises(ValueError):
                unique_labels(example)


def test_unique_labels_mixed_types():
    # Mix with binary or multiclass and multilabel
    mix_clf_format = product(
        EXAMPLES["multilabel-indicator"], EXAMPLES["multiclass"] + EXAMPLES["binary"]
    )

    for y_multilabel, y_multiclass in mix_clf_format:
        with pytest.raises(ValueError):
            unique_labels(y_multiclass, y_multilabel)
        with pytest.raises(ValueError):
            unique_labels(y_multilabel, y_multiclass)

    with pytest.raises(ValueError):
        unique_labels([[1, 2]], [["a", "d"]])
    with pytest.raises(ValueError):
        unique_labels(["1", 2])
    with pytest.raises(ValueError):
        unique_labels([["1", 2], [1, 3]])
    with pytest.raises(ValueError):
        unique_labels([["1", "2"], [2, 3]])


def test_is_multilabel():
    for group, group_examples in EXAMPLES.items():
        dense_exp = group == "multilabel-indicator"

        for example in group_examples:
            # Only mark explicitly defined sparse examples as valid sparse
            # multilabel-indicators
            sparse_exp = group == "multilabel-indicator" and issparse(example)

            if issparse(example) or (
                hasattr(example, "__array__")
                and np.asarray(example).ndim == 2
                and np.asarray(example).dtype.kind in "biuf"
                and np.asarray(example).shape[1] > 0
            ):
                examples_sparse = [
                    sparse_matrix(example)
                    for sparse_matrix in [
                        coo_matrix,
                        csc_matrix,
                        csr_matrix,
                        dok_matrix,
                        lil_matrix,
                    ]
                ]
                for exmpl_sparse in examples_sparse:
                    assert sparse_exp == is_multilabel(
                        exmpl_sparse
                    ), "is_multilabel(%r) should be %s" % (exmpl_sparse, sparse_exp)

            # Densify sparse examples before testing
            if issparse(example):
                example = example.toarray()

            assert dense_exp == is_multilabel(
                example
            ), "is_multilabel(%r) should be %s" % (example, dense_exp)
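

# Illustrative sketch: is_multilabel requires a 2-d indicator with more than one
# column, so 1-d multiclass targets and single-column indicators do not qualify.
def test_is_multilabel_spot_checks():
    assert is_multilabel(np.array([[1, 0], [0, 1]]))
    assert not is_multilabel([1, 0, 2])
    assert not is_multilabel(np.array([[1], [0], [1]]))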


def test_check_classification_targets():
    for y_type in EXAMPLES.keys():
        if y_type in ["unknown", "continuous", "continuous-multioutput"]:
            for example in EXAMPLES[y_type]:
                msg = "Unknown label type: "
                with pytest.raises(ValueError, match=msg):
                    check_classification_targets(example)
        else:
            for example in EXAMPLES[y_type]:
                check_classification_targets(example)


# @ignore_warnings
def test_type_of_target():
    for group, group_examples in EXAMPLES.items():
        for example in group_examples:
            assert (
                type_of_target(example) == group
            ), "type_of_target(%r) should be %r, got %r" % (
                example,
                group,
                type_of_target(example),
            )

    for example in NON_ARRAY_LIKE_EXAMPLES:
        msg_regex = r"Expected array-like \(array or non-string sequence\).*"
        with pytest.raises(ValueError, match=msg_regex):
            type_of_target(example)

    for example in MULTILABEL_SEQUENCES:
        msg = (
            "You appear to be using a legacy multi-label data "
            "representation. Sequence of sequences are no longer supported;"
            " use a binary array or sparse matrix instead."
        )
        with pytest.raises(ValueError, match=msg):
            type_of_target(example)
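

# Illustrative sketch: float targets whose values are all integral are treated as
# class labels, while non-integral floats are continuous, matching the float entries
# used in EXAMPLES.
def test_type_of_target_integral_floats():
    assert type_of_target(np.array([1.0, 0.0, 2.0])) == "multiclass"
    assert type_of_target(np.array([1.0, 0.5])) == "continuous"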


def test_type_of_target_pandas_sparse():
    pd = pytest.importorskip("pandas")

    y = pd.arrays.SparseArray([1, np.nan, np.nan, 1, np.nan])
    msg = "y cannot be class 'SparseSeries' or 'SparseArray'"
    with pytest.raises(ValueError, match=msg):
        type_of_target(y)


def test_type_of_target_pandas_nullable():
    """Check that type_of_target works with pandas nullable dtypes."""
    pd = pytest.importorskip("pandas")

    for dtype in ["Int32", "Float32"]:
        y_true = pd.Series([1, 0, 2, 3, 4], dtype=dtype)
        assert type_of_target(y_true) == "multiclass"

        y_true = pd.Series([1, 0, 1, 0], dtype=dtype)
        assert type_of_target(y_true) == "binary"

    y_true = pd.DataFrame([[1.4, 3.1], [3.1, 1.4]], dtype="Float32")
    assert type_of_target(y_true) == "continuous-multioutput"

    y_true = pd.DataFrame([[0, 1], [1, 1]], dtype="Int32")
    assert type_of_target(y_true) == "multilabel-indicator"

    y_true = pd.DataFrame([[1, 2], [3, 1]], dtype="Int32")
    assert type_of_target(y_true) == "multiclass-multioutput"


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_unique_labels_pandas_nullable(dtype):
    """Checks that unique_labels works with pandas nullable dtypes.

    Non-regression test for gh-25634.
    """
    pd = pytest.importorskip("pandas")

    y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
    y_predicted = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64")

    labels = unique_labels(y_true, y_predicted)
    assert_array_equal(labels, [0, 1])


def test_class_distribution():
    y = np.array(
        [
            [1, 0, 0, 1],
            [2, 2, 0, 1],
            [1, 3, 0, 1],
            [4, 2, 0, 1],
            [2, 0, 0, 1],
            [1, 3, 0, 1],
        ]
    )
    # Define the sparse matrix with a mix of implicit and explicit zeros
    data = np.array([1, 2, 1, 4, 2, 1, 0, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1])
    indices = np.array([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 4, 5])
    indptr = np.array([0, 6, 11, 11, 17])
    y_sp = sp.csc_matrix((data, indices, indptr), shape=(6, 4))

    classes, n_classes, class_prior = class_distribution(y)
    classes_sp, n_classes_sp, class_prior_sp = class_distribution(y_sp)
    classes_expected = [[1, 2, 4], [0, 2, 3], [0], [1]]
    n_classes_expected = [3, 3, 1, 1]
    class_prior_expected = [[3 / 6, 2 / 6, 1 / 6], [1 / 3, 1 / 3, 1 / 3], [1.0], [1.0]]

    for k in range(y.shape[1]):
        assert_array_almost_equal(classes[k], classes_expected[k])
        assert_array_almost_equal(n_classes[k], n_classes_expected[k])
        assert_array_almost_equal(class_prior[k], class_prior_expected[k])

        assert_array_almost_equal(classes_sp[k], classes_expected[k])
        assert_array_almost_equal(n_classes_sp[k], n_classes_expected[k])
        assert_array_almost_equal(class_prior_sp[k], class_prior_expected[k])

    # Test again with explicit sample weights
    (classes, n_classes, class_prior) = class_distribution(
        y, [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
    )
    (classes_sp, n_classes_sp, class_prior_sp) = class_distribution(
        y_sp, [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
    )
    class_prior_expected = [[4 / 9, 3 / 9, 2 / 9], [2 / 9, 4 / 9, 3 / 9], [1.0], [1.0]]

    for k in range(y.shape[1]):
        assert_array_almost_equal(classes[k], classes_expected[k])
        assert_array_almost_equal(n_classes[k], n_classes_expected[k])
        assert_array_almost_equal(class_prior[k], class_prior_expected[k])

        assert_array_almost_equal(classes_sp[k], classes_expected[k])
        assert_array_almost_equal(n_classes_sp[k], n_classes_expected[k])
        assert_array_almost_equal(class_prior_sp[k], class_prior_expected[k])
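

# Illustrative sketch: for a single-output target, class_distribution returns
# per-output lists of classes, class counts and empirical priors.
def test_class_distribution_single_output():
    classes, n_classes, class_prior = class_distribution(np.array([[0], [1], [1]]))
    assert_array_equal(classes[0], [0, 1])
    assert n_classes[0] == 2
    assert_allclose(class_prior[0], [1 / 3, 2 / 3])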


def test_safe_split_with_precomputed_kernel():
    clf = SVC()
    clfp = SVC(kernel="precomputed")

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    K = np.dot(X, X.T)

    cv = ShuffleSplit(test_size=0.25, random_state=0)
    train, test = list(cv.split(X))[0]

    X_train, y_train = _safe_split(clf, X, y, train)
    K_train, y_train2 = _safe_split(clfp, K, y, train)
    assert_array_almost_equal(K_train, np.dot(X_train, X_train.T))
    assert_array_almost_equal(y_train, y_train2)

    X_test, y_test = _safe_split(clf, X, y, test, train)
    K_test, y_test2 = _safe_split(clfp, K, y, test, train)
    assert_array_almost_equal(K_test, np.dot(X_test, X_train.T))
    assert_array_almost_equal(y_test, y_test2)
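

# Illustrative sketch of the slicing performed above: with a pairwise (precomputed
# kernel) estimator, _safe_split indexes rows by the requested indices and columns
# by the training indices, so the train block is square and the test block is not.
def test_safe_split_precomputed_kernel_shapes():
    rng = np.random.RandomState(0)
    X = rng.rand(20, 3)
    y = np.arange(20) % 2
    K = np.dot(X, X.T)
    clfp = SVC(kernel="precomputed")

    train, test = np.arange(15), np.arange(15, 20)
    K_train, _ = _safe_split(clfp, K, y, train)
    K_test, _ = _safe_split(clfp, K, y, test, train)
    assert K_train.shape == (15, 15)
    assert K_test.shape == (5, 15)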


def test_ovr_decision_function():
    # test properties for ovr decision function
    predictions = np.array([[0, 1, 1], [0, 1, 0], [0, 1, 1], [0, 1, 1]])
    confidences = np.array(
        [[-1e16, 0, -1e16], [1.0, 2.0, -3.0], [-5.0, 2.0, 5.0], [-0.5, 0.2, 0.5]]
    )
    n_classes = 3

    dec_values = _ovr_decision_function(predictions, confidences, n_classes)

    # check that the decision values are within 0.5 range of the votes
    votes = np.array([[1, 0, 2], [1, 1, 1], [1, 0, 2], [1, 0, 2]])
    assert_allclose(votes, dec_values, atol=0.5)

    # check that the predictions are what we expect:
    # the highest vote, or the highest confidence if there is a tie.
    # for the second sample we have a tie (should be won by class 1)
    expected_prediction = np.array([2, 1, 2, 2])
    assert_array_equal(np.argmax(dec_values, axis=1), expected_prediction)

    # the third and fourth samples have the same vote but the third sample
    # has higher confidence, which should be reflected in the decision values
    assert dec_values[2, 2] > dec_values[3, 2]

    # assert subset invariance.
    dec_values_one = [
        _ovr_decision_function(
            np.array([predictions[i]]), np.array([confidences[i]]), n_classes
        )[0]
        for i in range(4)
    ]
    assert_allclose(dec_values, dec_values_one, atol=1e-6)
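

# Illustrative sketch: rebuild the vote counts from the pairwise predictions,
# assuming the (0, 1), (0, 2), (1, 2) pair ordering implied by the expected votes
# above, and check that rounding the decision values recovers them (the confidence
# term added by _ovr_decision_function stays strictly within +/- 1/3 of the votes).
def test_ovr_decision_function_vote_reconstruction():
    predictions = np.array([[0, 1, 1], [0, 1, 0], [0, 1, 1], [0, 1, 1]])
    confidences = np.array(
        [[-1e16, 0, -1e16], [1.0, 2.0, -3.0], [-5.0, 2.0, 5.0], [-0.5, 0.2, 0.5]]
    )
    n_classes = 3

    # For each pair (i, j), prediction 0 votes for class i and prediction 1 for class j.
    votes = np.zeros((predictions.shape[0], n_classes))
    k = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            votes[predictions[:, k] == 0, i] += 1
            votes[predictions[:, k] == 1, j] += 1
            k += 1

    dec_values = _ovr_decision_function(predictions, confidences, n_classes)
    assert_allclose(votes, np.round(dec_values))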