_encode.py 11 KB


  1. from collections import Counter
  2. from contextlib import suppress
  3. from typing import NamedTuple
  4. import numpy as np
  5. from . import is_scalar_nan
  6. def _unique(values, *, return_inverse=False, return_counts=False):
  7. """Helper function to find unique values with support for python objects.
  8. Uses pure python method for object dtype, and numpy method for
  9. all other dtypes.
  10. Parameters
  11. ----------
  12. values : ndarray
  13. Values to check for unknowns.
  14. return_inverse : bool, default=False
  15. If True, also return the indices of the unique values.
  16. return_counts : bool, default=False
  17. If True, also return the number of times each unique item appears in
  18. values.
  19. Returns
  20. -------
  21. unique : ndarray
  22. The sorted unique values.
  23. unique_inverse : ndarray
  24. The indices to reconstruct the original array from the unique array.
  25. Only provided if `return_inverse` is True.
  26. unique_counts : ndarray
  27. The number of times each of the unique values comes up in the original
  28. array. Only provided if `return_counts` is True.
  29. """
  30. if values.dtype == object:
  31. return _unique_python(
  32. values, return_inverse=return_inverse, return_counts=return_counts
  33. )
  34. # numerical
  35. return _unique_np(
  36. values, return_inverse=return_inverse, return_counts=return_counts
  37. )
  38. def _unique_np(values, return_inverse=False, return_counts=False):
  39. """Helper function to find unique values for numpy arrays that correctly
  40. accounts for nans. See `_unique` documentation for details."""
  41. uniques = np.unique(
  42. values, return_inverse=return_inverse, return_counts=return_counts
  43. )
  44. inverse, counts = None, None
  45. if return_counts:
  46. *uniques, counts = uniques
  47. if return_inverse:
  48. *uniques, inverse = uniques
  49. if return_counts or return_inverse:
  50. uniques = uniques[0]
  51. # np.unique will have duplicate missing values at the end of `uniques`
  52. # here we clip the nans and remove it from uniques
  53. if uniques.size and is_scalar_nan(uniques[-1]):
  54. nan_idx = np.searchsorted(uniques, np.nan)
  55. uniques = uniques[: nan_idx + 1]
  56. if return_inverse:
  57. inverse[inverse > nan_idx] = nan_idx
  58. if return_counts:
  59. counts[nan_idx] = np.sum(counts[nan_idx:])
  60. counts = counts[: nan_idx + 1]
  61. ret = (uniques,)
  62. if return_inverse:
  63. ret += (inverse,)
  64. if return_counts:
  65. ret += (counts,)
  66. return ret[0] if len(ret) == 1 else ret
  67. class MissingValues(NamedTuple):
  68. """Data class for missing data information"""
  69. nan: bool
  70. none: bool
  71. def to_list(self):
  72. """Convert tuple to a list where None is always first."""
  73. output = []
  74. if self.none:
  75. output.append(None)
  76. if self.nan:
  77. output.append(np.nan)
  78. return output
  79. def _extract_missing(values):
  80. """Extract missing values from `values`.
  81. Parameters
  82. ----------
  83. values: set
  84. Set of values to extract missing from.
  85. Returns
  86. -------
  87. output: set
  88. Set with missing values extracted.
  89. missing_values: MissingValues
  90. Object with missing value information.
  91. """
  92. missing_values_set = {
  93. value for value in values if value is None or is_scalar_nan(value)
  94. }
  95. if not missing_values_set:
  96. return values, MissingValues(nan=False, none=False)
  97. if None in missing_values_set:
  98. if len(missing_values_set) == 1:
  99. output_missing_values = MissingValues(nan=False, none=True)
  100. else:
  101. # If there is more than one missing value, then it has to be
  102. # float('nan') or np.nan
  103. output_missing_values = MissingValues(nan=True, none=True)
  104. else:
  105. output_missing_values = MissingValues(nan=True, none=False)
  106. # create set without the missing values
  107. output = values - missing_values_set
  108. return output, output_missing_values
  109. class _nandict(dict):
  110. """Dictionary with support for nans."""
  111. def __init__(self, mapping):
  112. super().__init__(mapping)
  113. for key, value in mapping.items():
  114. if is_scalar_nan(key):
  115. self.nan_value = value
  116. break
  117. def __missing__(self, key):
  118. if hasattr(self, "nan_value") and is_scalar_nan(key):
  119. return self.nan_value
  120. raise KeyError(key)
  121. def _map_to_integer(values, uniques):
  122. """Map values based on its position in uniques."""
  123. table = _nandict({val: i for i, val in enumerate(uniques)})
  124. return np.array([table[v] for v in values])
  125. def _unique_python(values, *, return_inverse, return_counts):
  126. # Only used in `_uniques`, see docstring there for details
  127. try:
  128. uniques_set = set(values)
  129. uniques_set, missing_values = _extract_missing(uniques_set)
  130. uniques = sorted(uniques_set)
  131. uniques.extend(missing_values.to_list())
  132. uniques = np.array(uniques, dtype=values.dtype)
  133. except TypeError:
  134. types = sorted(t.__qualname__ for t in set(type(v) for v in values))
  135. raise TypeError(
  136. "Encoders require their input argument must be uniformly "
  137. f"strings or numbers. Got {types}"
  138. )
  139. ret = (uniques,)
  140. if return_inverse:
  141. ret += (_map_to_integer(values, uniques),)
  142. if return_counts:
  143. ret += (_get_counts(values, uniques),)
  144. return ret[0] if len(ret) == 1 else ret
  145. def _encode(values, *, uniques, check_unknown=True):
  146. """Helper function to encode values into [0, n_uniques - 1].
  147. Uses pure python method for object dtype, and numpy method for
  148. all other dtypes.
  149. The numpy method has the limitation that the `uniques` need to
  150. be sorted. Importantly, this is not checked but assumed to already be
  151. the case. The calling method needs to ensure this for all non-object
  152. values.
  153. Parameters
  154. ----------
  155. values : ndarray
  156. Values to encode.
  157. uniques : ndarray
  158. The unique values in `values`. If the dtype is not object, then
  159. `uniques` needs to be sorted.
  160. check_unknown : bool, default=True
  161. If True, check for values in `values` that are not in `unique`
  162. and raise an error. This is ignored for object dtype, and treated as
  163. True in this case. This parameter is useful for
  164. _BaseEncoder._transform() to avoid calling _check_unknown()
  165. twice.
  166. Returns
  167. -------
  168. encoded : ndarray
  169. Encoded values
  170. """
  171. if values.dtype.kind in "OUS":
  172. try:
  173. return _map_to_integer(values, uniques)
  174. except KeyError as e:
  175. raise ValueError(f"y contains previously unseen labels: {str(e)}")
  176. else:
  177. if check_unknown:
  178. diff = _check_unknown(values, uniques)
  179. if diff:
  180. raise ValueError(f"y contains previously unseen labels: {str(diff)}")
  181. return np.searchsorted(uniques, values)
