| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- # mypy: ignore-errors
- """
- Python polyfills for common builtins.
- """
- import math
- from typing import Any, Callable, Sequence
- import torch
def all(iterator):
    """Return True when every element of *iterator* is truthy.

    Stops at the first falsy element; an empty iterable gives True.
    """
    for item in iterator:
        if not item:
            # Found a falsy element: stop consuming the iterator.
            break
    else:
        return True
    return False
def any(iterator):
    """Return True when at least one element of *iterator* is truthy.

    Stops at the first truthy element; an empty iterable gives False.
    """
    for item in iterator:
        if not item:
            continue
        return True
    return False
def index(iterator, item, start=0, end=None):
    """Return the position of the first element equal to *item*.

    Mirrors ``seq.index(item, start, end)``. Only positions in the slice
    ``[start:end]`` are searched, but the returned index is relative to the
    start of the whole sequence. Negative ``start``/``end`` follow normal
    slice semantics.

    Raises:
        ValueError: if *item* does not occur in the searched range.
    """
    # Materialize the (index, element) pairs *before* slicing: an
    # ``enumerate`` object is not subscriptable, and slicing the pairs keeps
    # each element's absolute index.
    for i, elem in list(enumerate(iterator))[start:end]:
        if item == elem:
            return i
    # This will not run in dynamo
    raise ValueError(f"{item} is not in {type(iterator)}")
def repeat(item, count):
    """Generate *item* ``count`` times (no items when ``count <= 0``)."""
    yield from (item for _ in range(count))
def radians(x):
    """Convert the angle *x* from degrees to radians."""
    degrees_to_radians = math.pi / 180.0
    return degrees_to_radians * x
def accumulate_grad(x, new_grad):
    """Accumulate *new_grad* into ``x.grad``.

    A copy of *new_grad* (via ``torch.clone``) is always used, so the caller's
    tensor is never aliased. If ``x.grad`` already exists, the copy is added
    to it in place. Otherwise the copy becomes ``x.grad``.
    """
    grad_copy = torch.clone(new_grad)
    if x.grad is not None:
        x.grad.add_(grad_copy)
        return
    x.grad = grad_copy
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
    """Compare two sequences lexicographically with *op*, e.g. ``(1,2,3) > (1,2)``.

    *op* is applied to the first pair of elements that differ (tested with
    ``!=``). If one sequence is a prefix of the other, *op* is applied to the
    two lengths instead.
    """
    for lhs, rhs in zip(left, right):
        if lhs != rhs:
            break
    else:
        # Every paired element matched: the shorter sequence orders first.
        return op(len(left), len(right))
    return op(lhs, rhs)
def set_isdisjoint(set1, set2):
    """Return True if no element of *set1* is a member of *set2*."""
    for candidate in set1:
        if candidate in set2:
            break
    else:
        return True
    return False
def dropwhile(predicate, iterable):
    """Skip leading elements while *predicate* holds, then yield everything else.

    Example: ``dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1])`` -> 6 4 1.
    Once *predicate* returns false, it is never called again.
    """
    source = iter(iterable)
    for element in source:
        if predicate(element):
            continue
        yield element
        break
    # Whatever remains in the shared iterator is passed through unfiltered.
    yield from source
|