dataframes.py

# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF

# TODO(VitalyFedyunin): Add error when two different traces get combined

__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]


def disable_capture():
    CaptureControl.disabled = True


class CaptureControl:
    disabled = False


class DataFrameTracedOps(DFIterDataPipe):
    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        for item in self.source_datapipe:
            yield self.output_var.apply_ops(item)


# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe',
                 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle']

UNIMPLEMENTED_ATTR = ['__deepcopy__', '__setstate__', 'is_shardable', 'apply_sharding']


class Capture:
    # TODO: All operations are shared across the entire InitialCapture; need to figure out what happens if we join two captures
    def __init__(self, schema_df=None):
        self.ctx = {'operations': [], 'variables': [], 'schema_df': schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        res = ""
        for op in self.ctx['operations']:
            if len(res) > 0:
                res += "\n"
            res += str(op)
        return res

    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx['schema_df'] = None
        for var in self.ctx['variables']:
            var.calculated_value = None
        state = {}
        for item in self.__dict__:
            state[item] = getattr(self, item)
        return state

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        if attrname == 'kwarg' or attrname == 'kwargs':
            raise Exception('no kwargs!')  # noqa: TRY002
        if attrname in ['__deepcopy__']:
            raise AttributeError
        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
        return result

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        self.ctx['operations'].append(
            CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __sub__(self, add_val):
        res = CaptureSub(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __mul__(self, add_val):
        res = CaptureMul(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        self.ctx['operations'].append(t)
        return var

    def _is_context_empty(self):
        return len(self.ctx['operations']) == 0 and len(self.ctx['variables']) == 0

    def apply_ops_2(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()

    @property
    def columns(self):
        self.apply_ops_2(self.ctx['schema_df'])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join contexts if one of them is empty because we used capture
    def __call__(self, *args, **kwargs):
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            for arg in args:
                if isinstance(arg, Capture) and not arg._is_context_empty():
                    self.ctx = arg.ctx
                    break
            if self._is_context_empty():
                for k, v in kwargs.items():
                    if isinstance(k, Capture) and not k._is_context_empty():
                        self.ctx = k.ctx
                        break
                    if isinstance(v, Capture) and not v._is_context_empty():
                        self.ctx = v.ctx
                        break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx['operations'].append(t)
        return var
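

# Illustrative sketch (not part of the original module): the operator overloads on
# Capture above record operations into ``ctx['operations']`` instead of executing
# them, and the trace is replayed later on a concrete DataFrame. Assumes pandas is
# installed; ``df`` and the printed variable names are hypothetical.
#
#     import pandas as pd
#
#     df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
#     capture = CaptureInitial(schema_df=df)
#     total = capture['a'] + capture['b']   # records a CaptureAdd plus a variable assignment
#     print(capture._ops_str())             # e.g. var_1 = input_var_0["a"] + input_var_0["b"]
#     result = total.apply_ops(df)          # replays the recorded ops on a real DataFrame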


class CaptureF(Capture):
    def __init__(self, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {'operations': [], 'variables': []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs


class CaptureA(CaptureF):
    def __str__(self):
        return f"{self.kwargs['name']}"

    def execute(self):
        value = self.kwargs['real_attribute']
        return value


class CaptureLikeMock:
    def __init__(self, name):
        import unittest.mock as mock
        # TODO(VitalyFedyunin): Do not use a private function here; copy our own implementation instead.
        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.get_target = get_target
        self.attribute = attribute
        self.name = name

    def __enter__(self):
        self.save = getattr(self.get_target(), self.attribute)
        capt = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, capt)

    def __exit__(self, *exc_info):
        setattr(self.get_target(), self.attribute, self.save)
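

# Illustrative sketch (an assumption, not from the original module): CaptureLikeMock
# behaves like a narrow unittest.mock.patch. Inside the ``with`` block the named
# attribute is replaced by a CaptureA, so uses of it are recorded into the trace;
# on exit the original attribute is restored. ``some_module.process`` is a
# hypothetical target.
#
#     with CaptureLikeMock('some_module.process'):
#         ...  # code here sees a CaptureA instead of the real ``process``
#     # the original ``process`` is back in place after the block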


class CaptureCall(Capture):
    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {'operations': [], 'variables': []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(callable=self.callable, **self.kwargs)

    def execute(self):
        # TODO(VitalyFedyunin): Execute kwargs and maybe nested structures
        executed_args = []
        for arg in self.kwargs['args']:
            if isinstance(arg, Capture):
                executed_args.append(arg.execute())
            else:
                executed_args.append(arg)
        left = get_val(self.callable)
        return left(*executed_args, **self.kwargs['kwargs'])


class CaptureVariableAssign(CaptureF):
    def __str__(self):
        variable = self.kwargs['variable']
        value = self.kwargs['value']
        return f"{variable} = {value}"

    def execute(self):
        self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()


class CaptureVariable(Capture):
    # TODO(VitalyFedyunin): This should be atomic and thread safe
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise Exception('Attempting to create capture variable with capture off')  # noqa: TRY002
        self.ctx = ctx
        self.value = value
        self.name = f'var_{CaptureVariable.names_idx}'
        CaptureVariable.names_idx += 1
        self.ctx['variables'].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        return self.calculated_value

    def apply_ops(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()
        return self.calculated_value


class CaptureGetItem(Capture):
    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}]"

    def execute(self):
        left = self.left.execute()
        return left[self.key]


class CaptureSetItem(Capture):
    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}] = {self.value}"

    def execute(self):
        left = self.left.execute()
        value = self.value.execute()
        left[self.key] = value


class CaptureAdd(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} + {self.right}"

    def execute(self):
        return get_val(self.left) + get_val(self.right)


class CaptureMul(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} * {self.right}"

    def execute(self):
        return get_val(self.left) * get_val(self.right)


class CaptureSub(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} - {self.right}"

    def execute(self):
        return get_val(self.left) - get_val(self.right)


class CaptureGetAttr(Capture):
    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return f"{self.src}.{self.name}"

    def execute(self):
        val = get_val(self.src)
        return getattr(val, self.name)


def get_val(capture):
    if isinstance(capture, Capture):
        return capture.execute()
    elif isinstance(capture, str):
        return f'"{capture}"'
    else:
        return capture
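

# Behaviour of get_val at a glance (illustrative; the literal values are assumptions):
#
#     get_val(some_capture)  # -> some_capture.execute(), i.e. the computed value
#     get_val("a")           # -> '"a"', quoted so that recorded ops render readably as strings
#     get_val(42)            # -> 42, non-Capture, non-str values pass through unchanged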


class CaptureInitial(CaptureVariable):
    def __init__(self, schema_df=None):
        new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': [], 'schema_df': schema_df}
        super().__init__(None, new_ctx)
        self.name = f'input_{self.name}'


class CaptureDataFrame(CaptureInitial):
    pass


class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    def as_datapipe(self):
        return DataFrameTracedOps(
            self.ctx['variables'][0].source_datapipe, self)

    def raw_iterator(self):
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        dp._dp_contains_dataframe = True
        return dp

    def groupby(self,
                group_key_fn,
                *,
                buffer_size=10000,
                group_size=None,
                guaranteed_group_size=None,
                drop_remaining=False):
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(group_key_fn, buffer_size=buffer_size, group_size=group_size,
                                      guaranteed_group_size=guaranteed_group_size, drop_remaining=drop_remaining)
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        raise Exception("Can't collate unbatched DataFrames stream")  # noqa: TRY002

    def __getattr__(self, attrname):  # ?
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError('Attempting to get ', attrname)
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)


@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
    source_datapipe: Optional[Any] = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes
    def set_shuffle_settings(self, *args, **kwargs):
        pass

    def is_shardable(self):
        return False

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        if schema_df is None:
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)
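

# Illustrative usage sketch (not part of the original module): because DataFrameTracer
# is registered under ``trace_as_dataframe``, a datapipe that yields pandas DataFrames
# can be wrapped and then manipulated column-wise; the recorded trace is replayed on
# every DataFrame flowing through. The source data below is hypothetical.
#
#     import pandas as pd
#     from torch.utils.data.datapipes.iter import IterableWrapper
#
#     frames = [pd.DataFrame({'a': [1, 2], 'b': [3, 4]})]
#     dp = IterableWrapper(frames).trace_as_dataframe()
#     dp['c'] = dp['a'] + dp['b']       # recorded, not executed
#     for df in dp.raw_iterator():      # the trace is applied to each incoming DataFrame
#         print(df)                     # df now carries the derived column 'c'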