| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914 |
- #!/usr/bin/env python
- # coding=utf-8
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import ast
- import builtins
- import difflib
- from collections.abc import Mapping
- from importlib import import_module
- from typing import Any, Callable, Dict, List, Optional
- import numpy as np
- from ..utils import is_pandas_available
- if is_pandas_available():
- import pandas as pd
class InterpreterError(ValueError):
    """
    Raised when the interpreter refuses or fails to evaluate a piece of Python code: either the code uses syntax the
    interpreter does not support, or it attempts an operation that is not permitted (for example importing an
    unauthorized module or calling an unknown function).
    """
# Every builtin exception class keyed by its name, so interpreted code can raise and catch them by name.
ERRORS = {
    exception_name: candidate
    for exception_name, candidate in sorted(vars(builtins).items())
    if isinstance(candidate, type) and issubclass(candidate, BaseException)
}

# Modules interpreted code may import when the caller does not provide its own list of authorized imports.
LIST_SAFE_MODULES = [
    "random",
    "collections",
    "math",
    "time",
    "queue",
    "itertools",
    "re",
    "stat",
    "statistics",
    "unicodedata",
]

# Text accumulated from `print` calls made by interpreted code.
# NOTE(review): MAX_LEN_OUTPUT is not read in this section; presumably it caps PRINT_OUTPUTS elsewhere — TODO confirm.
PRINT_OUTPUTS = ""
MAX_LEN_OUTPUT = 50000

# Number of AST nodes evaluated so far, and the budget after which `evaluate_ast` refuses to continue.
OPERATIONS_COUNT = 0
MAX_OPERATIONS = 10000000
class BreakException(Exception):
    """Raised when interpreted code executes `break`; caught by the enclosing interpreted loop."""
class ContinueException(Exception):
    """Raised when interpreted code executes `continue`; caught by the enclosing interpreted loop."""
class ReturnException(Exception):
    """Raised when interpreted code executes `return`; carries the returned value to the enclosing function."""

    def __init__(self, value):
        # The evaluated `return` expression (None for a bare `return`).
        self.value = value
def get_iterable(obj):
    """Return `obj` as a list.

    Lists are returned unchanged (same object); any other object exposing `__iter__` is materialized into a new list.

    Raises:
        `InterpreterError`: if `obj` has no `__iter__`.
    """
    if isinstance(obj, list):
        return obj
    if not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return list(obj)
def evaluate_unaryop(expression, state, static_tools, custom_tools):
    """Evaluate a unary operation (`-x`, `+x`, `not x`, `~x`).

    Raises:
        `InterpreterError`: for any other unary operator.
    """
    operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
    unary_handlers = {
        ast.USub: lambda value: -value,
        ast.UAdd: lambda value: value,
        ast.Not: lambda value: not value,
        ast.Invert: lambda value: ~value,
    }
    handler = unary_handlers.get(type(expression.op))
    if handler is None:
        raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
    return handler(operand)
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
    """Turn an `ast.Lambda` node into a Python callable.

    Default values are evaluated once, when the lambda is created, as in Python. Each call runs the body in a copy of
    `state`: defaults are applied first, then positional arguments, then keyword arguments.

    Args:
        lambda_expression (`ast.Lambda`): The lambda node.
        state (`dict`): Variables visible to the lambda body (copied on every call).
        static_tools, custom_tools: Forwarded to `evaluate_ast`.

    Returns:
        `Callable`: a function that evaluates the lambda body.
    """
    arguments = lambda_expression.args
    arg_names = [arg.arg for arg in arguments.args]
    default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in arguments.defaults]
    # Defaults belong to the trailing positional parameters. The explicit start index avoids the `[-0:]` pitfall.
    defaults = dict(zip(arg_names[len(arg_names) - len(default_values) :], default_values))

    def lambda_func(*values, **kwargs):
        new_state = state.copy()
        new_state.update(defaults)
        for arg, value in zip(arg_names, values):
            new_state[arg] = value
        for name, value in kwargs.items():
            new_state[name] = value
        return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)

    return lambda_func
def evaluate_while(while_loop, state, static_tools, custom_tools, max_iterations=1000):
    """Execute a ``while`` loop, including its optional ``else`` clause.

    Args:
        while_loop (`ast.While`): The loop node.
        state, static_tools, custom_tools: Forwarded to `evaluate_ast` for the test and every body statement.
        max_iterations (`int`, *optional*, defaults to 1000): Number of completed iterations after which the loop is
            aborted, as a guard against infinite loops.

    Returns:
        `None`

    Raises:
        `InterpreterError`: if the loop body runs more than `max_iterations` times.
    """
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
        try:
            for node in while_loop.body:
                evaluate_ast(node, state, static_tools, custom_tools)
        except BreakException:
            # `break` leaves the loop and skips the `else` clause, as in Python.
            return None
        except ContinueException:
            # `continue` abandons the rest of the body and re-tests the condition.
            pass
        iterations += 1
        if iterations > max_iterations:
            raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
    # Reached only when the condition became false (never after `break`): run the `else` clause.
    for node in while_loop.orelse:
        evaluate_ast(node, state, static_tools, custom_tools)
    return None
