script_object.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from typing import Dict
  4. import torch
  5. from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
  6. from .base import VariableTracker
  7. from .user_defined import UserDefinedObjectVariable
  8. def _raise_hard_error_if_graph_break(reason):
  9. def deco(fn):
  10. @functools.wraps(fn)
  11. def graph_break_as_hard_error(*args, **kwargs):
  12. try:
  13. return fn(*args, **kwargs)
  14. except Unsupported as e:
  15. raise UnsafeScriptObjectError(e.msg) from e
  16. return graph_break_as_hard_error
  17. return deco
  18. class TorchScriptObjectVariable(UserDefinedObjectVariable):
  19. _fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
  20. @classmethod
  21. def is_matching_cls(cls, user_cls: type):
  22. return issubclass(user_cls, torch.ScriptObject)
  23. @staticmethod
  24. def create(proxy, value, **options):
  25. return TorchScriptObjectVariable(proxy, value, **options)
  26. def __init__(self, proxy, value, source, **kwargs):
  27. super().__init__(value, **kwargs)
  28. self.proxy = proxy
  29. self.proxy.node.meta["example_value"] = value
  30. self.source = source
  31. def as_proxy(self):
  32. return self.proxy
  33. @_raise_hard_error_if_graph_break(
  34. "Dynamo cannot safely trace script object due to graph break."
  35. )
  36. def var_getattr(self, tx, name: str) -> VariableTracker:
  37. from torch._higher_order_ops.torchbind import call_torchbind
  38. from ..source import AttrSource
  39. from .higher_order_ops import TorchHigherOrderOperatorVariable
  40. method = getattr(self.value, name, None)
  41. if method is None:
  42. unimplemented(
  43. f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
  44. )
  45. if not callable(method):
  46. unimplemented(
  47. "Only method calls on TorchScript objects can be supported safely."
  48. " Please use method calls instead of attribute access."
  49. )
  50. return TorchHigherOrderOperatorVariable.make(
  51. call_torchbind,
  52. source=AttrSource(self.source, name),
  53. script_obj_var=self,
  54. method_name=name,
  55. )
  56. # We only support method calls on script objects. Interpreting the bytecodes
  57. # should go through var_getattr then call_function instead of call_method.
  58. #
  59. # However, it's possible for call_method to be used directly e.g. for __setattr__.
  60. @_raise_hard_error_if_graph_break(
  61. "Dynamo cannot safely trace script object due to graph break."
  62. )
  63. def call_method(self, tx, name, args, kwargs):
  64. unimplemented(f"call method {name} on script object is not safe.")