# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# reference python implementations for C ops
import torch

from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map

DimList = _C.DimList

import operator
from functools import reduce

# use dict to avoid writing C++ bindings for set
pointwise = set(op_properties.pointwise)


def prod(x):
    return reduce(operator.mul, x, 1)

def _wrap_dim(d, N, keepdim):
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


def _dims(d, N, keepdim, single_dim):
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)

def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
            )
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
        )
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
            )

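# Editor's note (illustrative example, not part of the original file): given a dim
# pack with at most one unbound dim, _bind_dims_to_size above infers the missing
# size. For instance, if i.size == 3 and j is unbound, binding (i, j) against a
# dimension of size 12 sets j.size = 12 // 3 = 4; two unbound dims, or a size that
# does not divide evenly, raise DimensionMismatchError.
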
def _tensor_levels(inp):
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True


def _match_levels(v, from_levels, to_levels):
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v

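# Editor's note (illustrative example, not part of the original file): _match_levels
# above aligns a tensor described by `from_levels` to the layout `to_levels`,
# permuting the levels it has and inserting size-1 axes for the ones it lacks.
# For instance, with v.size() == (3, 5), from_levels == [i, -1] and
# to_levels == [j, i, -1], no permute is needed, the missing level j becomes a
# size-1 axis, and the result is a (1, 3, 5) view ready to broadcast over j.
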
# make a single dimension positional but do not permute it,
# used to do multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched

def seq(a, b):
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass

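# Editor's note (not part of the original file): llist/ltuple above reimplement
# membership and index lookups in terms of `seq`, which compares Dim objects by
# identity and everything else by ordinary equality. This is presumably because Dim
# objects can behave like tensors in expressions, so the `==` used by list.index and
# `in` would not reliably produce a plain boolean for them.
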
empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        ):
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if (
                    isinstance(f, _Tensor)
                    and not f._has_device
                    and device_holding_tensor is not None
                ):
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(
                    t, result_levels, device_holding_tensor is not None
                )
            return t

        return tree_map(wrap, result)
    else:

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)

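# Editor's note (illustrative sketch, not part of the original file): the pointwise
# branch above aligns every tensor argument to the union of their levels and then
# runs the original op on plain tensors. Assuming the public `dims` constructor and
# that addition is listed in op_properties.pointwise:
#
#     i, j = dims(2)
#     x = torch.rand(3, 4)[i, j]   # levels [i, j]
#     y = torch.rand(4)[j]         # levels [j]
#     z = x + y                    # y is aligned via _match_levels to shape (1, 4),
#                                  # then ordinary broadcasting produces z with dims i, j
#
# Any op not in `pointwise` instead falls back to the vmap-style batched path under
# _enable_layers.
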
def positional(self, *dims):
    from . import Dim, DimensionBindError, Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    nflat = len(flat_dims)
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(
                f"tensor of dimensions {self.dims} does not contain dim {d}"
            ) from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result

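# Editor's note (illustrative example, not part of the original file): `positional`
# above moves the named dims to the front as ordinary positional dims, in the order
# given. Assuming the public `dims` constructor:
#
#     i, j = dims(2)
#     t = torch.rand(3, 4)[i, j]   # fully first-class: no positional dims
#     positional(t, i, j)          # positional shape (3, 4), no first-class dims left
#     positional(t, j)             # positional shape (4,), dim i stays first-class
#
# Tuples of dims (dim packs) are flattened, and the final `reshape` collapses their
# positional dims back into a single dimension.
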
def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]

_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            # so that dims that really only take a single argument work...
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))

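# Editor's note (illustrative sketch, not part of the original file): `_def` is how
# torch.Tensor methods that take a `dim=` argument get re-registered on _Tensor so
# that first-class dims are accepted; the actual registrations live elsewhere in the
# package. A hypothetical registration and call would look like:
#
#     _def("sum")                  # _Tensor.sum now routes through _wrap(torch.Tensor.sum)
#     i = dims()                   # assuming the public `dims` constructor
#     t = torch.rand(3, 4)[i]
#     t.sum(i)                     # `i` is translated to its level index before calling sum
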
no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    def __init__(self):
        self.dims = llist()
        self.count = []

    def record(self, d):
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            # count repeated uses so that a dim used more than once is not
            # treated as a single-use dim by __getitem__ below
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]

def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count to the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)

    # this handles bool indexing, as well as some other simple cases.
    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        and
        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
        not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or isinstance(s, DimList) and not s.is_bound:
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )

    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [
                no_slice
            ] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)

    # flatten the dim lists into the indexing
    for i in reversed(dimlists):
        input[i : i + 1] = input[i]

    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well.

    # to index:
    # drop the first-class dims from self, they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place a tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # index into ptensor rather than self, which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)

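# Editor's note (illustrative sketch, not part of the original file): when this
# reference __getitem__ is installed in place of torch.Tensor.__getitem__, indexing
# with first-class dims behaves as follows (assuming the public `dims` constructor):
#
#     i, j = dims(2)
#     x = torch.rand(3, 4)
#     y = x[i, j]                  # binds i.size = 3 and j.size = 4; since both dims
#                                  # are single-use, no real indexing op is needed
#     w = torch.rand(12)[[i, j]]   # dim pack: the length-12 axis is viewed as (3, 4)
#                                  # and then bound to i and j
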
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))

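# Editor's note (descriptive, not part of the original file): `stack` above always
# introduces `new_dim` as a first-class dim for the stacking index. When `dim` is an
# ordinary int it defers to torch.stack and then binds the new positional dim; when
# `dim` is itself a first-class Dim, each input is first made positional at that dim
# via _positional_no_permute (aligning the inputs on the same positional index), and
# the two resulting positional dims are bound back to (new_dim, dim).
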
_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import _Tensor, Dim

    if isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    ):
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))  # ceiling division
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    return tuple(
        t.index(dim, d)
        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
    )

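# Editor's note (illustrative example, not part of the original file): with
# first-class dims, `split` can infer the sizes of unbound output dims, e.g.
# (assuming the public `dims` constructor):
#
#     i, a, b = dims(3)
#     t = torch.rand(10)[i]                 # binds i.size = 10
#     front, back = split(t, (a, b), dim=i)
#     # a.size == 5 and b.size == 5; `front` carries dim a and `back` carries dim b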