def create_function(func_def, state, static_tools, custom_tools):
    """Build a Python callable that runs the body of an interpreted ``def``.

    Args:
        func_def (`ast.FunctionDef`): The function definition node.
        state (`dict`): Variables visible to the function. A copy is taken on every call, so assignments in the body
            stay local to that call.
        static_tools, custom_tools: Forwarded to `evaluate_ast`.

    Returns:
        `Callable`: a function returning the value of a `return` statement, or otherwise the value of the last
        evaluated body statement.
    """

    def new_func(*args, **kwargs):
        func_state = state.copy()
        arguments = func_def.args
        arg_names = [arg.arg for arg in arguments.args]
        kwonly_names = [arg.arg for arg in arguments.kwonlyargs]
        # Positional defaults are re-evaluated against the enclosing state on every call.
        default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in arguments.defaults]
        # Defaults belong to the trailing positional parameters. The explicit start index avoids the `[-0:]` pitfall.
        defaults = dict(zip(arg_names[len(arg_names) - len(default_values) :], default_values))

        # Track explicitly supplied parameters. Testing `name in func_state` is wrong: func_state starts as a copy
        # of the enclosing state, so an outer variable sharing a parameter's name would hide the default.
        provided = set()
        for name, value in zip(arg_names, args):
            func_state[name] = value
            provided.add(name)
        for name, value in kwargs.items():
            func_state[name] = value
            provided.add(name)

        if arguments.vararg:
            # `*args` only receives the positional values not already bound to named parameters.
            func_state[arguments.vararg.arg] = args[len(arg_names) :]
        if arguments.kwarg:
            # `**kwargs` only receives keywords that do not match a named parameter.
            named_parameters = set(arg_names) | set(kwonly_names)
            func_state[arguments.kwarg.arg] = {k: v for k, v in kwargs.items() if k not in named_parameters}

        for name, value in defaults.items():
            if name not in provided:
                func_state[name] = value
        for kwonly_arg, default_node in zip(arguments.kwonlyargs, arguments.kw_defaults):
            # `kw_defaults` holds None for keyword-only parameters that have no default.
            if default_node is not None and kwonly_arg.arg not in provided:
                func_state[kwonly_arg.arg] = evaluate_ast(default_node, state, static_tools, custom_tools)

        # Expose `self` and `__class__` so that a zero-argument `super()` can be resolved by `evaluate_call`.
        # NOTE(review): `__class__` is the runtime class of `self`, not the defining class, so `super()` inside an
        # inherited method may resolve to the wrong parent — confirm before relying on deep hierarchies.
        if arg_names and arg_names[0] == "self" and args:
            func_state["self"] = args[0]
            func_state["__class__"] = args[0].__class__

        result = None
        try:
            for stmt in func_def.body:
                result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
        except ReturnException as e:
            result = e.value
        return result

    return new_func
def create_class(class_name, class_bases, class_body):
    """Create a new class named `class_name` with the given bases and a shallow copy of `class_body` as namespace."""
    namespace = {attribute: member for attribute, member in class_body.items()}
    bases = tuple(class_bases)
    return type(class_name, bases, namespace)
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """Create the callable for a ``def`` statement, register it in `custom_tools` under its name, and return it."""
    function = create_function(func_def, state, static_tools, custom_tools)
    custom_tools[func_def.name] = function
    return function
