plot_subgraphs.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
"""
=========
Subgraphs
=========
Example of partitioning a directed graph, whose nodes are labeled as
either supported or unsupported, into a list of subgraphs that each
contain only supported or only unsupported nodes.
Adapted from
https://github.com/lobpcg/python_examples/blob/master/networkx_example.py
"""
  11. import networkx as nx
  12. import matplotlib.pyplot as plt
  13. def graph_partitioning(G, plotting=True):
  14. """Partition a directed graph into a list of subgraphs that contain
  15. only entirely supported or entirely unsupported nodes.
  16. """
  17. # Categorize nodes by their node_type attribute
  18. supported_nodes = {n for n, d in G.nodes(data="node_type") if d == "supported"}
  19. unsupported_nodes = {n for n, d in G.nodes(data="node_type") if d == "unsupported"}
  20. # Make a copy of the graph.
  21. H = G.copy()
  22. # Remove all edges connecting supported and unsupported nodes.
  23. H.remove_edges_from(
  24. (n, nbr, d)
  25. for n, nbrs in G.adj.items()
  26. if n in supported_nodes
  27. for nbr, d in nbrs.items()
  28. if nbr in unsupported_nodes
  29. )
  30. H.remove_edges_from(
  31. (n, nbr, d)
  32. for n, nbrs in G.adj.items()
  33. if n in unsupported_nodes
  34. for nbr, d in nbrs.items()
  35. if nbr in supported_nodes
  36. )
  37. # Collect all removed edges for reconstruction.
  38. G_minus_H = nx.DiGraph()
  39. G_minus_H.add_edges_from(set(G.edges) - set(H.edges))
  40. if plotting:
  41. # Plot the stripped graph with the edges removed.
  42. _node_colors = [c for _, c in H.nodes(data="node_color")]
  43. _pos = nx.spring_layout(H)
  44. plt.figure(figsize=(8, 8))
  45. nx.draw_networkx_edges(H, _pos, alpha=0.3, edge_color="k")
  46. nx.draw_networkx_nodes(H, _pos, node_color=_node_colors)
  47. nx.draw_networkx_labels(H, _pos, font_size=14)
  48. plt.axis("off")
  49. plt.title("The stripped graph with the edges removed.")
  50. plt.show()
  51. # Plot the the edges removed.
  52. _pos = nx.spring_layout(G_minus_H)
  53. plt.figure(figsize=(8, 8))
  54. ncl = [G.nodes[n]["node_color"] for n in G_minus_H.nodes]
  55. nx.draw_networkx_edges(G_minus_H, _pos, alpha=0.3, edge_color="k")
  56. nx.draw_networkx_nodes(G_minus_H, _pos, node_color=ncl)
  57. nx.draw_networkx_labels(G_minus_H, _pos, font_size=14)
  58. plt.axis("off")
  59. plt.title("The removed edges.")
  60. plt.show()
  61. # Find the connected components in the stripped undirected graph.
  62. # And use the sets, specifying the components, to partition
  63. # the original directed graph into a list of directed subgraphs
  64. # that contain only entirely supported or entirely unsupported nodes.
  65. subgraphs = [
  66. H.subgraph(c).copy() for c in nx.connected_components(H.to_undirected())
  67. ]
  68. return subgraphs, G_minus_H
  69. ###############################################################################
  70. # Create an example directed graph.
  71. # ---------------------------------
  72. #
  73. # This directed graph has one input node labeled `in` and plotted in blue color
  74. # and one output node labeled `out` and plotted in magenta color.
  75. # The other six nodes are classified as four `supported` plotted in green color
  76. # and two `unsupported` plotted in red color. The goal is computing a list
  77. # of subgraphs that contain only entirely `supported` or `unsupported` nodes.
  78. G_ex = nx.DiGraph()
  79. G_ex.add_nodes_from(["In"], node_type="input", node_color="b")
  80. G_ex.add_nodes_from(["A", "C", "E", "F"], node_type="supported", node_color="g")
  81. G_ex.add_nodes_from(["B", "D"], node_type="unsupported", node_color="r")
  82. G_ex.add_nodes_from(["Out"], node_type="output", node_color="m")
  83. G_ex.add_edges_from(
  84. [
  85. ("In", "A"),
  86. ("A", "B"),
  87. ("B", "C"),
  88. ("B", "D"),
  89. ("D", "E"),
  90. ("C", "F"),
  91. ("E", "F"),
  92. ("F", "Out"),
  93. ]
  94. )
  95. ###############################################################################
  96. # Plot the original graph.
  97. # ------------------------
  98. #
  99. node_color_list = [nc for _, nc in G_ex.nodes(data="node_color")]
  100. pos = nx.spectral_layout(G_ex)
  101. plt.figure(figsize=(8, 8))
  102. nx.draw_networkx_edges(G_ex, pos, alpha=0.3, edge_color="k")
  103. nx.draw_networkx_nodes(G_ex, pos, alpha=0.8, node_color=node_color_list)
  104. nx.draw_networkx_labels(G_ex, pos, font_size=14)
  105. plt.axis("off")
  106. plt.title("The original graph.")
  107. plt.show()
  108. ###############################################################################
  109. # Calculate the subgraphs with plotting all results of intemediate steps.
  110. # -----------------------------------------------------------------------
  111. #
  112. subgraphs_of_G_ex, removed_edges = graph_partitioning(G_ex, plotting=True)
  113. ###############################################################################
  114. # Plot the results: every subgraph in the list.
  115. # ---------------------------------------------
  116. #
  117. for subgraph in subgraphs_of_G_ex:
  118. _pos = nx.spring_layout(subgraph)
  119. plt.figure(figsize=(8, 8))
  120. nx.draw_networkx_edges(subgraph, _pos, alpha=0.3, edge_color="k")
  121. node_color_list_c = [nc for _, nc in subgraph.nodes(data="node_color")]
  122. nx.draw_networkx_nodes(subgraph, _pos, node_color=node_color_list_c)
  123. nx.draw_networkx_labels(subgraph, _pos, font_size=14)
  124. plt.axis("off")
  125. plt.title("One of the subgraphs.")
  126. plt.show()
  127. ###############################################################################
  128. # Put the graph back from the list of subgraphs
  129. # ---------------------------------------------
  130. #
  131. G_ex_r = nx.DiGraph()
  132. # Composing all subgraphs.
  133. for subgraph in subgraphs_of_G_ex:
  134. G_ex_r = nx.compose(G_ex_r, subgraph)
  135. # Adding the previously stored edges.
  136. G_ex_r.add_edges_from(removed_edges.edges())
  137. ###############################################################################
  138. # Check that the original graph and the reconstructed graphs are isomorphic.
  139. # --------------------------------------------------------------------------
  140. #
  141. assert nx.is_isomorphic(G_ex, G_ex_r)
  142. ###############################################################################
  143. # Plot the reconstructed graph.
  144. # -----------------------------
  145. #
  146. node_color_list = [nc for _, nc in G_ex_r.nodes(data="node_color")]
  147. pos = nx.spectral_layout(G_ex_r)
  148. plt.figure(figsize=(8, 8))
  149. nx.draw_networkx_edges(G_ex_r, pos, alpha=0.3, edge_color="k")
  150. nx.draw_networkx_nodes(G_ex_r, pos, alpha=0.8, node_color=node_color_list)
  151. nx.draw_networkx_labels(G_ex_r, pos, font_size=14)
  152. plt.axis("off")
  153. plt.title("The reconstructed graph.")
  154. plt.show()