# plot_basic.py 1.1 KB
"""
================
Basic matplotlib
================

A basic example of 3D Graph visualization using `mpl_toolkits.mplot3d`.
"""
  7. import networkx as nx
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from mpl_toolkits.mplot3d import Axes3D
  11. # The graph to visualize
  12. G = nx.cycle_graph(20)
  13. # 3d spring layout
  14. pos = nx.spring_layout(G, dim=3, seed=779)
  15. # Extract node and edge positions from the layout
  16. node_xyz = np.array([pos[v] for v in sorted(G)])
  17. edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])
  18. # Create the 3D figure
  19. fig = plt.figure()
  20. ax = fig.add_subplot(111, projection="3d")
  21. # Plot the nodes - alpha is scaled by "depth" automatically
  22. ax.scatter(*node_xyz.T, s=100, ec="w")
  23. # Plot the edges
  24. for vizedge in edge_xyz:
  25. ax.plot(*vizedge.T, color="tab:gray")
  26. def _format_axes(ax):
  27. """Visualization options for the 3D axes."""
  28. # Turn gridlines off
  29. ax.grid(False)
  30. # Suppress tick labels
  31. for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
  32. dim.set_ticks([])
  33. # Set axes labels
  34. ax.set_xlabel("x")
  35. ax.set_ylabel("y")
  36. ax.set_zlabel("z")
  37. _format_axes(ax)
  38. fig.tight_layout()
  39. plt.show()