def evaluate_class_def(class_def, state, static_tools, custom_tools):
    """Create a class from a ``class`` statement and bind it in `state`.

    Supported body statements:
    - method definitions;
    - plain and annotated class-attribute assignments;
    - ``pass``;
    - constant expression statements such as ``...`` (a leading string becomes the class docstring).

    Returns:
        `type`: the new class, also stored as ``state[class_def.name]``.

    Raises:
        `InterpreterError`: for any other statement in the class body.
    """
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
    class_dict = {}

    for position, stmt in enumerate(class_def.body):
        if isinstance(stmt, ast.FunctionDef):
            # Build the method directly. Going through `evaluate_function_def` would also register it in
            # `custom_tools` and shadow any top-level function that has the same name.
            class_dict[stmt.name] = create_function(stmt, state, static_tools, custom_tools)
        elif isinstance(stmt, ast.Assign):
            # Evaluate once, so `a = b = []` shares a single object as in Python.
            value = evaluate_ast(stmt.value, state, static_tools, custom_tools)
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    class_dict[target.attr] = value
        elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            # `x: int = 0` defines an attribute; a bare annotation `x: int` defines nothing.
            if stmt.value is not None:
                class_dict[stmt.target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
        elif isinstance(stmt, ast.Pass):
            continue
        elif isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant):
            # Constant expression statements (docstrings, `...`) have no effect except that a leading string is
            # the class docstring.
            if position == 0 and isinstance(stmt.value.value, str):
                class_dict["__doc__"] = stmt.value.value
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class
def evaluate_augassign(expression, state, static_tools, custom_tools):
    """Evaluate an augmented assignment such as ``x += 1``, ``d[k] *= 2`` or ``obj.n -= 1``.

    The target's container and key expressions are evaluated exactly once, matching Python. The operation uses the
    in-place operator (`operator.iadd`, ...), so mutable values such as lists are updated in place and aliases see
    the change. An unknown variable name starts from 0.

    Returns:
        The updated value, which is also stored back into the target.

    Raises:
        `InterpreterError`: for unsupported targets or operators, or when adding a non-list to a list.
    """
    import operator

    in_place_operators = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.Div: operator.itruediv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.FloorDiv: operator.ifloordiv,
        ast.BitAnd: operator.iand,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
    }

    target = expression.target
    container = None
    key = None
    # Resolve the current value, remembering container/key so they are not re-evaluated when storing.
    if isinstance(target, ast.Name):
        current_value = state.get(target.id, 0)
    elif isinstance(target, ast.Subscript):
        container = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        current_value = container[key]
    elif isinstance(target, ast.Attribute):
        container = evaluate_ast(target.value, state, static_tools, custom_tools)
        current_value = getattr(container, target.attr)
    else:
        raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")

    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)

    operator_function = in_place_operators.get(type(expression.op))
    if operator_function is None:
        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
    if isinstance(expression.op, ast.Add) and isinstance(current_value, list) and not isinstance(value_to_add, list):
        raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
    updated_value = operator_function(current_value, value_to_add)

    # Store the result back; names go through `set_value` so static tools stay protected.
    if isinstance(target, ast.Name):
        set_value(target, updated_value, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        container[key] = updated_value
    else:
        setattr(container, target.attr, updated_value)
    return updated_value
def evaluate_boolop(node, state, static_tools, custom_tools):
    """Evaluate ``a and b ...`` / ``a or b ...`` with Python's short-circuit semantics.

    Like Python, the result is the operand that decided the outcome, not a coerced bool:
    - `and` returns the first falsy operand, or the last operand if all are truthy;
    - `or` returns the first truthy operand, or the last operand if all are falsy.
    This matters for idioms such as ``name = given or "default"``.
    """
    stop_when_truthy = isinstance(node.op, ast.Or)
    result = None
    for value_node in node.values:
        result = evaluate_ast(value_node, state, static_tools, custom_tools)
        if bool(result) == stop_when_truthy:
            return result
    return result
def evaluate_binop(binop, state, static_tools, custom_tools):
    """Evaluate a binary operation such as ``a + b`` or ``a @ b``.

    Both operands are evaluated (left first) before the operator is applied.

    Raises:
        `NotImplementedError`: if the operator is not in the supported table.
    """
    import operator

    binary_operators = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.FloorDiv: operator.floordiv,
        ast.BitAnd: operator.and_,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        # Matrix multiplication (`@`), e.g. between numpy arrays.
        ast.MatMult: operator.matmul,
    }
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
    operator_function = binary_operators.get(type(binop.op))
    if operator_function is None:
        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
    return operator_function(left_val, right_val)
def evaluate_assign(assign, state, static_tools, custom_tools):
    """Evaluate an assignment statement and return the assigned value.

    The right-hand side is evaluated once. For chained assignment (``a = b = value``), every target receives that
    same value, left to right, as in Python. Tuple/list unpacking inside a target is handled by `set_value`.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools)
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools)
    return result
def set_value(target, value, state, static_tools, custom_tools):
    """Store `value` into an assignment target.

    Supported targets:
    - names;
    - tuple or list patterns, including one starred element (``a, *rest = ...``), handled recursively;
    - subscripts (``obj[key]``);
    - attributes (``obj.attr``).

    Strings and bytes are deliberately not unpacked.

    Raises:
        `InterpreterError`: when assigning to a static tool name, when unpacking a non-iterable or a value of the
        wrong length, or for an unsupported target type.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elements = target.elts
        starred_positions = [i for i, element in enumerate(elements) if isinstance(element, ast.Starred)]
        if not starred_positions:
            if len(elements) != len(value):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            for element, item in zip(elements, value):
                set_value(element, item, state, static_tools, custom_tools)
        else:
            # The parser allows at most one starred element per pattern; it collects the middle values as a list.
            star = starred_positions[0]
            n_after = len(elements) - star - 1
            if len(value) < len(elements) - 1:
                raise InterpreterError("Cannot unpack tuple of wrong size")
            split_end = len(value) - n_after
            for element, item in zip(elements[:star], value[:star]):
                set_value(element, item, state, static_tools, custom_tools)
            set_value(elements[star].value, list(value[star:split_end]), state, static_tools, custom_tools)
            for element, item in zip(elements[star + 1 :], value[split_end:]):
                set_value(element, item, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        setattr(obj, target.attr, value)
    else:
        # Previously unsupported targets were silently ignored, hiding lost assignments.
        raise InterpreterError(f"Cannot assign to {type(target).__name__} targets.")
def _evaluate_call_arguments(call, state, static_tools, custom_tools):
    """Evaluate the arguments of `call` exactly once, expanding ``*iterable`` and ``**mapping``.

    Returns:
        `tuple[list, dict]`: the positional values and the keyword values.
    """
    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
            if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
                raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
            args.extend(unpacked)
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools))

    kwargs = {}
    for keyword in call.keywords:
        keyword_value = evaluate_ast(keyword.value, state, static_tools, custom_tools)
        if keyword.arg is None:
            # `**mapping` appears in the AST as a keyword whose name is None.
            if not hasattr(keyword_value, "keys"):
                raise InterpreterError(f"Cannot unpack non-mapping value {keyword_value} as keyword arguments")
            kwargs.update(keyword_value)
        else:
            kwargs[keyword.arg] = keyword_value
    return args, kwargs


