# class_weight.py
  1. # Authors: Andreas Mueller
  2. # Manoj Kumar
  3. # License: BSD 3 clause
  4. import numpy as np
  5. from scipy import sparse
  6. def compute_class_weight(class_weight, *, classes, y):
  7. """Estimate class weights for unbalanced datasets.
  8. Parameters
  9. ----------
  10. class_weight : dict, 'balanced' or None
  11. If 'balanced', class weights will be given by
  12. ``n_samples / (n_classes * np.bincount(y))``.
  13. If a dictionary is given, keys are classes and values
  14. are corresponding class weights.
  15. If None is given, the class weights will be uniform.
  16. classes : ndarray
  17. Array of the classes occurring in the data, as given by
  18. ``np.unique(y_org)`` with ``y_org`` the original class labels.
  19. y : array-like of shape (n_samples,)
  20. Array of original class labels per sample.
  21. Returns
  22. -------
  23. class_weight_vect : ndarray of shape (n_classes,)
  24. Array with class_weight_vect[i] the weight for i-th class.
  25. References
  26. ----------
  27. The "balanced" heuristic is inspired by
  28. Logistic Regression in Rare Events Data, King, Zen, 2001.
  29. """
  30. # Import error caused by circular imports.
  31. from ..preprocessing import LabelEncoder
  32. if set(y) - set(classes):
  33. raise ValueError("classes should include all valid labels that can be in y")
  34. if class_weight is None or len(class_weight) == 0:
  35. # uniform class weights
  36. weight = np.ones(classes.shape[0], dtype=np.float64, order="C")
  37. elif class_weight == "balanced":
  38. # Find the weight of each class as present in y.
  39. le = LabelEncoder()
  40. y_ind = le.fit_transform(y)
  41. if not all(np.isin(classes, le.classes_)):
  42. raise ValueError("classes should have valid labels that are in y")
  43. recip_freq = len(y) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64))
  44. weight = recip_freq[le.transform(classes)]
  45. else:
  46. # user-defined dictionary
  47. weight = np.ones(classes.shape[0], dtype=np.float64, order="C")
  48. if not isinstance(class_weight, dict):
  49. raise ValueError(
  50. "class_weight must be dict, 'balanced', or None, got: %r" % class_weight
  51. )
  52. unweighted_classes = []
  53. for i, c in enumerate(classes):
  54. if c in class_weight:
  55. weight[i] = class_weight[c]
  56. else:
  57. unweighted_classes.append(c)
  58. n_weighted_classes = len(classes) - len(unweighted_classes)
  59. if unweighted_classes and n_weighted_classes != len(class_weight):
  60. unweighted_classes_user_friendly_str = np.array(unweighted_classes).tolist()
  61. raise ValueError(
  62. f"The classes, {unweighted_classes_user_friendly_str}, are not in"
  63. " class_weight"
  64. )
  65. return weight
  66. def compute_sample_weight(class_weight, y, *, indices=None):
  67. """Estimate sample weights by class for unbalanced datasets.
  68. Parameters
  69. ----------
  70. class_weight : dict, list of dicts, "balanced", or None
  71. Weights associated with classes in the form ``{class_label: weight}``.
  72. If not given, all classes are supposed to have weight one. For
  73. multi-output problems, a list of dicts can be provided in the same
  74. order as the columns of y.
  75. Note that for multioutput (including multilabel) weights should be
  76. defined for each class of every column in its own dict. For example,
  77. for four-class multilabel classification weights should be
  78. [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
  79. [{1:1}, {2:5}, {3:1}, {4:1}].
  80. The "balanced" mode uses the values of y to automatically adjust
  81. weights inversely proportional to class frequencies in the input data:
  82. ``n_samples / (n_classes * np.bincount(y))``.
  83. For multi-output, the weights of each column of y will be multiplied.
  84. y : {array-like, sparse matrix} of shape (n_samples,) or (n_samples, n_outputs)
  85. Array of original class labels per sample.
  86. indices : array-like of shape (n_subsample,), default=None
  87. Array of indices to be used in a subsample. Can be of length less than
  88. n_samples in the case of a subsample, or equal to n_samples in the
  89. case of a bootstrap subsample with repeated indices. If None, the
  90. sample weight will be calculated over the full sample. Only "balanced"
  91. is supported for class_weight if this is provided.
  92. Returns
  93. -------
  94. sample_weight_vect : ndarray of shape (n_samples,)
  95. Array with sample weights as applied to the original y.
  96. """
  97. # Ensure y is 2D. Sparse matrices are already 2D.
  98. if not sparse.issparse(y):
  99. y = np.atleast_1d(y)
  100. if y.ndim == 1:
  101. y = np.reshape(y, (-1, 1))
  102. n_outputs = y.shape[1]
  103. if isinstance(class_weight, str):
  104. if class_weight not in ["balanced"]:
  105. raise ValueError(
  106. 'The only valid preset for class_weight is "balanced". Given "%s".'
  107. % class_weight
  108. )
  109. elif indices is not None and not isinstance(class_weight, str):
  110. raise ValueError(
  111. 'The only valid class_weight for subsampling is "balanced". Given "%s".'
  112. % class_weight
  113. )
  114. elif n_outputs > 1:
  115. if not hasattr(class_weight, "__iter__") or isinstance(class_weight, dict):
  116. raise ValueError(
  117. "For multi-output, class_weight should be a "
  118. "list of dicts, or a valid string."
  119. )
  120. if len(class_weight) != n_outputs:
  121. raise ValueError(
  122. "For multi-output, number of elements in "
  123. "class_weight should match number of outputs."
  124. )
  125. expanded_class_weight = []
  126. for k in range(n_outputs):
  127. y_full = y[:, k]
  128. if sparse.issparse(y_full):
  129. # Ok to densify a single column at a time
  130. y_full = y_full.toarray().flatten()
  131. classes_full = np.unique(y_full)
  132. classes_missing = None
  133. if class_weight == "balanced" or n_outputs == 1:
  134. class_weight_k = class_weight
  135. else:
  136. class_weight_k = class_weight[k]
  137. if indices is not None:
  138. # Get class weights for the subsample, covering all classes in
  139. # case some labels that were present in the original data are
  140. # missing from the sample.
  141. y_subsample = y_full[indices]
  142. classes_subsample = np.unique(y_subsample)
  143. weight_k = np.take(
  144. compute_class_weight(
  145. class_weight_k, classes=classes_subsample, y=y_subsample
  146. ),
  147. np.searchsorted(classes_subsample, classes_full),
  148. mode="clip",
  149. )
  150. classes_missing = set(classes_full) - set(classes_subsample)
  151. else:
  152. weight_k = compute_class_weight(
  153. class_weight_k, classes=classes_full, y=y_full
  154. )
  155. weight_k = weight_k[np.searchsorted(classes_full, y_full)]
  156. if classes_missing:
  157. # Make missing classes' weight zero
  158. weight_k[np.isin(y_full, list(classes_missing))] = 0.0
  159. expanded_class_weight.append(weight_k)
  160. expanded_class_weight = np.prod(expanded_class_weight, axis=0, dtype=np.float64)
  161. return expanded_class_weight