| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # mypy: ignore-errors
# Upper bound on how many items CycleIteratorVariable buffers from its input
# iterator; once the buffer exceeds this, tracing bails out via
# ``unimplemented`` instead of buffering further.
MAX_CYCLE = 3000
- import itertools
- import operator
- from typing import Dict, List, Optional
- from .. import polyfill, variables
- from ..exc import unimplemented
- from .base import MutableLocal, VariableTracker
- from .constant import ConstantVariable
class ItertoolsVariable(VariableTracker):
    """Tracks a callable from the ``itertools`` module during tracing.

    ``call_function`` emulates a subset of ``itertools`` on traced values:
    finite combinators whose inputs can be unpacked are evaluated eagerly
    into a ``ListIteratorVariable``; ``repeat``/``count``/``cycle`` are
    modelled by dedicated iterator variables (or a polyfill); ``dropwhile``
    is delegated to a polyfill; everything else falls back to the base class.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        # The itertools callable being tracked (e.g. ``itertools.product``).
        self.value = value

    def __repr__(self):
        return f"ItertoolsVariable({self.value})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @staticmethod
    def _is_none_constant(var):
        """Return True if ``var`` is a traced constant ``None``."""
        return isinstance(var, ConstantVariable) and var.as_python_constant() is None

    @staticmethod
    def _product_repeat(kwargs):
        """Return the ``repeat`` count for ``itertools.product`` or None.

        Absent ``repeat`` means 1 (as in ``itertools.product``). Only a
        non-negative Python ``int`` constant is supported; any other kwarg or
        value returns None so the caller can fall back to the base class.
        """
        if set(kwargs) - {"repeat"}:
            return None
        if "repeat" not in kwargs:
            return 1
        var = kwargs["repeat"]
        if not var.is_python_constant():
            return None
        value = var.as_python_constant()
        if not isinstance(value, int) or value < 0:
            return None
        return value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Emulate calling ``self.value(*args, **kwargs)`` during tracing."""
        if self.value is itertools.product:
            repeat = self._product_repeat(kwargs)
            if repeat is not None and all(
                arg.has_unpack_var_sequence(tx) for arg in args
            ):
                seqs = [arg.unpack_var_sequence(tx) for arg in args]
                items = [
                    variables.TupleVariable(list(item))
                    for item in itertools.product(*seqs, repeat=repeat)
                ]
                return variables.ListIteratorVariable(
                    items, mutable_local=MutableLocal()
                )
            return super().call_function(tx, args, kwargs)
        elif (
            self.value is itertools.chain
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = list(itertools.chain.from_iterable(seqs))
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            return self._call_accumulate(tx, args, kwargs)
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            return self._call_groupby(tx, args, kwargs)
        elif self.value is itertools.repeat:
            if len(args) == 1 and set(kwargs) == {"times"}:
                # Normalise repeat(obj, times=n) to the positional form so the
                # count is honoured instead of being dropped (which would turn
                # a finite repeat into an infinite one).
                args, kwargs = [args[0], kwargs["times"]], {}
            if len(args) < 2 and not kwargs:
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfill.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            unsupported = set(kwargs) - {"start", "step"}
            if unsupported:
                unimplemented(
                    "Unsupported kwargs for itertools.count: "
                    f"{','.join(sorted(unsupported))}"
                )
            # CountIteratorVariable calls the start value ``item``; map the
            # itertools.count keyword names onto its constructor parameters
            # instead of silently dropping them.
            count_kwargs = {}
            if "start" in kwargs:
                count_kwargs["item"] = kwargs["start"]
            if "step" in kwargs:
                count_kwargs["step"] = kwargs["step"]
            return variables.CountIteratorVariable(
                *args, **count_kwargs, mutable_local=MutableLocal()
            )
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)

    def _call_accumulate(self, tx, args, kwargs):
        """Eagerly evaluate ``itertools.accumulate(iterable[, func], *, initial=None)``.

        The binary function may be given positionally or as ``func=`` (not
        both) and defaults to ``operator.add``; ``func=None`` and
        ``initial=None`` are treated as "not supplied", as in Python.
        """
        from .builtin import BuiltinVariable

        unsupported = set(kwargs) - {"initial", "func"}
        if unsupported:
            unimplemented(
                "Unsupported kwargs for itertools.accumulate: "
                f"{','.join(sorted(unsupported))}"
            )
        if len(args) not in (1, 2) or not args[0].has_unpack_var_sequence(tx):
            unimplemented("Unsupported arguments for itertools.accumulate")
        if len(args) == 2 and "func" in kwargs:
            unimplemented(
                "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
            )

        seq = args[0].unpack_var_sequence(tx)
        func_var = args[1] if len(args) == 2 else kwargs.get("func")
        if func_var is None or self._is_none_constant(func_var):
            # Default to operator.add, matching itertools.accumulate.
            func = BuiltinVariable(operator.add).call_function
        else:
            func = func_var.call_function

        acc = kwargs.get("initial")
        if acc is not None and self._is_none_constant(acc):
            # ``initial=None`` is Python's "no initial value" sentinel.
            acc = None

        items = [] if acc is None else [acc]
        for item in seq:
            if acc is None:
                acc = item
            else:
                try:
                    acc = func(tx, [acc, item], {})
                except Exception as e:
                    unimplemented(
                        "Unexpected failure in invoking function during accumulate. "
                        f"Failed running func {func}({acc}, {item})",
                        from_exc=e,
                    )
            items.append(acc)
        return variables.ListIteratorVariable(items, mutable_local=MutableLocal())

    def _call_groupby(self, tx, args, kwargs):
        """Eagerly evaluate ``itertools.groupby(iterable, key=None)``.

        Keys produced by a ``key`` function must resolve to Python constants
        (constant or symbolic-node results); each group is materialised as a
        ``ListIteratorVariable``.
        """
        if any(kw != "key" for kw in kwargs):
            unimplemented(
                "Unsupported kwargs for itertools.groupby: "
                f"{','.join(sorted(set(kwargs) - {'key'}))}"
            )
        if len(args) != 1 or not args[0].has_unpack_var_sequence(tx):
            unimplemented("Unsupported arguments for itertools.groupby")

        def retrieve_const_key(key):
            if isinstance(key, variables.SymNodeVariable):
                return key.evaluate_expr()
            elif isinstance(key, variables.ConstantVariable):
                return key.as_python_constant()
            else:
                unimplemented(
                    "Unsupported key type for itertools.groupby: " + str(type(key))
                )

        seq = args[0].unpack_var_sequence(tx)
        key_var = kwargs.get("key")
        if key_var is None or self._is_none_constant(key_var):
            # key=None means "group by the elements themselves".
            keyfunc = None
        else:

            def keyfunc(x):
                return retrieve_const_key(key_var.call_function(tx, [x], {}))

        result = []
        try:
            for k, v in itertools.groupby(seq, key=keyfunc):
                # Python-level keys are re-wrapped; element keys (no keyfunc)
                # are already VariableTrackers.
                key_vt = (
                    variables.ConstantVariable.create(k)
                    if variables.ConstantVariable.is_literal(k)
                    else k
                )
                group = variables.ListIteratorVariable(
                    list(v), mutable_local=MutableLocal()
                )
                result.append(
                    variables.TupleVariable(
                        [key_vt, group], mutable_local=MutableLocal()
                    )
                )
        except Exception as e:
            unimplemented(
                "Unexpected failure when calling itertools.groupby",
                from_exc=e,
            )
        return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