def evaluate_call(call, state, static_tools, custom_tools):
    """Evaluate a function or method call.

    A plain-name callee is looked up in `state`, then `static_tools`, then `custom_tools`, then the builtin
    exceptions in `ERRORS`. Calling any other name is refused.

    Special cases:
    - Classes defined by interpreted code are instantiated with `__new__` + `__init__`.
    - A plain `super(...)` call is resolved here.
    - A plain `print(...)` call appends its output to the module-level `PRINT_OUTPUTS` instead of writing to stdout.

    Raises:
        `InterpreterError`: for unsupported callee expressions, unknown attributes or names, bad unpacking, or
        invalid `super()` usage.
    """
    global PRINT_OUTPUTS

    if not isinstance(call.func, (ast.Attribute, ast.Name)):
        raise InterpreterError(f"This is not a correct function: {call.func}).")

    if isinstance(call.func, ast.Attribute):
        obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
        func_name = call.func.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        func = getattr(obj, func_name)
    else:
        func_name = call.func.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
            )

    # Arguments are evaluated once (they were previously evaluated twice, duplicating side effects).
    args, kwargs = _evaluate_call_arguments(call, state, static_tools, custom_tools)

    if isinstance(func, type) and len(func.__module__.split(".")) > 1:  # Check for user-defined classes
        obj = func.__new__(func)
        if hasattr(obj, "__init__"):
            obj.__init__(*args, **kwargs)
        return obj

    # `super` and `print` are only special when called by bare name; `obj.print(...)` is an ordinary method call.
    is_name_call = isinstance(call.func, ast.Name)

    if is_name_call and func_name == "super":
        if not args:
            if "__class__" in state and "self" in state:
                return super(state["__class__"], state["self"])
            raise InterpreterError("super() needs at least one argument")
        cls = args[0]
        if not isinstance(cls, type):
            raise InterpreterError("super() argument 1 must be type")
        if len(args) == 1:
            return super(cls)
        if len(args) == 2:
            return super(cls, args[1])
        raise InterpreterError("super() takes at most 2 arguments")

    if is_name_call and func_name == "print":
        # Honour `sep`/`end` like the builtin; None means "use the default", as in Python.
        separator = kwargs.get("sep")
        terminator = kwargs.get("end")
        separator = " " if separator is None else separator
        terminator = "\n" if terminator is None else terminator
        # NOTE(review): output length is not capped here (MAX_LEN_OUTPUT is not consulted).
        PRINT_OUTPUTS += separator.join(map(str, args)) + terminator
        return None

    return func(*args, **kwargs)
