# Author: Arnaud Joly, Joel Nothman, Hamzeh Alsalhi
#
# License: BSD 3 clause
"""
Multi-class / multi-label utility function
==========================================
"""
import warnings
from collections.abc import Sequence
from itertools import chain

import numpy as np
from scipy.sparse import issparse

from ..utils._array_api import get_namespace
from ..utils.fixes import VisibleDeprecationWarning
from .validation import _assert_all_finite, check_array


def _unique_multiclass(y):
    xp, is_array_api_compliant = get_namespace(y)
    if hasattr(y, "__array__") or is_array_api_compliant:
        return xp.unique_values(xp.asarray(y))
    else:
        return set(y)


def _unique_indicator(y):
    return np.arange(
        check_array(y, input_name="y", accept_sparse=["csr", "csc", "coo"]).shape[1]
    )


_FN_UNIQUE_LABELS = {
    "binary": _unique_multiclass,
    "multiclass": _unique_multiclass,
    "multilabel-indicator": _unique_indicator,
}
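

# Illustrative sketch (not part of the original module): the dispatch table
# above maps a target type to the helper that extracts its labels. For an
# indicator matrix the "labels" are simply the column indices; for plain
# multiclass targets they are the distinct values. Defined but never called.
def _example_fn_unique_labels():
    y_indicator = np.array([[0, 1, 1], [1, 0, 0]])
    # Indicator targets: one label per column, hence arange(n_columns).
    assert list(_FN_UNIQUE_LABELS["multilabel-indicator"](y_indicator)) == [0, 1, 2]
    # Multiclass targets given as a plain list: labels are the set of values.
    assert _FN_UNIQUE_LABELS["multiclass"]([2, 2, 5]) == {2, 5}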


def unique_labels(*ys):
    """Extract an ordered array of unique labels.

    We don't allow:
        - mix of multilabel and multiclass (single label) targets
        - mix of label indicator matrix and anything else,
          because there are no explicit labels
        - mix of label indicator matrices of different sizes
        - mix of string and integer labels

    At the moment, we also don't allow "multiclass-multioutput" input type.

    Parameters
    ----------
    *ys : array-likes
        Label values.

    Returns
    -------
    out : ndarray of shape (n_unique_labels,)
        An ordered array of unique labels.

    Examples
    --------
    >>> from sklearn.utils.multiclass import unique_labels
    >>> unique_labels([3, 5, 5, 5, 7, 7])
    array([3, 5, 7])
    >>> unique_labels([1, 2, 3, 4], [2, 2, 3, 4])
    array([1, 2, 3, 4])
    >>> unique_labels([1, 2, 10], [5, 11])
    array([ 1,  2,  5, 10, 11])
    """
    xp, is_array_api_compliant = get_namespace(*ys)
    if not ys:
        raise ValueError("No argument has been passed.")
    # Check that we don't mix label format
    ys_types = set(type_of_target(x) for x in ys)
    if ys_types == {"binary", "multiclass"}:
        ys_types = {"multiclass"}

    if len(ys_types) > 1:
        raise ValueError("Mix type of y not allowed, got types %s" % ys_types)

    label_type = ys_types.pop()

    # Check consistency for the indicator format
    if (
        label_type == "multilabel-indicator"
        and len(
            set(
                check_array(y, accept_sparse=["csr", "csc", "coo"]).shape[1] for y in ys
            )
        )
        > 1
    ):
        raise ValueError(
            "Multi-label binary indicator input with different numbers of labels"
        )

    # Get the unique set of labels
    _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
    if not _unique_labels:
        raise ValueError("Unknown label type: %s" % repr(ys))

    if is_array_api_compliant:
        # array_api does not allow for mixed dtypes
        unique_ys = xp.concat([_unique_labels(y) for y in ys])
        return xp.unique_values(unique_ys)

    ys_labels = set(chain.from_iterable((i for i in _unique_labels(y)) for y in ys))
    # Check that we don't mix string type with number type
    if len(set(isinstance(label, str) for label in ys_labels)) > 1:
        raise ValueError("Mix of label input types (string and number)")

    return xp.asarray(sorted(ys_labels))
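

# Illustrative sketch (not part of the original module): `unique_labels`
# merges several label vectors into one sorted array and rejects mixed
# string/integer labels, per the docstring above. Defined but never called.
def _example_unique_labels():
    # Binary and multiclass inputs can be mixed; the union is returned sorted.
    assert list(unique_labels([3, 5, 5, 7], [5, 9])) == [3, 5, 7, 9]
    # Mixing number and string labels is refused.
    try:
        unique_labels([1, 2], ["a", "b"])
    except ValueError as exc:
        print(f"rejected as expected: {exc}")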


def _is_integral_float(y):
    return y.dtype.kind == "f" and np.all(y.astype(int) == y)
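

# Illustrative sketch (not part of the original module): `_is_integral_float`
# is True only for float arrays whose values are all exact integers.
def _example_is_integral_float():
    assert _is_integral_float(np.array([1.0, 2.0, 0.0]))
    assert not _is_integral_float(np.array([1.5, 2.0]))
    assert not _is_integral_float(np.array([1, 2]))  # integer dtype, kind != "f"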


def is_multilabel(y):
    """Check if ``y`` is in a multilabel format.

    Parameters
    ----------
    y : ndarray of shape (n_samples,)
        Target values.

    Returns
    -------
    out : bool
        Return ``True`` if ``y`` is in a multilabel format, else ``False``.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils.multiclass import is_multilabel
    >>> is_multilabel([0, 1, 0, 1])
    False
    >>> is_multilabel([[1], [0, 2], []])
    False
    >>> is_multilabel(np.array([[1, 0], [0, 0]]))
    True
    >>> is_multilabel(np.array([[1], [0], [0]]))
    False
    >>> is_multilabel(np.array([[1, 0, 0]]))
    True
    """
    xp, is_array_api_compliant = get_namespace(y)
    if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api_compliant:
        # DeprecationWarning will be replaced by ValueError, see NEP 34
        # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
        check_y_kwargs = dict(
            accept_sparse=True,
            allow_nd=True,
            force_all_finite=False,
            ensure_2d=False,
            ensure_min_samples=0,
            ensure_min_features=0,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("error", VisibleDeprecationWarning)
            try:
                y = check_array(y, dtype=None, **check_y_kwargs)
            except (VisibleDeprecationWarning, ValueError) as e:
                if str(e).startswith("Complex data not supported"):
                    raise

                # dtype=object should be provided explicitly for ragged arrays,
                # see NEP 34
                y = check_array(y, dtype=object, **check_y_kwargs)

    if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
        return False

    if issparse(y):
        if y.format in ("dok", "lil"):
            y = y.tocsr()
        labels = xp.unique_values(y.data)

        return (
            len(y.data) == 0
            or (labels.size == 1 or (labels.size == 2) and (0 in labels))
            and (y.dtype.kind in "biu" or _is_integral_float(labels))  # bool, int, uint
        )
    else:
        labels = xp.unique_values(y)
        return len(labels) < 3 and (
            y.dtype.kind in "biu" or _is_integral_float(labels)  # bool, int, uint
        )
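

# Illustrative sketch (not part of the original module): `is_multilabel` also
# accepts sparse indicator matrices via the `issparse` branch above, while a
# single 0/1 column is still not considered multilabel. Defined, never called.
def _example_is_multilabel_sparse():
    from scipy.sparse import csr_matrix

    indicator = csr_matrix(np.array([[1, 0, 1], [0, 1, 0]]))
    assert is_multilabel(indicator)
    # A single column is not a multilabel target, even if it is 0/1 valued.
    assert not is_multilabel(np.array([[1], [0], [0]]))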


def check_classification_targets(y):
    """Ensure that target y is of a non-regression type.

    Only the following target types (as defined in type_of_target) are allowed:
        'binary', 'multiclass', 'multiclass-multioutput',
        'multilabel-indicator', 'multilabel-sequences'

    Parameters
    ----------
    y : array-like
        Target values.
    """
    y_type = type_of_target(y, input_name="y")
    if y_type not in [
        "binary",
        "multiclass",
        "multiclass-multioutput",
        "multilabel-indicator",
        "multilabel-sequences",
    ]:
        raise ValueError(
            f"Unknown label type: {y_type}. Maybe you are trying to fit a "
            "classifier, which expects discrete classes on a "
            "regression target with continuous values."
        )
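

# Illustrative sketch (not part of the original module): the check passes
# silently for discrete targets and raises for continuous ones. Defined but
# never called at import time.
def _example_check_classification_targets():
    check_classification_targets(np.array([0, 1, 1, 0]))  # binary: no error
    try:
        check_classification_targets(np.array([0.1, 0.6, 1.3]))  # continuous
    except ValueError as exc:
        print(f"rejected as expected: {exc}")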


def type_of_target(y, input_name=""):
    """Determine the type of data indicated by the target.

    Note that this type is the most specific type that can be inferred.
    For example:

        * ``binary`` is more specific but compatible with ``multiclass``.
        * ``multiclass`` of integers is more specific but compatible with
          ``continuous``.
        * ``multilabel-indicator`` is more specific but compatible with
          ``multiclass-multioutput``.

    Parameters
    ----------
    y : {array-like, sparse matrix}
        Target values. If a sparse matrix, `y` is expected to be a
        CSR/CSC matrix.

    input_name : str, default=""
        The data name used to construct the error message.

        .. versionadded:: 1.1.0

    Returns
    -------
    target_type : str
        One of:

        * 'continuous': `y` is an array-like of floats that are not all
          integers, and is 1d or a column vector.
        * 'continuous-multioutput': `y` is a 2d array of floats that are
          not all integers, and both dimensions are of size > 1.
        * 'binary': `y` contains <= 2 discrete values and is 1d or a column
          vector.
        * 'multiclass': `y` contains more than two discrete values, is not a
          sequence of sequences, and is 1d or a column vector.
        * 'multiclass-multioutput': `y` is a 2d array that contains more
          than two discrete values, is not a sequence of sequences, and both
          dimensions are of size > 1.
        * 'multilabel-indicator': `y` is a label indicator matrix, an array
          of two dimensions with at least two columns, and at most 2 unique
          values.
        * 'unknown': `y` is array-like but none of the above, such as a 3d
          array, sequence of sequences, or an array of non-sequence objects.

    Examples
    --------
    >>> from sklearn.utils.multiclass import type_of_target
    >>> import numpy as np
    >>> type_of_target([0.1, 0.6])
    'continuous'
    >>> type_of_target([1, -1, -1, 1])
    'binary'
    >>> type_of_target(['a', 'b', 'a'])
    'binary'
    >>> type_of_target([1.0, 2.0])
    'binary'
    >>> type_of_target([1, 0, 2])
    'multiclass'
    >>> type_of_target([1.0, 0.0, 3.0])
    'multiclass'
    >>> type_of_target(['a', 'b', 'c'])
    'multiclass'
    >>> type_of_target(np.array([[1, 2], [3, 1]]))
    'multiclass-multioutput'
    >>> type_of_target([[1, 2]])
    'multilabel-indicator'
    >>> type_of_target(np.array([[1.5, 2.0], [3.0, 1.6]]))
    'continuous-multioutput'
    >>> type_of_target(np.array([[0, 1], [1, 1]]))
    'multilabel-indicator'
    """
    xp, is_array_api_compliant = get_namespace(y)
    valid = (
        (isinstance(y, Sequence) or issparse(y) or hasattr(y, "__array__"))
        and not isinstance(y, str)
        or is_array_api_compliant
    )

    if not valid:
        raise ValueError(
            "Expected array-like (array or non-string sequence), got %r" % y
        )

    sparse_pandas = y.__class__.__name__ in ["SparseSeries", "SparseArray"]
    if sparse_pandas:
        raise ValueError("y cannot be class 'SparseSeries' or 'SparseArray'")

    if is_multilabel(y):
        return "multilabel-indicator"

    # DeprecationWarning will be replaced by ValueError, see NEP 34
    # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
    # We therefore catch both the deprecation warning (NumPy < 1.24) and
    # the value error (NumPy >= 1.24).
    check_y_kwargs = dict(
        accept_sparse=True,
        allow_nd=True,
        force_all_finite=False,
        ensure_2d=False,
        ensure_min_samples=0,
        ensure_min_features=0,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("error", VisibleDeprecationWarning)
        if not issparse(y):
            try:
                y = check_array(y, dtype=None, **check_y_kwargs)
            except (VisibleDeprecationWarning, ValueError) as e:
                if str(e).startswith("Complex data not supported"):
                    raise

                # dtype=object should be provided explicitly for ragged arrays,
                # see NEP 34
                y = check_array(y, dtype=object, **check_y_kwargs)

    # The old sequence of sequences format
    try:
        if (
            not hasattr(y[0], "__array__")
            and isinstance(y[0], Sequence)
            and not isinstance(y[0], str)
        ):
            raise ValueError(
                "You appear to be using a legacy multi-label data"
                " representation. Sequence of sequences are no"
                " longer supported; use a binary array or sparse"
                " matrix instead - the MultiLabelBinarizer"
                " transformer can convert to this format."
            )
    except IndexError:
        pass

    # Invalid inputs
    if y.ndim not in (1, 2):
        # Number of dimensions greater than 2: [[[1, 2]]]
        return "unknown"
    if not min(y.shape):
        # Empty ndarray: []/[[]]
        if y.ndim == 1:
            # 1-D empty array: []
            return "binary"  # []
        # 2-D empty array: [[]]
        return "unknown"
    if not issparse(y) and y.dtype == object and not isinstance(y.flat[0], str):
        # [obj_1] and not ["label_1"]
        return "unknown"

    # Check if multioutput
    if y.ndim == 2 and y.shape[1] > 1:
        suffix = "-multioutput"  # [[1, 2], [1, 2]]
    else:
        suffix = ""  # [1, 2, 3] or [[1], [2], [3]]

    # Check whether the target is a float array containing non-integer values
    if xp.isdtype(y.dtype, "real floating"):
        # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
        data = y.data if issparse(y) else y
        if xp.any(data != xp.astype(data, int)):
            _assert_all_finite(data, input_name=input_name)
            return "continuous" + suffix

    # Check multiclass
    first_row = y[0] if not issparse(y) else y.getrow(0).data
    if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
        # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
        return "multiclass" + suffix
    else:
        return "binary"  # [1, 2] or [["a"], ["b"]]


def _check_partial_fit_first_call(clf, classes=None):
    """Private helper function for factorizing common classes param logic.

    Estimators that implement the ``partial_fit`` API need to be provided with
    the list of possible classes at the first call to partial_fit.

    Subsequent calls to partial_fit should check that ``classes`` is still
    consistent with a previous value of ``clf.classes_`` when provided.

    This function returns True if it detects that this was the first call to
    ``partial_fit`` on ``clf``. In that case the ``classes_`` attribute is also
    set on ``clf``.
    """
    if getattr(clf, "classes_", None) is None and classes is None:
        raise ValueError("classes must be passed on the first call to partial_fit.")

    elif classes is not None:
        if getattr(clf, "classes_", None) is not None:
            if not np.array_equal(clf.classes_, unique_labels(classes)):
                raise ValueError(
                    "`classes=%r` is not the same as on last call "
                    "to partial_fit, was: %r" % (classes, clf.classes_)
                )

        else:
            # This is the first call to partial_fit
            clf.classes_ = unique_labels(classes)
            return True

    # classes is None and clf.classes_ has already previously been set:
    # nothing to do
    return False
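

# Illustrative sketch (not part of the original module): the intended
# `partial_fit` call pattern, using a minimal stand-in object in place of an
# estimator. Defined but never called at import time.
def _example_check_partial_fit_first_call():
    class _Clf:
        pass

    clf = _Clf()
    # First call: `classes` is required, `classes_` gets set, True is returned.
    assert _check_partial_fit_first_call(clf, classes=[0, 2, 1])
    assert list(clf.classes_) == [0, 1, 2]
    # Later calls: omitting `classes` is a no-op and returns False.
    assert not _check_partial_fit_first_call(clf)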


def class_distribution(y, sample_weight=None):
    """Compute class priors from multioutput-multiclass target data.

    Parameters
    ----------
    y : {array-like, sparse matrix} of size (n_samples, n_outputs)
        The labels for each example.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    classes : list of size n_outputs of ndarray of size (n_classes,)
        List of classes for each column.

    n_classes : list of int of size n_outputs
        Number of classes in each column.

    class_prior : list of size n_outputs of ndarray of size (n_classes,)
        Class distribution of each column.
    """
    classes = []
    n_classes = []
    class_prior = []

    n_samples, n_outputs = y.shape
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    if issparse(y):
        y = y.tocsc()
        y_nnz = np.diff(y.indptr)

        for k in range(n_outputs):
            col_nonzero = y.indices[y.indptr[k] : y.indptr[k + 1]]
            # separate sample weights for zero and non-zero elements
            if sample_weight is not None:
                nz_samp_weight = sample_weight[col_nonzero]
                zeros_samp_weight_sum = np.sum(sample_weight) - np.sum(nz_samp_weight)
            else:
                nz_samp_weight = None
                zeros_samp_weight_sum = y.shape[0] - y_nnz[k]

            classes_k, y_k = np.unique(
                y.data[y.indptr[k] : y.indptr[k + 1]], return_inverse=True
            )
            class_prior_k = np.bincount(y_k, weights=nz_samp_weight)

            # An explicit zero was found, combine its weight with the weight
            # of the implicit zeros
            if 0 in classes_k:
                class_prior_k[classes_k == 0] += zeros_samp_weight_sum

            # If there is an implicit zero and it is not in classes and
            # class_prior, make an entry for it
            if 0 not in classes_k and y_nnz[k] < y.shape[0]:
                classes_k = np.insert(classes_k, 0, 0)
                class_prior_k = np.insert(class_prior_k, 0, zeros_samp_weight_sum)

            classes.append(classes_k)
            n_classes.append(classes_k.shape[0])
            class_prior.append(class_prior_k / class_prior_k.sum())
    else:
        for k in range(n_outputs):
            classes_k, y_k = np.unique(y[:, k], return_inverse=True)
            classes.append(classes_k)
            n_classes.append(classes_k.shape[0])
            class_prior_k = np.bincount(y_k, weights=sample_weight)
            class_prior.append(class_prior_k / class_prior_k.sum())

    return (classes, n_classes, class_prior)
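

# Illustrative sketch (not part of the original module): `class_distribution`
# on a small dense multioutput target; each output column gets its own classes
# and empirical prior. Defined but never called at import time.
def _example_class_distribution():
    y = np.array([[1, 0], [2, 0], [1, 1]])
    classes, n_classes, class_prior = class_distribution(y)
    assert [list(c) for c in classes] == [[1, 2], [0, 1]]
    assert n_classes == [2, 2]
    # Column 0 sees class 1 twice and class 2 once, hence priors 2/3 and 1/3.
    assert np.allclose(class_prior[0], [2 / 3, 1 / 3])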


def _ovr_decision_function(predictions, confidences, n_classes):
    """Compute a continuous, tie-breaking OvR decision function from OvO.

    It is important to include a continuous value, not only votes,
    to make computing AUC or calibration meaningful.

    Parameters
    ----------
    predictions : array-like of shape (n_samples, n_classifiers)
        Predicted classes for each binary classifier.

    confidences : array-like of shape (n_samples, n_classifiers)
        Decision functions or predicted probabilities for positive class
        for each binary classifier.

    n_classes : int
        Number of classes. n_classifiers must be
        ``n_classes * (n_classes - 1) / 2``.
    """
    n_samples = predictions.shape[0]
    votes = np.zeros((n_samples, n_classes))
    sum_of_confidences = np.zeros((n_samples, n_classes))

    k = 0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            sum_of_confidences[:, i] -= confidences[:, k]
            sum_of_confidences[:, j] += confidences[:, k]
            votes[predictions[:, k] == 0, i] += 1
            votes[predictions[:, k] == 1, j] += 1
            k += 1

    # Monotonically transform the sum_of_confidences to (-1/3, 1/3)
    # and add it to the votes. The monotonic transformation is
    # f: x -> x / (3 * (|x| + 1)); it uses 1/3 instead of 1/2
    # to ensure that we won't reach the limits and change vote order.
    # The motivation is to use confidence levels as a way to break ties in
    # the votes without switching any decision made based on a difference
    # of 1 vote.
    transformed_confidences = sum_of_confidences / (
        3 * (np.abs(sum_of_confidences) + 1)
    )
    return votes + transformed_confidences
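

# Illustrative sketch (not part of the original module): with three classes
# there are three pairwise (OvO) classifiers ordered (0 vs 1), (0 vs 2),
# (1 vs 2); the per-class votes dominate the score and the bounded confidence
# term in (-1/3, 1/3) only breaks ties. Defined but never called.
def _example_ovr_decision_function():
    # One sample: (0 vs 1) picks class 1, (0 vs 2) picks class 0,
    # (1 vs 2) picks class 1.
    predictions = np.array([[1, 0, 0]])
    confidences = np.array([[0.3, -0.1, 0.4]])
    scores = _ovr_decision_function(predictions, confidences, n_classes=3)
    assert scores.shape == (1, 3)
    # Class 1 collected two votes, so it wins regardless of the confidences.
    assert np.argmax(scores) == 1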