
# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
from typing import Optional, Tuple, List, Union, Any

import torch
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
from torch import Tensor

# Semi structured sparsity support
from .semi_structured import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)
# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.types import _dtype as DType
    DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]
__all__ = [
    'addmm',
    'check_sparse_tensor_invariants',
    'mm',
    'sum',
    'softmax',
    'log_softmax',
    'SparseSemiStructuredTensor',
    'SparseSemiStructuredTensorCUTLASS',
    'SparseSemiStructuredTensorCUSPARSELT',
    'to_sparse_semi_structured',
    'as_sparse_gradcheck',
]
addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor

This function does the exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse COO matrix :attr:`mat1`.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.

Supports both CSR and COO storage formats.

.. note::
    This function doesn't support computing derivatives with respect to CSR matrices.

Args:
    mat (Tensor): a dense matrix to be added
    mat1 (Tensor): a sparse matrix to be multiplied
    mat2 (Tensor): a dense matrix to be multiplied
    beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
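
Example (a minimal illustration; with the default ``beta=1., alpha=1.`` the call
computes ``mat + mat1 @ mat2``, here using an identity matrix as the sparse factor)::

    >>> mat = torch.zeros(2, 2)
    >>> mat1 = torch.eye(2).to_sparse()
    >>> mat2 = torch.tensor([[1., 2.], [3., 4.]])
    >>> # identity times mat2 plus zeros reproduces mat2
    >>> torch.sparse.addmm(mat, mat1, mat2)
    tensor([[1., 2.],
            [3., 4.]])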
  51. """)
mm = _add_docstr(_sparse._sparse_mm, r"""
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
:math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
:math:`(n \times p)` tensor.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.

Supports both CSR and COO storage formats.

.. note::
    This function doesn't support computing derivatives with respect to CSR matrices.

    This function additionally accepts an optional :attr:`reduce` argument that allows
    specification of a reduction operation; mathematically, it performs the following operation:

    .. math::

        z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}

    where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
    CSR storage format on CPU device.

Args:
    mat1 (Tensor): the first sparse matrix to be multiplied
    mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
    reduce (str, optional): the reduction operation to apply for non-unique indices
        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.

Shape:
    The format of the output tensor of this function follows:
    - sparse x sparse -> sparse
    - sparse x dense -> dense

Example::

    >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
    >>> a
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 2., 3.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
    >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
    >>> b
    tensor([[0., 1.],
            [2., 0.],
            [0., 0.]], requires_grad=True)
    >>> y = torch.sparse.mm(a, b)
    >>> y
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseAddmmBackward0>)
    >>> y.sum().backward()
    >>> a.grad
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 0., 2.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo)
    >>> c = a.detach().to_sparse_csr()
    >>> c
    tensor(crow_indices=tensor([0, 2, 3]),
           col_indices=tensor([0, 2, 1]),
           values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
           layout=torch.sparse_csr)
    >>> y1 = torch.sparse.mm(c, b, 'sum')
    >>> y1
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
    >>> y2 = torch.sparse.mm(c, b, 'amax')
    >>> y2
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
""")
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.

Mathematically this performs the following operation:

.. math::

    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
and :attr:`beta` are the scaling factors.
:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.

.. note::
    :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.

Args:
    input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
        the sampled matrix multiplication
    mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
    mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied

Keyword args:
    beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

Examples::

    >>> input = torch.eye(3, device='cuda').to_sparse_csr()
    >>> mat1 = torch.randn(3, 5, device='cuda')
    >>> mat2 = torch.randn(5, 3, device='cuda')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
    >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
    tensor([[ 0.2847,  0.0000,  0.0000],
            [ 0.0000, -0.7805,  0.0000],
            [ 0.0000,  0.0000, -0.1900]], device='cuda:0')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
""")
def sum(input: Tensor, dim: DimOrDims = None,
        dtype: Optional[DType] = None) -> Tensor:
    r"""Return the sum of each row of the given sparse tensor.

    Returns the sum of each row of the sparse tensor :attr:`input` in the given
    dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
    reduce over all of them. When summing over all ``sparse_dim`` dimensions, this
    method returns a dense tensor instead of a sparse tensor.

    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting in an output
    tensor having :attr:`dim` fewer dimensions than :attr:`input`.

    During backward, only gradients at ``nnz`` locations of :attr:`input`
    will propagate back. Note that the gradients of :attr:`input` are coalesced.

    Args:
        input (Tensor): the input sparse tensor
        dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
            over all dims.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
            Default: dtype of :attr:`input`.

    Example::

        >>> nnz = 3
        >>> dims = [5, 5, 2, 3]
        >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
        ...                torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
        >>> V = torch.randn(nnz, dims[2], dims[3])
        >>> size = torch.Size(dims)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> S = torch.sparse_coo_tensor(I, V, size)
        >>> S
        tensor(indices=tensor([[2, 0, 3],
                               [2, 4, 1]]),
               values=tensor([[[-0.6438, -1.6467,  1.4004],
                               [ 0.3411,  0.0918, -0.2312]],

                              [[ 0.5348,  0.0634, -2.0494],
                               [-0.7125, -1.0646,  2.1844]],

                              [[ 0.1276,  0.1874, -0.6334],
                               [-1.9682, -0.5340,  0.7483]]]),
               size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)

        # when sum over only part of sparse_dims, return a sparse tensor
        >>> torch.sparse.sum(S, [1, 3])
        tensor(indices=tensor([[0, 2, 3]]),
               values=tensor([[-1.4512,  0.4073],
                              [-0.8901,  0.2017],
                              [-0.3183, -1.7539]]),
               size=(5, 2), nnz=3, layout=torch.sparse_coo)

        # when sum over all sparse dim, return a dense tensor
        # with summed dims squeezed
        >>> torch.sparse.sum(S, [0, 1, 3])
        tensor([-2.6596, -1.1450])
    """
    if dtype is None:
        if dim is not None:
            return torch._sparse_sum(input, dim)
        else:
            return torch._sparse_sum(input)
    else:
        if dim is not None:
            return torch._sparse_sum(input, dim, dtype=dtype)
        else:
            return torch._sparse_sum(input, dtype=dtype)
softmax = _add_docstr(_sparse._sparse_softmax, r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function.

Softmax is defined as:

:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`

where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignored. This is equivalent to defining unspecified
entries as negative infinity so that :math:`exp(x_k) = 0` when the
entry with index :math:`k` is not specified.

It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor. If specified, the input tensor is
        cast to :attr:`dtype` before the operation is
        performed. This is useful for preventing data type
        overflows. Default: None
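
Example (an illustrative sketch: only the specified entries of each row take
part in the normalization; values shown to four decimal places)::

    >>> x = torch.tensor([[0., 1., 2.], [3., 0., 0.]]).to_sparse()
    >>> # row 0 normalizes over its two specified entries; row 1 has one entry
    >>> torch.sparse.softmax(x, dim=1)
    tensor(indices=tensor([[0, 0, 1],
                           [1, 2, 0]]),
           values=tensor([0.2689, 0.7311, 1.0000]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo)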
  231. """)
log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function followed by logarithm.

See :func:`~torch.sparse.softmax` for more details.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor. If specified, the input tensor is
        cast to :attr:`dtype` before the operation is
        performed. This is useful for preventing data type
        overflows. Default: None
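
Example (an illustrative sketch, the elementwise logarithm of the
:func:`torch.sparse.softmax` example; values shown to four decimal places)::

    >>> x = torch.tensor([[0., 1., 2.], [3., 0., 0.]]).to_sparse()
    >>> torch.sparse.log_softmax(x, dim=1)
    tensor(indices=tensor([[0, 0, 1],
                           [1, 2, 0]]),
           values=tensor([-1.3133, -0.3133,  0.0000]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo)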
  244. """)
spdiags = _add_docstr(
    _sparse._spdiags,
    r"""
sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor

Creates a sparse 2D tensor by placing the values from rows of :attr:`diagonals`
along specified diagonals of the output.

The :attr:`offsets` tensor controls which diagonals are set.

- If :attr:`offsets[i]` = 0, it is the main diagonal
- If :attr:`offsets[i]` < 0, it is below the main diagonal
- If :attr:`offsets[i]` > 0, it is above the main diagonal

The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
and an offset may not be repeated.

Args:
    diagonals (Tensor): Matrix storing diagonals row-wise
    offsets (Tensor): The diagonals to be set, stored as a vector
    shape (2-tuple of ints): The desired shape of the result

Keyword args:
    layout (:class:`torch.layout`, optional): The desired layout of the
        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
        are supported. Default: ``torch.sparse_coo``

Examples:

Set the main and first two lower diagonals of a matrix::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
    >>> s
    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
                           [0, 1, 2, 0, 1, 0]]),
           values=tensor([0, 1, 2, 3, 4, 6]),
           size=(3, 3), nnz=6, layout=torch.sparse_coo)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Change the output layout::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
    >>> s
    tensor(crow_indices=tensor([0, 1, 3, 6]),
           col_indices=tensor([0, 0, 1, 0, 1, 2]),
           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
           layout=torch.sparse_csr)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Set partial diagonals of a large output::

    >>> diags = torch.tensor([[1, 2], [3, 4]])
    >>> offsets = torch.tensor([0, -1])
    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
    tensor([[1, 0, 0, 0, 0],
            [3, 2, 0, 0, 0],
            [0, 4, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])

.. note::

    When setting the values along a given diagonal the index into the diagonal
    and the index into the row of :attr:`diagonals` is taken as the
    column index in the output. This has the effect that when setting a diagonal
    with a positive offset `k` the first value along that diagonal will be
    the value in position `k` of the row of :attr:`diagonals`.

Specifying a positive offset::

    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
    tensor([[1, 2, 3, 0, 0],
            [0, 2, 3, 0, 0],
            [0, 0, 3, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])
""")
class check_sparse_tensor_invariants:
    """A tool to control checking sparse tensor invariants.

    The following options exist to manage sparse tensor invariants
    checking in sparse tensor construction:

    1. Using a context manager:

       .. code:: python

           with torch.sparse.check_sparse_tensor_invariants():
               run_my_model()

    2. Using a procedural approach:

       .. code:: python

           prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
           torch.sparse.check_sparse_tensor_invariants.enable()

           run_my_model()

           if not prev_checks_enabled:
               torch.sparse.check_sparse_tensor_invariants.disable()

    3. Using function decoration:

       .. code:: python

           @torch.sparse.check_sparse_tensor_invariants()
           def run_my_model():
               ...

           run_my_model()

    4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
       For example:

       >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
       Traceback (most recent call last):
         File "<stdin>", line 1, in <module>
       RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
    """
    @staticmethod
    def is_enabled():
        r"""Return True if the sparse tensor invariants checking is enabled.

        .. note::

            Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
            :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
            manage the state of the sparse tensor invariants checks.
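
        Example (a minimal sketch, assuming the default state, where the
        checks are disabled and have not been toggled elsewhere)::

            >>> torch.sparse.check_sparse_tensor_invariants.is_enabled()
            False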
  355. """
  356. return torch._C._check_sparse_tensor_invariants()
  357. @staticmethod
  358. def enable():
  359. r"""Enable sparse tensor invariants checking in sparse tensor constructors.
  360. .. note::
  361. By default, the sparse tensor invariants checks are disabled. Use
  362. :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
  363. retrieve the current state of sparse tensor invariants checking.
  364. .. note::
  365. The sparse tensor invariants check flag is effective to all sparse
  366. tensor constructors, both in Python and ATen.
  367. The flag can be locally overridden by the ``check_invariants``
  368. optional argument of the sparse tensor constructor functions.
  369. """
  370. torch._C._set_check_sparse_tensor_invariants(True)
  371. @staticmethod
  372. def disable():
  373. r"""Disable sparse tensor invariants checking in sparse tensor constructors.
  374. See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
  375. """
  376. torch._C._set_check_sparse_tensor_invariants(False)
  377. # context manager support
  378. def __init__(self, enable=True):
  379. self.state = enable
  380. self.saved_state : Optional[bool] = None
  381. def __enter__(self):
  382. if self.saved_state is not None:
  383. raise RuntimeError('This context manager instance is already activated.'
  384. ' Use a different context manager instance for context nesting.')
  385. self.saved_state = self.is_enabled()
  386. torch._C._set_check_sparse_tensor_invariants(self.state)
  387. def __exit__(self, type, value, traceback):
  388. assert self.saved_state is not None
  389. torch._C._set_check_sparse_tensor_invariants(self.saved_state)
  390. self.saved_state = None
  391. # decorator support
  392. def __call__(self, mth):
  393. def test_mth(*args, **kwargs):
  394. with type(self)(self.state):
  395. return mth(*args, **kwargs)
  396. return test_mth
def as_sparse_gradcheck(gradcheck):
    """Decorate a function to extend gradcheck for sparse tensors.

    Decorator for torch.autograd.gradcheck or its functools.partial
    variants that extends the gradcheck function with support for input
    functions that operate on and/or return sparse tensors.

    The specified gradcheck function itself is guaranteed to operate
    on strided tensors only.

    For example:

    >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
    >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
    >>> gradcheck(lambda x: x.to_sparse_csr(), x)
    True
    """

    def gradcheck_with_sparse_support(func, inputs, **kwargs):
        """
        Create gradcheck with support for sparse tensors.

        Same as :func:`torch.autograd.gradcheck` but with support for sparse
        tensor inputs and outputs.
        """
        masked = kwargs.pop('masked', False)
        sparse_layouts = {torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
        sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
        sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
        STRIDED_REPRESENTATION = '__STRIDED_REPRESENTATION__'

        def convert_to_strided_representation(args):
            """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
            if not isinstance(args, (list, tuple)):
                args = (args,)
            new_args: List[Any] = []
            for obj in args:
                if isinstance(obj, torch.Tensor) and obj.requires_grad and obj.layout in sparse_layouts:
                    d = dict(layout=obj.layout, shape=obj.shape)
                    if not masked:
                        # Materialize unspecified elements with zero values
                        batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
                        blocksize = obj.values().shape[batch_dim + 1:batch_dim + 3] if obj.layout in sparse_block_layouts else None
                        full_mask = torch.ones(obj.shape, device=obj.device, dtype=torch.bool).to_sparse(
                            layout=obj.layout, blocksize=blocksize, dense_dim=obj.dense_dim())
                        obj = obj.to_dense().sparse_mask(full_mask)
                    if obj.layout is torch.sparse_coo:
                        d.update(indices=obj._indices(), is_coalesced=obj.is_coalesced())
                        values = obj._values()
                    elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        d.update(compressed_indices=obj.crow_indices(), plain_indices=obj.col_indices())
                        values = obj.values()
                    else:
                        d.update(compressed_indices=obj.ccol_indices(), plain_indices=obj.row_indices())
                        values = obj.values()
                    new_args.extend((STRIDED_REPRESENTATION, d, values.requires_grad_(True)))
                else:
                    new_args.append(obj)
            return tuple(new_args)
        def restore_from_strided_representation(args):
            """Restore non-strided differentiable tensors from their strided representations."""
            new_args = []
            args = list(args)
            while args:
                a = args.pop(0)
                if a == STRIDED_REPRESENTATION:
                    d, values = args.pop(0), args.pop(0)
                    if d['layout'] is torch.sparse_coo:
                        a = torch.sparse_coo_tensor(d['indices'], values, size=d['shape'], is_coalesced=d['is_coalesced'])
                    elif d['layout'] in sparse_compressed_layouts:
                        a = torch.sparse_compressed_tensor(d['compressed_indices'], d['plain_indices'], values,
                                                           size=d['shape'], layout=d['layout'])
                    else:
                        raise NotImplementedError(f'conversion of {d["layout"]} strided representation to tensor')
                new_args.append(a)
            return tuple(new_args)

        def func_wrapper(*args, **kwargs):
            restored_args = restore_from_strided_representation(args)

            # convert differentiable output sparse tensors to strided
            # tensors:
            outputs = func(*restored_args, **kwargs)

            strided_outputs = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
            strided_outputs = tuple((o.to_dense(masked_grad=masked)
                                     if isinstance(o, torch.Tensor) and o.requires_grad and o.layout in sparse_layouts else o)
                                    for o in strided_outputs)

            return strided_outputs if isinstance(outputs, (list, tuple)) else strided_outputs[0]

        args = (func_wrapper, convert_to_strided_representation(inputs))

        return gradcheck(*args, **kwargs)

    return gradcheck_with_sparse_support