def evaluate_subscript(subscript, state, static_tools, custom_tools):
    """Evaluate ``value[index]``.

    Lookup order:
    - pandas objects, numpy arrays and slices use native indexing;
    - lists, tuples and strings are bounds-checked so out-of-range access gives a readable error;
    - other containers use `in` + `[]`;
    - a mapping indexed by an unknown string key falls back to the closest key (difflib).

    Raises:
        `InterpreterError`: for a string indexed by a string, an out-of-bounds index, or when no lookup succeeds.
    """
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
    value = evaluate_ast(subscript.value, state, static_tools, custom_tools)

    if isinstance(value, str) and isinstance(index, str):
        raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")

    # pandas is optional: the module-level name `pd` exists only when pandas is installed. Referencing it directly
    # raised NameError for every subscript in environments without pandas.
    pandas_module = globals().get("pd")
    if pandas_module is not None:
        if isinstance(value, pandas_module.core.indexing._LocIndexer):
            parent_object = value.obj
            return parent_object.loc[index]
        if isinstance(
            value,
            (pandas_module.DataFrame, pandas_module.Series, pandas_module.core.groupby.generic.DataFrameGroupBy),
        ):
            return value[index]
    if isinstance(value, np.ndarray):
        return value[index]
    if isinstance(index, slice):
        return value[index]
    if isinstance(value, (list, tuple)):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
        return value[int(index)]
    if isinstance(value, str):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
        return value[index]
    if index in value:
        return value[index]
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if close_matches:
            return value[close_matches[0]]
    raise InterpreterError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, static_tools, custom_tools):
    """Resolve a variable name to its value.

    Lookup order matches `evaluate_call`: `state`, `static_tools`, `custom_tools` (functions defined by interpreted
    code), then builtin exceptions. As a last resort, the closest-spelled variable in `state` is returned.

    Raises:
        `InterpreterError`: if the name cannot be resolved.
    """
    if name.id in state:
        return state[name.id]
    if name.id in static_tools:
        return static_tools[name.id]
    # User-defined functions live in custom_tools; without this, `sorted(xs, key=my_func)` could not see them.
    if name.id in custom_tools:
        return custom_tools[name.id]
    if name.id in ERRORS:
        return ERRORS[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if close_matches:
        return state[close_matches[0]]
    raise InterpreterError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, static_tools, custom_tools):
    """Evaluate a (possibly chained) comparison such as ``a < b <= c``.

    Comparators are evaluated lazily, left to right, and evaluation stops at the first plain-bool False, as in
    Python's chained comparisons. Partial results are combined with ``&`` so element-wise (numpy/pandas) comparisons
    are supported.

    Returns:
        A `bool`, a pandas `Series` unchanged, or for other array-like results their `.all()` reduction.
    """
    import operator

    comparison_functions = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
        ast.In: lambda left, right: left in right,
        ast.NotIn: lambda left, right: left not in right,
    }

    result = True
    current_left = evaluate_ast(condition.left, state, static_tools, custom_tools)
    for op_node, comparator_node in zip(condition.ops, condition.comparators):
        compare = comparison_functions.get(type(op_node))
        if compare is None:
            raise InterpreterError(f"Operator not supported: {type(op_node)}")
        comparator = evaluate_ast(comparator_node, state, static_tools, custom_tools)
        result = result & compare(current_left, comparator)
        current_left = comparator
        if isinstance(result, bool) and not result:
            break

    if isinstance(result, bool):
        return result
    # pandas is optional: `pd` exists at module level only when pandas is installed.
    pandas_module = globals().get("pd")
    if pandas_module is not None and isinstance(result, pandas_module.Series):
        return result
    return result.all()
def evaluate_if(if_statement, state, static_tools, custom_tools):
    """Run the branch of an ``if`` statement selected by its test.

    Returns:
        The last non-None value produced by a statement of the executed branch, or None.
    """
    chosen_branch = (
        if_statement.body
        if evaluate_ast(if_statement.test, state, static_tools, custom_tools)
        else if_statement.orelse
    )
    last_value = None
    for statement in chosen_branch:
        outcome = evaluate_ast(statement, state, static_tools, custom_tools)
        if outcome is not None:
            last_value = outcome
    return last_value
def evaluate_for(for_loop, state, static_tools, custom_tools):
    """Execute a ``for`` loop, including its optional ``else`` clause.

    `continue` skips the remaining statements of the current iteration. `break` leaves the loop and skips the
    `else` clause.

    Returns:
        The last non-None value produced by a body (or `else`) statement, or None.
    """
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
    for counter in iterator:
        set_value(for_loop.target, counter, state, static_tools, custom_tools)
        try:
            for node in for_loop.body:
                line_result = evaluate_ast(node, state, static_tools, custom_tools)
                if line_result is not None:
                    result = line_result
        except ContinueException:
            # Abandon the rest of this iteration's body. Previously only the raising statement was skipped.
            continue
        except BreakException:
            break
    else:
        # Reached only when the iterator is exhausted without `break`.
        for node in for_loop.orelse:
            line_result = evaluate_ast(node, state, static_tools, custom_tools)
            if line_result is not None:
                result = line_result
    return result
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
    """Evaluate a list comprehension (also used for generator expressions, which therefore yield a list).

    Generators are nested left to right. Each iteration binds its target in a fresh copy of the enclosing state,
    so loop variables do not leak out.
    """

    def bind_target(target, value, scope):
        # Recursive so nested or list patterns such as `for i, (a, b) in ...` unpack correctly.
        if isinstance(target, (ast.Tuple, ast.List)):
            for idx, elem in enumerate(target.elts):
                bind_target(elem, value[idx], scope)
        else:
            scope[target.id] = value

    def inner_evaluate(generators, index, current_state):
        if index >= len(generators):
            return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
        generator = generators[index]
        iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            bind_target(generator.target, value, new_state)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
