quantized_lowerings.py

# mypy: allow-untyped-defs
import torch

from . import lowering

# Shortcuts for the quantized, internal-quantized, and ATen op namespaces.
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten


def register_quantized_ops():
    # Mark these ops' inputs as needing realization: the fallback kernels
    # operate on actual tensor buffers, so the inputs cannot stay fused
    # inside Inductor's IR.
    lowering.add_needs_realized_inputs(
        [
            quantized.max_pool2d,
            _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16,
            _quantized.wrapped_fbgemm_linear_fp16_weight,
        ]
    )
    # No native Inductor lowerings exist for these ops; fall back to the
    # eager kernels instead.
    lowering.make_fallback(quantized.max_pool2d)
    lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16)
    lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight)


def register_woq_mm_ops():
    # Weight-only-quantized int8 matmul: same pattern, realize the inputs
    # and fall back to the eager kernel.
    lowering.add_needs_realized_inputs(
        [
            aten._weight_int8pack_mm,
        ]
    )
    lowering.make_fallback(aten._weight_int8pack_mm)
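
These registration functions are meant to be invoked once while Inductor sets up its lowering tables. A minimal sketch of a call site, assuming the module is importable as torch._inductor.quantized_lowerings (which is where this file lives in the PyTorch tree):

# Sketch of how a lowering-setup module might invoke these registrations
# (assumed call site, modeled on torch/_inductor/lowering.py).
from torch._inductor import quantized_lowerings

quantized_lowerings.register_quantized_ops()  # quantized / fbgemm fp16 ops
quantized_lowerings.register_woq_mm_ops()     # weight-only int8 matmul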