| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Generic utilities
- """
- import inspect
- import tempfile
- import warnings
- from collections import OrderedDict, UserDict
- from collections.abc import MutableMapping
- from contextlib import ExitStack, contextmanager
- from dataclasses import fields, is_dataclass
- from enum import Enum
- from functools import partial, wraps
- from typing import Any, ContextManager, Iterable, List, Optional, Tuple
- import numpy as np
- from packaging import version
- from .import_utils import (
- get_torch_version,
- is_flax_available,
- is_mlx_available,
- is_tf_available,
- is_torch_available,
- is_torch_fx_proxy,
- )
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            # Accessed on the class itself: behave like `property` and return the descriptor.
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        cache_slot = "__cached_" + self.fget.__name__
        # NOTE: a computed value of None is indistinguishable from "not cached yet",
        # so a getter returning None is re-evaluated on every access (original behavior).
        if (value := getattr(obj, cache_slot, None)) is None:
            value = self.fget(obj)
            setattr(obj, cache_slot, value)
        return value
# vendored from distutils.util
def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    Raises ValueError if 'val' is anything else.
    """
    normalized = val.lower()
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    if normalized in truthy:
        return 1
    if normalized in falsy:
        return 0
    raise ValueError(f"invalid truth value {normalized!r}")
def infer_framework_from_repr(x):
    """
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).
    """
    representation = str(type(x))
    # Checked in order; first matching class-repr prefix wins.
    prefix_to_framework = (
        ("<class 'torch.", "pt"),
        ("<class 'tensorflow.", "tf"),
        ("<class 'jax", "jax"),
        ("<class 'numpy.", "np"),
        ("<class 'mlx.", "mlx"),
    )
    for prefix, framework in prefix_to_framework:
        if representation.startswith(prefix):
            return framework
    # Unknown type: no framework could be guessed.
    return None
