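"""Tests for the multiclass meta-estimators in ``sklearn.multiclass``:
OneVsRestClassifier, OneVsOneClassifier and OutputCodeClassifier."""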
from re import escape

import numpy as np
import pytest
import scipy.sparse as sp
from numpy.testing import assert_allclose

from sklearn import datasets, svm
from sklearn.datasets import load_breast_cancer
from sklearn.exceptions import NotFittedError
from sklearn.impute import SimpleImputer
from sklearn.linear_model import (
    ElasticNet,
    Lasso,
    LinearRegression,
    LogisticRegression,
    Perceptron,
    Ridge,
    SGDClassifier,
)
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.multiclass import (
    OneVsOneClassifier,
    OneVsRestClassifier,
    OutputCodeClassifier,
)
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils import (
    check_array,
    shuffle,
)
from sklearn.utils._mocking import CheckingClassifier
from sklearn.utils._testing import assert_almost_equal, assert_array_equal
from sklearn.utils.multiclass import check_classification_targets, type_of_target

msg = "The default value for `force_alpha` will change"
pytestmark = pytest.mark.filterwarnings(f"ignore:{msg}:FutureWarning")
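
# Shared test data: a shuffled copy of the iris dataset (3 classes).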
iris = datasets.load_iris()
rng = np.random.RandomState(0)
perm = rng.permutation(iris.target.size)
iris.data = iris.data[perm]
iris.target = iris.target[perm]
n_classes = 3


def test_ovr_exceptions():
    ovr = OneVsRestClassifier(LinearSVC(dual="auto", random_state=0))

    # test predicting without fitting
    with pytest.raises(NotFittedError):
        ovr.predict([])

    # Fail on multioutput data
    msg = "Multioutput target data is not supported with label binarization"
    with pytest.raises(ValueError, match=msg):
        X = np.array([[1, 0], [0, 1]])
        y = np.array([[1, 2], [3, 1]])
        OneVsRestClassifier(MultinomialNB()).fit(X, y)

    with pytest.raises(ValueError, match=msg):
        X = np.array([[1, 0], [0, 1]])
        y = np.array([[1.5, 2.4], [3.1, 0.8]])
        OneVsRestClassifier(MultinomialNB()).fit(X, y)


def test_check_classification_targets():
    # Test that check_classification_targets raises a ValueError whose message
    # names the invalid target type (here "continuous"). See issue #5782.
    y = np.array([0.0, 1.1, 2.0, 3.0])
    msg = type_of_target(y)
    with pytest.raises(ValueError, match=msg):
        check_classification_targets(y)


def test_ovr_fit_predict():
    # A classifier which implements decision_function.
    ovr = OneVsRestClassifier(LinearSVC(dual="auto", random_state=0))
    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
    assert len(ovr.estimators_) == n_classes

    clf = LinearSVC(dual="auto", random_state=0)
    pred2 = clf.fit(iris.data, iris.target).predict(iris.data)
    assert np.mean(iris.target == pred) == np.mean(iris.target == pred2)

    # A classifier which implements predict_proba.
    ovr = OneVsRestClassifier(MultinomialNB())
    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
    assert np.mean(iris.target == pred) > 0.65


def test_ovr_partial_fit():
    # Test if partial_fit is working as intended
    X, y = shuffle(iris.data, iris.target, random_state=0)
    ovr = OneVsRestClassifier(MultinomialNB())
    ovr.partial_fit(X[:100], y[:100], np.unique(y))
    ovr.partial_fit(X[100:], y[100:])
    pred = ovr.predict(X)
    ovr2 = OneVsRestClassifier(MultinomialNB())
    pred2 = ovr2.fit(X, y).predict(X)

    assert_almost_equal(pred, pred2)
    assert len(ovr.estimators_) == len(np.unique(y))
    assert np.mean(y == pred) > 0.65

    # Test when mini-batches don't contain all classes,
    # with SGDClassifier
    X = np.abs(np.random.randn(14, 2))
    y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3]

    ovr = OneVsRestClassifier(
        SGDClassifier(max_iter=1, tol=None, shuffle=False, random_state=0)
    )
    ovr.partial_fit(X[:7], y[:7], np.unique(y))
    ovr.partial_fit(X[7:], y[7:])
    pred = ovr.predict(X)
    ovr1 = OneVsRestClassifier(
        SGDClassifier(max_iter=1, tol=None, shuffle=False, random_state=0)
    )
    pred1 = ovr1.fit(X, y).predict(X)
    assert np.mean(pred == y) == np.mean(pred1 == y)

    # test partial_fit only exists if estimator has it:
    ovr = OneVsRestClassifier(SVC())
    assert not hasattr(ovr, "partial_fit")


def test_ovr_partial_fit_exceptions():
    ovr = OneVsRestClassifier(MultinomialNB())
    X = np.abs(np.random.randn(14, 2))
    y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3]
    ovr.partial_fit(X[:7], y[:7], np.unique(y))

    # If a new class that was not in the first call of partial_fit is seen,
    # it should raise ValueError
    y1 = [5] + y[7:-1]
    msg = r"Mini-batch contains \[.+\] while classes must be subset of \[.+\]"
    with pytest.raises(ValueError, match=msg):
        ovr.partial_fit(X=X[7:], y=y1)


def test_ovr_ovo_regressor():
    # test that ovr and ovo work on regressors which don't have
    # a decision_function
    ovr = OneVsRestClassifier(DecisionTreeRegressor())
    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
    assert len(ovr.estimators_) == n_classes
    assert_array_equal(np.unique(pred), [0, 1, 2])
    # we are doing something sensible
    assert np.mean(pred == iris.target) > 0.9

    ovr = OneVsOneClassifier(DecisionTreeRegressor())
    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
    assert len(ovr.estimators_) == n_classes * (n_classes - 1) / 2
    assert_array_equal(np.unique(pred), [0, 1, 2])
    # we are doing something sensible
    assert np.mean(pred == iris.target) > 0.9
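

# Check that OvR accepts sparse label matrices in the common scipy.sparse formats.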
def test_ovr_fit_predict_sparse():
    for sparse in [
        sp.csr_matrix,
        sp.csc_matrix,
        sp.coo_matrix,
        sp.dok_matrix,
        sp.lil_matrix,
    ]:
        base_clf = MultinomialNB(alpha=1)

        X, Y = datasets.make_multilabel_classification(
            n_samples=100,
            n_features=20,
            n_classes=5,
            n_labels=3,
            length=50,
            allow_unlabeled=True,
            random_state=0,
        )
        X_train, Y_train = X[:80], Y[:80]
        X_test = X[80:]

        clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
        Y_pred = clf.predict(X_test)

        clf_sprs = OneVsRestClassifier(base_clf).fit(X_train, sparse(Y_train))
        Y_pred_sprs = clf_sprs.predict(X_test)

        assert clf.multilabel_
        assert sp.issparse(Y_pred_sprs)
        assert_array_equal(Y_pred_sprs.toarray(), Y_pred)

        # Test predict_proba
        Y_proba = clf_sprs.predict_proba(X_test)

        # predict assigns a label if the probability that the
        # sample has the label is greater than 0.5.
        pred = Y_proba > 0.5
        assert_array_equal(pred, Y_pred_sprs.toarray())

        # Test decision_function
        clf = svm.SVC()
        clf_sprs = OneVsRestClassifier(clf).fit(X_train, sparse(Y_train))
        dec_pred = (clf_sprs.decision_function(X_test) > 0).astype(int)
        assert_array_equal(dec_pred, clf_sprs.predict(X_test).toarray())


def test_ovr_always_present():
    # Test that OvR works with classes that are always present or absent.
    # Note: this tests the case where _ConstantPredictor is used.
    X = np.ones((10, 2))
    X[:5, :] = 0

    # Build an indicator matrix where two features are always on.
    # As a list of lists, it would be: [[int(i >= 5), 2, 3] for i in range(10)]
    y = np.zeros((10, 3))
    y[5:, 0] = 1
    y[:, 1] = 1
    y[:, 2] = 1

    ovr = OneVsRestClassifier(LogisticRegression())
    msg = r"Label .+ is present in all training examples"
    with pytest.warns(UserWarning, match=msg):
        ovr.fit(X, y)
    y_pred = ovr.predict(X)
    assert_array_equal(np.array(y_pred), np.array(y))
    y_pred = ovr.decision_function(X)
    assert np.unique(y_pred[:, -2:]) == 1
    y_pred = ovr.predict_proba(X)
    assert_array_equal(y_pred[:, -1], np.ones(X.shape[0]))

    # y has a constantly absent label
    y = np.zeros((10, 2))
    y[5:, 0] = 1  # variable label
    ovr = OneVsRestClassifier(LogisticRegression())
    msg = r"Label not 1 is present in all training examples"
    with pytest.warns(UserWarning, match=msg):
        ovr.fit(X, y)
    y_pred = ovr.predict_proba(X)
    assert_array_equal(y_pred[:, -1], np.zeros(X.shape[0]))


def test_ovr_multiclass():
    # Toy dataset where features correspond directly to labels.
    X = np.array([[0, 0, 5], [0, 5, 0], [3, 0, 0], [0, 0, 6], [6, 0, 0]])
    y = ["eggs", "spam", "ham", "eggs", "ham"]
    Y = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]])

    classes = set("ham eggs spam".split())

    for base_clf in (
        MultinomialNB(),
        LinearSVC(dual="auto", random_state=0),
        LinearRegression(),
        Ridge(),
        ElasticNet(),
    ):
        clf = OneVsRestClassifier(base_clf).fit(X, y)
        assert set(clf.classes_) == classes
        y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
        assert_array_equal(y_pred, ["eggs"])

        # test input as label indicator matrix
        clf = OneVsRestClassifier(base_clf).fit(X, Y)
        y_pred = clf.predict([[0, 0, 4]])[0]
        assert_array_equal(y_pred, [0, 0, 1])


def test_ovr_binary():
    # Toy dataset where features correspond directly to labels.
    X = np.array([[0, 0, 5], [0, 5, 0], [3, 0, 0], [0, 0, 6], [6, 0, 0]])
    y = ["eggs", "spam", "spam", "eggs", "spam"]
    Y = np.array([[0, 1, 1, 0, 1]]).T

    classes = set("eggs spam".split())

    def conduct_test(base_clf, test_predict_proba=False):
        clf = OneVsRestClassifier(base_clf).fit(X, y)
        assert set(clf.classes_) == classes
        y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
        assert_array_equal(y_pred, ["eggs"])
        if hasattr(base_clf, "decision_function"):
            dec = clf.decision_function(X)
            assert dec.shape == (5,)

        if test_predict_proba:
            X_test = np.array([[0, 0, 4]])
            probabilities = clf.predict_proba(X_test)
            assert 2 == len(probabilities[0])
            assert clf.classes_[np.argmax(probabilities, axis=1)] == clf.predict(
                X_test
            )

        # test input as label indicator matrix
        clf = OneVsRestClassifier(base_clf).fit(X, Y)
        y_pred = clf.predict([[3, 0, 0]])[0]
        assert y_pred == 1

    for base_clf in (
        LinearSVC(dual="auto", random_state=0),
        LinearRegression(),
        Ridge(),
        ElasticNet(),
    ):
        conduct_test(base_clf)

    for base_clf in (MultinomialNB(), SVC(probability=True), LogisticRegression()):
        conduct_test(base_clf, test_predict_proba=True)


def test_ovr_multilabel():
    # Toy dataset where features correspond directly to labels.
    X = np.array([[0, 4, 5], [0, 5, 0], [3, 3, 3], [4, 0, 6], [6, 0, 0]])
    y = np.array([[0, 1, 1], [0, 1, 0], [1, 1, 1], [1, 0, 1], [1, 0, 0]])

    for base_clf in (
        MultinomialNB(),
        LinearSVC(dual="auto", random_state=0),
        LinearRegression(),
        Ridge(),
        ElasticNet(),
        Lasso(alpha=0.5),
    ):
        clf = OneVsRestClassifier(base_clf).fit(X, y)
        y_pred = clf.predict([[0, 4, 4]])[0]
        assert_array_equal(y_pred, [0, 1, 1])
        assert clf.multilabel_
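

# Smoke test: OvR with an SVC base estimator fits one binary SVC per class.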
def test_ovr_fit_predict_svc():
    ovr = OneVsRestClassifier(svm.SVC())
    ovr.fit(iris.data, iris.target)
    assert len(ovr.estimators_) == 3
    assert ovr.score(iris.data, iris.target) > 0.9


def test_ovr_multilabel_dataset():
    base_clf = MultinomialNB(alpha=1)
    for au, prec, recall in zip((True, False), (0.51, 0.66), (0.51, 0.80)):
        X, Y = datasets.make_multilabel_classification(
            n_samples=100,
            n_features=20,
            n_classes=5,
            n_labels=2,
            length=50,
            allow_unlabeled=au,
            random_state=0,
        )
        X_train, Y_train = X[:80], Y[:80]
        X_test, Y_test = X[80:], Y[80:]
        clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
        Y_pred = clf.predict(X_test)

        assert clf.multilabel_
        assert_almost_equal(
            precision_score(Y_test, Y_pred, average="micro"), prec, decimal=2
        )
        assert_almost_equal(
            recall_score(Y_test, Y_pred, average="micro"), recall, decimal=2
        )


def test_ovr_multilabel_predict_proba():
    base_clf = MultinomialNB(alpha=1)
    for au in (False, True):
        X, Y = datasets.make_multilabel_classification(
            n_samples=100,
            n_features=20,
            n_classes=5,
            n_labels=3,
            length=50,
            allow_unlabeled=au,
            random_state=0,
        )
        X_train, Y_train = X[:80], Y[:80]
        X_test = X[80:]
        clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)

        # Decision function only estimator.
        decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
        assert not hasattr(decision_only, "predict_proba")

        # Estimator with predict_proba disabled, depending on parameters.
        decision_only = OneVsRestClassifier(svm.SVC(probability=False))
        assert not hasattr(decision_only, "predict_proba")
        decision_only.fit(X_train, Y_train)
        assert not hasattr(decision_only, "predict_proba")
        assert hasattr(decision_only, "decision_function")

        # Estimator which can get predict_proba enabled after fitting
        gs = GridSearchCV(
            svm.SVC(probability=False), param_grid={"probability": [True]}
        )
        proba_after_fit = OneVsRestClassifier(gs)
        assert not hasattr(proba_after_fit, "predict_proba")
        proba_after_fit.fit(X_train, Y_train)
        assert hasattr(proba_after_fit, "predict_proba")

        Y_pred = clf.predict(X_test)
        Y_proba = clf.predict_proba(X_test)

        # predict assigns a label if the probability that the
        # sample has the label is greater than 0.5.
        pred = Y_proba > 0.5
        assert_array_equal(pred, Y_pred)


def test_ovr_single_label_predict_proba():
    base_clf = MultinomialNB(alpha=1)
    X, Y = iris.data, iris.target
    X_train, Y_train = X[:80], Y[:80]
    X_test = X[80:]
    clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)

    # Decision function only estimator.
    decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
    assert not hasattr(decision_only, "predict_proba")

    Y_pred = clf.predict(X_test)
    Y_proba = clf.predict_proba(X_test)

    assert_almost_equal(Y_proba.sum(axis=1), 1.0)
    # predict assigns the label with the greatest predictive probability.
    pred = Y_proba.argmax(axis=1)
    assert not (pred - Y_pred).any()


def test_ovr_multilabel_decision_function():
    X, Y = datasets.make_multilabel_classification(
        n_samples=100,
        n_features=20,
        n_classes=5,
        n_labels=3,
        length=50,
        allow_unlabeled=True,
        random_state=0,
    )
    X_train, Y_train = X[:80], Y[:80]
    X_test = X[80:]
    clf = OneVsRestClassifier(svm.SVC()).fit(X_train, Y_train)
    assert_array_equal(
        (clf.decision_function(X_test) > 0).astype(int), clf.predict(X_test)
    )


def test_ovr_single_label_decision_function():
    X, Y = datasets.make_classification(n_samples=100, n_features=20, random_state=0)
    X_train, Y_train = X[:80], Y[:80]
    X_test = X[80:]
    clf = OneVsRestClassifier(svm.SVC()).fit(X_train, Y_train)
    assert_array_equal(clf.decision_function(X_test).ravel() > 0, clf.predict(X_test))
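

# The base estimator's hyper-parameters should be reachable through the
# meta-estimator via the ``estimator__`` prefix in a grid search.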
def test_ovr_gridsearch():
    ovr = OneVsRestClassifier(LinearSVC(dual="auto", random_state=0))
    Cs = [0.1, 0.5, 0.8]
    cv = GridSearchCV(ovr, {"estimator__C": Cs})
    cv.fit(iris.data, iris.target)
    best_C = cv.best_estimator_.estimators_[0].C
    assert best_C in Cs


def test_ovr_pipeline():
    # Test with a pipeline of length one
    # This test is needed because the multiclass estimators may fail to detect
    # the presence of predict_proba or decision_function.
    clf = Pipeline([("tree", DecisionTreeClassifier())])
    ovr_pipe = OneVsRestClassifier(clf)
    ovr_pipe.fit(iris.data, iris.target)
    ovr = OneVsRestClassifier(DecisionTreeClassifier())
    ovr.fit(iris.data, iris.target)
    assert_array_equal(ovr.predict(iris.data), ovr_pipe.predict(iris.data))


def test_ovo_exceptions():
    ovo = OneVsOneClassifier(LinearSVC(dual="auto", random_state=0))
    with pytest.raises(NotFittedError):
        ovo.predict([])


def test_ovo_fit_on_list():
    # Test that OneVsOne fitting works with a list of targets and yields the
    # same output as predict from an array
    ovo = OneVsOneClassifier(LinearSVC(dual="auto", random_state=0))
    prediction_from_array = ovo.fit(iris.data, iris.target).predict(iris.data)
    iris_data_list = [list(a) for a in iris.data]
    prediction_from_list = ovo.fit(iris_data_list, list(iris.target)).predict(
        iris_data_list
    )
    assert_array_equal(prediction_from_array, prediction_from_list)
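

# OvO fits one binary classifier per pair of classes, i.e.
# n_classes * (n_classes - 1) / 2 estimators in total.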
def test_ovo_fit_predict():
    # A classifier which implements decision_function.
    ovo = OneVsOneClassifier(LinearSVC(dual="auto", random_state=0))
    ovo.fit(iris.data, iris.target).predict(iris.data)
    assert len(ovo.estimators_) == n_classes * (n_classes - 1) / 2

    # A classifier which implements predict_proba.
    ovo = OneVsOneClassifier(MultinomialNB())
    ovo.fit(iris.data, iris.target).predict(iris.data)
    assert len(ovo.estimators_) == n_classes * (n_classes - 1) / 2


def test_ovo_partial_fit_predict():
    temp = datasets.load_iris()
    X, y = temp.data, temp.target
    ovo1 = OneVsOneClassifier(MultinomialNB())
    ovo1.partial_fit(X[:100], y[:100], np.unique(y))
    ovo1.partial_fit(X[100:], y[100:])
    pred1 = ovo1.predict(X)

    ovo2 = OneVsOneClassifier(MultinomialNB())
    ovo2.fit(X, y)
    pred2 = ovo2.predict(X)
    assert len(ovo1.estimators_) == n_classes * (n_classes - 1) / 2
    assert np.mean(y == pred1) > 0.65
    assert_almost_equal(pred1, pred2)

    # Test when mini-batches have binary target classes
    ovo1 = OneVsOneClassifier(MultinomialNB())
    ovo1.partial_fit(X[:60], y[:60], np.unique(y))
    ovo1.partial_fit(X[60:], y[60:])
    pred1 = ovo1.predict(X)
    ovo2 = OneVsOneClassifier(MultinomialNB())
    pred2 = ovo2.fit(X, y).predict(X)

    assert_almost_equal(pred1, pred2)
    assert len(ovo1.estimators_) == len(np.unique(y))
    assert np.mean(y == pred1) > 0.65

    ovo = OneVsOneClassifier(MultinomialNB())
    X = np.random.rand(14, 2)
    y = [1, 1, 2, 3, 3, 0, 0, 4, 4, 4, 4, 4, 2, 2]
    ovo.partial_fit(X[:7], y[:7], [0, 1, 2, 3, 4])
    ovo.partial_fit(X[7:], y[7:])
    pred = ovo.predict(X)
    ovo2 = OneVsOneClassifier(MultinomialNB())
    pred2 = ovo2.fit(X, y).predict(X)
    assert_almost_equal(pred, pred2)

    # raises error when a mini-batch contains classes outside of all_classes
    ovo = OneVsOneClassifier(MultinomialNB())
    error_y = [0, 1, 2, 3, 4, 5, 2]
    message_re = escape(
        "Mini-batch contains {0} while it must be subset of {1}".format(
            np.unique(error_y), np.unique(y)
        )
    )
    with pytest.raises(ValueError, match=message_re):
        ovo.partial_fit(X[:7], error_y, np.unique(y))

    # test partial_fit only exists if estimator has it:
    ovr = OneVsOneClassifier(SVC())
    assert not hasattr(ovr, "partial_fit")


def test_ovo_decision_function():
    n_samples = iris.data.shape[0]

    ovo_clf = OneVsOneClassifier(LinearSVC(dual="auto", random_state=0))
    # first binary
    ovo_clf.fit(iris.data, iris.target == 0)
    decisions = ovo_clf.decision_function(iris.data)
    assert decisions.shape == (n_samples,)

    # then multi-class
    ovo_clf.fit(iris.data, iris.target)
    decisions = ovo_clf.decision_function(iris.data)
    assert decisions.shape == (n_samples, n_classes)
    assert_array_equal(decisions.argmax(axis=1), ovo_clf.predict(iris.data))

    # Compute the votes
    votes = np.zeros((n_samples, n_classes))

    k = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            pred = ovo_clf.estimators_[k].predict(iris.data)
            votes[pred == 0, i] += 1
            votes[pred == 1, j] += 1
            k += 1

    # Extract votes and verify
    assert_array_equal(votes, np.round(decisions))

    for class_idx in range(n_classes):
        # For each sample and each class, there are only 3 possible vote levels
        # because there are only 3 distinct class pairs, and thus 3 distinct
        # binary classifiers.
        # Therefore, sorting predictions based on votes would yield
        # mostly tied predictions:
        assert set(votes[:, class_idx]).issubset(set([0.0, 1.0, 2.0]))

        # The OvO decision function, on the other hand, is able to resolve
        # most of the ties on this data as it combines both the vote counts
        # and the aggregated confidence levels of the binary classifiers
        # to compute the aggregate decision function. The iris dataset
        # has 150 samples with a couple of duplicates. The OvO decisions
        # can resolve most of the ties:
        assert len(np.unique(decisions[:, class_idx])) > 146


def test_ovo_gridsearch():
    ovo = OneVsOneClassifier(LinearSVC(dual="auto", random_state=0))
    Cs = [0.1, 0.5, 0.8]
    cv = GridSearchCV(ovo, {"estimator__C": Cs})
    cv.fit(iris.data, iris.target)
    best_C = cv.best_estimator_.estimators_[0].C
    assert best_C in Cs


def test_ovo_ties():
    # Test that ties are broken using the decision function,
    # not defaulting to the smallest label
    X = np.array([[1, 2], [2, 1], [-2, 1], [-2, -1]])
    y = np.array([2, 0, 1, 2])
    multi_clf = OneVsOneClassifier(Perceptron(shuffle=False, max_iter=4, tol=None))
    ovo_prediction = multi_clf.fit(X, y).predict(X)
    ovo_decision = multi_clf.decision_function(X)

    # Classifiers are in order 0-1, 0-2, 1-2
    # Use decision_function to compute the votes and the normalized
    # sum_of_confidences, which is used to disambiguate when there is a tie in
    # votes.
    votes = np.round(ovo_decision)
    normalized_confidences = ovo_decision - votes

    # For the first point, there is one vote per class
    assert_array_equal(votes[0, :], 1)
    # For the rest, there is no tie and the prediction is the argmax
    assert_array_equal(np.argmax(votes[1:], axis=1), ovo_prediction[1:])
    # For the tie, the prediction is the class with the highest score
    assert ovo_prediction[0] == normalized_confidences[0].argmax()


def test_ovo_ties2():
    # test that ties can not only be won by the first two labels
    X = np.array([[1, 2], [2, 1], [-2, 1], [-2, -1]])
    y_ref = np.array([2, 0, 1, 2])

    # cycle through labels so that each label wins once
    for i in range(3):
        y = (y_ref + i) % 3
        multi_clf = OneVsOneClassifier(Perceptron(shuffle=False, max_iter=4, tol=None))
        ovo_prediction = multi_clf.fit(X, y).predict(X)
        assert ovo_prediction[0] == i % 3


def test_ovo_string_y():
    # Test that the OvO doesn't mess up the encoding of string labels
    X = np.eye(4)
    y = np.array(["a", "b", "c", "d"])

    ovo = OneVsOneClassifier(LinearSVC(dual="auto"))
    ovo.fit(X, y)
    assert_array_equal(y, ovo.predict(X))


def test_ovo_one_class():
    # Test error for OvO with one class
    X = np.eye(4)
    y = np.array(["a"] * 4)

    ovo = OneVsOneClassifier(LinearSVC(dual="auto"))
    msg = "when only one class"
    with pytest.raises(ValueError, match=msg):
        ovo.fit(X, y)


def test_ovo_float_y():
    # Test that the OvO errors on float targets
    X = iris.data
    y = iris.data[:, 0]

    ovo = OneVsOneClassifier(LinearSVC(dual="auto"))
    msg = "Unknown label type"
    with pytest.raises(ValueError, match=msg):
        ovo.fit(X, y)


def test_ecoc_exceptions():
    ecoc = OutputCodeClassifier(LinearSVC(dual="auto", random_state=0))
    with pytest.raises(NotFittedError):
        ecoc.predict([])


def test_ecoc_fit_predict():
    # A classifier which implements decision_function.
    ecoc = OutputCodeClassifier(
        LinearSVC(dual="auto", random_state=0), code_size=2, random_state=0
    )
    ecoc.fit(iris.data, iris.target).predict(iris.data)
    assert len(ecoc.estimators_) == n_classes * 2

    # A classifier which implements predict_proba.
    ecoc = OutputCodeClassifier(MultinomialNB(), code_size=2, random_state=0)
    ecoc.fit(iris.data, iris.target).predict(iris.data)
    assert len(ecoc.estimators_) == n_classes * 2


def test_ecoc_gridsearch():
    ecoc = OutputCodeClassifier(LinearSVC(dual="auto", random_state=0), random_state=0)
    Cs = [0.1, 0.5, 0.8]
    cv = GridSearchCV(ecoc, {"estimator__C": Cs})
    cv.fit(iris.data, iris.target)
    best_C = cv.best_estimator_.estimators_[0].C
    assert best_C in Cs


def test_ecoc_float_y():
    # Test that the OutputCodeClassifier errors on float targets
    X = iris.data
    y = iris.data[:, 0]

    ecoc = OutputCodeClassifier(LinearSVC(dual="auto"))
    msg = "Unknown label type"
    with pytest.raises(ValueError, match=msg):
        ecoc.fit(X, y)


def test_ecoc_delegate_sparse_base_estimator():
    # Non-regression test for
    # https://github.com/scikit-learn/scikit-learn/issues/17218
    X, y = iris.data, iris.target
    X_sp = sp.csc_matrix(X)

    # create an estimator that does not support sparse input
    base_estimator = CheckingClassifier(
        check_X=check_array,
        check_X_params={"ensure_2d": True, "accept_sparse": False},
    )
    ecoc = OutputCodeClassifier(base_estimator, random_state=0)

    with pytest.raises(TypeError, match="A sparse matrix was passed"):
        ecoc.fit(X_sp, y)

    ecoc.fit(X, y)
    with pytest.raises(TypeError, match="A sparse matrix was passed"):
        ecoc.predict(X_sp)

    # smoke test to check when sparse input should be supported
    ecoc = OutputCodeClassifier(LinearSVC(dual="auto", random_state=0))
    ecoc.fit(X_sp, y).predict(X_sp)
    assert len(ecoc.estimators_) == 4
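

# With a precomputed kernel, each OvO sub-estimator is trained only on the rows
# and columns of the kernel that belong to its pair of classes;
# pairwise_indices_ exposes the corresponding sample indices.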
def test_pairwise_indices():
    clf_precomputed = svm.SVC(kernel="precomputed")
    X, y = iris.data, iris.target

    ovr_false = OneVsOneClassifier(clf_precomputed)
    linear_kernel = np.dot(X, X.T)
    ovr_false.fit(linear_kernel, y)

    n_estimators = len(ovr_false.estimators_)
    precomputed_indices = ovr_false.pairwise_indices_

    for idx in precomputed_indices:
        assert (
            idx.shape[0] * n_estimators / (n_estimators - 1) == linear_kernel.shape[0]
        )


def test_pairwise_n_features_in():
    """Check the n_features_in_ attributes of the meta and base estimators.

    When the training data is a regular design matrix, everything is intuitive.
    However, when the training data is a precomputed kernel matrix, the
    multiclass strategy can resample the kernel matrix of the underlying base
    estimator both row-wise and column-wise, and this has a non-trivial impact
    on the expected value of the n_features_in_ of both the meta and the base
    estimators.
    """
    X, y = iris.data, iris.target

    # Remove the last sample to make the classes not exactly balanced and make
    # the test more interesting.
    assert y[-1] == 0
    X = X[:-1]
    y = y[:-1]

    # Fitting directly on the design matrix:
    assert X.shape == (149, 4)

    clf_notprecomputed = svm.SVC(kernel="linear").fit(X, y)
    assert clf_notprecomputed.n_features_in_ == 4

    ovr_notprecomputed = OneVsRestClassifier(clf_notprecomputed).fit(X, y)
    assert ovr_notprecomputed.n_features_in_ == 4
    for est in ovr_notprecomputed.estimators_:
        assert est.n_features_in_ == 4

    ovo_notprecomputed = OneVsOneClassifier(clf_notprecomputed).fit(X, y)
    assert ovo_notprecomputed.n_features_in_ == 4
    assert ovo_notprecomputed.n_classes_ == 3
    assert len(ovo_notprecomputed.estimators_) == 3
    for est in ovo_notprecomputed.estimators_:
        assert est.n_features_in_ == 4

    # When working with precomputed kernels we have one "feature" per training
    # sample:
    K = X @ X.T
    assert K.shape == (149, 149)

    clf_precomputed = svm.SVC(kernel="precomputed").fit(K, y)
    assert clf_precomputed.n_features_in_ == 149

    ovr_precomputed = OneVsRestClassifier(clf_precomputed).fit(K, y)
    assert ovr_precomputed.n_features_in_ == 149
    assert ovr_precomputed.n_classes_ == 3
    assert len(ovr_precomputed.estimators_) == 3
    for est in ovr_precomputed.estimators_:
        assert est.n_features_in_ == 149

    # This becomes really interesting with OvO and a precomputed kernel together:
    # internally, OvO will drop the samples of the classes not part of the pair
    # of classes under consideration for a given binary classifier. Since we
    # use a precomputed kernel, it will also drop the matching columns of the
    # kernel matrix, and therefore we have fewer "features" as a result.
    #
    # Since class 0 has 49 samples, and classes 1 and 2 have 50 samples each, a
    # single OvO binary classifier works with a sub-kernel matrix of shape
    # either (99, 99) or (100, 100).
    ovo_precomputed = OneVsOneClassifier(clf_precomputed).fit(K, y)
    assert ovo_precomputed.n_features_in_ == 149
    assert ovo_precomputed.n_classes_ == 3
    assert len(ovo_precomputed.estimators_) == 3
    assert ovo_precomputed.estimators_[0].n_features_in_ == 99  # class 0 vs class 1
    assert ovo_precomputed.estimators_[1].n_features_in_ == 99  # class 0 vs class 2
    assert ovo_precomputed.estimators_[2].n_features_in_ == 100  # class 1 vs class 2
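

# The "pairwise" estimator tag of the meta-estimator should mirror whether the
# wrapped estimator expects a precomputed kernel.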
@pytest.mark.parametrize(
    "MultiClassClassifier", [OneVsRestClassifier, OneVsOneClassifier]
)
def test_pairwise_tag(MultiClassClassifier):
    clf_precomputed = svm.SVC(kernel="precomputed")
    clf_notprecomputed = svm.SVC()

    ovr_false = MultiClassClassifier(clf_notprecomputed)
    assert not ovr_false._get_tags()["pairwise"]

    ovr_true = MultiClassClassifier(clf_precomputed)
    assert ovr_true._get_tags()["pairwise"]


@pytest.mark.parametrize(
    "MultiClassClassifier", [OneVsRestClassifier, OneVsOneClassifier]
)
def test_pairwise_cross_val_score(MultiClassClassifier):
    clf_precomputed = svm.SVC(kernel="precomputed")
    clf_notprecomputed = svm.SVC(kernel="linear")

    X, y = iris.data, iris.target

    multiclass_clf_notprecomputed = MultiClassClassifier(clf_notprecomputed)
    multiclass_clf_precomputed = MultiClassClassifier(clf_precomputed)

    linear_kernel = np.dot(X, X.T)
    score_not_precomputed = cross_val_score(
        multiclass_clf_notprecomputed, X, y, error_score="raise"
    )
    score_precomputed = cross_val_score(
        multiclass_clf_precomputed, linear_kernel, y, error_score="raise"
    )
    assert_array_equal(score_precomputed, score_not_precomputed)


@pytest.mark.parametrize(
    "MultiClassClassifier", [OneVsRestClassifier, OneVsOneClassifier]
)
# FIXME: we should move this test to `estimator_checks` once we are able
# to construct meta-estimator instances
def test_support_missing_values(MultiClassClassifier):
    # smoke test to check that the OvR and OvO classifiers delegate the
    # validation of missing values to the underlying pipeline or classifiers
    rng = np.random.RandomState(42)
    X, y = iris.data, iris.target
    X = np.copy(X)  # Copy to avoid that the original data is modified
    mask = rng.choice([1, 0], X.shape, p=[0.1, 0.9]).astype(bool)
    X[mask] = np.nan
    lr = make_pipeline(SimpleImputer(), LogisticRegression(random_state=rng))

    MultiClassClassifier(lr).fit(X, y).score(X, y)


@pytest.mark.parametrize("make_y", [np.ones, np.zeros])
def test_constant_int_target(make_y):
    """Check that a constant y target does not raise.

    Non-regression test for #21869
    """
    X = np.ones((10, 2))
    y = make_y((10, 1), dtype=np.int32)
    ovr = OneVsRestClassifier(LogisticRegression())

    ovr.fit(X, y)
    y_pred = ovr.predict_proba(X)
    expected = np.zeros((X.shape[0], 2))
    expected[:, 0] = 1
    assert_allclose(y_pred, expected)


def test_ovo_consistent_binary_classification():
    """Check that OvO is consistent with a binary classifier.

    Non-regression test for #13617.
    """
    X, y = load_breast_cancer(return_X_y=True)

    clf = KNeighborsClassifier(n_neighbors=8, weights="distance")
    ovo = OneVsOneClassifier(clf)

    clf.fit(X, y)
    ovo.fit(X, y)

    assert_array_equal(clf.predict(X), ovo.predict(X))