test_graph.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import numpy as np
  2. import pytest
  3. from scipy.sparse.csgraph import connected_components
  4. from sklearn.metrics.pairwise import pairwise_distances
  5. from sklearn.neighbors import kneighbors_graph
  6. from sklearn.utils.graph import _fix_connected_components
  7. def test_fix_connected_components():
  8. # Test that _fix_connected_components reduces the number of component to 1.
  9. X = np.array([0, 1, 2, 5, 6, 7])[:, None]
  10. graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
  11. n_connected_components, labels = connected_components(graph)
  12. assert n_connected_components > 1
  13. graph = _fix_connected_components(X, graph, n_connected_components, labels)
  14. n_connected_components, labels = connected_components(graph)
  15. assert n_connected_components == 1
  16. def test_fix_connected_components_precomputed():
  17. # Test that _fix_connected_components accepts precomputed distance matrix.
  18. X = np.array([0, 1, 2, 5, 6, 7])[:, None]
  19. graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
  20. n_connected_components, labels = connected_components(graph)
  21. assert n_connected_components > 1
  22. distances = pairwise_distances(X)
  23. graph = _fix_connected_components(
  24. distances, graph, n_connected_components, labels, metric="precomputed"
  25. )
  26. n_connected_components, labels = connected_components(graph)
  27. assert n_connected_components == 1
  28. # but it does not work with precomputed neighbors graph
  29. with pytest.raises(RuntimeError, match="does not work with a sparse"):
  30. _fix_connected_components(
  31. graph, graph, n_connected_components, labels, metric="precomputed"
  32. )
  33. def test_fix_connected_components_wrong_mode():
  34. # Test that the an error is raised if the mode string is incorrect.
  35. X = np.array([0, 1, 2, 5, 6, 7])[:, None]
  36. graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
  37. n_connected_components, labels = connected_components(graph)
  38. with pytest.raises(ValueError, match="Unknown mode"):
  39. graph = _fix_connected_components(
  40. X, graph, n_connected_components, labels, mode="foo"
  41. )
  42. def test_fix_connected_components_connectivity_mode():
  43. # Test that the connectivity mode fill new connections with ones.
  44. X = np.array([0, 1, 6, 7])[:, None]
  45. graph = kneighbors_graph(X, n_neighbors=1, mode="connectivity")
  46. n_connected_components, labels = connected_components(graph)
  47. graph = _fix_connected_components(
  48. X, graph, n_connected_components, labels, mode="connectivity"
  49. )
  50. assert np.all(graph.data == 1)
  51. def test_fix_connected_components_distance_mode():
  52. # Test that the distance mode does not fill new connections with ones.
  53. X = np.array([0, 1, 6, 7])[:, None]
  54. graph = kneighbors_graph(X, n_neighbors=1, mode="distance")
  55. assert np.all(graph.data == 1)
  56. n_connected_components, labels = connected_components(graph)
  57. graph = _fix_connected_components(
  58. X, graph, n_connected_components, labels, mode="distance"
  59. )
  60. assert not np.all(graph.data == 1)