def _get_frameworks_and_test_func(x):
    """
    Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
    we can guess from the repr first, then Numpy, then the others.
    """
    framework_to_test = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    # Cheap repr-based guess goes first, then numpy, then everything else.
    guessed = infer_framework_from_repr(x)
    ordered = []
    if guessed is not None:
        ordered.append(guessed)
    if guessed != "np":
        ordered.append("np")
    ordered.extend(name for name in framework_to_test if name not in ordered)
    return {name: framework_to_test[name] for name in ordered}
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array`
    in the order defined by `infer_framework_from_repr`
    """
    # Try each framework's membership test, most likely framework first.
    for check in _get_frameworks_and_test_func(x).values():
        if check(x):
            return True

    # Tracers (symbolic values seen while tracing) also count as tensors.
    if is_torch_fx_proxy(x):
        return True
    if is_flax_available():
        from jax.core import Tracer

        if isinstance(x, Tracer):
            return True

    return False
- def _is_numpy(x):
- return isinstance(x, np.ndarray)
- def is_numpy_array(x):
- """
- Tests if `x` is a numpy array or not.
- """
- return _is_numpy(x)
def _is_torch(x):
    # Deferred import: only called once torch is known to be installed.
    import torch

    return isinstance(x, torch.Tensor)


def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x):
    # Deferred import: only called once torch is known to be installed.
    import torch

    return isinstance(x, torch.device)


def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)
def _is_torch_dtype(x):
    # Deferred import: only called once torch is known to be installed.
    import torch

    if isinstance(x, str):
        # Accept dtype names like "float32" by resolving them on the torch module.
        if not hasattr(torch, x):
            return False
        x = getattr(torch, x)
    return isinstance(x, torch.dtype)


def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)
def _is_tensorflow(x):
    # Deferred import: only called once tensorflow is known to be installed.
    import tensorflow as tf

    return isinstance(x, tf.Tensor)


def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
def _is_tf_symbolic_tensor(x):
    import tensorflow as tf

    # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
    if hasattr(tf, "is_symbolic_tensor"):
        return tf.is_symbolic_tensor(x)
    # Older TF: fall back to a plain tensor isinstance check.
    return isinstance(x, tf.Tensor)


def is_tf_symbolic_tensor(x):
    """
    Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
    installed.
    """
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)
def _is_jax(x):
    # Deferred import: only called once jax is known to be installed.
    import jax.numpy as jnp  # noqa: F811

    return isinstance(x, jnp.ndarray)


def is_jax_tensor(x):
    """
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
    """
    if not is_flax_available():
        return False
    return _is_jax(x)
def _is_mlx(x):
    # Deferred import: only called once mlx is known to be installed.
    import mlx.core as mx

    return isinstance(x, mx.array)


def is_mlx_array(x):
    """
    Tests if `x` is a mlx array or not. Safe to call even when mlx is not installed.
    """
    if not is_mlx_available():
        return False
    return _is_mlx(x)
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
    """
    # Containers are converted recursively (tuples become lists).
    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]

    framework_to_py_obj = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
    }
    # This gives us a smart order to test the frameworks with the corresponding tests.
    for framework, check in _get_frameworks_and_test_func(obj).items():
        if check(obj):
            return framework_to_py_obj[framework](obj)

    # tolist also works on 0d np arrays
    if isinstance(obj, np.number):
        return obj.tolist()
    # Anything else (plain Python scalar, string, ...) is returned unchanged.
    return obj
def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.
    """
    # Mappings are converted value-by-value; sequences become one numpy array.
    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return np.array(obj)

    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
    }
    # This gives us a smart order to test the frameworks with the corresponding tests.
    for framework, check in _get_frameworks_and_test_func(obj).items():
        if check(obj):
            return framework_to_numpy[framework](obj)

    # Not a recognized tensor/array type: return unchanged.
    return obj
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if is_torch_available():
            # torch 2.2 renamed the private `_register_pytree_node` to the public
            # `register_pytree_node` and added `serialized_type_name`.
            if version.parse(get_torch_version()) >= version.parse("2.2"):
                _torch_pytree.register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                    serialized_type_name=f"{cls.__module__}.{cls.__name__}",
                )
            else:
                _torch_pytree._register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Subclasses of ModelOutput must use the @dataclass decorator
        # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
        # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
        # Just need to check that the current class is not ModelOutput
        is_modeloutput_subclass = self.__class__ != ModelOutput

        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
                " This is a subclass of ModelOutput and so must use the @dataclass decorator."
            )

    def __post_init__(self):
        """Check the ModelOutput dataclass.

        Only occurs if @dataclass decorator has been used.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        # If only the first field is set and it is not a tensor, it may be a dict or an
        # iterator of (key, value) pairs meant to populate the mapping directly.
        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # If we do not have an iterator of key/values, set it as attribute
                            self[class_fields[0].name] = first_field
                        else:
                            # If we have a mixed iterator, raise an error
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            # General case: expose every non-None field as a mapping entry.
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        # Mutation of the mapping after construction is deliberately forbidden.
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # String keys behave like dict access; int/slice keys index the non-None tuple view.
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep attribute and mapping views in sync (None values are not stored as keys).
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        # Pickle support: dataclass subclasses are rebuilt from their field values.
        if not is_dataclass(self):
            return super().__reduce__()
        callable, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return callable, args, *remaining

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
if is_torch_available():
    import torch.utils._pytree as _torch_pytree

    # Flatten a ModelOutput into (leaf values, keys) so torch's pytree utilities can traverse it.
    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        return list(output.values()), list(output.keys())

    # Rebuild a ModelOutput (or subclass, via `output_type`) from flattened values and the saved keys.
    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        return output_type(**dict(zip(context, values)))

    # Register the base class itself; subclasses register themselves in
    # ModelOutput.__init_subclass__. torch >= 2.2 uses the public API with a
    # serialized type name, older torch only has the private `_register_pytree_node`.
    if version.parse(get_torch_version()) >= version.parse("2.2"):
        _torch_pytree.register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
            serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        )
    else:
        _torch_pytree._register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
        )
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when a lookup fails; list the accepted values in the error.
        valid_values = list(cls._value2member_map_.keys())
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.
    """

    # Pad to the longest sequence in the batch.
    LONGEST = "longest"
    # Pad to a caller-provided maximum length.
    MAX_LENGTH = "max_length"
    # Do not pad at all.
    DO_NOT_PAD = "do_not_pad"
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    # Values match the framework keys used throughout this module (see `_get_frameworks_and_test_func`).
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.

    Context managers are entered in list order on `__enter__` and exited in reverse
    order on `__exit__` (LIFO, as guaranteed by `ExitStack`).
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)
        # Fix: return self so `with ContextManagers(...) as cm:` binds the wrapper
        # instead of None. Callers that ignore the bound value are unaffected.
        return self

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.
    """
    framework = infer_framework(model_class)
    # Each framework exposes its forward pass under a different method name.
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    return any(
        name == "return_loss" and param.default is True
        for name, param in signature.parameters.items()
    )
def find_labels(model_class):
    """
    Find the labels used by a given model.

    Args:
        model_class (`type`): The class of the model.
    """
    model_name = model_class.__name__
    framework = infer_framework(model_class)
    # Each framework exposes its forward pass under a different method name.
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    # QA heads additionally take span positions instead of a single labels argument.
    if "QuestionAnswering" in model_name:
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    return [p for p in signature.parameters if "label" in p]
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Flatten a nested dict into a single level dict."""
    flat = {}
    for k, v in d.items():
        compound_key = str(parent_key) + delimiter + str(k) if parent_key else k
        # Non-empty mappings are flattened recursively; empty ones are kept as values.
        if v and isinstance(v, MutableMapping):
            flat.update(flatten_dict(v, compound_key, delimiter=delimiter))
        else:
            flat[compound_key] = v
    return flat
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """Yield `working_dir`, or a fresh temporary directory (removed on exit) when `use_temp_dir` is True."""
    if not use_temp_dir:
        yield working_dir
        return
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir
def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    if is_torch_tensor(array):
        # torch has no `axes=` form: .T reverses dims, permute takes an explicit order.
        return array.T if axes is None else array.permute(*axes)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)
    if is_jax_tensor(array):
        import jax.numpy as jnp

        return jnp.transpose(array, axes=axes)
    raise ValueError(f"Type not supported for transpose: {type(array)}.")
def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    if is_torch_tensor(array):
        # torch's reshape takes the dimensions unpacked, not as a single tuple.
        return array.reshape(*newshape)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)
    if is_jax_tensor(array):
        import jax.numpy as jnp

        return jnp.reshape(array, newshape)
    raise ValueError(f"Type not supported for reshape: {type(array)}.")
def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    if is_torch_tensor(array):
        # torch rejects dim=None, so the no-axis case needs a separate call.
        return array.squeeze() if axis is None else array.squeeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    if is_jax_tensor(array):
        import jax.numpy as jnp

        return jnp.squeeze(array, axis=axis)
    raise ValueError(f"Type not supported for squeeze: {type(array)}.")
def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    if is_torch_tensor(array):
        return array.unsqueeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)
    if is_jax_tensor(array):
        import jax.numpy as jnp

        return jnp.expand_dims(array, axis=axis)
    raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
    """
    if is_numpy_array(array):
        return np.size(array)
    if is_torch_tensor(array):
        return array.numel()
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.size(array)
    if is_jax_tensor(array):
        # jax arrays expose numpy-style .size directly.
        return array.size
    raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Adds the information of the repo_id to a given auto map.

    Entries already namespaced (containing `--`) and None entries are left untouched;
    the mapping is modified in place and also returned.
    """
    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            # Tuples are normalized to lists while tagging each entry.
            auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"
    return auto_map
def add_model_info_to_custom_pipelines(custom_pipeline, repo_id):
    """
    Adds the information of the repo_id to a given custom pipeline.

    The mapping is modified in place and also returned.
    """
    # {custom_pipelines : {task: {"impl": "path.to.task"},...} }
    for task in custom_pipeline:
        spec = custom_pipeline[task]
        if "impl" in spec:
            module = spec["impl"]
            # Skip entries that are already namespaced with a repo id.
            if "--" not in module:
                spec["impl"] = f"{repo_id}--{module}"
    return custom_pipeline
def infer_framework(model_class):
    """
    Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
    classes are imported or available.
    """
    # Walk the MRO: the first base whose module (or well-known base-class name)
    # identifies a framework decides the answer.
    for base_class in inspect.getmro(model_class):
        module = base_class.__module__
        name = base_class.__name__
        if module.startswith("tensorflow") or module.startswith("keras") or name == "TFPreTrainedModel":
            return "tf"
        if module.startswith("torch") or name == "PreTrainedModel":
            return "pt"
        if module.startswith("flax") or module.startswith("jax") or name == "FlaxPreTrainedModel":
            return "flax"
    # No base class matched any known framework.
    raise TypeError(f"Could not infer framework from class {model_class}.")
def torch_int(x):
    """
    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
    """
    if not is_torch_available():
        return int(x)

    import torch

    # Only keep tensor form while TorchScript tracing, so traced graphs stay symbolic.
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)
def torch_float(x):
    """
    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
    """
    if not is_torch_available():
        # Fix: previously returned int(x) (copy-paste from torch_int), truncating
        # fractional values contrary to the documented float contract.
        return float(x)

    import torch

    # Keep tensor form only while TorchScript tracing; otherwise produce a Python float.
    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
    """
    Decorator to filter out named arguments that are not in the function signature.

    This decorator ensures that only the keyword arguments that match the function's signature, or are specified in the
    `extra` list, are passed to the function. Any additional keyword arguments are filtered out and a warning is issued.

    Parameters:
        extra (`Optional[list]`, *optional*):
            A list of extra keyword argument names that are allowed even if they are not in the function's signature.

    Returns:
        Callable:
            A decorator that wraps the function and filters out invalid keyword arguments.

    Example usage:

        ```python
        @filter_out_non_signature_kwargs(extra=["allowed_extra_arg"])
        def my_function(arg1, arg2, **kwargs):
            print(arg1, arg2, kwargs)

        my_function(arg1=1, arg2=2, allowed_extra_arg=3, invalid_arg=4)
        # This will print: 1 2 {"allowed_extra_arg": 3}
        # And issue a warning: "The following named arguments are not valid for `my_function` and were ignored: 'invalid_arg'"
        ```
    """
    allowed_extra = set(extra) if extra else set()

    def decorator(func):
        signature = inspect.signature(func)
        accepted_names = set(signature.parameters) | allowed_extra

        # Required for better warning message
        takes_self = "self" in signature.parameters
        takes_cls = "cls" in signature.parameters

        # Mark function as decorated
        func._filter_out_non_signature_kwargs = True

        @wraps(func)
        def wrapper(*args, **kwargs):
            accepted = {k: v for k, v in kwargs.items() if k in accepted_names}
            rejected = [k for k in kwargs if k not in accepted_names]

            if rejected:
                # Get the class name for better warning message
                if takes_self:
                    cls_prefix = args[0].__class__.__name__ + "."
                elif takes_cls:
                    cls_prefix = args[0].__name__ + "."
                else:
                    cls_prefix = ""

                rejected_names = ", ".join(f"'{k}'" for k in rejected)
                warnings.warn(
                    f"The following named arguments are not valid for `{cls_prefix}{func.__name__}`"
                    f" and were ignored: {rejected_names}",
                    UserWarning,
                    stacklevel=2,
                )

            return func(*args, **accepted)

        return wrapper

    return decorator
|