polyfill.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # mypy: ignore-errors
  2. """
  3. Python polyfills for common builtins.
  4. """
  5. import math
  6. from typing import Any, Callable, Sequence
  7. import torch
  8. def all(iterator):
  9. for elem in iterator:
  10. if not elem:
  11. return False
  12. return True
  13. def any(iterator):
  14. for elem in iterator:
  15. if elem:
  16. return True
  17. return False
  18. def index(iterator, item, start=0, end=None):
  19. for i, elem in enumerate(list(iterator))[start:end]:
  20. if item == elem:
  21. return i
  22. # This will not run in dynamo
  23. raise ValueError(f"{item} is not in {type(iterator)}")
  24. def repeat(item, count):
  25. for i in range(count):
  26. yield item
  27. def radians(x):
  28. return math.pi / 180.0 * x
  29. def accumulate_grad(x, new_grad):
  30. new_grad = torch.clone(new_grad)
  31. if x.grad is None:
  32. x.grad = new_grad
  33. else:
  34. x.grad.add_(new_grad)
  35. def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
  36. """emulate `(1,2,3) > (1,2)` etc"""
  37. for a, b in zip(left, right):
  38. if a != b:
  39. return op(a, b)
  40. return op(len(left), len(right))
  41. def set_isdisjoint(set1, set2):
  42. for x in set1:
  43. if x in set2:
  44. return False
  45. return True
  46. def dropwhile(predicate, iterable):
  47. # dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
  48. iterable = iter(iterable)
  49. for x in iterable:
  50. if not predicate(x):
  51. yield x
  52. break
  53. yield from iterable