| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import numpy as np
- import pytest
- from scipy.sparse.csgraph import connected_components
- from sklearn.metrics.pairwise import pairwise_distances
- from sklearn.neighbors import kneighbors_graph
- from sklearn.utils.graph import _fix_connected_components
def test_fix_connected_components():
    # _fix_connected_components must merge a disconnected neighbors graph
    # into a single connected component.
    #
    # Two well-separated 1-D clusters, {0, 1, 2} and {5, 6, 7}: with only two
    # neighbors per sample, no edge crosses the gap between them.
    X = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(X, n_neighbors=2, mode="distance")

    # Sanity check: the input really is disconnected.
    n_before, component_labels = connected_components(knn_graph)
    assert n_before > 1

    fixed_graph = _fix_connected_components(
        X, knn_graph, n_before, component_labels
    )

    n_after, _ = connected_components(fixed_graph)
    assert n_after == 1
def test_fix_connected_components_precomputed():
    # _fix_connected_components must accept a precomputed distance matrix in
    # place of the raw samples when metric="precomputed".
    X = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(X, n_neighbors=2, mode="distance")

    # Sanity check: the neighbors graph starts out disconnected.
    n_components, component_labels = connected_components(knn_graph)
    assert n_components > 1

    pairwise = pairwise_distances(X)
    fixed_graph = _fix_connected_components(
        pairwise,
        knn_graph,
        n_components,
        component_labels,
        metric="precomputed",
    )

    n_components, component_labels = connected_components(fixed_graph)
    assert n_components == 1

    # A precomputed neighbors graph, unlike a precomputed distance matrix, is
    # rejected when passed as the first argument.
    expected_message = "does not work with a sparse"
    with pytest.raises(RuntimeError, match=expected_message):
        _fix_connected_components(
            fixed_graph,
            fixed_graph,
            n_components,
            component_labels,
            metric="precomputed",
        )
def test_fix_connected_components_wrong_mode():
    # Test that an error is raised if the mode string is incorrect.
    X = np.array([0, 1, 2, 5, 6, 7])[:, None]
    graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
    n_connected_components, labels = connected_components(graph)

    # The call is expected to raise, so its result is deliberately not bound:
    # an assignment here could never complete and would only be a dead store.
    with pytest.raises(ValueError, match="Unknown mode"):
        _fix_connected_components(
            X, graph, n_connected_components, labels, mode="foo"
        )
def test_fix_connected_components_connectivity_mode():
    # In "connectivity" mode, the connections added to join the components
    # must be stored as ones, like every other entry of the graph.
    X = np.array([0, 1, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(X, n_neighbors=1, mode="connectivity")
    n_components, component_labels = connected_components(knn_graph)

    fixed_graph = _fix_connected_components(
        X,
        knn_graph,
        n_components,
        component_labels,
        mode="connectivity",
    )

    stored_values = fixed_graph.data
    assert np.all(stored_values == 1)
def test_fix_connected_components_distance_mode():
    # In "distance" mode, the connections added to join the components must
    # not simply be filled with ones.
    X = np.array([0, 1, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(X, n_neighbors=1, mode="distance")

    # Before the fix every stored distance equals 1, so any other value seen
    # afterwards must come from a connection added by the fix.
    assert np.all(knn_graph.data == 1)

    n_components, component_labels = connected_components(knn_graph)
    fixed_graph = _fix_connected_components(
        X,
        knn_graph,
        n_components,
        component_labels,
        mode="distance",
    )

    # At least one stored value must now differ from 1.
    assert np.any(fixed_graph.data != 1)
|