union.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from dataclasses import fields
  4. from typing import Hashable, Set
  5. class _UnionTag(str):
  6. _cls: Hashable
  7. @staticmethod
  8. def create(t, cls):
  9. tag = _UnionTag(t)
  10. assert not hasattr(tag, "_cls")
  11. tag._cls = cls
  12. return tag
  13. def __eq__(self, cmp) -> bool:
  14. assert isinstance(cmp, str)
  15. other = str(cmp)
  16. assert other in _get_field_names(
  17. self._cls
  18. ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
  19. return str(self) == other
  20. def __hash__(self):
  21. return hash(str(self))
  22. @functools.lru_cache(maxsize=None)
  23. def _get_field_names(cls) -> Set[str]:
  24. return {f.name for f in fields(cls)}
  25. class _Union:
  26. _type: _UnionTag
  27. @classmethod
  28. def create(cls, **kwargs):
  29. assert len(kwargs) == 1
  30. obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]
  31. obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls)
  32. return obj
  33. def __post_init__(self):
  34. assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc]
  35. @property
  36. def type(self) -> str:
  37. try:
  38. return self._type
  39. except AttributeError as e:
  40. raise RuntimeError(
  41. f"Please use {type(self).__name__}.create to instantiate the union type."
  42. ) from e
  43. @property
  44. def value(self):
  45. return getattr(self, self.type)
  46. def __getattribute__(self, name):
  47. attr = super().__getattribute__(name)
  48. if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type]
  49. raise AttributeError(f"Field {name} is not set.")
  50. return attr
  51. def __str__(self):
  52. return self.__repr__()
  53. def __repr__(self):
  54. return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"