# optimize_indexing.py
  1. # mypy: allow-untyped-defs
  2. import math
  3. import sympy
  4. import torch
  5. from torch.utils._sympy.value_ranges import ValueRanges
  6. from .ir import LoopBody
  7. from .utils import dominated_nodes
  8. def val_expressable_in_32_bits(val):
  9. if getattr(val, "is_Boolean", False):
  10. return True
  11. if isinstance(val, sympy.Expr):
  12. assert val.is_number
  13. if val.is_Integer or val.is_Boolean:
  14. val = int(val)
  15. else:
  16. val = float(val)
  17. # bound within mantissa
  18. if isinstance(val, float):
  19. return val <= (2**24) and val >= -(2**24)
  20. if isinstance(val, int):
  21. iinfo = torch.iinfo(torch.int32)
  22. return val <= iinfo.max and val >= iinfo.min
  23. raise TypeError(f"Unexpected value {val}")
  24. def range_expressable_in_32_bits(range):
  25. return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
  26. range.upper
  27. )
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    """
    Attempt to downcast a ``to_dtype`` fx node's target dtype to int32.

    The cast is applied only when it is provably safe: every node dominated by
    ``node`` must have a value range expressable in 32 bits, and for any
    indirect indexing variable set downstream, every indexing expression it
    participates in must also stay within int32 bounds.

    On success, mutates ``node.args`` in place (args[2] becomes torch.int32);
    returns None in all cases.
    """
    # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
    # then its precision is set for that chain of uses, and we don't need to consider those
    # dominated values
    def skip_filter(node):
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
    # int32 without changing the output precision of the node. this case hasn't shown up
    for dominated in dominated_nodes([node], skip_filter):
        # stores/outputs don't themselves constrain the intermediate precision
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            # target is "set_indirect<N>"; recover N to locate the indirect variable
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in with int32
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    # an unbounded range can never be proven to fit in int32
                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.

                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # every dominated use provably fits in 32 bits: downcast in place
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
  67. def indexing_dtype_strength_reduction(loop_body: LoopBody):
  68. """
  69. Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
  70. intermediaries from int64 to int32
  71. """
  72. bv = loop_body.bounds()
  73. int64_dtype_nodes = [
  74. node
  75. for node in loop_body.get_nodes()
  76. if (
  77. node.target == "to_dtype"
  78. and node.args[2] == torch.int64
  79. and node not in bv.unbounded_vars
  80. )
  81. ]
  82. if not int64_dtype_nodes:
  83. return
  84. bounds = bv.get_bounds()
  85. # TODO - if dominated node of one to_dtype is not expressible in int32,
  86. # we should short circuit another to_dtype node if that node also dominates
  87. for node in int64_dtype_nodes:
  88. try_to_reduce_precision(
  89. node,
  90. bounds,
  91. loop_body.indirect_vars,
  92. loop_body.indexing_exprs,
  93. bv.replacement_vals,
  94. )