_digraph.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # mypy: allow-untyped-defs
  2. from collections import deque
  3. from typing import List, Set
  4. class DiGraph:
  5. """Really simple unweighted directed graph data structure to track dependencies.
  6. The API is pretty much the same as networkx so if you add something just
  7. copy their API.
  8. """
  9. def __init__(self):
  10. # Dict of node -> dict of arbitrary attributes
  11. self._node = {}
  12. # Nested dict of node -> successor node -> nothing.
  13. # (didn't implement edge data)
  14. self._succ = {}
  15. # Nested dict of node -> predecessor node -> nothing.
  16. self._pred = {}
  17. # Keep track of the order in which nodes are added to
  18. # the graph.
  19. self._node_order = {}
  20. self._insertion_idx = 0
  21. def add_node(self, n, **kwargs):
  22. """Add a node to the graph.
  23. Args:
  24. n: the node. Can we any object that is a valid dict key.
  25. **kwargs: any attributes you want to attach to the node.
  26. """
  27. if n not in self._node:
  28. self._node[n] = kwargs
  29. self._succ[n] = {}
  30. self._pred[n] = {}
  31. self._node_order[n] = self._insertion_idx
  32. self._insertion_idx += 1
  33. else:
  34. self._node[n].update(kwargs)
  35. def add_edge(self, u, v):
  36. """Add an edge to graph between nodes ``u`` and ``v``
  37. ``u`` and ``v`` will be created if they do not already exist.
  38. """
  39. # add nodes
  40. self.add_node(u)
  41. self.add_node(v)
  42. # add the edge
  43. self._succ[u][v] = True
  44. self._pred[v][u] = True
  45. def successors(self, n):
  46. """Returns an iterator over successor nodes of n."""
  47. try:
  48. return iter(self._succ[n])
  49. except KeyError as e:
  50. raise ValueError(f"The node {n} is not in the digraph.") from e
  51. def predecessors(self, n):
  52. """Returns an iterator over predecessors nodes of n."""
  53. try:
  54. return iter(self._pred[n])
  55. except KeyError as e:
  56. raise ValueError(f"The node {n} is not in the digraph.") from e
  57. @property
  58. def edges(self):
  59. """Returns an iterator over all edges (u, v) in the graph"""
  60. for n, successors in self._succ.items():
  61. for succ in successors:
  62. yield n, succ
  63. @property
  64. def nodes(self):
  65. """Returns a dictionary of all nodes to their attributes."""
  66. return self._node
  67. def __iter__(self):
  68. """Iterate over the nodes."""
  69. return iter(self._node)
  70. def __contains__(self, n):
  71. """Returns True if ``n`` is a node in the graph, False otherwise."""
  72. try:
  73. return n in self._node
  74. except TypeError:
  75. return False
  76. def forward_transitive_closure(self, src: str) -> Set[str]:
  77. """Returns a set of nodes that are reachable from src"""
  78. result = set(src)
  79. working_set = deque(src)
  80. while len(working_set) > 0:
  81. cur = working_set.popleft()
  82. for n in self.successors(cur):
  83. if n not in result:
  84. result.add(n)
  85. working_set.append(n)
  86. return result
  87. def backward_transitive_closure(self, src: str) -> Set[str]:
  88. """Returns a set of nodes that are reachable from src in reverse direction"""
  89. result = set(src)
  90. working_set = deque(src)
  91. while len(working_set) > 0:
  92. cur = working_set.popleft()
  93. for n in self.predecessors(cur):
  94. if n not in result:
  95. result.add(n)
  96. working_set.append(n)
  97. return result
  98. def all_paths(self, src: str, dst: str):
  99. """Returns a subgraph rooted at src that shows all the paths to dst."""
  100. result_graph = DiGraph()
  101. # First compute forward transitive closure of src (all things reachable from src).
  102. forward_reachable_from_src = self.forward_transitive_closure(src)
  103. if dst not in forward_reachable_from_src:
  104. return result_graph
  105. # Second walk the reverse dependencies of dst, adding each node to
  106. # the output graph iff it is also present in forward_reachable_from_src.
  107. # we don't use backward_transitive_closures for optimization purposes
  108. working_set = deque(dst)
  109. while len(working_set) > 0:
  110. cur = working_set.popleft()
  111. for n in self.predecessors(cur):
  112. if n in forward_reachable_from_src:
  113. result_graph.add_edge(n, cur)
  114. # only explore further if its reachable from src
  115. working_set.append(n)
  116. return result_graph.to_dot()
  117. def first_path(self, dst: str) -> List[str]:
  118. """Returns a list of nodes that show the first path that resulted in dst being added to the graph."""
  119. path = []
  120. while dst:
  121. path.append(dst)
  122. candidates = self._pred[dst].keys()
  123. dst, min_idx = "", None
  124. for candidate in candidates:
  125. idx = self._node_order.get(candidate, None)
  126. if idx is None:
  127. break
  128. if min_idx is None or idx < min_idx:
  129. min_idx = idx
  130. dst = candidate
  131. return list(reversed(path))
  132. def to_dot(self) -> str:
  133. """Returns the dot representation of the graph.
  134. Returns:
  135. A dot representation of the graph.
  136. """
  137. edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
  138. return f"""\
  139. digraph G {{
  140. rankdir = LR;
  141. node [shape=box];
  142. {edges}
  143. }}
  144. """