def evaluate_try(try_node, state, static_tools, custom_tools):
    """Execute a ``try`` statement with its ``except``, ``else`` and ``finally`` clauses.

    The interpreter's own control-flow signals (`break`, `continue`, `return`) pass through the handlers untouched,
    because in Python they are not exceptions; the `finally` clause still runs. An exception that matches no
    handler is re-raised.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    except (BreakException, ContinueException, ReturnException):
        # Without this, a user's `except Exception:` would swallow `break`/`continue`/`return`.
        raise
    except Exception as e:
        for handler in try_node.handlers:
            if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools)
                break
        else:
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools)
def evaluate_raise(raise_node, state, static_tools, custom_tools):
    """Execute a ``raise`` statement.

    Supported forms:
    - ``raise exc``;
    - ``raise exc from cause``, including ``from None`` to suppress the context;
    - a bare ``raise`` while an exception is being handled (e.g. inside an interpreted ``except`` block).

    Raises:
        The requested exception, or `InterpreterError` if there is nothing to raise.
    """
    import sys

    exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools) if raise_node.exc is not None else None
    has_cause = raise_node.cause is not None
    cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools) if has_cause else None

    if exc is not None:
        if has_cause:
            # Use the presence of a `from` clause, not the cause's value: `raise X from None` must suppress context.
            raise exc from cause
        raise exc

    if raise_node.exc is None:
        # Bare `raise`: re-raise the exception currently being handled, if any.
        active_exception = sys.exc_info()[1]
        if active_exception is not None:
            raise active_exception
    raise InterpreterError("Re-raise is not supported without an active exception")
def evaluate_assert(assert_node, state, static_tools, custom_tools):
    """Execute an ``assert`` statement.

    Raises:
        `AssertionError`: when the test is falsy. The message is the evaluated `msg` if one was given; otherwise it
        quotes the failing condition's source.
    """
    if evaluate_ast(assert_node.test, state, static_tools, custom_tools):
        return None
    if not assert_node.msg:
        # Without a user message, show which condition failed.
        failing_source = ast.unparse(assert_node.test)
        raise AssertionError(f"Assertion failed: {failing_source}")
    raise AssertionError(evaluate_ast(assert_node.msg, state, static_tools, custom_tools))
def evaluate_with(with_node, state, static_tools, custom_tools):
    """Execute a ``with`` statement with one or more context managers.

    Managers are entered left to right and exited in reverse order via `contextlib.ExitStack`. This gives the
    standard semantics: `__exit__` is called on the manager itself (not on the value returned by `__enter__`); an
    exception is suppressed if `__exit__` returns True; managers already entered are unwound when a later one
    fails. `break`/`continue`/`return` in the body exit the managers normally and then propagate.
    """
    import contextlib

    pending_signal = None
    with contextlib.ExitStack() as stack:
        for item in with_node.items:
            manager = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
            entered_value = stack.enter_context(manager)
            if item.optional_vars is not None:
                set_value(item.optional_vars, entered_value, state, static_tools, custom_tools)
        try:
            for stmt in with_node.body:
                evaluate_ast(stmt, state, static_tools, custom_tools)
        except (BreakException, ContinueException, ReturnException) as signal:
            # Control flow is not an error: let the managers exit cleanly, then re-raise after the block.
            pending_signal = signal
    if pending_signal is not None:
        raise pending_signal
    return None
def import_modules(expression, state, authorized_imports):
    """Execute an ``import`` or ``from ... import`` statement if the module is authorized.

    Args:
        expression (`ast.Import` or `ast.ImportFrom`): The import statement.
        state (`dict`): Namespace that receives the imported names.
        authorized_imports (`list[str]`): Importable module names. A dotted module is allowed when the module itself
            or any of its parent packages is listed.

    Returns:
        `None`

    Raises:
        `InterpreterError`: if a module is not authorized, or for relative imports.
    """

    def check_module_authorized(module_name):
        module_path = module_name.split(".")
        module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
        return any(subpath in authorized_imports for subpath in module_subpaths)

    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if not check_module_authorized(alias.name):
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
            module = import_module(alias.name)
            if alias.asname:
                state[alias.asname] = module
                continue
            top_level_name = alias.name.split(".")[0]
            if check_module_authorized(top_level_name):
                # Like Python, `import a.b` binds the top-level package `a` (with `a.b` already loaded).
                state[top_level_name] = import_module(top_level_name)
            else:
                # Only the submodule is authorized: do not expose the whole parent package.
                state[alias.name] = module
        return None
    elif isinstance(expression, ast.ImportFrom):
        # Relative imports have no package context here (`module` is None for `from . import x`).
        if expression.level or expression.module is None:
            raise InterpreterError("Relative imports are not supported.")
        if not check_module_authorized(expression.module):
            raise InterpreterError(f"Import from {expression.module} is not allowed.")
        module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
        for alias in expression.names:
            if alias.name == "*":
                # Star import: honour `__all__` when defined, otherwise take every public name.
                public_names = getattr(module, "__all__", None)
                if public_names is None:
                    public_names = [name for name in dir(module) if not name.startswith("_")]
                for public_name in public_names:
                    state[public_name] = getattr(module, public_name)
            else:
                state[alias.asname or alias.name] = getattr(module, alias.name)
        return None
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
    """Evaluate a dict comprehension.

    Multiple ``for`` clauses are nested left to right, as in Python, so an inner clause can use outer loop
    variables. Previously each clause was iterated on its own, which broke ``{k: v for a in A for b in B}``.
    Targets are bound with `set_value` in a fresh copy of the enclosing state per iteration.
    """
    result = {}

    def walk(index, scope):
        if index == len(dictcomp.generators):
            key = evaluate_ast(dictcomp.key, scope, static_tools, custom_tools)
            val = evaluate_ast(dictcomp.value, scope, static_tools, custom_tools)
            result[key] = val
            return
        generator = dictcomp.generators[index]
        for item in evaluate_ast(generator.iter, scope, static_tools, custom_tools):
            inner_scope = scope.copy()
            set_value(generator.target, item, inner_scope, static_tools, custom_tools)
            if all(evaluate_ast(if_clause, inner_scope, static_tools, custom_tools) for if_clause in generator.ifs):
                walk(index + 1, inner_scope)

    walk(0, state)
    return result
def evaluate_ast(
    expression: ast.AST,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. These custom_tools can be overwritten.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
            Add more at your own risk!

    Returns:
        The value of the evaluated node. Nodes that produce no value (e.g. `pass`, imports, `while` loops) return
        None.

    Raises:
        `InterpreterError`: if the node type is not supported or the global operation budget is exhausted.
        `BreakException` / `ContinueException` / `ReturnException`: signals for `break`, `continue` and `return`,
        handled by the enclosing loop or function evaluator.
    """
    global OPERATIONS_COUNT
    if OPERATIONS_COUNT >= MAX_OPERATIONS:
        raise InterpreterError(
            f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
        )
    OPERATIONS_COUNT += 1
    # NOTE(review): the evaluate_* helpers call evaluate_ast without `authorized_imports`, so imports nested inside
    # functions, loops, `if` or `try` blocks are checked against the default LIST_SAFE_MODULES, not the caller's list.
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state.
        # We return the variable assigned as it may be used to determine the final result.
        return evaluate_assign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.AnnAssign):
        # Annotated assignment `x: T = value`; a bare annotation `x: T` assigns nothing.
        if expression.value is None:
            return None
        annotated_value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        set_value(expression.target, annotated_value, state, static_tools, custom_tools)
        return annotated_value
    elif isinstance(expression, ast.AugAssign):
        return evaluate_augassign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Tuple):
        return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
    elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
        return evaluate_listcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.SetComp):
        # Same shape as a list comprehension (elt + generators); deduplicate the result.
        return set(evaluate_listcomp(expression, state, static_tools, custom_tools))
    elif isinstance(expression, ast.UnaryOp):
        return evaluate_unaryop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Starred):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.BoolOp):
        # Boolean operation -> evaluate the operation
        return evaluate_boolop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Break):
        raise BreakException()
    elif isinstance(expression, ast.Continue):
        raise ContinueException()
    elif isinstance(expression, ast.Pass):
        return None
    elif isinstance(expression, ast.BinOp):
        # Binary operation -> execute operation
        return evaluate_binop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Compare):
        # Comparison -> evaluate the comparison
        return evaluate_condition(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Lambda):
        return evaluate_lambda(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FunctionDef):
        return evaluate_function_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Dict):
        # Dict -> evaluate key/value pairs in source order; a None key marks `**mapping` unpacking.
        result = {}
        for key_node, value_node in zip(expression.keys, expression.values):
            if key_node is None:
                result.update(evaluate_ast(value_node, state, static_tools, custom_tools))
            else:
                dict_key = evaluate_ast(key_node, state, static_tools, custom_tools)
                result[dict_key] = evaluate_ast(value_node, state, static_tools, custom_tools)
        return result
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.For):
        # For loop -> execute the loop
        return evaluate_for(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FormattedValue):
        # Part of an f-string: apply the `!r`/`!s`/`!a` conversion and the `:spec` format like Python
        # (e.g. f"{x:.2f}", which previously ignored the spec).
        formatted = evaluate_ast(expression.value, state, static_tools, custom_tools)
        if expression.conversion == ord("r"):
            formatted = repr(formatted)
        elif expression.conversion == ord("s"):
            formatted = str(formatted)
        elif expression.conversion == ord("a"):
            formatted = ascii(formatted)
        format_spec = (
            evaluate_ast(expression.format_spec, state, static_tools, custom_tools)
            if expression.format_spec is not None
            else ""
        )
        return format(formatted, format_spec)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return evaluate_if(expression, state, static_tools, custom_tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.JoinedStr):
        return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements
        return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.IfExp):
        test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
        if test_val:
            return evaluate_ast(expression.body, state, static_tools, custom_tools)
        else:
            return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Attribute):
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        return getattr(value, expression.attr)
    elif isinstance(expression, ast.Slice):
        return slice(
            evaluate_ast(expression.lower, state, static_tools, custom_tools)
            if expression.lower is not None
            else None,
            evaluate_ast(expression.upper, state, static_tools, custom_tools)
            if expression.upper is not None
            else None,
            evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
        )
    elif isinstance(expression, ast.DictComp):
        return evaluate_dictcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.While):
        return evaluate_while(expression, state, static_tools, custom_tools)
    elif isinstance(expression, (ast.Import, ast.ImportFrom)):
        return import_modules(expression, state, authorized_imports)
    elif isinstance(expression, ast.ClassDef):
        return evaluate_class_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Try):
        return evaluate_try(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Raise):
        return evaluate_raise(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Assert):
        return evaluate_assert(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.With):
        return evaluate_with(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Set):
        return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
    elif isinstance(expression, ast.Return):
        raise ReturnException(
            evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
        )
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
def _truncate_print_outputs(print_outputs: str) -> str:
    """Cap captured print output at `MAX_LEN_OUTPUT` characters.

    Returns `print_outputs` unchanged when it fits. Otherwise returns its first `MAX_LEN_OUTPUT`
    characters followed by a note saying the output was truncated.
    """
    # `<=` so that output of exactly MAX_LEN_OUTPUT characters is not reported as "over" the limit.
    if len(print_outputs) <= MAX_LEN_OUTPUT:
        return print_outputs
    return (
        print_outputs[:MAX_LEN_OUTPUT]
        + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
    )


