# mypy: allow-untyped-defs
import torch

doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        # set_() points the sacrificial placeholder at the freshly created tensor's
        # storage and metadata, turning the placeholder into the actual parameter.
        return placeholder.set_(tensor)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    with torch.set_grad_enabled(placeholder.requires_grad):
        out = TracableCreateParameter.apply(tensor, placeholder)
    return out


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result
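

# Illustrative eager-mode sketch of how the helpers above compose: the placeholder
# is allocated before the graph with no storage backing it, then the real tensor is
# swapped into it so gradients land on the placeholder as if it were a graph input.
# The shape, device, and sum() loss are arbitrary choices for demonstration only.
if __name__ == "__main__":
    placeholder = new_parameter_placeholder(
        (4, 4), torch.float32, "cpu", requires_grad=True
    )
    data = torch.randn(4, 4)  # the real value; must not require grad (see forward)
    param = tracable_create_parameter(data, placeholder)
    param.sum().backward()
    print(placeholder.grad)  # gradient flowed into the sacrificial placeholder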