| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- """
- =========
- Subgraphs
- =========
- Example of partitioning a directed graph with nodes labeled as
- supported and unsupported nodes into a list of subgraphs
- that contain only entirely supported or entirely unsupported nodes.
- Adapted from
- https://github.com/lobpcg/python_examples/blob/master/networkx_example.py
- """
- import networkx as nx
- import matplotlib.pyplot as plt
def _draw_graph(graph, node_colors, title):
    """Draw ``graph`` with a spring layout in a new 8x8 figure and show it."""
    layout = nx.spring_layout(graph)
    plt.figure(figsize=(8, 8))
    nx.draw_networkx_edges(graph, layout, alpha=0.3, edge_color="k")
    nx.draw_networkx_nodes(graph, layout, node_color=node_colors)
    nx.draw_networkx_labels(graph, layout, font_size=14)
    plt.axis("off")
    plt.title(title)
    plt.show()


def graph_partitioning(G, plotting=True):
    """Partition a directed graph into a list of subgraphs that contain
    only entirely supported or entirely unsupported nodes.

    Every edge that joins a ``"supported"`` node and an ``"unsupported"``
    node, in either direction, is cut from a copy of ``G``. The connected
    components of what remains (edge direction ignored) become the subgraphs.

    Parameters
    ----------
    G : networkx.DiGraph
        Graph whose nodes carry a ``node_type`` attribute. Only the values
        ``"supported"`` and ``"unsupported"`` are examined; edges touching a
        node of any other type are never cut. ``G`` itself is not modified.
    plotting : bool, default True
        If true, show two figures: the graph with the cut edges removed, and
        the cut edges on their own. Node colors are read from the
        ``node_color`` node attribute.

    Returns
    -------
    subgraphs : list of networkx.DiGraph
        Independent copies of the connected pieces of the stripped graph.
    G_minus_H : networkx.DiGraph
        Graph holding exactly the cut edges, together with their edge
        attributes, so that the original graph can be reassembled.
    """
    # Categorize nodes by their node_type attribute.
    node_types = G.nodes(data="node_type")
    supported_nodes = {n for n, kind in node_types if kind == "supported"}
    unsupported_nodes = {n for n, kind in node_types if kind == "unsupported"}

    # Collect, in a single pass and in the iteration order of G.edges, every
    # edge that crosses the supported/unsupported boundary. Keeping the
    # attribute dict lets the removed edges be restored without losing data.
    cut_edges = [
        (u, v, attrs)
        for u, v, attrs in G.edges(data=True)
        if (u in supported_nodes and v in unsupported_nodes)
        or (u in unsupported_nodes and v in supported_nodes)
    ]

    # Work on a copy so that the caller's graph is left untouched.
    H = G.copy()
    H.remove_edges_from(cut_edges)

    # Store the removed edges for reconstruction.
    G_minus_H = nx.DiGraph()
    G_minus_H.add_edges_from(cut_edges)

    if plotting:
        # Plot the stripped graph with the edges removed.
        _draw_graph(
            H,
            [color for _, color in H.nodes(data="node_color")],
            "The stripped graph with the edges removed.",
        )
        # Plot the removed edges; G_minus_H has no node attributes of its
        # own, so the colors are looked up in the original graph.
        _draw_graph(
            G_minus_H,
            [G.nodes[n]["node_color"] for n in G_minus_H.nodes],
            "The removed edges.",
        )

    # Connectivity is evaluated with edge direction ignored. An undirected
    # *view* is sufficient for that and avoids deep-copying the whole graph.
    # Each component's node set then selects an independent directed subgraph.
    undirected_view = H.to_undirected(as_view=True)
    subgraphs = [
        H.subgraph(component).copy()
        for component in nx.connected_components(undirected_view)
    ]

    return subgraphs, G_minus_H
