_metadata_requests.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324
  1. """
  2. Metadata Routing Utility
  3. In order to better understand the components implemented in this file, one
  4. needs to understand their relationship to one another.
  5. The only public API relevant to end users is the ``set_{method}_request``
  6. methods, e.g. ``estimator.set_fit_request(sample_weight=True)``. However,
  7. third-party developers and users who implement custom meta-estimators need
  8. to deal with the objects implemented in this file.
  9. All estimators (should) implement a ``get_metadata_routing`` method, returning
  10. the routing requests set for the estimator. This method is automatically
  11. implemented via ``BaseEstimator`` for all simple estimators, but needs a custom
  12. implementation for meta-estimators.
  13. In non-routing consumers, i.e. the simplest case, e.g. ``SVM``,
  14. ``get_metadata_routing`` returns a ``MetadataRequest`` object.
  15. In routers, e.g. meta-estimators and a multi metric scorer,
  16. ``get_metadata_routing`` returns a ``MetadataRouter`` object.
  17. For an object which is both a router and a consumer, e.g. a meta-estimator
  18. which consumes ``sample_weight`` and routes ``sample_weight`` to its
  19. sub-estimators, the routing information includes both information about the
  20. object itself (added via ``MetadataRouter.add_self_request``) and the routing
  21. information for its sub-estimators.
  22. A ``MetadataRequest`` instance includes one ``MethodMetadataRequest`` per
  23. method in ``METHODS``, which includes ``fit``, ``score``, etc.
  24. Request values are added to the routing mechanism by adding them to
  25. ``MethodMetadataRequest`` instances, e.g.
  26. ``metadatarequest.fit.add_request(param="sample_weight", alias="my_weights")``.
  27. This is used in ``set_{method}_request`` which are automatically generated, so
  28. users and developers almost never need to directly call methods on a
  29. ``MethodMetadataRequest``.
  30. The ``alias`` above in the ``add_request`` method has to be either a string
  31. (an alias), or one of ``{True (requested), False (unrequested), None (error if
  32. passed)}``. There are some other special values such as ``UNUSED`` and ``WARN``
  33. which are used for purposes such as warning of removing a metadata in a child
  34. class, but not used by the end users.
  35. ``MetadataRouter`` includes information about sub-objects' routing and how
  36. methods are mapped together. For instance, the information about which methods
  37. of a sub-estimator are called in which methods of the meta-estimator are all
  38. stored here. Conceptually, this information looks like:
  39. ```
  40. {
  41. "sub_estimator1": (
  42. mapping=[(caller="fit", callee="transform"), ...],
  43. router=MetadataRequest(...), # or another MetadataRouter
  44. ),
  45. ...
  46. }
  47. ```
  48. To give the above representation some structure, we use the following objects:
  49. - ``(caller, callee)`` is a namedtuple called ``MethodPair``
  50. - The list of ``MethodPair`` stored in the ``mapping`` field is a
  51. ``MethodMapping`` object
  52. - ``(mapping=..., router=...)`` is a namedtuple called ``RouterMappingPair``
  53. The ``set_{method}_request`` methods are dynamically generated for estimators
  54. which inherit from the ``BaseEstimator``. This is done by attaching instances
  55. of the ``RequestMethod`` descriptor to classes, which is done in the
  56. ``_MetadataRequester`` class, and ``BaseEstimator`` inherits from this mixin.
  57. This mixin also implements the ``get_metadata_routing``, which meta-estimators
  58. need to override, but it works for simple consumers as is.
  59. """
  60. # Author: Adrin Jalali <adrin.jalali@gmail.com>
  61. # License: BSD 3 clause
  62. import inspect
  63. from collections import namedtuple
  64. from copy import deepcopy
  65. from typing import Optional, Union
  66. from warnings import warn
  67. from .. import get_config
  68. from ..exceptions import UnsetMetadataPassedError
  69. from ._bunch import Bunch
  70. # Only the following methods are supported in the routing mechanism. Adding new
  71. # methods at the moment involves monkeypatching this list.
  72. METHODS = [
  73. "fit",
  74. "partial_fit",
  75. "predict",
  76. "predict_proba",
  77. "predict_log_proba",
  78. "decision_function",
  79. "score",
  80. "split",
  81. "transform",
  82. "inverse_transform",
  83. ]
  84. def _routing_enabled():
  85. """Return whether metadata routing is enabled.
  86. .. versionadded:: 1.3
  87. Returns
  88. -------
  89. enabled : bool
  90. Whether metadata routing is enabled. If the config is not set, it
  91. defaults to False.
  92. """
  93. return get_config().get("enable_metadata_routing", False)
  94. # Request values
  95. # ==============
  96. # Each request value needs to be one of the following values, or an alias.
  97. # this is used in `__metadata_request__*` attributes to indicate that a
  98. # metadata is not present even though it may be present in the
  99. # corresponding method's signature.
  100. UNUSED = "$UNUSED$"
  101. # this is used whenever a default value is changed, and therefore the user
  102. # should explicitly set the value, otherwise a warning is shown. An example
  103. # is when a meta-estimator is only a router, but then becomes also a
  104. # consumer in a new release.
  105. WARN = "$WARN$"
  106. # this is the default used in `set_{method}_request` methods to indicate no
  107. # change requested by the user.
  108. UNCHANGED = "$UNCHANGED$"
  109. VALID_REQUEST_VALUES = [False, True, None, UNUSED, WARN]
  110. def request_is_alias(item):
  111. """Check if an item is a valid alias.
  112. Values in ``VALID_REQUEST_VALUES`` are not considered aliases in this
  113. context. Only a string which is a valid identifier is.
  114. Parameters
  115. ----------
  116. item : object
  117. The given item to be checked if it can be an alias.
  118. Returns
  119. -------
  120. result : bool
  121. Whether the given item is a valid alias.
  122. """
  123. if item in VALID_REQUEST_VALUES:
  124. return False
  125. # item is only an alias if it's a valid identifier
  126. return isinstance(item, str) and item.isidentifier()
  127. def request_is_valid(item):
  128. """Check if an item is a valid request value (and not an alias).
  129. Parameters
  130. ----------
  131. item : object
  132. The given item to be checked.
  133. Returns
  134. -------
  135. result : bool
  136. Whether the given item is valid.
  137. """
  138. return item in VALID_REQUEST_VALUES
  139. # Metadata Request for Simple Consumers
  140. # =====================================
  141. # This section includes MethodMetadataRequest and MetadataRequest which are
  142. # used in simple consumers.
  143. class MethodMetadataRequest:
  144. """A prescription of how metadata is to be passed to a single method.
  145. Refer to :class:`MetadataRequest` for how this class is used.
  146. .. versionadded:: 1.3
  147. Parameters
  148. ----------
  149. owner : str
  150. A display name for the object owning these requests.
  151. method : str
  152. The name of the method to which these requests belong.
  153. """
  154. def __init__(self, owner, method):
  155. self._requests = dict()
  156. self.owner = owner
  157. self.method = method
  158. @property
  159. def requests(self):
  160. """Dictionary of the form: ``{key: alias}``."""
  161. return self._requests
  162. def add_request(
  163. self,
  164. *,
  165. param,
  166. alias,
  167. ):
  168. """Add request info for a metadata.
  169. Parameters
  170. ----------
  171. param : str
  172. The property for which a request is set.
  173. alias : str, or {True, False, None}
  174. Specifies which metadata should be routed to `param`
  175. - str: the name (or alias) of metadata given to a meta-estimator that
  176. should be routed to this parameter.
  177. - True: requested
  178. - False: not requested
  179. - None: error if passed
  180. """
  181. if not request_is_alias(alias) and not request_is_valid(alias):
  182. raise ValueError(
  183. f"The alias you're setting for `{param}` should be either a "
  184. "valid identifier or one of {None, True, False}, but given "
  185. f"value is: `{alias}`"
  186. )
  187. if alias == param:
  188. alias = True
  189. if alias == UNUSED:
  190. if param in self._requests:
  191. del self._requests[param]
  192. else:
  193. raise ValueError(
  194. f"Trying to remove parameter {param} with UNUSED which doesn't"
  195. " exist."
  196. )
  197. else:
  198. self._requests[param] = alias
  199. return self
  200. def _get_param_names(self, return_alias):
  201. """Get names of all metadata that can be consumed or routed by this method.
  202. This method returns the names of all metadata, even the ``False``
  203. ones.
  204. Parameters
  205. ----------
  206. return_alias : bool
  207. Controls whether original or aliased names should be returned. If
  208. ``False``, aliases are ignored and original names are returned.
  209. Returns
  210. -------
  211. names : set of str
  212. A set of strings with the names of all parameters.
  213. """
  214. return set(
  215. alias if return_alias and not request_is_valid(alias) else prop
  216. for prop, alias in self._requests.items()
  217. if not request_is_valid(alias) or alias is not False
  218. )
  219. def _check_warnings(self, *, params):
  220. """Check whether metadata is passed which is marked as WARN.
  221. If any metadata is passed which is marked as WARN, a warning is raised.
  222. Parameters
  223. ----------
  224. params : dict
  225. The metadata passed to a method.
  226. """
  227. params = {} if params is None else params
  228. warn_params = {
  229. prop
  230. for prop, alias in self._requests.items()
  231. if alias == WARN and prop in params
  232. }
  233. for param in warn_params:
  234. warn(
  235. f"Support for {param} has recently been added to this class. "
  236. "To maintain backward compatibility, it is ignored now. "
  237. "You can set the request value to False to silence this "
  238. "warning, or to True to consume and use the metadata."
  239. )
  240. def _route_params(self, params):
  241. """Prepare the given parameters to be passed to the method.
  242. The output of this method can be used directly as the input to the
  243. corresponding method as extra props.
  244. Parameters
  245. ----------
  246. params : dict
  247. A dictionary of provided metadata.
  248. Returns
  249. -------
  250. params : Bunch
  251. A :class:`~utils.Bunch` of {prop: value} which can be given to the
  252. corresponding method.
  253. """
  254. self._check_warnings(params=params)
  255. unrequested = dict()
  256. args = {arg: value for arg, value in params.items() if value is not None}
  257. res = Bunch()
  258. for prop, alias in self._requests.items():
  259. if alias is False or alias == WARN:
  260. continue
  261. elif alias is True and prop in args:
  262. res[prop] = args[prop]
  263. elif alias is None and prop in args:
  264. unrequested[prop] = args[prop]
  265. elif alias in args:
  266. res[prop] = args[alias]
  267. if unrequested:
  268. raise UnsetMetadataPassedError(
  269. message=(
  270. f"[{', '.join([key for key in unrequested])}] are passed but are"
  271. " not explicitly set as requested or not for"
  272. f" {self.owner}.{self.method}"
  273. ),
  274. unrequested_params=unrequested,
  275. routed_params=res,
  276. )
  277. return res
  278. def _serialize(self):
  279. """Serialize the object.
  280. Returns
  281. -------
  282. obj : dict
  283. A serialized version of the instance in the form of a dictionary.
  284. """
  285. return self._requests
  286. def __repr__(self):
  287. return str(self._serialize())
  288. def __str__(self):
  289. return str(repr(self))
  290. class MetadataRequest:
  291. """Contains the metadata request info of a consumer.
  292. Instances of `MethodMetadataRequest` are used in this class for each
  293. available method under `metadatarequest.{method}`.
  294. Consumer-only classes such as simple estimators return a serialized
  295. version of this class as the output of `get_metadata_routing()`.
  296. .. versionadded:: 1.3
  297. Parameters
  298. ----------
  299. owner : str
  300. The name of the object to which these requests belong.
  301. """
  302. # this is here for us to use this attribute's value instead of doing
  303. # `isinstance` in our checks, so that we avoid issues when people vendor
  304. # this file instead of using it directly from scikit-learn.
  305. _type = "metadata_request"
  306. def __init__(self, owner):
  307. for method in METHODS:
  308. setattr(
  309. self,
  310. method,
  311. MethodMetadataRequest(owner=owner, method=method),
  312. )
  313. def _get_param_names(self, method, return_alias, ignore_self_request=None):
  314. """Get names of all metadata that can be consumed or routed by specified \
  315. method.
  316. This method returns the names of all metadata, even the ``False``
  317. ones.
  318. Parameters
  319. ----------
  320. method : str
  321. The name of the method for which metadata names are requested.
  322. return_alias : bool
  323. Controls whether original or aliased names should be returned. If
  324. ``False``, aliases are ignored and original names are returned.
  325. ignore_self_request : bool
  326. Ignored. Present for API compatibility.
  327. Returns
  328. -------
  329. names : set of str
  330. A set of strings with the names of all parameters.
  331. """
  332. return getattr(self, method)._get_param_names(return_alias=return_alias)
  333. def _route_params(self, *, method, params):
  334. """Prepare the given parameters to be passed to the method.
  335. The output of this method can be used directly as the input to the
  336. corresponding method as extra keyword arguments to pass metadata.
  337. Parameters
  338. ----------
  339. method : str
  340. The name of the method for which the parameters are requested and
  341. routed.
  342. params : dict
  343. A dictionary of provided metadata.
  344. Returns
  345. -------
  346. params : Bunch
  347. A :class:`~utils.Bunch` of {prop: value} which can be given to the
  348. corresponding method.
  349. """
  350. return getattr(self, method)._route_params(params=params)
  351. def _check_warnings(self, *, method, params):
  352. """Check whether metadata is passed which is marked as WARN.
  353. If any metadata is passed which is marked as WARN, a warning is raised.
  354. Parameters
  355. ----------
  356. method : str
  357. The name of the method for which the warnings should be checked.
  358. params : dict
  359. The metadata passed to a method.
  360. """
  361. getattr(self, method)._check_warnings(params=params)
  362. def _serialize(self):
  363. """Serialize the object.
  364. Returns
  365. -------
  366. obj : dict
  367. A serialized version of the instance in the form of a dictionary.
  368. """
  369. output = dict()
  370. for method in METHODS:
  371. mmr = getattr(self, method)
  372. if len(mmr.requests):
  373. output[method] = mmr._serialize()
  374. return output
  375. def __repr__(self):
  376. return str(self._serialize())
  377. def __str__(self):
  378. return str(repr(self))
  379. # Metadata Request for Routers
  380. # ============================
  381. # This section includes all objects required for MetadataRouter which is used
  382. # in routers, returned by their ``get_metadata_routing``.
  383. # This namedtuple is used to store a (mapping, routing) pair. Mapping is a
  384. # MethodMapping object, and routing is the output of `get_metadata_routing`.
  385. # MetadataRouter stores a collection of these namedtuples.
  386. RouterMappingPair = namedtuple("RouterMappingPair", ["mapping", "router"])
  387. # A namedtuple storing a single method route. A collection of these namedtuples
  388. # is stored in a MetadataRouter.
  389. MethodPair = namedtuple("MethodPair", ["callee", "caller"])
  390. class MethodMapping:
  391. """Stores the mapping between callee and caller methods for a router.
  392. This class is primarily used in a ``get_metadata_routing()`` of a router
  393. object when defining the mapping between a sub-object (a sub-estimator or a
  394. scorer) to the router's methods. It stores a collection of ``Route``
  395. namedtuples.
  396. Iterating through an instance of this class will yield named
  397. ``MethodPair(callee, caller)`` tuples.
  398. .. versionadded:: 1.3
  399. """
  400. def __init__(self):
  401. self._routes = []
  402. def __iter__(self):
  403. return iter(self._routes)
  404. def add(self, *, callee, caller):
  405. """Add a method mapping.
  406. Parameters
  407. ----------
  408. callee : str
  409. Child object's method name. This method is called in ``caller``.
  410. caller : str
  411. Parent estimator's method name in which the ``callee`` is called.
  412. Returns
  413. -------
  414. self : MethodMapping
  415. Returns self.
  416. """
  417. if callee not in METHODS:
  418. raise ValueError(
  419. f"Given callee:{callee} is not a valid method. Valid methods are:"
  420. f" {METHODS}"
  421. )
  422. if caller not in METHODS:
  423. raise ValueError(
  424. f"Given caller:{caller} is not a valid method. Valid methods are:"
  425. f" {METHODS}"
  426. )
  427. self._routes.append(MethodPair(callee=callee, caller=caller))
  428. return self
  429. def _serialize(self):
  430. """Serialize the object.
  431. Returns
  432. -------
  433. obj : list
  434. A serialized version of the instance in the form of a list.
  435. """
  436. result = list()
  437. for route in self._routes:
  438. result.append({"callee": route.callee, "caller": route.caller})
  439. return result
  440. @classmethod
  441. def from_str(cls, route):
  442. """Construct an instance from a string.
  443. Parameters
  444. ----------
  445. route : str
  446. A string representing the mapping, it can be:
  447. - `"one-to-one"`: a one to one mapping for all methods.
  448. - `"method"`: the name of a single method, such as ``fit``,
  449. ``transform``, ``score``, etc.
  450. Returns
  451. -------
  452. obj : MethodMapping
  453. A :class:`~sklearn.utils.metadata_routing.MethodMapping` instance
  454. constructed from the given string.
  455. """
  456. routing = cls()
  457. if route == "one-to-one":
  458. for method in METHODS:
  459. routing.add(callee=method, caller=method)
  460. elif route in METHODS:
  461. routing.add(callee=route, caller=route)
  462. else:
  463. raise ValueError("route should be 'one-to-one' or a single method!")
  464. return routing
  465. def __repr__(self):
  466. return str(self._serialize())
  467. def __str__(self):
  468. return str(repr(self))
  469. class MetadataRouter:
  470. """Stores and handles metadata routing for a router object.
  471. This class is used by router objects to store and handle metadata routing.
  472. Routing information is stored as a dictionary of the form ``{"object_name":
  473. RouteMappingPair(method_mapping, routing_info)}``, where ``method_mapping``
  474. is an instance of :class:`~sklearn.utils.metadata_routing.MethodMapping` and
  475. ``routing_info`` is either a
  476. :class:`~utils.metadata_routing.MetadataRequest` or a
  477. :class:`~utils.metadata_routing.MetadataRouter` instance.
  478. .. versionadded:: 1.3
  479. Parameters
  480. ----------
  481. owner : str
  482. The name of the object to which these requests belong.
  483. """
  484. # this is here for us to use this attribute's value instead of doing
  485. # `isinstance`` in our checks, so that we avoid issues when people vendor
  486. # this file instead of using it directly from scikit-learn.
  487. _type = "metadata_router"
  488. def __init__(self, owner):
  489. self._route_mappings = dict()
  490. # `_self_request` is used if the router is also a consumer.
  491. # _self_request, (added using `add_self_request()`) is treated
  492. # differently from the other objects which are stored in
  493. # _route_mappings.
  494. self._self_request = None
  495. self.owner = owner
  496. def add_self_request(self, obj):
  497. """Add `self` (as a consumer) to the routing.
  498. This method is used if the router is also a consumer, and hence the
  499. router itself needs to be included in the routing. The passed object
  500. can be an estimator or a
  501. :class:`~utils.metadata_routing.MetadataRequest`.
  502. A router should add itself using this method instead of `add` since it
  503. should be treated differently than the other objects to which metadata
  504. is routed by the router.
  505. Parameters
  506. ----------
  507. obj : object
  508. This is typically the router instance, i.e. `self` in a
  509. ``get_metadata_routing()`` implementation. It can also be a
  510. ``MetadataRequest`` instance.
  511. Returns
  512. -------
  513. self : MetadataRouter
  514. Returns `self`.
  515. """
  516. if getattr(obj, "_type", None) == "metadata_request":
  517. self._self_request = deepcopy(obj)
  518. elif hasattr(obj, "_get_metadata_request"):
  519. self._self_request = deepcopy(obj._get_metadata_request())
  520. else:
  521. raise ValueError(
  522. "Given `obj` is neither a `MetadataRequest` nor does it implement the"
  523. " required API. Inheriting from `BaseEstimator` implements the required"
  524. " API."
  525. )
  526. return self
  527. def add(self, *, method_mapping, **objs):
  528. """Add named objects with their corresponding method mapping.
  529. Parameters
  530. ----------
  531. method_mapping : MethodMapping or str
  532. The mapping between the child and the parent's methods. If str, the
  533. output of :func:`~sklearn.utils.metadata_routing.MethodMapping.from_str`
  534. is used.
  535. **objs : dict
  536. A dictionary of objects from which metadata is extracted by calling
  537. :func:`~sklearn.utils.metadata_routing.get_routing_for_object` on them.
  538. Returns
  539. -------
  540. self : MetadataRouter
  541. Returns `self`.
  542. """
  543. if isinstance(method_mapping, str):
  544. method_mapping = MethodMapping.from_str(method_mapping)
  545. else:
  546. method_mapping = deepcopy(method_mapping)
  547. for name, obj in objs.items():
  548. self._route_mappings[name] = RouterMappingPair(
  549. mapping=method_mapping, router=get_routing_for_object(obj)
  550. )
  551. return self
  552. def _get_param_names(self, *, method, return_alias, ignore_self_request):
  553. """Get names of all metadata that can be consumed or routed by specified \
  554. method.
  555. This method returns the names of all metadata, even the ``False``
  556. ones.
  557. Parameters
  558. ----------
  559. method : str
  560. The name of the method for which metadata names are requested.
  561. return_alias : bool
  562. Controls whether original or aliased names should be returned,
  563. which only applies to the stored `self`. If no `self` routing
  564. object is stored, this parameter has no effect.
  565. ignore_self_request : bool
  566. If `self._self_request` should be ignored. This is used in `_route_params`.
  567. If ``True``, ``return_alias`` has no effect.
  568. Returns
  569. -------
  570. names : set of str
  571. A set of strings with the names of all parameters.
  572. """
  573. res = set()
  574. if self._self_request and not ignore_self_request:
  575. res = res.union(
  576. self._self_request._get_param_names(
  577. method=method, return_alias=return_alias
  578. )
  579. )
  580. for name, route_mapping in self._route_mappings.items():
  581. for callee, caller in route_mapping.mapping:
  582. if caller == method:
  583. res = res.union(
  584. route_mapping.router._get_param_names(
  585. method=callee, return_alias=True, ignore_self_request=False
  586. )
  587. )
  588. return res
  589. def _route_params(self, *, params, method):
  590. """Prepare the given parameters to be passed to the method.
  591. This is used when a router is used as a child object of another router.
  592. The parent router then passes all parameters understood by the child
  593. object to it and delegates their validation to the child.
  594. The output of this method can be used directly as the input to the
  595. corresponding method as extra props.
  596. Parameters
  597. ----------
  598. method : str
  599. The name of the method for which the parameters are requested and
  600. routed.
  601. params : dict
  602. A dictionary of provided metadata.
  603. Returns
  604. -------
  605. params : Bunch
  606. A :class:`~sklearn.utils.Bunch` of {prop: value} which can be given to the
  607. corresponding method.
  608. """
  609. res = Bunch()
  610. if self._self_request:
  611. res.update(self._self_request._route_params(params=params, method=method))
  612. param_names = self._get_param_names(
  613. method=method, return_alias=True, ignore_self_request=True
  614. )
  615. child_params = {
  616. key: value for key, value in params.items() if key in param_names
  617. }
  618. for key in set(res.keys()).intersection(child_params.keys()):
  619. # conflicts are okay if the passed objects are the same, but it's
  620. # an issue if they're different objects.
  621. if child_params[key] is not res[key]:
  622. raise ValueError(
  623. f"In {self.owner}, there is a conflict on {key} between what is"
  624. " requested for this estimator and what is requested by its"
  625. " children. You can resolve this conflict by using an alias for"
  626. " the child estimator(s) requested metadata."
  627. )
  628. res.update(child_params)
  629. return res
  630. def route_params(self, *, caller, params):
  631. """Return the input parameters requested by child objects.
  632. The output of this method is a bunch, which includes the inputs for all
  633. methods of each child object that are used in the router's `caller`
  634. method.
  635. If the router is also a consumer, it also checks for warnings of
  636. `self`'s/consumer's requested metadata.
  637. Parameters
  638. ----------
  639. caller : str
  640. The name of the method for which the parameters are requested and
  641. routed. If called inside the :term:`fit` method of a router, it
  642. would be `"fit"`.
  643. params : dict
  644. A dictionary of provided metadata.
  645. Returns
  646. -------
  647. params : Bunch
  648. A :class:`~utils.Bunch` of the form
  649. ``{"object_name": {"method_name": {prop: value}}}`` which can be
  650. used to pass the required metadata to corresponding methods or
  651. corresponding child objects.
  652. """
  653. if self._self_request:
  654. self._self_request._check_warnings(params=params, method=caller)
  655. res = Bunch()
  656. for name, route_mapping in self._route_mappings.items():
  657. router, mapping = route_mapping.router, route_mapping.mapping
  658. res[name] = Bunch()
  659. for _callee, _caller in mapping:
  660. if _caller == caller:
  661. res[name][_callee] = router._route_params(
  662. params=params, method=_callee
  663. )
  664. return res
  665. def validate_metadata(self, *, method, params):
  666. """Validate given metadata for a method.
  667. This raises a ``ValueError`` if some of the passed metadata are not
  668. understood by child objects.
  669. Parameters
  670. ----------
  671. method : str
  672. The name of the method for which the parameters are requested and
  673. routed. If called inside the :term:`fit` method of a router, it
  674. would be `"fit"`.
  675. params : dict
  676. A dictionary of provided metadata.
  677. """
  678. param_names = self._get_param_names(
  679. method=method, return_alias=False, ignore_self_request=False
  680. )
  681. if self._self_request:
  682. self_params = self._self_request._get_param_names(
  683. method=method, return_alias=False
  684. )
  685. else:
  686. self_params = set()
  687. extra_keys = set(params.keys()) - param_names - self_params
  688. if extra_keys:
  689. raise TypeError(
  690. f"{method} got unexpected argument(s) {extra_keys}, which are "
  691. "not requested metadata in any object."
  692. )
  693. def _serialize(self):
  694. """Serialize the object.
  695. Returns
  696. -------
  697. obj : dict
  698. A serialized version of the instance in the form of a dictionary.
  699. """
  700. res = dict()
  701. if self._self_request:
  702. res["$self_request"] = self._self_request._serialize()
  703. for name, route_mapping in self._route_mappings.items():
  704. res[name] = dict()
  705. res[name]["mapping"] = route_mapping.mapping._serialize()
  706. res[name]["router"] = route_mapping.router._serialize()
  707. return res
  708. def __iter__(self):
  709. if self._self_request:
  710. yield "$self_request", RouterMappingPair(
  711. mapping=MethodMapping.from_str("one-to-one"), router=self._self_request
  712. )
  713. for name, route_mapping in self._route_mappings.items():
  714. yield (name, route_mapping)
  715. def __repr__(self):
  716. return str(self._serialize())
  717. def __str__(self):
  718. return str(repr(self))
  719. def get_routing_for_object(obj=None):
  720. """Get a ``Metadata{Router, Request}`` instance from the given object.
  721. This function returns a
  722. :class:`~sklearn.utils.metadata_routing.MetadataRouter` or a
  723. :class:`~sklearn.utils.metadata_routing.MetadataRequest` from the given input.
  724. This function always returns a copy or an instance constructed from the
  725. input, such that changing the output of this function will not change the
  726. original object.
  727. .. versionadded:: 1.3
  728. Parameters
  729. ----------
  730. obj : object
  731. - If the object is already a
  732. :class:`~sklearn.utils.metadata_routing.MetadataRequest` or a
  733. :class:`~sklearn.utils.metadata_routing.MetadataRouter`, return a copy
  734. of that.
  735. - If the object provides a `get_metadata_routing` method, return a copy
  736. of the output of that method.
  737. - Returns an empty :class:`~sklearn.utils.metadata_routing.MetadataRequest`
  738. otherwise.
  739. Returns
  740. -------
  741. obj : MetadataRequest or MetadataRouting
  742. A ``MetadataRequest`` or a ``MetadataRouting`` taken or created from
  743. the given object.
  744. """
  745. # doing this instead of a try/except since an AttributeError could be raised
  746. # for other reasons.
  747. if hasattr(obj, "get_metadata_routing"):
  748. return deepcopy(obj.get_metadata_routing())
  749. elif getattr(obj, "_type", None) in ["metadata_request", "metadata_router"]:
  750. return deepcopy(obj)
  751. return MetadataRequest(owner=None)
