bytecode_analysis.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # mypy: allow-untyped-defs
  2. import bisect
  3. import dataclasses
  4. import dis
  5. import sys
  6. from typing import Any, Set, Union
  7. TERMINAL_OPCODES = {
  8. dis.opmap["RETURN_VALUE"],
  9. dis.opmap["JUMP_FORWARD"],
  10. dis.opmap["RAISE_VARARGS"],
  11. # TODO(jansel): double check exception handling
  12. }
  13. if sys.version_info >= (3, 9):
  14. TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
  15. if sys.version_info >= (3, 11):
  16. TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
  17. TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
  18. else:
  19. TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
  20. if sys.version_info >= (3, 12):
  21. TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
  22. JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
  23. JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
  24. HASLOCAL = set(dis.haslocal)
  25. HASFREE = set(dis.hasfree)
  26. stack_effect = dis.stack_effect
  27. def get_indexof(insts):
  28. """
  29. Get a mapping from instruction memory address to index in instruction list.
  30. Additionally checks that each instruction only appears once in the list.
  31. """
  32. indexof = {}
  33. for i, inst in enumerate(insts):
  34. assert inst not in indexof
  35. indexof[inst] = i
  36. return indexof
  37. def remove_dead_code(instructions):
  38. """Dead code elimination"""
  39. indexof = get_indexof(instructions)
  40. live_code = set()
  41. def find_live_code(start):
  42. for i in range(start, len(instructions)):
  43. if i in live_code:
  44. return
  45. live_code.add(i)
  46. inst = instructions[i]
  47. if inst.exn_tab_entry:
  48. find_live_code(indexof[inst.exn_tab_entry.target])
  49. if inst.opcode in JUMP_OPCODES:
  50. find_live_code(indexof[inst.target])
  51. if inst.opcode in TERMINAL_OPCODES:
  52. return
  53. find_live_code(0)
  54. # change exception table entries if start/end instructions are dead
  55. # assumes that exception table entries have been propagated,
  56. # e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
  57. # and that instructions with an exn_tab_entry lies within its start/end.
  58. if sys.version_info >= (3, 11):
  59. live_idx = sorted(live_code)
  60. for i, inst in enumerate(instructions):
  61. if i in live_code and inst.exn_tab_entry:
  62. # find leftmost live instruction >= start
  63. start_idx = bisect.bisect_left(
  64. live_idx, indexof[inst.exn_tab_entry.start]
  65. )
  66. assert start_idx < len(live_idx)
  67. # find rightmost live instruction <= end
  68. end_idx = (
  69. bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
  70. )
  71. assert end_idx >= 0
  72. assert live_idx[start_idx] <= i <= live_idx[end_idx]
  73. inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
  74. inst.exn_tab_entry.end = instructions[live_idx[end_idx]]
  75. return [inst for i, inst in enumerate(instructions) if i in live_code]
  76. def remove_pointless_jumps(instructions):
  77. """Eliminate jumps to the next instruction"""
  78. pointless_jumps = {
  79. id(a)
  80. for a, b in zip(instructions, instructions[1:])
  81. if a.opname == "JUMP_ABSOLUTE" and a.target is b
  82. }
  83. return [inst for inst in instructions if id(inst) not in pointless_jumps]
  84. def propagate_line_nums(instructions):
  85. """Ensure every instruction has line number set in case some are removed"""
  86. cur_line_no = None
  87. def populate_line_num(inst):
  88. nonlocal cur_line_no
  89. if inst.starts_line:
  90. cur_line_no = inst.starts_line
  91. inst.starts_line = cur_line_no
  92. for inst in instructions:
  93. populate_line_num(inst)
  94. def remove_extra_line_nums(instructions):
  95. """Remove extra starts line properties before packing bytecode"""
  96. cur_line_no = None
  97. def remove_line_num(inst):
  98. nonlocal cur_line_no
  99. if inst.starts_line is None:
  100. return
  101. elif inst.starts_line == cur_line_no:
  102. inst.starts_line = None
  103. else:
  104. cur_line_no = inst.starts_line
  105. for inst in instructions:
  106. remove_line_num(inst)
  107. @dataclasses.dataclass
  108. class ReadsWrites:
  109. reads: Set[Any]
  110. writes: Set[Any]
  111. visited: Set[Any]
  112. def livevars_analysis(instructions, instruction):
  113. indexof = get_indexof(instructions)
  114. must = ReadsWrites(set(), set(), set())
  115. may = ReadsWrites(set(), set(), set())
  116. def walk(state, start):
  117. if start in state.visited:
  118. return
  119. state.visited.add(start)
  120. for i in range(start, len(instructions)):
  121. inst = instructions[i]
  122. if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
  123. if "LOAD" in inst.opname or "DELETE" in inst.opname:
  124. if inst.argval not in must.writes:
  125. state.reads.add(inst.argval)
  126. elif "STORE" in inst.opname:
  127. state.writes.add(inst.argval)
  128. elif inst.opname == "MAKE_CELL":
  129. pass
  130. else:
  131. raise NotImplementedError(f"unhandled {inst.opname}")
  132. if inst.exn_tab_entry:
  133. walk(may, indexof[inst.exn_tab_entry.target])
  134. if inst.opcode in JUMP_OPCODES:
  135. walk(may, indexof[inst.target])
  136. state = may
  137. if inst.opcode in TERMINAL_OPCODES:
  138. return
  139. walk(must, indexof[instruction])
  140. return must.reads | may.reads
  141. @dataclasses.dataclass
  142. class FixedPointBox:
  143. value: bool = True
  144. @dataclasses.dataclass
  145. class StackSize:
  146. low: Union[int, float]
  147. high: Union[int, float]
  148. fixed_point: FixedPointBox
  149. def zero(self):
  150. self.low = 0
  151. self.high = 0
  152. self.fixed_point.value = False
  153. def offset_of(self, other, n):
  154. prior = (self.low, self.high)
  155. self.low = min(self.low, other.low + n)
  156. self.high = max(self.high, other.high + n)
  157. if (self.low, self.high) != prior:
  158. self.fixed_point.value = False
  159. def exn_tab_jump(self, depth):
  160. prior = (self.low, self.high)
  161. self.low = min(self.low, depth)
  162. self.high = max(self.high, depth)
  163. if (self.low, self.high) != prior:
  164. self.fixed_point.value = False
  165. def stacksize_analysis(instructions) -> Union[int, float]:
  166. assert instructions
  167. fixed_point = FixedPointBox()
  168. stack_sizes = {
  169. inst: StackSize(float("inf"), float("-inf"), fixed_point)
  170. for inst in instructions
  171. }
  172. stack_sizes[instructions[0]].zero()
  173. for _ in range(100):
  174. if fixed_point.value:
  175. break
  176. fixed_point.value = True
  177. for inst, next_inst in zip(instructions, instructions[1:] + [None]):
  178. stack_size = stack_sizes[inst]
  179. # CALL_FINALLY in Python 3.8 is handled differently when determining stack depth.
  180. # See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450.
  181. # Essentially, the stack effect of CALL_FINALLY is computed with jump=True,
  182. # but the resulting stack depth is propagated to the next instruction, not the
  183. # jump target.
  184. is_call_finally = (
  185. sys.version_info < (3, 9) and inst.opcode == dis.opmap["CALL_FINALLY"]
  186. )
  187. if inst.opcode not in TERMINAL_OPCODES:
  188. assert next_inst is not None, f"missing next inst: {inst}"
  189. # total stack effect of CALL_FINALLY and END_FINALLY in 3.8 is 0
  190. eff = (
  191. 0
  192. if is_call_finally
  193. else stack_effect(inst.opcode, inst.arg, jump=False)
  194. )
  195. stack_sizes[next_inst].offset_of(stack_size, eff)
  196. if inst.opcode in JUMP_OPCODES and not is_call_finally:
  197. stack_sizes[inst.target].offset_of(
  198. stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
  199. )
  200. if inst.exn_tab_entry:
  201. # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
  202. # on why depth is computed this way.
  203. depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1
  204. stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth)
  205. if False:
  206. for inst in instructions:
  207. stack_size = stack_sizes[inst]
  208. print(stack_size.low, stack_size.high, inst)
  209. low = min(x.low for x in stack_sizes.values())
  210. high = max(x.high for x in stack_sizes.values())
  211. assert fixed_point.value, "failed to reach fixed point"
  212. assert low >= 0
  213. return high