def evaluate_python_code(
    code: str,
    static_tools: Optional[Dict[str, Callable]] = None,
    custom_tools: Optional[Dict[str, Callable]] = None,
    state: Optional[Dict[str, Any]] = None,
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Parse `code` and evaluate its top-level statements one by one with the restricted AST interpreter.

    Variables are read from and written to `state`, and only a given set of functions (the tools below)
    may be called. Each top-level statement is handed to `evaluate_ast`, which recurses through its nodes.

    Args:
        code (`str`):
            The Python source to evaluate (any number of statements).
        static_tools (`Dict[str, Callable]`, *optional*):
            The functions that may be called during the evaluation.
            These tools cannot be overwritten in the code: any assignment to their name will raise an error.
        custom_tools (`Dict[str, Callable]`, *optional*):
            The functions that may be called during the evaluation.
            These tools can be overwritten in the code: any assignment to their name will overwrite them.
        state (`Dict[str, Any]`, *optional*):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            The print outputs (capped at `MAX_LEN_OUTPUT` characters) will be stored in the state under the key
            'print_outputs'.
        authorized_imports (`List[str]`, defaults to `LIST_SAFE_MODULES`):
            Module names the evaluated code is allowed to import; forwarded to `evaluate_ast`.

    Returns:
        The value of the last top-level statement evaluated, or `None` if `code` contains no statements.

    Raises:
        SyntaxError: If `code` cannot be parsed.
        InterpreterError: If evaluating a statement raises `InterpreterError`. The new message contains the
            (possibly truncated) print outputs produced so far and the source of the failing statement.
        Exceptions of any other type raised during evaluation propagate unchanged.

    Note:
        This resets the module-level `PRINT_OUTPUTS` and `OPERATIONS_COUNT` globals, so concurrent or nested
        calls would interfere with each other.
    """
    global PRINT_OUTPUTS
    global OPERATIONS_COUNT

    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") from e

    if state is None:
        state = {}
    if static_tools is None:
        static_tools = {}
    if custom_tools is None:
        custom_tools = {}

    result = None
    PRINT_OUTPUTS = ""
    OPERATIONS_COUNT = 0
    # Initialize the key up front so it is present (and not stale from a reused `state`) even when `code`
    # has no statements and the loop below never runs.
    state["print_outputs"] = PRINT_OUTPUTS

    for node in expression.body:
        try:
            result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        except InterpreterError as e:
            msg = ""
            if PRINT_OUTPUTS:
                msg += f"Print outputs:\n{_truncate_print_outputs(PRINT_OUTPUTS)}\n====\n"
            msg += (
                "EXECUTION FAILED:\n"
                f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' "
                f"because of the following error:\n{e}"
            )
            raise InterpreterError(msg) from e
        finally:
            # Refresh after every statement, so that state always reflects output up to the last statement
            # executed, including the one that raised.
            state["print_outputs"] = _truncate_print_outputs(PRINT_OUTPUTS)
    return result
|