# mypy: allow-untyped-defs
import math

import sympy

import torch
from torch.utils._sympy.value_ranges import ValueRanges

from .ir import LoopBody
from .utils import dominated_nodes
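
# This pass walks a LoopBody's fx graph and, where value-range analysis proves
# it safe, narrows int64 indexing intermediaries to int32 (narrower index
# arithmetic is generally cheaper in the generated kernels).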


def val_expressable_in_32_bits(val):
    """Return True if `val` (a concrete sympy number/boolean or a Python
    bool/int/float) is exactly representable in 32 bits: int32 for integers,
    the float32 significand for floats."""
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        if val.is_Integer or val.is_Boolean:
            val = int(val)
        else:
            val = float(val)

    # float32 has a 24-bit significand, so integers of magnitude up to 2**24
    # are representable without rounding
    if isinstance(val, float):
        return val <= (2**24) and val >= -(2**24)

    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return val <= iinfo.max and val >= iinfo.min

    raise TypeError(f"Unexpected value {val}")
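
# Illustrative examples (not exercised in this file):
#   val_expressable_in_32_bits(sympy.Integer(2**31 - 1))  -> True  (int32 max)
#   val_expressable_in_32_bits(sympy.Float(2**24 + 1))    -> False (exceeds the
#                                                          float32 significand)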


def range_expressable_in_32_bits(range):
    return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
        range.upper
    )
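
# e.g. (illustrative) ValueRanges(0, 2**20) is expressable in 32 bits, while
# ValueRanges(0, 2**40) is not, since 2**40 exceeds torch.iinfo(torch.int32).max.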


def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    """Rewrite `node` (a to_dtype(..., torch.int64) fx node) to produce int32
    instead, provided every value it dominates fits in 32 bits."""

    # If a downstream use of a node explicitly converts to int32 or
    # float32/float64, then its precision is fixed for that chain of uses, and
    # we don't need to consider those dominated values.
    def skip_filter(node):
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    # TODO - there are dominated uses whose dtype does not depend on whether we
    # reduce the precision here, e.g. for add(int64, int64), one of the args can
    # be reduced to int32 without changing the output precision of the node;
    # this case hasn't shown up in practice.
    for dominated in dominated_nodes([node], skip_filter):
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # Check that every index expression the indirect variable
            # participates in can be computed with int32.
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # All indices are integers, so use the bounds of integers
                    # instead of floats.
                    # TODO - not sure if we should be doing int/float casts
                    # while tracing, might interfere with sympy.
                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # Every dominated use fits in 32 bits: rewrite the cast in place.
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
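
# Illustrative effect (not a test): a node to_dtype(x, torch.int64) whose own
# bounds and all dominated bounds fit within [-2**31, 2**31) has its dtype arg
# rewritten so the node becomes to_dtype(x, torch.int32).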


def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Perform value-range analysis on LoopBody's fx graph to reduce the precision
    of int64 intermediaries to int32 where it is provably safe.
    """
    bv = loop_body.bounds()

    # Candidate nodes: explicit casts to int64 whose value ranges are known.
    int64_dtype_nodes = [
        node
        for node in loop_body.get_nodes()
        if (
            node.target == "to_dtype"
            and node.args[2] == torch.int64
            and node not in bv.unbounded_vars
        )
    ]
    if not int64_dtype_nodes:
        return

    bounds = bv.get_bounds()

    # TODO - if a dominated node of one to_dtype is not expressible in int32,
    # we should short-circuit another to_dtype node if that node also dominates.
    for node in int64_dtype_nodes:
        try_to_reduce_precision(
            node,
            bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bv.replacement_vals,
        )
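
# A minimal usage sketch (assumes a LoopBody already produced by inductor's
# lowering; constructing one by hand is out of scope here):
#
#   body: LoopBody = ...  # obtained during lowering
#   indexing_dtype_strength_reduction(body)
#   # afterwards, to_dtype nodes whose value ranges provably fit in int32
#   # carry torch.int32 in place of torch.int64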