# Request method
# ==============
# This section includes what's needed for the request method descriptor and
# the dynamic generation of the ``set_{method}_request`` methods, which is done
# in ``_MetadataRequester.__init_subclass__``.
  756. # These strings are used to dynamically generate the docstrings for
  757. # set_{method}_request methods.
# Head of the generated docstring. It is formatted with the ``{method}``
# placeholder and ends with an open ``Parameters`` section to which one
# ``REQUESTER_DOC_PARAM`` entry per metadata key is appended.
# Continuation lines after a trailing backslash start at column 0 on purpose,
# so that no indentation ends up in the middle of a rendered sentence.
REQUESTER_DOC = """        Request metadata passed to the ``{method}`` method.

        Note that this method is only relevant if
        ``enable_metadata_routing=True`` (see :func:`sklearn.set_config`).
        Please see :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        The options for each parameter are:

        - ``True``: metadata is requested, and \
passed to ``{method}`` if provided. The request is ignored if \
metadata is not provided.

        - ``False``: metadata is not requested and the meta-estimator \
will not pass it to ``{method}``.

        - ``None``: metadata is not requested, and the meta-estimator \
will raise an error if the user provides it.

        - ``str``: metadata should be passed to the meta-estimator with \
this given alias instead of the original name.

        The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the
        existing request. This allows you to change the request for some
        parameters and not others.

        .. versionadded:: 1.3

        .. note::
            This method is only relevant if this estimator is used as a
            sub-estimator of a meta-estimator, e.g. used inside a
            :class:`~sklearn.pipeline.Pipeline`. Otherwise it has no effect.

        Parameters
        ----------
"""
# One entry of the ``Parameters`` section; formatted with the ``{metadata}``
# and ``{method}`` placeholders.
REQUESTER_DOC_PARAM = """        {metadata} : str, True, False, or None, \
                    default=sklearn.utils.metadata_routing.UNCHANGED
            Metadata routing for ``{metadata}`` parameter in ``{method}``.

"""
# Closing ``Returns`` section, appended after all parameter entries.
REQUESTER_DOC_RETURN = """        Returns
        -------
        self : object
            The updated object.
"""
  793. class RequestMethod:
  794. """
  795. A descriptor for request methods.
  796. .. versionadded:: 1.3
  797. Parameters
  798. ----------
  799. name : str
  800. The name of the method for which the request function should be
  801. created, e.g. ``"fit"`` would create a ``set_fit_request`` function.
  802. keys : list of str
  803. A list of strings which are accepted parameters by the created
  804. function, e.g. ``["sample_weight"]`` if the corresponding method
  805. accepts it as a metadata.
  806. validate_keys : bool, default=True
  807. Whether to check if the requested parameters fit the actual parameters
  808. of the method.
  809. Notes
  810. -----
  811. This class is a descriptor [1]_ and uses PEP-362 to set the signature of
  812. the returned function [2]_.
  813. References
  814. ----------
  815. .. [1] https://docs.python.org/3/howto/descriptor.html
  816. .. [2] https://www.python.org/dev/peps/pep-0362/
  817. """
  818. def __init__(self, name, keys, validate_keys=True):
  819. self.name = name
  820. self.keys = keys
  821. self.validate_keys = validate_keys
  822. def __get__(self, instance, owner):
  823. # we would want to have a method which accepts only the expected args
  824. def func(**kw):
  825. """Updates the request for provided parameters
  826. This docstring is overwritten below.
  827. See REQUESTER_DOC for expected functionality
  828. """
  829. if not _routing_enabled():
  830. raise RuntimeError(
  831. "This method is only available when metadata routing is enabled."
  832. " You can enable it using"
  833. " sklearn.set_config(enable_metadata_routing=True)."
  834. )
  835. if self.validate_keys and (set(kw) - set(self.keys)):
  836. raise TypeError(
  837. f"Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments"
  838. f" are: {set(self.keys)}"
  839. )
  840. requests = instance._get_metadata_request()
  841. method_metadata_request = getattr(requests, self.name)
  842. for prop, alias in kw.items():
  843. if alias is not UNCHANGED:
  844. method_metadata_request.add_request(param=prop, alias=alias)
  845. instance._metadata_request = requests
  846. return instance
  847. # Now we set the relevant attributes of the function so that it seems
  848. # like a normal method to the end user, with known expected arguments.
  849. func.__name__ = f"set_{self.name}_request"
  850. params = [
  851. inspect.Parameter(
  852. name="self",
  853. kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
  854. annotation=owner,
  855. )
  856. ]
  857. params.extend(
  858. [
  859. inspect.Parameter(
  860. k,
  861. inspect.Parameter.KEYWORD_ONLY,
  862. default=UNCHANGED,
  863. annotation=Optional[Union[bool, None, str]],
  864. )
  865. for k in self.keys
  866. ]
  867. )
  868. func.__signature__ = inspect.Signature(
  869. params,
  870. return_annotation=owner,
  871. )
  872. doc = REQUESTER_DOC.format(method=self.name)
  873. for metadata in self.keys:
  874. doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name)
  875. doc += REQUESTER_DOC_RETURN
  876. func.__doc__ = doc
  877. return func
  878. class _MetadataRequester:
  879. """Mixin class for adding metadata request functionality.
  880. ``BaseEstimator`` inherits from this Mixin.
  881. .. versionadded:: 1.3
  882. """
  883. def __init_subclass__(cls, **kwargs):
  884. """Set the ``set_{method}_request`` methods.
  885. This uses PEP-487 [1]_ to set the ``set_{method}_request`` methods. It
  886. looks for the information available in the set default values which are
  887. set using ``__metadata_request__*`` class attributes, or inferred
  888. from method signatures.
  889. The ``__metadata_request__*`` class attributes are used when a method
  890. does not explicitly accept a metadata through its arguments or if the
  891. developer would like to specify a request value for those metadata
  892. which are different from the default ``None``.
  893. References
  894. ----------
  895. .. [1] https://www.python.org/dev/peps/pep-0487
  896. """
  897. try:
  898. requests = cls._get_default_requests()
  899. except Exception:
  900. # if there are any issues in the default values, it will be raised
  901. # when ``get_metadata_routing`` is called. Here we are going to
  902. # ignore all the issues such as bad defaults etc.
  903. super().__init_subclass__(**kwargs)
  904. return
  905. for method in METHODS:
  906. mmr = getattr(requests, method)
  907. # set ``set_{method}_request``` methods
  908. if not len(mmr.requests):
  909. continue
  910. setattr(
  911. cls,
  912. f"set_{method}_request",
  913. RequestMethod(method, sorted(mmr.requests.keys())),
  914. )
  915. super().__init_subclass__(**kwargs)
  916. @classmethod
  917. def _build_request_for_signature(cls, router, method):
  918. """Build the `MethodMetadataRequest` for a method using its signature.
  919. This method takes all arguments from the method signature and uses
  920. ``None`` as their default request value, except ``X``, ``y``, ``Y``,
  921. ``Xt``, ``yt``, ``*args``, and ``**kwargs``.
  922. Parameters
  923. ----------
  924. router : MetadataRequest
  925. The parent object for the created `MethodMetadataRequest`.
  926. method : str
  927. The name of the method.
  928. Returns
  929. -------
  930. method_request : MethodMetadataRequest
  931. The prepared request using the method's signature.
  932. """
  933. mmr = MethodMetadataRequest(owner=cls.__name__, method=method)
  934. # Here we use `isfunction` instead of `ismethod` because calling `getattr`
  935. # on a class instead of an instance returns an unbound function.
  936. if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)):
  937. return mmr
  938. # ignore the first parameter of the method, which is usually "self"
  939. params = list(inspect.signature(getattr(cls, method)).parameters.items())[1:]
  940. for pname, param in params:
  941. if pname in {"X", "y", "Y", "Xt", "yt"}:
  942. continue
  943. if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}:
  944. continue
  945. mmr.add_request(
  946. param=pname,
  947. alias=None,
  948. )
  949. return mmr
  950. @classmethod
  951. def _get_default_requests(cls):
  952. """Collect default request values.
  953. This method combines the information present in ``__metadata_request__*``
  954. class attributes, as well as determining request keys from method
  955. signatures.
  956. """
  957. requests = MetadataRequest(owner=cls.__name__)
  958. for method in METHODS:
  959. setattr(
  960. requests,
  961. method,
  962. cls._build_request_for_signature(router=requests, method=method),
  963. )
  964. # Then overwrite those defaults with the ones provided in
  965. # __metadata_request__* attributes. Defaults set in
  966. # __metadata_request__* attributes take precedence over signature
  967. # sniffing.
  968. # need to go through the MRO since this is a class attribute and
  969. # ``vars`` doesn't report the parent class attributes. We go through
  970. # the reverse of the MRO so that child classes have precedence over
  971. # their parents.
  972. defaults = dict()
  973. for base_class in reversed(inspect.getmro(cls)):
  974. base_defaults = {
  975. attr: value
  976. for attr, value in vars(base_class).items()
  977. if "__metadata_request__" in attr
  978. }
  979. defaults.update(base_defaults)
  980. defaults = dict(sorted(defaults.items()))
  981. for attr, value in defaults.items():
  982. # we don't check for attr.startswith() since python prefixes attrs
  983. # starting with __ with the `_ClassName`.
  984. substr = "__metadata_request__"
  985. method = attr[attr.index(substr) + len(substr) :]
  986. for prop, alias in value.items():
  987. getattr(requests, method).add_request(param=prop, alias=alias)
  988. return requests
  989. def _get_metadata_request(self):
  990. """Get requested data properties.
  991. Please check :ref:`User Guide <metadata_routing>` on how the routing
  992. mechanism works.
  993. Returns
  994. -------
  995. request : MetadataRequest
  996. A :class:`~sklearn.utils.metadata_routing.MetadataRequest` instance.
  997. """
  998. if hasattr(self, "_metadata_request"):
  999. requests = get_routing_for_object(self._metadata_request)
  1000. else:
  1001. requests = self._get_default_requests()
  1002. return requests
  1003. def get_metadata_routing(self):
  1004. """Get metadata routing of this object.
  1005. Please check :ref:`User Guide <metadata_routing>` on how the routing
  1006. mechanism works.
  1007. Returns
  1008. -------
  1009. routing : MetadataRequest
  1010. A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating
  1011. routing information.
  1012. """
  1013. return self._get_metadata_request()
  1014. # Process Routing in Routers
  1015. # ==========================
  1016. # This is almost always the only method used in routers to process and route
  1017. # given metadata. This is to minimize the boilerplate required in routers.
  1018. def process_routing(obj, method, other_params, **kwargs):
  1019. """Validate and route input parameters.
  1020. This function is used inside a router's method, e.g. :term:`fit`,
  1021. to validate the metadata and handle the routing.
  1022. Assuming this signature: ``fit(self, X, y, sample_weight=None, **fit_params)``,
  1023. a call to this function would be:
  1024. ``process_routing(self, fit_params, sample_weight=sample_weight)``.
  1025. .. versionadded:: 1.3
  1026. Parameters
  1027. ----------
  1028. obj : object
  1029. An object implementing ``get_metadata_routing``. Typically a
  1030. meta-estimator.
  1031. method : str
  1032. The name of the router's method in which this function is called.
  1033. other_params : dict
  1034. A dictionary of extra parameters passed to the router's method,
  1035. e.g. ``**fit_params`` passed to a meta-estimator's :term:`fit`.
  1036. **kwargs : dict
  1037. Parameters explicitly accepted and included in the router's method
  1038. signature.
  1039. Returns
  1040. -------
  1041. routed_params : Bunch
  1042. A :class:`~utils.Bunch` of the form ``{"object_name": {"method_name":
  1043. {prop: value}}}`` which can be used to pass the required metadata to
  1044. corresponding methods or corresponding child objects. The object names
  1045. are those defined in `obj.get_metadata_routing()`.
  1046. """
  1047. if not hasattr(obj, "get_metadata_routing"):
  1048. raise AttributeError(
  1049. f"This {repr(obj.__class__.__name__)} has not implemented the routing"
  1050. " method `get_metadata_routing`."
  1051. )
  1052. if method not in METHODS:
  1053. raise TypeError(
  1054. f"Can only route and process input on these methods: {METHODS}, "
  1055. f"while the passed method is: {method}."
  1056. )
  1057. # We take the extra params (**fit_params) which is passed as `other_params`
  1058. # and add the explicitly passed parameters (passed as **kwargs) to it. This
  1059. # is equivalent to a code such as this in a router:
  1060. # if sample_weight is not None:
  1061. # fit_params["sample_weight"] = sample_weight
  1062. all_params = other_params if other_params is not None else dict()
  1063. all_params.update(kwargs)
  1064. request_routing = get_routing_for_object(obj)
  1065. request_routing.validate_metadata(params=all_params, method=method)
  1066. routed_params = request_routing.route_params(params=all_params, caller=method)
  1067. return routed_params