_label.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951
  1. # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
  2. # Mathieu Blondel <mathieu@mblondel.org>
  3. # Olivier Grisel <olivier.grisel@ensta.org>
  4. # Andreas Mueller <amueller@ais.uni-bonn.de>
  5. # Joel Nothman <joel.nothman@gmail.com>
  6. # Hamzeh Alsalhi <ha258@cornell.edu>
  7. # License: BSD 3 clause
  8. import array
  9. import itertools
  10. import warnings
  11. from collections import defaultdict
  12. from numbers import Integral
  13. import numpy as np
  14. import scipy.sparse as sp
  15. from ..base import BaseEstimator, TransformerMixin, _fit_context
  16. from ..utils import column_or_1d
  17. from ..utils._encode import _encode, _unique
  18. from ..utils._param_validation import Interval, validate_params
  19. from ..utils.multiclass import type_of_target, unique_labels
  20. from ..utils.sparsefuncs import min_max_axis
  21. from ..utils.validation import _num_samples, check_array, check_is_fitted
# Public API of this module: the names exported by ``from <module> import *``.
# Private helpers (``_inverse_binarize_*``) are intentionally not listed.
__all__ = [
    "label_binarize",
    "LabelBinarizer",
    "LabelEncoder",
    "MultiLabelBinarizer",
]
  28. class LabelEncoder(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
  29. """Encode target labels with value between 0 and n_classes-1.
  30. This transformer should be used to encode target values, *i.e.* `y`, and
  31. not the input `X`.
  32. Read more in the :ref:`User Guide <preprocessing_targets>`.
  33. .. versionadded:: 0.12
  34. Attributes
  35. ----------
  36. classes_ : ndarray of shape (n_classes,)
  37. Holds the label for each class.
  38. See Also
  39. --------
  40. OrdinalEncoder : Encode categorical features using an ordinal encoding
  41. scheme.
  42. OneHotEncoder : Encode categorical features as a one-hot numeric array.
  43. Examples
  44. --------
  45. `LabelEncoder` can be used to normalize labels.
  46. >>> from sklearn.preprocessing import LabelEncoder
  47. >>> le = LabelEncoder()
  48. >>> le.fit([1, 2, 2, 6])
  49. LabelEncoder()
  50. >>> le.classes_
  51. array([1, 2, 6])
  52. >>> le.transform([1, 1, 2, 6])
  53. array([0, 0, 1, 2]...)
  54. >>> le.inverse_transform([0, 0, 1, 2])
  55. array([1, 1, 2, 6])
  56. It can also be used to transform non-numerical labels (as long as they are
  57. hashable and comparable) to numerical labels.
  58. >>> le = LabelEncoder()
  59. >>> le.fit(["paris", "paris", "tokyo", "amsterdam"])
  60. LabelEncoder()
  61. >>> list(le.classes_)
  62. ['amsterdam', 'paris', 'tokyo']
  63. >>> le.transform(["tokyo", "tokyo", "paris"])
  64. array([2, 2, 1]...)
  65. >>> list(le.inverse_transform([2, 2, 1]))
  66. ['tokyo', 'tokyo', 'paris']
  67. """
  68. def fit(self, y):
  69. """Fit label encoder.
  70. Parameters
  71. ----------
  72. y : array-like of shape (n_samples,)
  73. Target values.
  74. Returns
  75. -------
  76. self : returns an instance of self.
  77. Fitted label encoder.
  78. """
  79. y = column_or_1d(y, warn=True)
  80. self.classes_ = _unique(y)
  81. return self
  82. def fit_transform(self, y):
  83. """Fit label encoder and return encoded labels.
  84. Parameters
  85. ----------
  86. y : array-like of shape (n_samples,)
  87. Target values.
  88. Returns
  89. -------
  90. y : array-like of shape (n_samples,)
  91. Encoded labels.
  92. """
  93. y = column_or_1d(y, warn=True)
  94. self.classes_, y = _unique(y, return_inverse=True)
  95. return y
  96. def transform(self, y):
  97. """Transform labels to normalized encoding.
  98. Parameters
  99. ----------
  100. y : array-like of shape (n_samples,)
  101. Target values.
  102. Returns
  103. -------
  104. y : array-like of shape (n_samples,)
  105. Labels as normalized encodings.
  106. """
  107. check_is_fitted(self)
  108. y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)
  109. # transform of empty array is empty array
  110. if _num_samples(y) == 0:
  111. return np.array([])
  112. return _encode(y, uniques=self.classes_)
  113. def inverse_transform(self, y):
  114. """Transform labels back to original encoding.
  115. Parameters
  116. ----------
  117. y : ndarray of shape (n_samples,)
  118. Target values.
  119. Returns
  120. -------
  121. y : ndarray of shape (n_samples,)
  122. Original encoding.
  123. """
  124. check_is_fitted(self)
  125. y = column_or_1d(y, warn=True)
  126. # inverse transform of empty array is empty array
  127. if _num_samples(y) == 0:
  128. return np.array([])
  129. diff = np.setdiff1d(y, np.arange(len(self.classes_)))
  130. if len(diff):
  131. raise ValueError("y contains previously unseen labels: %s" % str(diff))
  132. y = np.asarray(y)
  133. return self.classes_[y]
  134. def _more_tags(self):
  135. return {"X_types": ["1dlabels"]}
  136. class LabelBinarizer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
  137. """Binarize labels in a one-vs-all fashion.
  138. Several regression and binary classification algorithms are
  139. available in scikit-learn. A simple way to extend these algorithms
  140. to the multi-class classification case is to use the so-called
  141. one-vs-all scheme.
  142. At learning time, this simply consists in learning one regressor
  143. or binary classifier per class. In doing so, one needs to convert
  144. multi-class labels to binary labels (belong or does not belong
  145. to the class). `LabelBinarizer` makes this process easy with the
  146. transform method.
  147. At prediction time, one assigns the class for which the corresponding
  148. model gave the greatest confidence. `LabelBinarizer` makes this easy
  149. with the :meth:`inverse_transform` method.
  150. Read more in the :ref:`User Guide <preprocessing_targets>`.
  151. Parameters
  152. ----------
  153. neg_label : int, default=0
  154. Value with which negative labels must be encoded.
  155. pos_label : int, default=1
  156. Value with which positive labels must be encoded.
  157. sparse_output : bool, default=False
  158. True if the returned array from transform is desired to be in sparse
  159. CSR format.
  160. Attributes
  161. ----------
  162. classes_ : ndarray of shape (n_classes,)
  163. Holds the label for each class.
  164. y_type_ : str
  165. Represents the type of the target data as evaluated by
  166. :func:`~sklearn.utils.multiclass.type_of_target`. Possible type are
  167. 'continuous', 'continuous-multioutput', 'binary', 'multiclass',
  168. 'multiclass-multioutput', 'multilabel-indicator', and 'unknown'.
  169. sparse_input_ : bool
  170. `True` if the input data to transform is given as a sparse matrix,
  171. `False` otherwise.
  172. See Also
  173. --------
  174. label_binarize : Function to perform the transform operation of
  175. LabelBinarizer with fixed classes.
  176. OneHotEncoder : Encode categorical features using a one-hot aka one-of-K
  177. scheme.
  178. Examples
  179. --------
  180. >>> from sklearn.preprocessing import LabelBinarizer
  181. >>> lb = LabelBinarizer()
  182. >>> lb.fit([1, 2, 6, 4, 2])
  183. LabelBinarizer()
  184. >>> lb.classes_
  185. array([1, 2, 4, 6])
  186. >>> lb.transform([1, 6])
  187. array([[1, 0, 0, 0],
  188. [0, 0, 0, 1]])
  189. Binary targets transform to a column vector
  190. >>> lb = LabelBinarizer()
  191. >>> lb.fit_transform(['yes', 'no', 'no', 'yes'])
  192. array([[1],
  193. [0],
  194. [0],
  195. [1]])
  196. Passing a 2D matrix for multilabel classification
  197. >>> import numpy as np
  198. >>> lb.fit(np.array([[0, 1, 1], [1, 0, 0]]))
  199. LabelBinarizer()
  200. >>> lb.classes_
  201. array([0, 1, 2])
  202. >>> lb.transform([0, 1, 2, 1])
  203. array([[1, 0, 0],
  204. [0, 1, 0],
  205. [0, 0, 1],
  206. [0, 1, 0]])
  207. """
  208. _parameter_constraints: dict = {
  209. "neg_label": [Integral],
  210. "pos_label": [Integral],
  211. "sparse_output": ["boolean"],
  212. }
  213. def __init__(self, *, neg_label=0, pos_label=1, sparse_output=False):
  214. self.neg_label = neg_label
  215. self.pos_label = pos_label
  216. self.sparse_output = sparse_output
  217. @_fit_context(prefer_skip_nested_validation=True)
  218. def fit(self, y):
  219. """Fit label binarizer.
  220. Parameters
  221. ----------
  222. y : ndarray of shape (n_samples,) or (n_samples, n_classes)
  223. Target values. The 2-d matrix should only contain 0 and 1,
  224. represents multilabel classification.
  225. Returns
  226. -------
  227. self : object
  228. Returns the instance itself.
  229. """
  230. if self.neg_label >= self.pos_label:
  231. raise ValueError(
  232. f"neg_label={self.neg_label} must be strictly less than "
  233. f"pos_label={self.pos_label}."
  234. )
  235. if self.sparse_output and (self.pos_label == 0 or self.neg_label != 0):
  236. raise ValueError(
  237. "Sparse binarization is only supported with non "
  238. "zero pos_label and zero neg_label, got "
  239. f"pos_label={self.pos_label} and neg_label={self.neg_label}"
  240. )
  241. self.y_type_ = type_of_target(y, input_name="y")
  242. if "multioutput" in self.y_type_:
  243. raise ValueError(
  244. "Multioutput target data is not supported with label binarization"
  245. )
  246. if _num_samples(y) == 0:
  247. raise ValueError("y has 0 samples: %r" % y)
  248. self.sparse_input_ = sp.issparse(y)
  249. self.classes_ = unique_labels(y)
  250. return self
  251. def fit_transform(self, y):
  252. """Fit label binarizer/transform multi-class labels to binary labels.
  253. The output of transform is sometimes referred to as
  254. the 1-of-K coding scheme.
  255. Parameters
  256. ----------
  257. y : {ndarray, sparse matrix} of shape (n_samples,) or \
  258. (n_samples, n_classes)
  259. Target values. The 2-d matrix should only contain 0 and 1,
  260. represents multilabel classification. Sparse matrix can be
  261. CSR, CSC, COO, DOK, or LIL.
  262. Returns
  263. -------
  264. Y : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  265. Shape will be (n_samples, 1) for binary problems. Sparse matrix
  266. will be of CSR format.
  267. """
  268. return self.fit(y).transform(y)
  269. def transform(self, y):
  270. """Transform multi-class labels to binary labels.
  271. The output of transform is sometimes referred to by some authors as
  272. the 1-of-K coding scheme.
  273. Parameters
  274. ----------
  275. y : {array, sparse matrix} of shape (n_samples,) or \
  276. (n_samples, n_classes)
  277. Target values. The 2-d matrix should only contain 0 and 1,
  278. represents multilabel classification. Sparse matrix can be
  279. CSR, CSC, COO, DOK, or LIL.
  280. Returns
  281. -------
  282. Y : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  283. Shape will be (n_samples, 1) for binary problems. Sparse matrix
  284. will be of CSR format.
  285. """
  286. check_is_fitted(self)
  287. y_is_multilabel = type_of_target(y).startswith("multilabel")
  288. if y_is_multilabel and not self.y_type_.startswith("multilabel"):
  289. raise ValueError("The object was not fitted with multilabel input.")
  290. return label_binarize(
  291. y,
  292. classes=self.classes_,
  293. pos_label=self.pos_label,
  294. neg_label=self.neg_label,
  295. sparse_output=self.sparse_output,
  296. )
  297. def inverse_transform(self, Y, threshold=None):
  298. """Transform binary labels back to multi-class labels.
  299. Parameters
  300. ----------
  301. Y : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  302. Target values. All sparse matrices are converted to CSR before
  303. inverse transformation.
  304. threshold : float, default=None
  305. Threshold used in the binary and multi-label cases.
  306. Use 0 when ``Y`` contains the output of :term:`decision_function`
  307. (classifier).
  308. Use 0.5 when ``Y`` contains the output of :term:`predict_proba`.
  309. If None, the threshold is assumed to be half way between
  310. neg_label and pos_label.
  311. Returns
  312. -------
  313. y : {ndarray, sparse matrix} of shape (n_samples,)
  314. Target values. Sparse matrix will be of CSR format.
  315. Notes
  316. -----
  317. In the case when the binary labels are fractional
  318. (probabilistic), :meth:`inverse_transform` chooses the class with the
  319. greatest value. Typically, this allows to use the output of a
  320. linear model's :term:`decision_function` method directly as the input
  321. of :meth:`inverse_transform`.
  322. """
  323. check_is_fitted(self)
  324. if threshold is None:
  325. threshold = (self.pos_label + self.neg_label) / 2.0
  326. if self.y_type_ == "multiclass":
  327. y_inv = _inverse_binarize_multiclass(Y, self.classes_)
  328. else:
  329. y_inv = _inverse_binarize_thresholding(
  330. Y, self.y_type_, self.classes_, threshold
  331. )
  332. if self.sparse_input_:
  333. y_inv = sp.csr_matrix(y_inv)
  334. elif sp.issparse(y_inv):
  335. y_inv = y_inv.toarray()
  336. return y_inv
  337. def _more_tags(self):
  338. return {"X_types": ["1dlabels"]}
  339. @validate_params(
  340. {
  341. "y": ["array-like"],
  342. "classes": ["array-like"],
  343. "neg_label": [Interval(Integral, None, None, closed="neither")],
  344. "pos_label": [Interval(Integral, None, None, closed="neither")],
  345. "sparse_output": ["boolean"],
  346. },
  347. prefer_skip_nested_validation=True,
  348. )
  349. def label_binarize(y, *, classes, neg_label=0, pos_label=1, sparse_output=False):
  350. """Binarize labels in a one-vs-all fashion.
  351. Several regression and binary classification algorithms are
  352. available in scikit-learn. A simple way to extend these algorithms
  353. to the multi-class classification case is to use the so-called
  354. one-vs-all scheme.
  355. This function makes it possible to compute this transformation for a
  356. fixed set of class labels known ahead of time.
  357. Parameters
  358. ----------
  359. y : array-like
  360. Sequence of integer labels or multilabel data to encode.
  361. classes : array-like of shape (n_classes,)
  362. Uniquely holds the label for each class.
  363. neg_label : int, default=0
  364. Value with which negative labels must be encoded.
  365. pos_label : int, default=1
  366. Value with which positive labels must be encoded.
  367. sparse_output : bool, default=False,
  368. Set to true if output binary array is desired in CSR sparse format.
  369. Returns
  370. -------
  371. Y : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  372. Shape will be (n_samples, 1) for binary problems. Sparse matrix will
  373. be of CSR format.
  374. See Also
  375. --------
  376. LabelBinarizer : Class used to wrap the functionality of label_binarize and
  377. allow for fitting to classes independently of the transform operation.
  378. Examples
  379. --------
  380. >>> from sklearn.preprocessing import label_binarize
  381. >>> label_binarize([1, 6], classes=[1, 2, 4, 6])
  382. array([[1, 0, 0, 0],
  383. [0, 0, 0, 1]])
  384. The class ordering is preserved:
  385. >>> label_binarize([1, 6], classes=[1, 6, 4, 2])
  386. array([[1, 0, 0, 0],
  387. [0, 1, 0, 0]])
  388. Binary targets transform to a column vector
  389. >>> label_binarize(['yes', 'no', 'no', 'yes'], classes=['no', 'yes'])
  390. array([[1],
  391. [0],
  392. [0],
  393. [1]])
  394. """
  395. if not isinstance(y, list):
  396. # XXX Workaround that will be removed when list of list format is
  397. # dropped
  398. y = check_array(
  399. y, input_name="y", accept_sparse="csr", ensure_2d=False, dtype=None
  400. )
  401. else:
  402. if _num_samples(y) == 0:
  403. raise ValueError("y has 0 samples: %r" % y)
  404. if neg_label >= pos_label:
  405. raise ValueError(
  406. "neg_label={0} must be strictly less than pos_label={1}.".format(
  407. neg_label, pos_label
  408. )
  409. )
  410. if sparse_output and (pos_label == 0 or neg_label != 0):
  411. raise ValueError(
  412. "Sparse binarization is only supported with non "
  413. "zero pos_label and zero neg_label, got "
  414. "pos_label={0} and neg_label={1}"
  415. "".format(pos_label, neg_label)
  416. )
  417. # To account for pos_label == 0 in the dense case
  418. pos_switch = pos_label == 0
  419. if pos_switch:
  420. pos_label = -neg_label
  421. y_type = type_of_target(y)
  422. if "multioutput" in y_type:
  423. raise ValueError(
  424. "Multioutput target data is not supported with label binarization"
  425. )
  426. if y_type == "unknown":
  427. raise ValueError("The type of target data is not known")
  428. n_samples = y.shape[0] if sp.issparse(y) else len(y)
  429. n_classes = len(classes)
  430. classes = np.asarray(classes)
  431. if y_type == "binary":
  432. if n_classes == 1:
  433. if sparse_output:
  434. return sp.csr_matrix((n_samples, 1), dtype=int)
  435. else:
  436. Y = np.zeros((len(y), 1), dtype=int)
  437. Y += neg_label
  438. return Y
  439. elif len(classes) >= 3:
  440. y_type = "multiclass"
  441. sorted_class = np.sort(classes)
  442. if y_type == "multilabel-indicator":
  443. y_n_classes = y.shape[1] if hasattr(y, "shape") else len(y[0])
  444. if classes.size != y_n_classes:
  445. raise ValueError(
  446. "classes {0} mismatch with the labels {1} found in the data".format(
  447. classes, unique_labels(y)
  448. )
  449. )
  450. if y_type in ("binary", "multiclass"):
  451. y = column_or_1d(y)
  452. # pick out the known labels from y
  453. y_in_classes = np.isin(y, classes)
  454. y_seen = y[y_in_classes]
  455. indices = np.searchsorted(sorted_class, y_seen)
  456. indptr = np.hstack((0, np.cumsum(y_in_classes)))
  457. data = np.empty_like(indices)
  458. data.fill(pos_label)
  459. Y = sp.csr_matrix((data, indices, indptr), shape=(n_samples, n_classes))
  460. elif y_type == "multilabel-indicator":
  461. Y = sp.csr_matrix(y)
  462. if pos_label != 1:
  463. data = np.empty_like(Y.data)
  464. data.fill(pos_label)
  465. Y.data = data
  466. else:
  467. raise ValueError(
  468. "%s target data is not supported with label binarization" % y_type
  469. )
  470. if not sparse_output:
  471. Y = Y.toarray()
  472. Y = Y.astype(int, copy=False)
  473. if neg_label != 0:
  474. Y[Y == 0] = neg_label
  475. if pos_switch:
  476. Y[Y == pos_label] = 0
  477. else:
  478. Y.data = Y.data.astype(int, copy=False)
  479. # preserve label ordering
  480. if np.any(classes != sorted_class):
  481. indices = np.searchsorted(sorted_class, classes)
  482. Y = Y[:, indices]
  483. if y_type == "binary":
  484. if sparse_output:
  485. Y = Y.getcol(-1)
  486. else:
  487. Y = Y[:, -1].reshape((-1, 1))
  488. return Y
  489. def _inverse_binarize_multiclass(y, classes):
  490. """Inverse label binarization transformation for multiclass.
  491. Multiclass uses the maximal score instead of a threshold.
  492. """
  493. classes = np.asarray(classes)
  494. if sp.issparse(y):
  495. # Find the argmax for each row in y where y is a CSR matrix
  496. y = y.tocsr()
  497. n_samples, n_outputs = y.shape
  498. outputs = np.arange(n_outputs)
  499. row_max = min_max_axis(y, 1)[1]
  500. row_nnz = np.diff(y.indptr)
  501. y_data_repeated_max = np.repeat(row_max, row_nnz)
  502. # picks out all indices obtaining the maximum per row
  503. y_i_all_argmax = np.flatnonzero(y_data_repeated_max == y.data)
  504. # For corner case where last row has a max of 0
  505. if row_max[-1] == 0:
  506. y_i_all_argmax = np.append(y_i_all_argmax, [len(y.data)])
  507. # Gets the index of the first argmax in each row from y_i_all_argmax
  508. index_first_argmax = np.searchsorted(y_i_all_argmax, y.indptr[:-1])
  509. # first argmax of each row
  510. y_ind_ext = np.append(y.indices, [0])
  511. y_i_argmax = y_ind_ext[y_i_all_argmax[index_first_argmax]]
  512. # Handle rows of all 0
  513. y_i_argmax[np.where(row_nnz == 0)[0]] = 0
  514. # Handles rows with max of 0 that contain negative numbers
  515. samples = np.arange(n_samples)[(row_nnz > 0) & (row_max.ravel() == 0)]
  516. for i in samples:
  517. ind = y.indices[y.indptr[i] : y.indptr[i + 1]]
  518. y_i_argmax[i] = classes[np.setdiff1d(outputs, ind)][0]
  519. return classes[y_i_argmax]
  520. else:
  521. return classes.take(y.argmax(axis=1), mode="clip")
  522. def _inverse_binarize_thresholding(y, output_type, classes, threshold):
  523. """Inverse label binarization transformation using thresholding."""
  524. if output_type == "binary" and y.ndim == 2 and y.shape[1] > 2:
  525. raise ValueError("output_type='binary', but y.shape = {0}".format(y.shape))
  526. if output_type != "binary" and y.shape[1] != len(classes):
  527. raise ValueError(
  528. "The number of class is not equal to the number of dimension of y."
  529. )
  530. classes = np.asarray(classes)
  531. # Perform thresholding
  532. if sp.issparse(y):
  533. if threshold > 0:
  534. if y.format not in ("csr", "csc"):
  535. y = y.tocsr()
  536. y.data = np.array(y.data > threshold, dtype=int)
  537. y.eliminate_zeros()
  538. else:
  539. y = np.array(y.toarray() > threshold, dtype=int)
  540. else:
  541. y = np.array(y > threshold, dtype=int)
  542. # Inverse transform data
  543. if output_type == "binary":
  544. if sp.issparse(y):
  545. y = y.toarray()
  546. if y.ndim == 2 and y.shape[1] == 2:
  547. return classes[y[:, 1]]
  548. else:
  549. if len(classes) == 1:
  550. return np.repeat(classes[0], len(y))
  551. else:
  552. return classes[y.ravel()]
  553. elif output_type == "multilabel-indicator":
  554. return y
  555. else:
  556. raise ValueError("{0} format is not supported".format(output_type))
  557. class MultiLabelBinarizer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
  558. """Transform between iterable of iterables and a multilabel format.
  559. Although a list of sets or tuples is a very intuitive format for multilabel
  560. data, it is unwieldy to process. This transformer converts between this
  561. intuitive format and the supported multilabel format: a (samples x classes)
  562. binary matrix indicating the presence of a class label.
  563. Parameters
  564. ----------
  565. classes : array-like of shape (n_classes,), default=None
  566. Indicates an ordering for the class labels.
  567. All entries should be unique (cannot contain duplicate classes).
  568. sparse_output : bool, default=False
  569. Set to True if output binary array is desired in CSR sparse format.
  570. Attributes
  571. ----------
  572. classes_ : ndarray of shape (n_classes,)
  573. A copy of the `classes` parameter when provided.
  574. Otherwise it corresponds to the sorted set of classes found
  575. when fitting.
  576. See Also
  577. --------
  578. OneHotEncoder : Encode categorical features using a one-hot aka one-of-K
  579. scheme.
  580. Examples
  581. --------
  582. >>> from sklearn.preprocessing import MultiLabelBinarizer
  583. >>> mlb = MultiLabelBinarizer()
  584. >>> mlb.fit_transform([(1, 2), (3,)])
  585. array([[1, 1, 0],
  586. [0, 0, 1]])
  587. >>> mlb.classes_
  588. array([1, 2, 3])
  589. >>> mlb.fit_transform([{'sci-fi', 'thriller'}, {'comedy'}])
  590. array([[0, 1, 1],
  591. [1, 0, 0]])
  592. >>> list(mlb.classes_)
  593. ['comedy', 'sci-fi', 'thriller']
  594. A common mistake is to pass in a list, which leads to the following issue:
  595. >>> mlb = MultiLabelBinarizer()
  596. >>> mlb.fit(['sci-fi', 'thriller', 'comedy'])
  597. MultiLabelBinarizer()
  598. >>> mlb.classes_
  599. array(['-', 'c', 'd', 'e', 'f', 'h', 'i', 'l', 'm', 'o', 'r', 's', 't',
  600. 'y'], dtype=object)
  601. To correct this, the list of labels should be passed in as:
  602. >>> mlb = MultiLabelBinarizer()
  603. >>> mlb.fit([['sci-fi', 'thriller', 'comedy']])
  604. MultiLabelBinarizer()
  605. >>> mlb.classes_
  606. array(['comedy', 'sci-fi', 'thriller'], dtype=object)
  607. """
  608. _parameter_constraints: dict = {
  609. "classes": ["array-like", None],
  610. "sparse_output": ["boolean"],
  611. }
  612. def __init__(self, *, classes=None, sparse_output=False):
  613. self.classes = classes
  614. self.sparse_output = sparse_output
  615. @_fit_context(prefer_skip_nested_validation=True)
  616. def fit(self, y):
  617. """Fit the label sets binarizer, storing :term:`classes_`.
  618. Parameters
  619. ----------
  620. y : iterable of iterables
  621. A set of labels (any orderable and hashable object) for each
  622. sample. If the `classes` parameter is set, `y` will not be
  623. iterated.
  624. Returns
  625. -------
  626. self : object
  627. Fitted estimator.
  628. """
  629. self._cached_dict = None
  630. if self.classes is None:
  631. classes = sorted(set(itertools.chain.from_iterable(y)))
  632. elif len(set(self.classes)) < len(self.classes):
  633. raise ValueError(
  634. "The classes argument contains duplicate "
  635. "classes. Remove these duplicates before passing "
  636. "them to MultiLabelBinarizer."
  637. )
  638. else:
  639. classes = self.classes
  640. dtype = int if all(isinstance(c, int) for c in classes) else object
  641. self.classes_ = np.empty(len(classes), dtype=dtype)
  642. self.classes_[:] = classes
  643. return self
  644. @_fit_context(prefer_skip_nested_validation=True)
  645. def fit_transform(self, y):
  646. """Fit the label sets binarizer and transform the given label sets.
  647. Parameters
  648. ----------
  649. y : iterable of iterables
  650. A set of labels (any orderable and hashable object) for each
  651. sample. If the `classes` parameter is set, `y` will not be
  652. iterated.
  653. Returns
  654. -------
  655. y_indicator : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  656. A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]`
  657. is in `y[i]`, and 0 otherwise. Sparse matrix will be of CSR
  658. format.
  659. """
  660. if self.classes is not None:
  661. return self.fit(y).transform(y)
  662. self._cached_dict = None
  663. # Automatically increment on new class
  664. class_mapping = defaultdict(int)
  665. class_mapping.default_factory = class_mapping.__len__
  666. yt = self._transform(y, class_mapping)
  667. # sort classes and reorder columns
  668. tmp = sorted(class_mapping, key=class_mapping.get)
  669. # (make safe for tuples)
  670. dtype = int if all(isinstance(c, int) for c in tmp) else object
  671. class_mapping = np.empty(len(tmp), dtype=dtype)
  672. class_mapping[:] = tmp
  673. self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
  674. # ensure yt.indices keeps its current dtype
  675. yt.indices = np.array(inverse[yt.indices], dtype=yt.indices.dtype, copy=False)
  676. if not self.sparse_output:
  677. yt = yt.toarray()
  678. return yt
  679. def transform(self, y):
  680. """Transform the given label sets.
  681. Parameters
  682. ----------
  683. y : iterable of iterables
  684. A set of labels (any orderable and hashable object) for each
  685. sample. If the `classes` parameter is set, `y` will not be
  686. iterated.
  687. Returns
  688. -------
  689. y_indicator : array or CSR matrix, shape (n_samples, n_classes)
  690. A matrix such that `y_indicator[i, j] = 1` iff `classes_[j]` is in
  691. `y[i]`, and 0 otherwise.
  692. """
  693. check_is_fitted(self)
  694. class_to_index = self._build_cache()
  695. yt = self._transform(y, class_to_index)
  696. if not self.sparse_output:
  697. yt = yt.toarray()
  698. return yt
  699. def _build_cache(self):
  700. if self._cached_dict is None:
  701. self._cached_dict = dict(zip(self.classes_, range(len(self.classes_))))
  702. return self._cached_dict
  703. def _transform(self, y, class_mapping):
  704. """Transforms the label sets with a given mapping.
  705. Parameters
  706. ----------
  707. y : iterable of iterables
  708. A set of labels (any orderable and hashable object) for each
  709. sample. If the `classes` parameter is set, `y` will not be
  710. iterated.
  711. class_mapping : Mapping
  712. Maps from label to column index in label indicator matrix.
  713. Returns
  714. -------
  715. y_indicator : sparse matrix of shape (n_samples, n_classes)
  716. Label indicator matrix. Will be of CSR format.
  717. """
  718. indices = array.array("i")
  719. indptr = array.array("i", [0])
  720. unknown = set()
  721. for labels in y:
  722. index = set()
  723. for label in labels:
  724. try:
  725. index.add(class_mapping[label])
  726. except KeyError:
  727. unknown.add(label)
  728. indices.extend(index)
  729. indptr.append(len(indices))
  730. if unknown:
  731. warnings.warn(
  732. "unknown class(es) {0} will be ignored".format(sorted(unknown, key=str))
  733. )
  734. data = np.ones(len(indices), dtype=int)
  735. return sp.csr_matrix(
  736. (data, indices, indptr), shape=(len(indptr) - 1, len(class_mapping))
  737. )
  738. def inverse_transform(self, yt):
  739. """Transform the given indicator matrix into label sets.
  740. Parameters
  741. ----------
  742. yt : {ndarray, sparse matrix} of shape (n_samples, n_classes)
  743. A matrix containing only 1s ands 0s.
  744. Returns
  745. -------
  746. y : list of tuples
  747. The set of labels for each sample such that `y[i]` consists of
  748. `classes_[j]` for each `yt[i, j] == 1`.
  749. """
  750. check_is_fitted(self)
  751. if yt.shape[1] != len(self.classes_):
  752. raise ValueError(
  753. "Expected indicator for {0} classes, but got {1}".format(
  754. len(self.classes_), yt.shape[1]
  755. )
  756. )
  757. if sp.issparse(yt):
  758. yt = yt.tocsr()
  759. if len(yt.data) != 0 and len(np.setdiff1d(yt.data, [0, 1])) > 0:
  760. raise ValueError("Expected only 0s and 1s in label indicator.")
  761. return [
  762. tuple(self.classes_.take(yt.indices[start:end]))
  763. for start, end in zip(yt.indptr[:-1], yt.indptr[1:])
  764. ]
  765. else:
  766. unexpected = np.setdiff1d(yt, [0, 1])
  767. if len(unexpected) > 0:
  768. raise ValueError(
  769. "Expected only 0s and 1s in label indicator. Also got {0}".format(
  770. unexpected
  771. )
  772. )
  773. return [tuple(self.classes_.compress(indicators)) for indicators in yt]
  774. def _more_tags(self):
  775. return {"X_types": ["2dlabels"]}