python_interpreter.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
import ast
import builtins
import difflib
import operator
import sys
from collections.abc import Mapping
from contextlib import ExitStack
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional

import numpy as np

from ..utils import is_pandas_available


if is_pandas_available():
    import pandas as pd
  26. class InterpreterError(ValueError):
  27. """
  28. An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
  29. operations.
  30. """
  31. pass
  32. ERRORS = {
  33. name: getattr(builtins, name)
  34. for name in dir(builtins)
  35. if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
  36. }
  37. LIST_SAFE_MODULES = [
  38. "random",
  39. "collections",
  40. "math",
  41. "time",
  42. "queue",
  43. "itertools",
  44. "re",
  45. "stat",
  46. "statistics",
  47. "unicodedata",
  48. ]
  49. PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
  50. OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
  51. class BreakException(Exception):
  52. pass
  53. class ContinueException(Exception):
  54. pass
  55. class ReturnException(Exception):
  56. def __init__(self, value):
  57. self.value = value
  58. def get_iterable(obj):
  59. if isinstance(obj, list):
  60. return obj
  61. elif hasattr(obj, "__iter__"):
  62. return list(obj)
  63. else:
  64. raise InterpreterError("Object is not iterable")
  65. def evaluate_unaryop(expression, state, static_tools, custom_tools):
  66. operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
  67. if isinstance(expression.op, ast.USub):
  68. return -operand
  69. elif isinstance(expression.op, ast.UAdd):
  70. return operand
  71. elif isinstance(expression.op, ast.Not):
  72. return not operand
  73. elif isinstance(expression.op, ast.Invert):
  74. return ~operand
  75. else:
  76. raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
  77. def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
  78. args = [arg.arg for arg in lambda_expression.args.args]
  79. def lambda_func(*values):
  80. new_state = state.copy()
  81. for arg, value in zip(args, values):
  82. new_state[arg] = value
  83. return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
  84. return lambda_func
  85. def evaluate_while(while_loop, state, static_tools, custom_tools):
  86. max_iterations = 1000
  87. iterations = 0
  88. while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
  89. for node in while_loop.body:
  90. try:
  91. evaluate_ast(node, state, static_tools, custom_tools)
  92. except BreakException:
  93. return None
  94. except ContinueException:
  95. break
  96. iterations += 1
  97. if iterations > max_iterations:
  98. raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
  99. return None
  100. def create_function(func_def, state, static_tools, custom_tools):
  101. def new_func(*args, **kwargs):
  102. func_state = state.copy()
  103. arg_names = [arg.arg for arg in func_def.args.args]
  104. default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
  105. # Apply default values
  106. defaults = dict(zip(arg_names[-len(default_values) :], default_values))
  107. # Set positional arguments
  108. for name, value in zip(arg_names, args):
  109. func_state[name] = value
  110. # # Set keyword arguments
  111. for name, value in kwargs.items():
  112. func_state[name] = value
  113. # Handle variable arguments
  114. if func_def.args.vararg:
  115. vararg_name = func_def.args.vararg.arg
  116. func_state[vararg_name] = args
  117. if func_def.args.kwarg:
  118. kwarg_name = func_def.args.kwarg.arg
  119. func_state[kwarg_name] = kwargs
  120. # Set default values for arguments that were not provided
  121. for name, value in defaults.items():
  122. if name not in func_state:
  123. func_state[name] = value
  124. # Update function state with self and __class__
  125. if func_def.args.args and func_def.args.args[0].arg == "self":
  126. if args:
  127. func_state["self"] = args[0]
  128. func_state["__class__"] = args[0].__class__
  129. result = None
  130. try:
  131. for stmt in func_def.body:
  132. result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
  133. except ReturnException as e:
  134. result = e.value
  135. return result
  136. return new_func
  137. def create_class(class_name, class_bases, class_body):
  138. class_dict = {}
  139. for key, value in class_body.items():
  140. class_dict[key] = value
  141. return type(class_name, tuple(class_bases), class_dict)
  142. def evaluate_function_def(func_def, state, static_tools, custom_tools):
  143. custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
  144. return custom_tools[func_def.name]
  145. def evaluate_class_def(class_def, state, static_tools, custom_tools):
  146. class_name = class_def.name
  147. bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
  148. class_dict = {}
  149. for stmt in class_def.body:
  150. if isinstance(stmt, ast.FunctionDef):
  151. class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
  152. elif isinstance(stmt, ast.Assign):
  153. for target in stmt.targets:
  154. if isinstance(target, ast.Name):
  155. class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
  156. elif isinstance(target, ast.Attribute):
  157. class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
  158. else:
  159. raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
  160. new_class = type(class_name, tuple(bases), class_dict)
  161. state[class_name] = new_class
  162. return new_class
  163. def evaluate_augassign(expression, state, static_tools, custom_tools):
  164. # Helper function to get current value and set new value based on the target type
  165. def get_current_value(target):
  166. if isinstance(target, ast.Name):
  167. return state.get(target.id, 0)
  168. elif isinstance(target, ast.Subscript):
  169. obj = evaluate_ast(target.value, state, static_tools, custom_tools)
  170. key = evaluate_ast(target.slice, state, static_tools, custom_tools)
  171. return obj[key]
  172. elif isinstance(target, ast.Attribute):
  173. obj = evaluate_ast(target.value, state, static_tools, custom_tools)
  174. return getattr(obj, target.attr)
  175. elif isinstance(target, ast.Tuple):
  176. return tuple(get_current_value(elt) for elt in target.elts)
  177. elif isinstance(target, ast.List):
  178. return [get_current_value(elt) for elt in target.elts]
  179. else:
  180. raise InterpreterError("AugAssign not supported for {type(target)} targets.")
  181. current_value = get_current_value(expression.target)
  182. value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
  183. # Determine the operation and apply it
  184. if isinstance(expression.op, ast.Add):
  185. if isinstance(current_value, list):
  186. if not isinstance(value_to_add, list):
  187. raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
  188. updated_value = current_value + value_to_add
  189. else:
  190. updated_value = current_value + value_to_add
  191. elif isinstance(expression.op, ast.Sub):
  192. updated_value = current_value - value_to_add
  193. elif isinstance(expression.op, ast.Mult):
  194. updated_value = current_value * value_to_add
  195. elif isinstance(expression.op, ast.Div):
  196. updated_value = current_value / value_to_add
  197. elif isinstance(expression.op, ast.Mod):
  198. updated_value = current_value % value_to_add
  199. elif isinstance(expression.op, ast.Pow):
  200. updated_value = current_value**value_to_add
  201. elif isinstance(expression.op, ast.FloorDiv):
  202. updated_value = current_value // value_to_add
  203. elif isinstance(expression.op, ast.BitAnd):
  204. updated_value = current_value & value_to_add
  205. elif isinstance(expression.op, ast.BitOr):
  206. updated_value = current_value | value_to_add
  207. elif isinstance(expression.op, ast.BitXor):
  208. updated_value = current_value ^ value_to_add
  209. elif isinstance(expression.op, ast.LShift):
  210. updated_value = current_value << value_to_add
  211. elif isinstance(expression.op, ast.RShift):
  212. updated_value = current_value >> value_to_add
  213. else:
  214. raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
  215. # Update the state
  216. set_value(expression.target, updated_value, state, static_tools, custom_tools)
  217. return updated_value
  218. def evaluate_boolop(node, state, static_tools, custom_tools):
  219. if isinstance(node.op, ast.And):
  220. for value in node.values:
  221. if not evaluate_ast(value, state, static_tools, custom_tools):
  222. return False
  223. return True
  224. elif isinstance(node.op, ast.Or):
  225. for value in node.values:
  226. if evaluate_ast(value, state, static_tools, custom_tools):
  227. return True
  228. return False
  229. def evaluate_binop(binop, state, static_tools, custom_tools):
  230. # Recursively evaluate the left and right operands
  231. left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
  232. right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
  233. # Determine the operation based on the type of the operator in the BinOp
  234. if isinstance(binop.op, ast.Add):
  235. return left_val + right_val
  236. elif isinstance(binop.op, ast.Sub):
  237. return left_val - right_val
  238. elif isinstance(binop.op, ast.Mult):
  239. return left_val * right_val
  240. elif isinstance(binop.op, ast.Div):
  241. return left_val / right_val
  242. elif isinstance(binop.op, ast.Mod):
  243. return left_val % right_val
  244. elif isinstance(binop.op, ast.Pow):
  245. return left_val**right_val
  246. elif isinstance(binop.op, ast.FloorDiv):
  247. return left_val // right_val
  248. elif isinstance(binop.op, ast.BitAnd):
  249. return left_val & right_val
  250. elif isinstance(binop.op, ast.BitOr):
  251. return left_val | right_val
  252. elif isinstance(binop.op, ast.BitXor):
  253. return left_val ^ right_val
  254. elif isinstance(binop.op, ast.LShift):
  255. return left_val << right_val
  256. elif isinstance(binop.op, ast.RShift):
  257. return left_val >> right_val
  258. else:
  259. raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
  260. def evaluate_assign(assign, state, static_tools, custom_tools):
  261. result = evaluate_ast(assign.value, state, static_tools, custom_tools)
  262. if len(assign.targets) == 1:
  263. target = assign.targets[0]
  264. set_value(target, result, state, static_tools, custom_tools)
  265. else:
  266. if len(assign.targets) != len(result):
  267. raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
  268. expanded_values = []
  269. for tgt in assign.targets:
  270. if isinstance(tgt, ast.Starred):
  271. expanded_values.extend(result)
  272. else:
  273. expanded_values.append(result)
  274. for tgt, val in zip(assign.targets, expanded_values):
  275. set_value(tgt, val, state, static_tools, custom_tools)
  276. return result
  277. def set_value(target, value, state, static_tools, custom_tools):
  278. if isinstance(target, ast.Name):
  279. if target.id in static_tools:
  280. raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
  281. state[target.id] = value
  282. elif isinstance(target, ast.Tuple):
  283. if not isinstance(value, tuple):
  284. if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
  285. value = tuple(value)
  286. else:
  287. raise InterpreterError("Cannot unpack non-tuple value")
  288. if len(target.elts) != len(value):
  289. raise InterpreterError("Cannot unpack tuple of wrong size")
  290. for i, elem in enumerate(target.elts):
  291. set_value(elem, value[i], state, static_tools, custom_tools)
  292. elif isinstance(target, ast.Subscript):
  293. obj = evaluate_ast(target.value, state, static_tools, custom_tools)
  294. key = evaluate_ast(target.slice, state, static_tools, custom_tools)
  295. obj[key] = value
  296. elif isinstance(target, ast.Attribute):
  297. obj = evaluate_ast(target.value, state, static_tools, custom_tools)
  298. setattr(obj, target.attr, value)
  299. def evaluate_call(call, state, static_tools, custom_tools):
  300. if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
  301. raise InterpreterError(f"This is not a correct function: {call.func}).")
  302. if isinstance(call.func, ast.Attribute):
  303. obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
  304. func_name = call.func.attr
  305. if not hasattr(obj, func_name):
  306. raise InterpreterError(f"Object {obj} has no attribute {func_name}")
  307. func = getattr(obj, func_name)
  308. elif isinstance(call.func, ast.Name):
  309. func_name = call.func.id
  310. if func_name in state:
  311. func = state[func_name]
  312. elif func_name in static_tools:
  313. func = static_tools[func_name]
  314. elif func_name in custom_tools:
  315. func = custom_tools[func_name]
  316. elif func_name in ERRORS:
  317. func = ERRORS[func_name]
  318. else:
  319. raise InterpreterError(
  320. f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
  321. )
  322. args = []
  323. for arg in call.args:
  324. if isinstance(arg, ast.Starred):
  325. args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
  326. else:
  327. args.append(evaluate_ast(arg, state, static_tools, custom_tools))
  328. args = []
  329. for arg in call.args:
  330. if isinstance(arg, ast.Starred):
  331. unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
  332. if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
  333. raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
  334. args.extend(unpacked)
  335. else:
  336. args.append(evaluate_ast(arg, state, static_tools, custom_tools))
  337. kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
  338. if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
  339. # Instantiate the class using its constructor
  340. obj = func.__new__(func) # Create a new instance of the class
  341. if hasattr(obj, "__init__"): # Check if the class has an __init__ method
  342. obj.__init__(*args, **kwargs) # Call the __init__ method correctly
  343. return obj
  344. else:
  345. if func_name == "super":
  346. if not args:
  347. if "__class__" in state and "self" in state:
  348. return super(state["__class__"], state["self"])
  349. else:
  350. raise InterpreterError("super() needs at least one argument")
  351. cls = args[0]
  352. if not isinstance(cls, type):
  353. raise InterpreterError("super() argument 1 must be type")
  354. if len(args) == 1:
  355. return super(cls)
  356. elif len(args) == 2:
  357. instance = args[1]
  358. return super(cls, instance)
  359. else:
  360. raise InterpreterError("super() takes at most 2 arguments")
  361. else:
  362. if func_name == "print":
  363. output = " ".join(map(str, args))
  364. global PRINT_OUTPUTS
  365. PRINT_OUTPUTS += output + "\n"
  366. # cap the number of lines
  367. return None
  368. else: # Assume it's a callable object
  369. output = func(*args, **kwargs)
  370. return output
  371. def evaluate_subscript(subscript, state, static_tools, custom_tools):
  372. index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
  373. value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
  374. if isinstance(value, str) and isinstance(index, str):
  375. raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
  376. if isinstance(value, pd.core.indexing._LocIndexer):
  377. parent_object = value.obj
  378. return parent_object.loc[index]
  379. if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
  380. return value[index]
  381. elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
  382. return value[index]
  383. elif isinstance(index, slice):
  384. return value[index]
  385. elif isinstance(value, (list, tuple)):
  386. if not (-len(value) <= index < len(value)):
  387. raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
  388. return value[int(index)]
  389. elif isinstance(value, str):
  390. if not (-len(value) <= index < len(value)):
  391. raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
  392. return value[index]
  393. elif index in value:
  394. return value[index]
  395. elif isinstance(index, str) and isinstance(value, Mapping):
  396. close_matches = difflib.get_close_matches(index, list(value.keys()))
  397. if len(close_matches) > 0:
  398. return value[close_matches[0]]
  399. raise InterpreterError(f"Could not index {value} with '{index}'.")
  400. def evaluate_name(name, state, static_tools, custom_tools):
  401. if name.id in state:
  402. return state[name.id]
  403. elif name.id in static_tools:
  404. return static_tools[name.id]
  405. elif name.id in ERRORS:
  406. return ERRORS[name.id]
  407. close_matches = difflib.get_close_matches(name.id, list(state.keys()))
  408. if len(close_matches) > 0:
  409. return state[close_matches[0]]
  410. raise InterpreterError(f"The variable `{name.id}` is not defined.")
  411. def evaluate_condition(condition, state, static_tools, custom_tools):
  412. left = evaluate_ast(condition.left, state, static_tools, custom_tools)
  413. comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
  414. ops = [type(op) for op in condition.ops]
  415. result = True
  416. current_left = left
  417. for op, comparator in zip(ops, comparators):
  418. if op == ast.Eq:
  419. current_result = current_left == comparator
  420. elif op == ast.NotEq:
  421. current_result = current_left != comparator
  422. elif op == ast.Lt:
  423. current_result = current_left < comparator
  424. elif op == ast.LtE:
  425. current_result = current_left <= comparator
  426. elif op == ast.Gt:
  427. current_result = current_left > comparator
  428. elif op == ast.GtE:
  429. current_result = current_left >= comparator
  430. elif op == ast.Is:
  431. current_result = current_left is comparator
  432. elif op == ast.IsNot:
  433. current_result = current_left is not comparator
  434. elif op == ast.In:
  435. current_result = current_left in comparator
  436. elif op == ast.NotIn:
  437. current_result = current_left not in comparator
  438. else:
  439. raise InterpreterError(f"Operator not supported: {op}")
  440. result = result & current_result
  441. current_left = comparator
  442. if isinstance(result, bool) and not result:
  443. break
  444. return result if isinstance(result, (bool, pd.Series)) else result.all()
  445. def evaluate_if(if_statement, state, static_tools, custom_tools):
  446. result = None
  447. test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
  448. if test_result:
  449. for line in if_statement.body:
  450. line_result = evaluate_ast(line, state, static_tools, custom_tools)
  451. if line_result is not None:
  452. result = line_result
  453. else:
  454. for line in if_statement.orelse:
  455. line_result = evaluate_ast(line, state, static_tools, custom_tools)
  456. if line_result is not None:
  457. result = line_result
  458. return result
  459. def evaluate_for(for_loop, state, static_tools, custom_tools):
  460. result = None
  461. iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
  462. for counter in iterator:
  463. set_value(for_loop.target, counter, state, static_tools, custom_tools)
  464. for node in for_loop.body:
  465. try:
  466. line_result = evaluate_ast(node, state, static_tools, custom_tools)
  467. if line_result is not None:
  468. result = line_result
  469. except BreakException:
  470. break
  471. except ContinueException:
  472. continue
  473. else:
  474. continue
  475. break
  476. return result
  477. def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
  478. def inner_evaluate(generators, index, current_state):
  479. if index >= len(generators):
  480. return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
  481. generator = generators[index]
  482. iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
  483. result = []
  484. for value in iter_value:
  485. new_state = current_state.copy()
  486. if isinstance(generator.target, ast.Tuple):
  487. for idx, elem in enumerate(generator.target.elts):
  488. new_state[elem.id] = value[idx]
  489. else:
  490. new_state[generator.target.id] = value
  491. if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
  492. result.extend(inner_evaluate(generators, index + 1, new_state))
  493. return result
  494. return inner_evaluate(listcomp.generators, 0, state)
  495. def evaluate_try(try_node, state, static_tools, custom_tools):
  496. try:
  497. for stmt in try_node.body:
  498. evaluate_ast(stmt, state, static_tools, custom_tools)
  499. except Exception as e:
  500. matched = False
  501. for handler in try_node.handlers:
  502. if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
  503. matched = True
  504. if handler.name:
  505. state[handler.name] = e
  506. for stmt in handler.body:
  507. evaluate_ast(stmt, state, static_tools, custom_tools)
  508. break
  509. if not matched:
  510. raise e
  511. else:
  512. if try_node.orelse:
  513. for stmt in try_node.orelse:
  514. evaluate_ast(stmt, state, static_tools, custom_tools)
  515. finally:
  516. if try_node.finalbody:
  517. for stmt in try_node.finalbody:
  518. evaluate_ast(stmt, state, static_tools, custom_tools)
  519. def evaluate_raise(raise_node, state, static_tools, custom_tools):
  520. if raise_node.exc is not None:
  521. exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
  522. else:
  523. exc = None
  524. if raise_node.cause is not None:
  525. cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
  526. else:
  527. cause = None
  528. if exc is not None:
  529. if cause is not None:
  530. raise exc from cause
  531. else:
  532. raise exc
  533. else:
  534. raise InterpreterError("Re-raise is not supported without an active exception")
  535. def evaluate_assert(assert_node, state, static_tools, custom_tools):
  536. test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
  537. if not test_result:
  538. if assert_node.msg:
  539. msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
  540. raise AssertionError(msg)
  541. else:
  542. # Include the failing condition in the assertion message
  543. test_code = ast.unparse(assert_node.test)
  544. raise AssertionError(f"Assertion failed: {test_code}")
  545. def evaluate_with(with_node, state, static_tools, custom_tools):
  546. contexts = []
  547. for item in with_node.items:
  548. context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
  549. if item.optional_vars:
  550. state[item.optional_vars.id] = context_expr.__enter__()
  551. contexts.append(state[item.optional_vars.id])
  552. else:
  553. context_var = context_expr.__enter__()
  554. contexts.append(context_var)
  555. try:
  556. for stmt in with_node.body:
  557. evaluate_ast(stmt, state, static_tools, custom_tools)
  558. except Exception as e:
  559. for context in reversed(contexts):
  560. context.__exit__(type(e), e, e.__traceback__)
  561. raise
  562. else:
  563. for context in reversed(contexts):
  564. context.__exit__(None, None, None)
  565. def import_modules(expression, state, authorized_imports):
  566. def check_module_authorized(module_name):
  567. module_path = module_name.split(".")
  568. module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
  569. return any(subpath in authorized_imports for subpath in module_subpaths)
  570. if isinstance(expression, ast.Import):
  571. for alias in expression.names:
  572. if check_module_authorized(alias.name):
  573. module = import_module(alias.name)
  574. state[alias.asname or alias.name] = module
  575. else:
  576. raise InterpreterError(
  577. f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
  578. )
  579. return None
  580. elif isinstance(expression, ast.ImportFrom):
  581. if check_module_authorized(expression.module):
  582. module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
  583. for alias in expression.names:
  584. state[alias.asname or alias.name] = getattr(module, alias.name)
  585. else:
  586. raise InterpreterError(f"Import from {expression.module} is not allowed.")
  587. return None
  588. def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
  589. result = {}
  590. for gen in dictcomp.generators:
  591. iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
  592. for value in iter_value:
  593. new_state = state.copy()
  594. set_value(gen.target, value, new_state, static_tools, custom_tools)
  595. if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
  596. key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
  597. val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
  598. result[key] = val
  599. return result
  600. def evaluate_ast(
  601. expression: ast.AST,
  602. state: Dict[str, Any],
  603. static_tools: Dict[str, Callable],
  604. custom_tools: Dict[str, Callable],
  605. authorized_imports: List[str] = LIST_SAFE_MODULES,
  606. ):
  607. """
  608. Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
  609. set of functions.
  610. This function will recurse trough the nodes of the tree provided.
  611. Args:
  612. expression (`ast.AST`):
  613. The code to evaluate, as an abstract syntax tree.
  614. state (`Dict[str, Any]`):
  615. A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
  616. encounters assignements.
  617. static_tools (`Dict[str, Callable]`):
  618. Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
  619. custom_tools (`Dict[str, Callable]`):
  620. Functions that may be called during the evaluation. These static_tools can be overwritten.
  621. authorized_imports (`List[str]`):
  622. The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
  623. Add more at your own risk!
  624. """
  625. global OPERATIONS_COUNT
  626. if OPERATIONS_COUNT >= MAX_OPERATIONS:
  627. raise InterpreterError(
  628. f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
  629. )
  630. OPERATIONS_COUNT += 1
  631. if isinstance(expression, ast.Assign):
  632. # Assignement -> we evaluate the assignment which should update the state
  633. # We return the variable assigned as it may be used to determine the final result.
  634. return evaluate_assign(expression, state, static_tools, custom_tools)
  635. elif isinstance(expression, ast.AugAssign):
  636. return evaluate_augassign(expression, state, static_tools, custom_tools)
  637. elif isinstance(expression, ast.Call):
  638. # Function call -> we return the value of the function call
  639. return evaluate_call(expression, state, static_tools, custom_tools)
  640. elif isinstance(expression, ast.Constant):
  641. # Constant -> just return the value
  642. return expression.value
  643. elif isinstance(expression, ast.Tuple):
  644. return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
  645. elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
  646. return evaluate_listcomp(expression, state, static_tools, custom_tools)
  647. elif isinstance(expression, ast.UnaryOp):
  648. return evaluate_unaryop(expression, state, static_tools, custom_tools)
  649. elif isinstance(expression, ast.Starred):
  650. return evaluate_ast(expression.value, state, static_tools, custom_tools)
  651. elif isinstance(expression, ast.BoolOp):
  652. # Boolean operation -> evaluate the operation
  653. return evaluate_boolop(expression, state, static_tools, custom_tools)
  654. elif isinstance(expression, ast.Break):
  655. raise BreakException()
  656. elif isinstance(expression, ast.Continue):
  657. raise ContinueException()
  658. elif isinstance(expression, ast.BinOp):
  659. # Binary operation -> execute operation
  660. return evaluate_binop(expression, state, static_tools, custom_tools)
  661. elif isinstance(expression, ast.Compare):
  662. # Comparison -> evaluate the comparison
  663. return evaluate_condition(expression, state, static_tools, custom_tools)
  664. elif isinstance(expression, ast.Lambda):
  665. return evaluate_lambda(expression, state, static_tools, custom_tools)
  666. elif isinstance(expression, ast.FunctionDef):
  667. return evaluate_function_def(expression, state, static_tools, custom_tools)
  668. elif isinstance(expression, ast.Dict):
  669. # Dict -> evaluate all keys and values
  670. keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
  671. values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
  672. return dict(zip(keys, values))
  673. elif isinstance(expression, ast.Expr):
  674. # Expression -> evaluate the content
  675. return evaluate_ast(expression.value, state, static_tools, custom_tools)
  676. elif isinstance(expression, ast.For):
  677. # For loop -> execute the loop
  678. return evaluate_for(expression, state, static_tools, custom_tools)
  679. elif isinstance(expression, ast.FormattedValue):
  680. # Formatted value (part of f-string) -> evaluate the content and return
  681. return evaluate_ast(expression.value, state, static_tools, custom_tools)
  682. elif isinstance(expression, ast.If):
  683. # If -> execute the right branch
  684. return evaluate_if(expression, state, static_tools, custom_tools)
  685. elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
  686. return evaluate_ast(expression.value, state, static_tools, custom_tools)
  687. elif isinstance(expression, ast.JoinedStr):
  688. return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
  689. elif isinstance(expression, ast.List):
  690. # List -> evaluate all elements
  691. return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
  692. elif isinstance(expression, ast.Name):
  693. # Name -> pick up the value in the state
  694. return evaluate_name(expression, state, static_tools, custom_tools)
  695. elif isinstance(expression, ast.Subscript):
  696. # Subscript -> return the value of the indexing
  697. return evaluate_subscript(expression, state, static_tools, custom_tools)
  698. elif isinstance(expression, ast.IfExp):
  699. test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
  700. if test_val:
  701. return evaluate_ast(expression.body, state, static_tools, custom_tools)
  702. else:
  703. return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
  704. elif isinstance(expression, ast.Attribute):
  705. value = evaluate_ast(expression.value, state, static_tools, custom_tools)
  706. return getattr(value, expression.attr)
  707. elif isinstance(expression, ast.Slice):
  708. return slice(
  709. evaluate_ast(expression.lower, state, static_tools, custom_tools)
  710. if expression.lower is not None
  711. else None,
  712. evaluate_ast(expression.upper, state, static_tools, custom_tools)
  713. if expression.upper is not None
  714. else None,
  715. evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
  716. )
  717. elif isinstance(expression, ast.DictComp):
  718. return evaluate_dictcomp(expression, state, static_tools, custom_tools)
  719. elif isinstance(expression, ast.While):
  720. return evaluate_while(expression, state, static_tools, custom_tools)
  721. elif isinstance(expression, (ast.Import, ast.ImportFrom)):
  722. return import_modules(expression, state, authorized_imports)
  723. elif isinstance(expression, ast.ClassDef):
  724. return evaluate_class_def(expression, state, static_tools, custom_tools)
  725. elif isinstance(expression, ast.Try):
  726. return evaluate_try(expression, state, static_tools, custom_tools)
  727. elif isinstance(expression, ast.Raise):
  728. return evaluate_raise(expression, state, static_tools, custom_tools)
  729. elif isinstance(expression, ast.Assert):
  730. return evaluate_assert(expression, state, static_tools, custom_tools)
  731. elif isinstance(expression, ast.With):
  732. return evaluate_with(expression, state, static_tools, custom_tools)
  733. elif isinstance(expression, ast.Set):
  734. return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
  735. elif isinstance(expression, ast.Return):
  736. raise ReturnException(
  737. evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
  738. )
  739. else:
  740. # For now we refuse anything else. Let's add things as we need them.
  741. raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
  742. def evaluate_python_code(
  743. code: str,
  744. static_tools: Optional[Dict[str, Callable]] = None,
  745. custom_tools: Optional[Dict[str, Callable]] = None,
  746. state: Optional[Dict[str, Any]] = None,
  747. authorized_imports: List[str] = LIST_SAFE_MODULES,
  748. ):
  749. """
  750. Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
  751. of functions.
  752. This function will recurse through the nodes of the tree provided.
  753. Args:
  754. code (`str`):
  755. The code to evaluate.
  756. static_tools (`Dict[str, Callable]`):
  757. The functions that may be called during the evaluation.
  758. These tools cannot be overwritten in the code: any assignment to their name will raise an error.
  759. custom_tools (`Dict[str, Callable]`):
  760. The functions that may be called during the evaluation.
  761. These tools can be overwritten in the code: any assignment to their name will overwrite them.
  762. state (`Dict[str, Any]`):
  763. A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
  764. updated by this function to contain all variables as they are evaluated.
  765. The print outputs will be stored in the state under the key 'print_outputs'.
  766. """
  767. try:
  768. expression = ast.parse(code)
  769. except SyntaxError as e:
  770. raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
  771. if state is None:
  772. state = {}
  773. if static_tools is None:
  774. static_tools = {}
  775. if custom_tools is None:
  776. custom_tools = {}
  777. result = None
  778. global PRINT_OUTPUTS
  779. PRINT_OUTPUTS = ""
  780. global OPERATIONS_COUNT
  781. OPERATIONS_COUNT = 0
  782. for node in expression.body:
  783. try:
  784. result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
  785. except InterpreterError as e:
  786. msg = ""
  787. if len(PRINT_OUTPUTS) > 0:
  788. if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
  789. msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
  790. else:
  791. msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
  792. msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
  793. raise InterpreterError(msg)
  794. finally:
  795. if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
  796. state["print_outputs"] = PRINT_OUTPUTS
  797. else:
  798. state["print_outputs"] = (
  799. PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
  800. + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
  801. )
  802. return result