create_parameter_op.py

# mypy: allow-untyped-defs
import torch

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        return placeholder.set_(tensor)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        out = TracableCreateParameter.apply(tensor, placeholder)
    return out


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result
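
A minimal usage sketch, not part of the file above: it assumes the two helpers defined above are in scope, and the tensor name `data`, its shape, and the eager-mode `backward()` call at the end are illustrative assumptions rather than anything dynamo itself does. It only shows the calling convention: allocate the storage-less placeholder, then mutate it into the real parameter so gradients land on the placeholder.

import torch

# Assumes new_parameter_placeholder / tracable_create_parameter (defined above)
# are importable, e.g. from this module.

data = torch.randn(4, 4)  # freshly computed values; must not require grad
placeholder = new_parameter_placeholder(
    data.size(), data.dtype, data.device, requires_grad=True
)
param = tracable_create_parameter(data, placeholder)

# Gradients computed through `param` accumulate on the sacrificial placeholder,
# as if the parameter had been an input to the graph.
(param * 2.0).sum().backward()
print(placeholder.grad.shape)  # torch.Size([4, 4])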