def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object and string dtypes (numpy kinds "O",
    "U" and "S"), and numpy method for all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.
    known_values : array
        Known values. Must be unique.
    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.
        NaN in `values` counts as known when `known_values` contains a NaN;
        for object/string dtypes the same holds for ``None``.
    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.
    """
    valid_mask = None

    if values.dtype.kind in "OUS":
        # Python path: compare via sets, with None/NaN pulled out first and
        # tracked as flags, since NaN != NaN defeats plain set membership.
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        # A missing marker is unknown only if it occurs in `values` but not
        # in `known_values`.
        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            # Parsed as: known non-missing value, OR (None is known and value
            # is None), OR (NaN is known and value is NaN).
            return (
                value in uniques_set
                or missing_in_uniques.none
                and value is None
                or missing_in_uniques.nan
                and is_scalar_nan(value)
            )

        if return_mask:
            # Only build the per-element mask when something is unknown;
            # otherwise every entry is valid.
            if diff or nan_in_diff or none_in_diff:
                valid_mask = np.array([is_valid(value) for value in values])
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        unique_values = np.unique(values)
        diff = np.setdiff1d(unique_values, known_values, assume_unique=True)
        if return_mask:
            if diff.size:
                valid_mask = np.isin(values, known_values)
            else:
                valid_mask = np.ones(len(values), dtype=bool)

        # NaN never compares equal, so setdiff1d/isin above report NaNs as
        # unknown even when known_values holds a NaN; correct for that here.
        if np.isnan(known_values).any():
            diff_is_nan = np.isnan(diff)
            if diff_is_nan.any():
                # Mark NaN entries of `values` as valid in the mask.
                # (diff.size is always non-zero here because diff_is_nan.any()
                # is True, so the effective condition is return_mask.)
                if diff.size and return_mask:
                    is_nan = np.isnan(values)
                    valid_mask[is_nan] = 1

                # remove nan from diff
                diff = diff[~diff_is_nan]
        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff
  252. class _NaNCounter(Counter):
  253. """Counter with support for nan values."""
  254. def __init__(self, items):
  255. super().__init__(self._generate_items(items))
  256. def _generate_items(self, items):
  257. """Generate items without nans. Stores the nan counts separately."""
  258. for item in items:
  259. if not is_scalar_nan(item):
  260. yield item
  261. continue
  262. if not hasattr(self, "nan_count"):
  263. self.nan_count = 0
  264. self.nan_count += 1
  265. def __missing__(self, key):
  266. if hasattr(self, "nan_count") and is_scalar_nan(key):
  267. return self.nan_count
  268. raise KeyError(key)
  269. def _get_counts(values, uniques):
  270. """Get the count of each of the `uniques` in `values`.
  271. The counts will use the order passed in by `uniques`. For non-object dtypes,
  272. `uniques` is assumed to be sorted and `np.nan` is at the end.
  273. """
  274. if values.dtype.kind in "OU":
  275. counter = _NaNCounter(values)
  276. output = np.zeros(len(uniques), dtype=np.int64)
  277. for i, item in enumerate(uniques):
  278. with suppress(KeyError):
  279. output[i] = counter[item]
  280. return output
  281. unique_values, counts = _unique_np(values, return_counts=True)
  282. # Recorder unique_values based on input: `uniques`
  283. uniques_in_values = np.isin(uniques, unique_values, assume_unique=True)
  284. if np.isnan(unique_values[-1]) and np.isnan(uniques[-1]):
  285. uniques_in_values[-1] = True
  286. unique_valid_indices = np.searchsorted(unique_values, uniques[uniques_in_values])
  287. output = np.zeros_like(uniques, dtype=np.int64)
  288. output[uniques_in_values] = counts[unique_valid_indices]
  289. return output