_funcs.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # mypy: ignore-errors
  2. import inspect
  3. import itertools
  4. from . import _funcs_impl, _reductions_impl
  5. from ._normalizations import normalizer
  6. # _funcs_impl.py contains functions which mimic NumPy's eponymous equivalents,
  7. # and consume/return PyTorch tensors/dtypes.
  8. # They are also type annotated.
  9. # Pull these functions from _funcs_impl and decorate them with @normalizer, which
  10. # - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`.
  11. # - Maps NumPy dtypes to PyTorch dtypes
  12. # - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple
  13. # - Implements the semantics for the `out=` arg
  14. # - Wraps back the outputs into `torch._numpy.ndarrays`
  15. def _public_functions(mod):
  16. def is_public_function(f):
  17. return inspect.isfunction(f) and not f.__name__.startswith("_")
  18. return inspect.getmembers(mod, is_public_function)
  19. # We fill in __all__ in the loop below
  20. __all__ = []
  21. # decorate implementer functions with argument normalizers and export to the top namespace
  22. for name, func in itertools.chain(
  23. _public_functions(_funcs_impl), _public_functions(_reductions_impl)
  24. ):
  25. if name in ["percentile", "quantile", "median"]:
  26. decorated = normalizer(func, promote_scalar_result=True)
  27. elif name == "einsum":
  28. # normalized manually
  29. decorated = func
  30. else:
  31. decorated = normalizer(func)
  32. decorated.__qualname__ = name
  33. decorated.__name__ = name
  34. vars()[name] = decorated
  35. __all__.append(name)
  36. """
  37. Vendored objects from numpy.lib.index_tricks
  38. """
  39. class IndexExpression:
  40. """
  41. Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
  42. last revision: 1999-7-23
  43. Cosmetic changes by T. Oliphant 2001
  44. """
  45. def __init__(self, maketuple):
  46. self.maketuple = maketuple
  47. def __getitem__(self, item):
  48. if self.maketuple and not isinstance(item, tuple):
  49. return (item,)
  50. else:
  51. return item
  52. index_exp = IndexExpression(maketuple=True)
  53. s_ = IndexExpression(maketuple=False)
  54. __all__ += ["index_exp", "s_"]