###############################################################################
# Create an example directed graph.
# ---------------------------------
#
# The graph has a single input node, `In`, drawn in blue, and a single output
# node, `Out`, drawn in magenta. Of the six remaining nodes, four are
# `supported` (green) and two are `unsupported` (red). The goal is to compute
# a list of subgraphs that contain only entirely `supported` or entirely
# `unsupported` nodes.

G_ex = nx.DiGraph()
G_ex.add_node("In", node_type="input", node_color="b")
G_ex.add_nodes_from(("A", "C", "E", "F"), node_type="supported", node_color="g")
G_ex.add_nodes_from(("B", "D"), node_type="unsupported", node_color="r")
G_ex.add_node("Out", node_type="output", node_color="m")
# Two routes lead from In to Out: they split after B and rejoin at F.
nx.add_path(G_ex, ["In", "A", "B", "C", "F", "Out"])
nx.add_path(G_ex, ["B", "D", "E", "F"])
###############################################################################
# Plot the original graph.
# ------------------------
#
# Nodes are colored by their `node_color` attribute and positioned with a
# spectral layout.

plt.figure(figsize=(8, 8))
pos = nx.spectral_layout(G_ex)
node_color_list = [color for _node, color in G_ex.nodes.data("node_color")]
nx.draw_networkx_edges(G_ex, pos=pos, edge_color="k", alpha=0.3)
nx.draw_networkx_nodes(G_ex, pos=pos, node_color=node_color_list, alpha=0.8)
nx.draw_networkx_labels(G_ex, pos=pos, font_size=14)
plt.title("The original graph.")
plt.axis("off")
plt.show()
###############################################################################
# Calculate the subgraphs, plotting the results of all intermediate steps.
# ------------------------------------------------------------------------
#
subgraphs_of_G_ex, removed_edges = graph_partitioning(G=G_ex, plotting=True)
###############################################################################
# Plot the results: every subgraph in the list.
# ---------------------------------------------
#
for subgraph in subgraphs_of_G_ex:
    # Each subgraph gets its own figure and its own spring layout.
    node_color_list_c = [color for _node, color in subgraph.nodes.data("node_color")]
    _pos = nx.spring_layout(subgraph)
    plt.figure(figsize=(8, 8))
    nx.draw_networkx_edges(subgraph, pos=_pos, edge_color="k", alpha=0.3)
    nx.draw_networkx_nodes(subgraph, pos=_pos, node_color=node_color_list_c)
    nx.draw_networkx_labels(subgraph, pos=_pos, font_size=14)
    plt.title("One of the subgraphs.")
    plt.axis("off")
    plt.show()
###############################################################################
# Put the graph back from the list of subgraphs
# ---------------------------------------------
#

# Compose all subgraphs in a single call instead of re-copying the growing
# result once per subgraph. Seeding the list with an empty DiGraph fixes the
# result type and keeps the call valid even for an empty list of subgraphs.
G_ex_r = nx.compose_all([nx.DiGraph(), *subgraphs_of_G_ex])
# Add back the edges that were cut during partitioning.
G_ex_r.add_edges_from(removed_edges.edges())
###############################################################################
# Check that the original graph and the reconstructed graph are isomorphic.
# --------------------------------------------------------------------------
#
assert nx.is_isomorphic(G1=G_ex, G2=G_ex_r)
###############################################################################
# Plot the reconstructed graph.
# -----------------------------
#
# The same drawing settings as for the original graph are used, so the two
# figures can be compared directly.

plt.figure(figsize=(8, 8))
pos = nx.spectral_layout(G_ex_r)
node_color_list = [color for _node, color in G_ex_r.nodes.data("node_color")]
nx.draw_networkx_edges(G_ex_r, pos=pos, edge_color="k", alpha=0.3)
nx.draw_networkx_nodes(G_ex_r, pos=pos, node_color=node_color_list, alpha=0.8)
nx.draw_networkx_labels(G_ex_r, pos=pos, font_size=14)
plt.title("The reconstructed graph.")
plt.axis("off")
plt.show()
|