class IteratorVariable(VariableTracker):
    """Abstract base for traced iterator objects.

    Concrete subclasses override ``next_variable`` to hand back the next
    element; the base implementation refuses via ``unimplemented``.
    """

    def __init__(self, **options):
        super().__init__(**options)

    def next_variable(self, tx):
        """Produce the next element; must be overridden by subclasses."""
        unimplemented("abstract method, must implement")
class RepeatIteratorVariable(IteratorVariable):
    """Iterator that yields the same ``item`` on every step, forever."""

    def __init__(self, item: VariableTracker, **options):
        super().__init__(**options)
        # The single value returned by every call to next_variable.
        self.item = item

    def next_variable(self, tx):
        # No state changes between steps, so no side-effect mutation is
        # recorded here.
        return self.item
class CountIteratorVariable(IteratorVariable):
    """Traced model of ``itertools.count(start, step)``.

    ``item`` holds the next value to be produced and ``step`` the increment;
    plain Python values are wrapped in ``ConstantVariable``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        self.item = item
        self.step = step

    def next_variable(self, tx):
        """Return the current value, then advance ``item`` by ``step``.

        Like ``itertools.count``, the first call yields the start value
        itself (previously the value was advanced before being returned, so
        the start value was skipped).
        """
        assert self.mutable_local
        current = self.item
        tx.output.side_effects.mutation(self)
        self.item = current.call_method(tx, "__add__", [self.step], {})
        return current
class CycleIteratorVariable(IteratorVariable):
    """Traced model of ``itertools.cycle(iterable)``.

    While the wrapped ``iterator`` still yields, each item is returned and
    buffered in ``saved``. Once it is exhausted (``iterator`` set to None),
    the buffered items are replayed in order; ``saved_index`` is the
    position of the next item to replay. ``item`` is the last value returned.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        """Return the next element of the cycle.

        Raises:
            StopIteration: if the input iterator was empty.
        """
        assert self.mutable_local

        if self.iterator is not None:
            # Only the inner iterator's advance can signal exhaustion.
            try:
                new_item = self.iterator.next_variable(tx)
            except StopIteration:
                tx.output.side_effects.mutation(self)
                self.iterator = None
                return self.next_variable(tx)
            if len(self.saved) > MAX_CYCLE:
                unimplemented("input iterator to itertools.cycle has too many items")
            tx.output.side_effects.mutation(self)
            self.saved.append(new_item)
            self.item = new_item
            if self.item is None:
                # NOTE(review): a Python None from the inner iterator is
                # buffered but skipped here; confirm this is intended.
                return self.next_variable(tx)
            return self.item

        if not self.saved:
            raise StopIteration

        # Replay phase: hand back the buffered item at saved_index, then
        # advance (previously the index advanced but the stale last item
        # was returned every time).
        tx.output.side_effects.mutation(self)
        self.item = self.saved[self.saved_index]
        self.saved_index = (self.saved_index + 1) % len(self.saved)
        return self.item
|