  1. """Hierarchical Agglomerative Clustering
  2. These routines perform some hierarchical agglomerative clustering of some
  3. input data.
  4. Authors : Vincent Michel, Bertrand Thirion, Alexandre Gramfort,
  5. Gael Varoquaux
  6. License: BSD 3 clause
  7. """
  8. import warnings
  9. from heapq import heapify, heappop, heappush, heappushpop
  10. from numbers import Integral, Real
  11. import numpy as np
  12. from scipy import sparse
  13. from scipy.sparse.csgraph import connected_components
  14. from ..base import (
  15. BaseEstimator,
  16. ClassNamePrefixFeaturesOutMixin,
  17. ClusterMixin,
  18. _fit_context,
  19. )
  20. from ..metrics import DistanceMetric
  21. from ..metrics._dist_metrics import METRIC_MAPPING64
  22. from ..metrics.pairwise import _VALID_METRICS, paired_distances
  23. from ..utils import check_array
  24. from ..utils._fast_dict import IntFloatDict
  25. from ..utils._param_validation import (
  26. HasMethods,
  27. Hidden,
  28. Interval,
  29. StrOptions,
  30. validate_params,
  31. )
  32. from ..utils.graph import _fix_connected_components
  33. from ..utils.validation import check_memory
  34. # mypy error: Module 'sklearn.cluster' has no attribute '_hierarchical_fast'
  35. from . import _hierarchical_fast as _hierarchical # type: ignore
  36. from ._feature_agglomeration import AgglomerationTransform
  37. ###############################################################################
  38. # For non fully-connected graphs


def _fix_connectivity(X, connectivity, affinity):
    """
    Fixes the connectivity matrix.

    The different steps are:

    - copies it
    - makes it symmetric
    - converts it to LIL if necessary
    - completes it if necessary.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : sparse matrix, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is `None`, i.e., the Ward algorithm is unstructured.

    affinity : {"euclidean", "precomputed"}, default="euclidean"
        Which affinity to use. At the moment `precomputed` and
        ``euclidean`` are supported. `euclidean` uses the
        negative squared Euclidean distance between points.

    Returns
    -------
    connectivity : sparse matrix
        The fixed connectivity matrix.

    n_connected_components : int
        The number of connected components in the graph.
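
    Examples
    --------
    A minimal sketch (illustrative only; the toy graph below is hypothetical
    and intentionally has two connected components, so a completion warning
    is expected on stderr):

    >>> import numpy as np
    >>> from scipy import sparse
    >>> from sklearn.cluster._agglomerative import _fix_connectivity
    >>> X = np.array([[0.0], [1.0], [10.0], [11.0]])
    >>> graph = sparse.lil_matrix(np.array(
    ...     [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
    ...     dtype=float,
    ... ))
    >>> fixed, n_cc = _fix_connectivity(X, graph, affinity="euclidean")
    >>> n_cc
    2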
  66. """
  67. n_samples = X.shape[0]
  68. if connectivity.shape[0] != n_samples or connectivity.shape[1] != n_samples:
  69. raise ValueError(
  70. "Wrong shape for connectivity matrix: %s when X is %s"
  71. % (connectivity.shape, X.shape)
  72. )
  73. # Make the connectivity matrix symmetric:
  74. connectivity = connectivity + connectivity.T
  75. # Convert connectivity matrix to LIL
  76. if not sparse.issparse(connectivity):
  77. connectivity = sparse.lil_matrix(connectivity)
  78. # `connectivity` is a sparse matrix at this point
  79. if connectivity.format != "lil":
  80. connectivity = connectivity.tolil()
  81. # Compute the number of nodes
  82. n_connected_components, labels = connected_components(connectivity)
  83. if n_connected_components > 1:
  84. warnings.warn(
  85. "the number of connected components of the "
  86. "connectivity matrix is %d > 1. Completing it to avoid "
  87. "stopping the tree early." % n_connected_components,
  88. stacklevel=2,
  89. )
  90. # XXX: Can we do without completing the matrix?
  91. connectivity = _fix_connected_components(
  92. X=X,
  93. graph=connectivity,
  94. n_connected_components=n_connected_components,
  95. component_labels=labels,
  96. metric=affinity,
  97. mode="connectivity",
  98. )
  99. return connectivity, n_connected_components


def _single_linkage_tree(
    connectivity,
    n_samples,
    n_nodes,
    n_clusters,
    n_connected_components,
    return_distance,
):
    """
    Perform single linkage clustering on sparse data using the minimum
    spanning tree from scipy.sparse.csgraph, then label the result with
    union-find. The parent array is then generated by walking through
    the tree.
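
    Examples
    --------
    A minimal sketch (illustrative only; this is a private helper, and the
    toy 3-sample chain graph below is hypothetical):

    >>> import numpy as np
    >>> from scipy import sparse
    >>> from sklearn.cluster._agglomerative import _single_linkage_tree
    >>> graph = sparse.csr_matrix(
    ...     np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 2.0], [0.0, 2.0, 0.0]])
    ... )
    >>> children, n_cc, n_leaves, parent = _single_linkage_tree(
    ...     graph, n_samples=3, n_nodes=5, n_clusters=None,
    ...     n_connected_components=1, return_distance=False)
    >>> children.shape
    (2, 2)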
  112. """
  113. from scipy.sparse.csgraph import minimum_spanning_tree
  114. # explicitly cast connectivity to ensure safety
  115. connectivity = connectivity.astype(np.float64, copy=False)
  116. # Ensure zero distances aren't ignored by setting them to "epsilon"
  117. epsilon_value = np.finfo(dtype=connectivity.data.dtype).eps
  118. connectivity.data[connectivity.data == 0] = epsilon_value
  119. # Use scipy.sparse.csgraph to generate a minimum spanning tree
  120. mst = minimum_spanning_tree(connectivity.tocsr())
  121. # Convert the graph to scipy.cluster.hierarchy array format
  122. mst = mst.tocoo()
  123. # Undo the epsilon values
  124. mst.data[mst.data == epsilon_value] = 0
  125. mst_array = np.vstack([mst.row, mst.col, mst.data]).T
  126. # Sort edges of the min_spanning_tree by weight
  127. mst_array = mst_array[np.argsort(mst_array.T[2], kind="mergesort"), :]
  128. # Convert edge list into standard hierarchical clustering format
  129. single_linkage_tree = _hierarchical._single_linkage_label(mst_array)
  130. children_ = single_linkage_tree[:, :2].astype(int)
  131. # Compute parents
  132. parent = np.arange(n_nodes, dtype=np.intp)
  133. for i, (left, right) in enumerate(children_, n_samples):
  134. if n_clusters is not None and i >= n_nodes:
  135. break
  136. if left < n_nodes:
  137. parent[left] = i
  138. if right < n_nodes:
  139. parent[right] = i
  140. if return_distance:
  141. distances = single_linkage_tree[:, 2]
  142. return children_, n_connected_components, n_samples, parent, distances
  143. return children_, n_connected_components, n_samples, parent


###############################################################################
# Hierarchical tree building functions


@validate_params(
    {
        "X": ["array-like"],
        "connectivity": ["array-like", "sparse matrix", None],
        "n_clusters": [Interval(Integral, 1, None, closed="left"), None],
        "return_distance": ["boolean"],
    },
    prefer_skip_nested_validation=True,
)
def ward_tree(X, *, connectivity=None, n_clusters=None, return_distance=False):
    """Ward clustering based on a feature matrix.

    Recursively merges the pair of clusters that minimally increases
    within-cluster variance.

    The inertia matrix uses a Heapq-based representation.

    This is the structured version, that takes into account some topological
    structure between samples.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : {array-like, sparse matrix}, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is None, i.e., the Ward algorithm is unstructured.

    n_clusters : int, default=None
        `n_clusters` should be less than `n_samples`. Stop early the
        construction of the tree at `n_clusters`. This is useful to decrease
        computation time if the number of clusters is not small compared to the
        number of samples. In this case, the complete tree is not computed, thus
        the 'children' output is of limited use, and the 'parents' output should
        rather be used. This option is valid only when specifying a connectivity
        matrix.

    return_distance : bool, default=False
        If `True`, return the distance between the clusters.

    Returns
    -------
    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_connected_components : int
        The number of connected components in the graph.

    n_leaves : int
        The number of leaves in the tree.

    parents : ndarray of shape (n_nodes,) or None
        The parent of each node. Only returned when a connectivity matrix
        is specified, otherwise 'None' is returned.

    distances : ndarray of shape (n_nodes-1,)
        Only returned if `return_distance` is set to `True` (for compatibility).
        The distances between the centers of the nodes. `distances[i]`
        corresponds to a weighted Euclidean distance between
        the nodes `children[i, 0]` and `children[i, 1]`. If the nodes refer to
        leaves of the tree, then `distances[i]` is their unweighted Euclidean
        distance. Distances are updated in the following way
        (from scipy.hierarchy.linkage):

        The new entry :math:`d(u,v)` is computed as follows,

        .. math::

           d(u,v) = \\sqrt{\\frac{|v|+|s|}{T}d(v,s)^2
                         + \\frac{|v|+|t|}{T}d(v,t)^2
                         - \\frac{|v|}{T}d(s,t)^2}

        where :math:`u` is the newly joined cluster consisting of
        clusters :math:`s` and :math:`t`, :math:`v` is an unused
        cluster in the forest, :math:`T=|v|+|s|+|t|`, and
        :math:`|*|` is the cardinality of its argument. This is also
        known as the incremental algorithm.
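
    Examples
    --------
    A minimal sketch of the unstructured case (illustrative only; the toy
    feature matrix below is hypothetical):

    >>> import numpy as np
    >>> from sklearn.cluster import ward_tree
    >>> X = np.array([[0.0], [0.1], [5.0], [5.1]])
    >>> children, n_cc, n_leaves, parents = ward_tree(X)
    >>> children.shape
    (3, 2)
    >>> n_cc, n_leaves, parents
    (1, 4, None)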
  219. """
  220. X = np.asarray(X)
  221. if X.ndim == 1:
  222. X = np.reshape(X, (-1, 1))
  223. n_samples, n_features = X.shape
  224. if connectivity is None:
  225. from scipy.cluster import hierarchy # imports PIL
  226. if n_clusters is not None:
  227. warnings.warn(
  228. (
  229. "Partial build of the tree is implemented "
  230. "only for structured clustering (i.e. with "
  231. "explicit connectivity). The algorithm "
  232. "will build the full tree and only "
  233. "retain the lower branches required "
  234. "for the specified number of clusters"
  235. ),
  236. stacklevel=2,
  237. )
  238. X = np.require(X, requirements="W")
  239. out = hierarchy.ward(X)
  240. children_ = out[:, :2].astype(np.intp)
  241. if return_distance:
  242. distances = out[:, 2]
  243. return children_, 1, n_samples, None, distances
  244. else:
  245. return children_, 1, n_samples, None
  246. connectivity, n_connected_components = _fix_connectivity(
  247. X, connectivity, affinity="euclidean"
  248. )
  249. if n_clusters is None:
  250. n_nodes = 2 * n_samples - 1
  251. else:
  252. if n_clusters > n_samples:
  253. raise ValueError(
  254. "Cannot provide more clusters than samples. "
  255. "%i n_clusters was asked, and there are %i "
  256. "samples." % (n_clusters, n_samples)
  257. )
  258. n_nodes = 2 * n_samples - n_clusters
  259. # create inertia matrix
  260. coord_row = []
  261. coord_col = []
  262. A = []
  263. for ind, row in enumerate(connectivity.rows):
  264. A.append(row)
  265. # We keep only the upper triangular for the moments
  266. # Generator expressions are faster than arrays on the following
  267. row = [i for i in row if i < ind]
  268. coord_row.extend(
  269. len(row)
  270. * [
  271. ind,
  272. ]
  273. )
  274. coord_col.extend(row)
  275. coord_row = np.array(coord_row, dtype=np.intp, order="C")
  276. coord_col = np.array(coord_col, dtype=np.intp, order="C")
  277. # build moments as a list
  278. moments_1 = np.zeros(n_nodes, order="C")
  279. moments_1[:n_samples] = 1
  280. moments_2 = np.zeros((n_nodes, n_features), order="C")
  281. moments_2[:n_samples] = X
  282. inertia = np.empty(len(coord_row), dtype=np.float64, order="C")
  283. _hierarchical.compute_ward_dist(moments_1, moments_2, coord_row, coord_col, inertia)
  284. inertia = list(zip(inertia, coord_row, coord_col))
  285. heapify(inertia)
  286. # prepare the main fields
  287. parent = np.arange(n_nodes, dtype=np.intp)
  288. used_node = np.ones(n_nodes, dtype=bool)
  289. children = []
  290. if return_distance:
  291. distances = np.empty(n_nodes - n_samples)
  292. not_visited = np.empty(n_nodes, dtype=bool, order="C")
  293. # recursive merge loop
  294. for k in range(n_samples, n_nodes):
  295. # identify the merge
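        # Lazy deletion: the heap may still hold edges that reference
        # already-merged nodes, so pop until both endpoints are alive.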
        while True:
            inert, i, j = heappop(inertia)
            if used_node[i] and used_node[j]:
                break
        parent[i], parent[j] = k, k
        children.append((i, j))
        used_node[i] = used_node[j] = False
        if return_distance:  # store inertia value
            distances[k - n_samples] = inert
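
        # moments_1 stores cluster sizes and moments_2 stores per-cluster
        # coordinate sums, so merging two clusters is a plain addition.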
        # update the moments
        moments_1[k] = moments_1[i] + moments_1[j]
        moments_2[k] = moments_2[i] + moments_2[j]

        # update the structure matrix A and the inertia matrix
        coord_col = []
        not_visited.fill(1)
        not_visited[k] = 0
        _hierarchical._get_parents(A[i], coord_col, parent, not_visited)
        _hierarchical._get_parents(A[j], coord_col, parent, not_visited)
        # List comprehension is faster than a for loop
        [A[col].append(k) for col in coord_col]
        A.append(coord_col)
        coord_col = np.array(coord_col, dtype=np.intp, order="C")
        coord_row = np.empty(coord_col.shape, dtype=np.intp, order="C")
        coord_row.fill(k)
        n_additions = len(coord_row)
        ini = np.empty(n_additions, dtype=np.float64, order="C")

        _hierarchical.compute_ward_dist(moments_1, moments_2, coord_row, coord_col, ini)

        # List comprehension is faster than a for loop
        [heappush(inertia, (ini[idx], k, coord_col[idx])) for idx in range(n_additions)]

    # Separate leaves in children (empty lists up to now)
    n_leaves = n_samples

    # sort children to get consistent output with unstructured version
    children = [c[::-1] for c in children]
    children = np.array(children)  # return numpy array for efficient caching

    if return_distance:
        # 2 is scaling factor to compare w/ unstructured version
        distances = np.sqrt(2.0 * distances)
        return children, n_connected_components, n_leaves, parent, distances
    else:
        return children, n_connected_components, n_leaves, parent


# single, average and complete linkage
def linkage_tree(
    X,
    connectivity=None,
    n_clusters=None,
    linkage="complete",
    affinity="euclidean",
    return_distance=False,
):
    """Linkage agglomerative clustering based on a feature matrix.

    The inertia matrix uses a Heapq-based representation.

    This is the structured version, that takes into account some topological
    structure between samples.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : sparse matrix, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is `None`, i.e., the hierarchical clustering algorithm is
        unstructured.

    n_clusters : int, default=None
        Stop early the construction of the tree at `n_clusters`. This is
        useful to decrease computation time if the number of clusters is
        not small compared to the number of samples. In this case, the
        complete tree is not computed, thus the 'children' output is of
        limited use, and the 'parents' output should rather be used.
        This option is valid only when specifying a connectivity matrix.

    linkage : {"average", "complete", "single"}, default="complete"
        Which linkage criterion to use. The linkage criterion determines which
        distance to use between sets of observations.

        - "average" uses the average of the distances of each observation of
          the two sets.
        - "complete" or maximum linkage uses the maximum distances between
          all observations of the two sets.
        - "single" uses the minimum of the distances between all
          observations of the two sets.

    affinity : str or callable, default='euclidean'
        Which metric to use. Can be 'euclidean', 'manhattan', or any
        distance known to paired distance (see metrics.pairwise).

    return_distance : bool, default=False
        Whether or not to return the distances between the clusters.

    Returns
    -------
    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_connected_components : int
        The number of connected components in the graph.

    n_leaves : int
        The number of leaves in the tree.

    parents : ndarray of shape (n_nodes, ) or None
        The parent of each node. Only returned when a connectivity matrix
        is specified, otherwise 'None' is returned.

    distances : ndarray of shape (n_nodes-1,)
        Returned when `return_distance` is set to `True`.

        distances[i] refers to the distance between children[i][0] and
        children[i][1] when they are merged.

    See Also
    --------
    ward_tree : Hierarchical clustering with ward linkage.
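
    Examples
    --------
    A minimal sketch of the unstructured case (illustrative only; the toy
    feature matrix below is hypothetical, and the private import path simply
    mirrors this module):

    >>> import numpy as np
    >>> from sklearn.cluster._agglomerative import linkage_tree
    >>> X = np.array([[0.0], [0.1], [5.0], [5.1]])
    >>> children, n_cc, n_leaves, parents = linkage_tree(X, linkage="average")
    >>> children.shape
    (3, 2)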
  403. """
  404. X = np.asarray(X)
  405. if X.ndim == 1:
  406. X = np.reshape(X, (-1, 1))
  407. n_samples, n_features = X.shape
  408. linkage_choices = {
  409. "complete": _hierarchical.max_merge,
  410. "average": _hierarchical.average_merge,
  411. "single": None,
  412. } # Single linkage is handled differently
  413. try:
  414. join_func = linkage_choices[linkage]
  415. except KeyError as e:
  416. raise ValueError(
  417. "Unknown linkage option, linkage should be one of %s, but %s was given"
  418. % (linkage_choices.keys(), linkage)
  419. ) from e
  420. if affinity == "cosine" and np.any(~np.any(X, axis=1)):
  421. raise ValueError("Cosine affinity cannot be used when X contains zero vectors")
  422. if connectivity is None:
  423. from scipy.cluster import hierarchy # imports PIL
  424. if n_clusters is not None:
  425. warnings.warn(
  426. (
  427. "Partial build of the tree is implemented "
  428. "only for structured clustering (i.e. with "
  429. "explicit connectivity). The algorithm "
  430. "will build the full tree and only "
  431. "retain the lower branches required "
  432. "for the specified number of clusters"
  433. ),
  434. stacklevel=2,
  435. )
  436. if affinity == "precomputed":
  437. # for the linkage function of hierarchy to work on precomputed
  438. # data, provide as first argument an ndarray of the shape returned
  439. # by sklearn.metrics.pairwise_distances.
  440. if X.shape[0] != X.shape[1]:
  441. raise ValueError(
  442. f"Distance matrix should be square, got matrix of shape {X.shape}"
  443. )
  444. i, j = np.triu_indices(X.shape[0], k=1)
  445. X = X[i, j]
  446. elif affinity == "l2":
  447. # Translate to something understood by scipy
  448. affinity = "euclidean"
  449. elif affinity in ("l1", "manhattan"):
  450. affinity = "cityblock"
  451. elif callable(affinity):
  452. X = affinity(X)
  453. i, j = np.triu_indices(X.shape[0], k=1)
  454. X = X[i, j]
  455. if (
  456. linkage == "single"
  457. and affinity != "precomputed"
  458. and not callable(affinity)
  459. and affinity in METRIC_MAPPING64
  460. ):
  461. # We need the fast cythonized metric from neighbors
  462. dist_metric = DistanceMetric.get_metric(affinity)
  463. # The Cython routines used require contiguous arrays
  464. X = np.ascontiguousarray(X, dtype=np.double)
  465. mst = _hierarchical.mst_linkage_core(X, dist_metric)
  466. # Sort edges of the min_spanning_tree by weight
  467. mst = mst[np.argsort(mst.T[2], kind="mergesort"), :]
  468. # Convert edge list into standard hierarchical clustering format
  469. out = _hierarchical.single_linkage_label(mst)
  470. else:
  471. out = hierarchy.linkage(X, method=linkage, metric=affinity)
  472. children_ = out[:, :2].astype(int, copy=False)
  473. if return_distance:
  474. distances = out[:, 2]
  475. return children_, 1, n_samples, None, distances
  476. return children_, 1, n_samples, None
    connectivity, n_connected_components = _fix_connectivity(
        X, connectivity, affinity=affinity
    )
    connectivity = connectivity.tocoo()
    # Put the diagonal to zero
    diag_mask = connectivity.row != connectivity.col
    connectivity.row = connectivity.row[diag_mask]
    connectivity.col = connectivity.col[diag_mask]
    connectivity.data = connectivity.data[diag_mask]
    del diag_mask

    if affinity == "precomputed":
        distances = X[connectivity.row, connectivity.col].astype(np.float64, copy=False)
    else:
        # FIXME We compute all the distances, while we could have only computed
        # the "interesting" distances
        distances = paired_distances(
            X[connectivity.row], X[connectivity.col], metric=affinity
        )
    connectivity.data = distances

    if n_clusters is None:
        n_nodes = 2 * n_samples - 1
    else:
        assert n_clusters <= n_samples
        n_nodes = 2 * n_samples - n_clusters

    if linkage == "single":
        return _single_linkage_tree(
            connectivity,
            n_samples,
            n_nodes,
            n_clusters,
            n_connected_components,
            return_distance,
        )

    if return_distance:
        distances = np.empty(n_nodes - n_samples)
    # create inertia heap and connection matrix
    A = np.empty(n_nodes, dtype=object)
    inertia = list()

    # LIL seems to be the best format to access the rows quickly,
    # without the numpy overhead of slicing CSR indices and data.
    connectivity = connectivity.tolil()
    # We are storing the graph in a list of IntFloatDict
    for ind, (data, row) in enumerate(zip(connectivity.data, connectivity.rows)):
        A[ind] = IntFloatDict(
            np.asarray(row, dtype=np.intp), np.asarray(data, dtype=np.float64)
        )
        # We keep only the upper triangular for the heap
        # Generator expressions are faster than arrays on the following
        inertia.extend(
            _hierarchical.WeightedEdge(d, ind, r) for r, d in zip(row, data) if r < ind
        )
    del connectivity

    heapify(inertia)

    # prepare the main fields
    parent = np.arange(n_nodes, dtype=np.intp)
    used_node = np.ones(n_nodes, dtype=np.intp)
    children = []

    # recursive merge loop
    for k in range(n_samples, n_nodes):
        # identify the merge
        while True:
            edge = heappop(inertia)
            if used_node[edge.a] and used_node[edge.b]:
                break
        i = edge.a
        j = edge.b

        if return_distance:
            # store distances
            distances[k - n_samples] = edge.weight

        parent[i] = parent[j] = k
        children.append((i, j))
        # Keep track of the number of elements per cluster
        n_i = used_node[i]
        n_j = used_node[j]
        used_node[k] = n_i + n_j
        used_node[i] = used_node[j] = False

        # update the structure matrix A and the inertia matrix
        # a clever 'min', or 'max' operation between A[i] and A[j]
        coord_col = join_func(A[i], A[j], used_node, n_i, n_j)
        for col, d in coord_col:
            A[col].append(k, d)
            # Here we use the information from coord_col (containing the
            # distances) to update the heap
            heappush(inertia, _hierarchical.WeightedEdge(d, k, col))
        A[k] = coord_col
        # Clear A[i] and A[j] to save memory
        A[i] = A[j] = 0

    # Separate leaves in children (empty lists up to now)
    n_leaves = n_samples

    # return numpy array for efficient caching
    children = np.array(children)[:, ::-1]

    if return_distance:
        return children, n_connected_components, n_leaves, parent, distances
    return children, n_connected_components, n_leaves, parent


# Matching names to tree-building strategies
def _complete_linkage(*args, **kwargs):
    kwargs["linkage"] = "complete"
    return linkage_tree(*args, **kwargs)


def _average_linkage(*args, **kwargs):
    kwargs["linkage"] = "average"
    return linkage_tree(*args, **kwargs)


def _single_linkage(*args, **kwargs):
    kwargs["linkage"] = "single"
    return linkage_tree(*args, **kwargs)


_TREE_BUILDERS = dict(
    ward=ward_tree,
    complete=_complete_linkage,
    average=_average_linkage,
    single=_single_linkage,
)
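
# For instance, ``AgglomerativeClustering(linkage="average")`` looks up
# ``_TREE_BUILDERS["average"]`` in ``_fit`` and therefore ends up calling
# ``linkage_tree(..., linkage="average")`` through the wrapper above.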


###############################################################################
# Functions for cutting hierarchical clustering tree


def _hc_cut(n_clusters, children, n_leaves):
    """Function cutting the ward tree for a given number of clusters.

    Parameters
    ----------
    n_clusters : int or ndarray
        The number of clusters to form.

    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_leaves : int
        Number of leaves of the tree.

    Returns
    -------
    labels : array [n_samples]
        Cluster labels for each point.
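
    Examples
    --------
    A minimal sketch (illustrative only): build a small Ward tree, then cut
    it into two clusters.

    >>> import numpy as np
    >>> from sklearn.cluster import ward_tree
    >>> from sklearn.cluster._agglomerative import _hc_cut
    >>> X = np.array([[0.0], [0.1], [5.0], [5.1]])
    >>> children, _, n_leaves, _ = ward_tree(X)
    >>> labels = _hc_cut(2, children, n_leaves)
    >>> len(set(labels))
    2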
  608. """
  609. if n_clusters > n_leaves:
  610. raise ValueError(
  611. "Cannot extract more clusters than samples: "
  612. "%s clusters where given for a tree with %s leaves."
  613. % (n_clusters, n_leaves)
  614. )
  615. # In this function, we store nodes as a heap to avoid recomputing
  616. # the max of the nodes: the first element is always the smallest
  617. # We use negated indices as heaps work on smallest elements, and we
  618. # are interested in largest elements
  619. # children[-1] is the root of the tree
  620. nodes = [-(max(children[-1]) + 1)]
  621. for _ in range(n_clusters - 1):
  622. # As we have a heap, nodes[0] is the smallest element
  623. these_children = children[-nodes[0] - n_leaves]
  624. # Insert the 2 children and remove the largest node
  625. heappush(nodes, -these_children[0])
  626. heappushpop(nodes, -these_children[1])
  627. label = np.zeros(n_leaves, dtype=np.intp)
  628. for i, node in enumerate(nodes):
  629. label[_hierarchical._hc_get_descendent(-node, children, n_leaves)] = i
  630. return label


###############################################################################


class AgglomerativeClustering(ClusterMixin, BaseEstimator):
    """
    Agglomerative Clustering.

    Recursively merges pair of clusters of sample data; uses linkage distance.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    n_clusters : int or None, default=2
        The number of clusters to find. It must be ``None`` if
        ``distance_threshold`` is not ``None``.

    affinity : str or callable, default='euclidean'
        The metric to use when calculating distance between instances in a
        feature array. If metric is a string or callable, it must be one of
        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
        its metric parameter.
        If linkage is "ward", only "euclidean" is accepted.
        If "precomputed", a distance matrix (instead of a similarity matrix)
        is needed as input for the fit method.

        .. deprecated:: 1.2
            `affinity` was deprecated in version 1.2 and will be renamed to
            `metric` in 1.4.

    metric : str or callable, default=None
        Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
        "manhattan", "cosine", or "precomputed". If set to `None` then
        "euclidean" is used. If linkage is "ward", only "euclidean" is
        accepted. If "precomputed", a distance matrix is needed as input for
        the fit method.

        .. versionadded:: 1.2

    memory : str or object with the joblib.Memory interface, default=None
        Used to cache the output of the computation of the tree.
        By default, no caching is done. If a string is given, it is the
        path to the caching directory.

    connectivity : array-like or callable, default=None
        Connectivity matrix. Defines for each sample the neighboring
        samples following a given structure of the data.
        This can be a connectivity matrix itself or a callable that transforms
        the data into a connectivity matrix, such as derived from
        `kneighbors_graph`. Default is ``None``, i.e., the
        hierarchical clustering algorithm is unstructured.

    compute_full_tree : 'auto' or bool, default='auto'
        Stop early the construction of the tree at ``n_clusters``. This is
        useful to decrease computation time if the number of clusters is not
        small compared to the number of samples. This option is useful only
        when specifying a connectivity matrix. Note also that when varying the
        number of clusters and using caching, it may be advantageous to compute
        the full tree. It must be ``True`` if ``distance_threshold`` is not
        ``None``. By default `compute_full_tree` is "auto", which is equivalent
        to `True` when `distance_threshold` is not `None` or when `n_clusters`
        is lower than the maximum of 100 and `0.02 * n_samples`.
        Otherwise, "auto" is equivalent to `False`.

    linkage : {'ward', 'complete', 'average', 'single'}, default='ward'
        Which linkage criterion to use. The linkage criterion determines which
        distance to use between sets of observations. The algorithm will merge
        the pairs of clusters that minimize this criterion.

        - 'ward' minimizes the variance of the clusters being merged.
        - 'average' uses the average of the distances of each observation of
          the two sets.
        - 'complete' or 'maximum' linkage uses the maximum distances between
          all observations of the two sets.
        - 'single' uses the minimum of the distances between all observations
          of the two sets.

        .. versionadded:: 0.20
            Added the 'single' option

    distance_threshold : float, default=None
        The linkage distance threshold at or above which clusters will not be
        merged. If not ``None``, ``n_clusters`` must be ``None`` and
        ``compute_full_tree`` must be ``True``.

        .. versionadded:: 0.21

    compute_distances : bool, default=False
        Computes distances between clusters even if `distance_threshold` is not
        used. This can be used to make dendrogram visualization, but introduces
        a computational and memory overhead.

        .. versionadded:: 0.24

    Attributes
    ----------
    n_clusters_ : int
        The number of clusters found by the algorithm. If
        ``distance_threshold=None``, it will be equal to the given
        ``n_clusters``.

    labels_ : ndarray of shape (n_samples)
        Cluster labels for each point.

    n_leaves_ : int
        Number of leaves in the hierarchical tree.

    n_connected_components_ : int
        The estimated number of connected components in the graph.

        .. versionadded:: 0.21
            ``n_connected_components_`` was added to replace ``n_components_``.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    children_ : array-like of shape (n_samples-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    distances_ : array-like of shape (n_nodes-1,)
        Distances between nodes in the corresponding place in `children_`.
        Only computed if `distance_threshold` is used or `compute_distances`
        is set to `True`.

    See Also
    --------
    FeatureAgglomeration : Agglomerative clustering but for features instead of
        samples.
    ward_tree : Hierarchical clustering with ward linkage.

    Examples
    --------
    >>> from sklearn.cluster import AgglomerativeClustering
    >>> import numpy as np
    >>> X = np.array([[1, 2], [1, 4], [1, 0],
    ...               [4, 2], [4, 4], [4, 0]])
    >>> clustering = AgglomerativeClustering().fit(X)
    >>> clustering
    AgglomerativeClustering()
    >>> clustering.labels_
    array([1, 1, 1, 0, 0, 0])
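
    A hypothetical variant using a distance threshold instead of a fixed
    number of clusters (illustrative only; the threshold value below was
    picked by hand for this toy dataset):

    >>> clustering = AgglomerativeClustering(
    ...     n_clusters=None, distance_threshold=4.0
    ... ).fit(X)
    >>> clustering.n_clusters_
    2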
  753. """

    _parameter_constraints: dict = {
        "n_clusters": [Interval(Integral, 1, None, closed="left"), None],
        "affinity": [
            Hidden(StrOptions({"deprecated"})),
            StrOptions(set(_VALID_METRICS) | {"precomputed"}),
            callable,
        ],
        "metric": [
            StrOptions(set(_VALID_METRICS) | {"precomputed"}),
            callable,
            None,
        ],
        "memory": [str, HasMethods("cache"), None],
        "connectivity": ["array-like", callable, None],
        "compute_full_tree": [StrOptions({"auto"}), "boolean"],
        "linkage": [StrOptions(set(_TREE_BUILDERS.keys()))],
        "distance_threshold": [Interval(Real, 0, None, closed="left"), None],
        "compute_distances": ["boolean"],
    }

    def __init__(
        self,
        n_clusters=2,
        *,
        affinity="deprecated",  # TODO(1.4): Remove
        metric=None,  # TODO(1.4): Set to "euclidean"
        memory=None,
        connectivity=None,
        compute_full_tree="auto",
        linkage="ward",
        distance_threshold=None,
        compute_distances=False,
    ):
        self.n_clusters = n_clusters
        self.distance_threshold = distance_threshold
        self.memory = memory
        self.connectivity = connectivity
        self.compute_full_tree = compute_full_tree
        self.linkage = linkage
        self.affinity = affinity
        self.metric = metric
        self.compute_distances = compute_distances

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        """Fit the hierarchical clustering from features, or distance matrix.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features) or \
                (n_samples, n_samples)
            Training instances to cluster, or distances between instances if
            ``metric='precomputed'``.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self : object
            Returns the fitted instance.
        """
        X = self._validate_data(X, ensure_min_samples=2)
        return self._fit(X)

    def _fit(self, X):
        """Fit without validation

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
            Training instances to cluster, or distances between instances if
            ``affinity='precomputed'``.

        Returns
        -------
        self : object
            Returns the fitted instance.
        """
        memory = check_memory(self.memory)

        self._metric = self.metric
        # TODO(1.4): Remove
        if self.affinity != "deprecated":
            if self.metric is not None:
                raise ValueError(
                    "Both `affinity` and `metric` attributes were set. Attribute"
                    " `affinity` was deprecated in version 1.2 and will be removed in"
                    " 1.4. To avoid this error, only set the `metric` attribute."
                )
            warnings.warn(
                (
                    "Attribute `affinity` was deprecated in version 1.2 and will be"
                    " removed in 1.4. Use `metric` instead"
                ),
                FutureWarning,
            )
            self._metric = self.affinity
        elif self.metric is None:
            self._metric = "euclidean"

        if not ((self.n_clusters is None) ^ (self.distance_threshold is None)):
            raise ValueError(
                "Exactly one of n_clusters and "
                "distance_threshold has to be set, and the other "
                "needs to be None."
            )

        if self.distance_threshold is not None and not self.compute_full_tree:
            raise ValueError(
                "compute_full_tree must be True if distance_threshold is set."
            )

        if self.linkage == "ward" and self._metric != "euclidean":
            raise ValueError(
                f"{self._metric} was provided as metric. Ward can only "
                "work with euclidean distances."
            )

        tree_builder = _TREE_BUILDERS[self.linkage]

        connectivity = self.connectivity
        if self.connectivity is not None:
            if callable(self.connectivity):
                connectivity = self.connectivity(X)
            connectivity = check_array(
                connectivity, accept_sparse=["csr", "coo", "lil"]
            )

        n_samples = len(X)
        compute_full_tree = self.compute_full_tree
        if self.connectivity is None:
            compute_full_tree = True
        if compute_full_tree == "auto":
            if self.distance_threshold is not None:
                compute_full_tree = True
            else:
                # Early stopping is likely to give a speed up only for
                # a large number of clusters. The actual threshold
                # implemented here is heuristic
                compute_full_tree = self.n_clusters < max(100, 0.02 * n_samples)
        n_clusters = self.n_clusters
        if compute_full_tree:
            n_clusters = None

        # Construct the tree
        kwargs = {}
        if self.linkage != "ward":
            kwargs["linkage"] = self.linkage
            kwargs["affinity"] = self._metric

        distance_threshold = self.distance_threshold

        return_distance = (distance_threshold is not None) or self.compute_distances

        out = memory.cache(tree_builder)(
            X,
            connectivity=connectivity,
            n_clusters=n_clusters,
            return_distance=return_distance,
            **kwargs,
        )
        (self.children_, self.n_connected_components_, self.n_leaves_, parents) = out[
            :4
        ]

        if return_distance:
            self.distances_ = out[-1]

        if self.distance_threshold is not None:  # distance_threshold is used
            self.n_clusters_ = (
                np.count_nonzero(self.distances_ >= distance_threshold) + 1
            )
        else:  # n_clusters is used
            self.n_clusters_ = self.n_clusters

        # Cut the tree
        if compute_full_tree:
            self.labels_ = _hc_cut(self.n_clusters_, self.children_, self.n_leaves_)
        else:
            labels = _hierarchical.hc_get_heads(parents, copy=False)
            # copy to avoid holding a reference on the original array
            labels = np.copy(labels[:n_samples])
            # Reassign cluster numbers
            self.labels_ = np.searchsorted(np.unique(labels), labels)
        return self

    def fit_predict(self, X, y=None):
        """Fit and return the result of each sample's clustering assignment.

        In addition to fitting, this method also returns the result of the
        clustering assignment for each sample in the training set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or \
                (n_samples, n_samples)
            Training instances to cluster, or distances between instances if
            ``affinity='precomputed'``.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        labels : ndarray of shape (n_samples,)
            Cluster labels.
        """
        return super().fit_predict(X, y)


class FeatureAgglomeration(
    ClassNamePrefixFeaturesOutMixin, AgglomerativeClustering, AgglomerationTransform
):
    """Agglomerate features.

    Recursively merges pair of clusters of features.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    n_clusters : int or None, default=2
        The number of clusters to find. It must be ``None`` if
        ``distance_threshold`` is not ``None``.

    affinity : str or callable, default='euclidean'
        The metric to use when calculating distance between instances in a
        feature array. If metric is a string or callable, it must be one of
        the options allowed by :func:`sklearn.metrics.pairwise_distances` for
        its metric parameter.
        If linkage is "ward", only "euclidean" is accepted.
        If "precomputed", a distance matrix (instead of a similarity matrix)
        is needed as input for the fit method.

        .. deprecated:: 1.2
            `affinity` was deprecated in version 1.2 and will be renamed to
            `metric` in 1.4.

    metric : str or callable, default=None
        Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
        "manhattan", "cosine", or "precomputed". If set to `None` then
        "euclidean" is used. If linkage is "ward", only "euclidean" is
        accepted. If "precomputed", a distance matrix is needed as input for
        the fit method.

        .. versionadded:: 1.2

    memory : str or object with the joblib.Memory interface, default=None
        Used to cache the output of the computation of the tree.
        By default, no caching is done. If a string is given, it is the
        path to the caching directory.

    connectivity : array-like or callable, default=None
        Connectivity matrix. Defines for each feature the neighboring
        features following a given structure of the data.
        This can be a connectivity matrix itself or a callable that transforms
        the data into a connectivity matrix, such as derived from
        `kneighbors_graph`. Default is `None`, i.e., the
        hierarchical clustering algorithm is unstructured.

    compute_full_tree : 'auto' or bool, default='auto'
        Stop early the construction of the tree at `n_clusters`. This is useful
        to decrease computation time if the number of clusters is not small
        compared to the number of features. This option is useful only when
        specifying a connectivity matrix. Note also that when varying the
        number of clusters and using caching, it may be advantageous to compute
        the full tree. It must be ``True`` if ``distance_threshold`` is not
        ``None``. By default `compute_full_tree` is "auto", which is equivalent
        to `True` when `distance_threshold` is not `None` or when `n_clusters`
        is lower than the maximum of 100 and `0.02 * n_samples`.
        Otherwise, "auto" is equivalent to `False`.

    linkage : {"ward", "complete", "average", "single"}, default="ward"
        Which linkage criterion to use. The linkage criterion determines which
        distance to use between sets of features. The algorithm will merge
        the pairs of clusters that minimize this criterion.

        - "ward" minimizes the variance of the clusters being merged.
        - "complete" or maximum linkage uses the maximum distances between
          all features of the two sets.
        - "average" uses the average of the distances of each feature of
          the two sets.
        - "single" uses the minimum of the distances between all features
          of the two sets.

    pooling_func : callable, default=np.mean
        This combines the values of agglomerated features into a single
        value, and should accept an array of shape [M, N] and the keyword
        argument `axis=1`, and reduce it to an array of size [M].

    distance_threshold : float, default=None
        The linkage distance threshold at or above which clusters will not be
        merged. If not ``None``, ``n_clusters`` must be ``None`` and
        ``compute_full_tree`` must be ``True``.

        .. versionadded:: 0.21

    compute_distances : bool, default=False
        Computes distances between clusters even if `distance_threshold` is not
        used. This can be used to make dendrogram visualization, but introduces
        a computational and memory overhead.

        .. versionadded:: 0.24

    Attributes
    ----------
    n_clusters_ : int
        The number of clusters found by the algorithm. If
        ``distance_threshold=None``, it will be equal to the given
        ``n_clusters``.

    labels_ : array-like of shape (n_features,)
        Cluster labels for each feature.

    n_leaves_ : int
        Number of leaves in the hierarchical tree.

    n_connected_components_ : int
        The estimated number of connected components in the graph.

        .. versionadded:: 0.21
            ``n_connected_components_`` was added to replace ``n_components_``.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    children_ : array-like of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_features`
        correspond to leaves of the tree which are the original features.
        A node `i` greater than or equal to `n_features` is a non-leaf
        node and has children `children_[i - n_features]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_features + i`.

    distances_ : array-like of shape (n_nodes-1,)
        Distances between nodes in the corresponding place in `children_`.
        Only computed if `distance_threshold` is used or `compute_distances`
        is set to `True`.

    See Also
    --------
    AgglomerativeClustering : Agglomerative clustering of samples instead of
        features.
    ward_tree : Hierarchical clustering with ward linkage.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn import datasets, cluster
    >>> digits = datasets.load_digits()
    >>> images = digits.images
    >>> X = np.reshape(images, (len(images), -1))
    >>> agglo = cluster.FeatureAgglomeration(n_clusters=32)
    >>> agglo.fit(X)
    FeatureAgglomeration(n_clusters=32)
    >>> X_reduced = agglo.transform(X)
    >>> X_reduced.shape
    (1797, 32)
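
    A follow-up sketch (illustrative only): the reduced data can be mapped
    back to the original 64-pixel space, with each pixel receiving its
    cluster's pooled value.

    >>> X_restored = agglo.inverse_transform(X_reduced)
    >>> X_restored.shape
    (1797, 64)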
  1062. """

    _parameter_constraints: dict = {
        "n_clusters": [Interval(Integral, 1, None, closed="left"), None],
        "affinity": [
            Hidden(StrOptions({"deprecated"})),
            StrOptions(set(_VALID_METRICS) | {"precomputed"}),
            callable,
        ],
        "metric": [
            StrOptions(set(_VALID_METRICS) | {"precomputed"}),
            callable,
            None,
        ],
        "memory": [str, HasMethods("cache"), None],
        "connectivity": ["array-like", callable, None],
        "compute_full_tree": [StrOptions({"auto"}), "boolean"],
        "linkage": [StrOptions(set(_TREE_BUILDERS.keys()))],
        "pooling_func": [callable],
        "distance_threshold": [Interval(Real, 0, None, closed="left"), None],
        "compute_distances": ["boolean"],
    }

    def __init__(
        self,
        n_clusters=2,
        *,
        affinity="deprecated",  # TODO(1.4): Remove
        metric=None,  # TODO(1.4): Set to "euclidean"
        memory=None,
        connectivity=None,
        compute_full_tree="auto",
        linkage="ward",
        pooling_func=np.mean,
        distance_threshold=None,
        compute_distances=False,
    ):
        super().__init__(
            n_clusters=n_clusters,
            memory=memory,
            connectivity=connectivity,
            compute_full_tree=compute_full_tree,
            linkage=linkage,
            affinity=affinity,
            metric=metric,
            distance_threshold=distance_threshold,
            compute_distances=compute_distances,
        )
        self.pooling_func = pooling_func

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        """Fit the hierarchical clustering on the data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self : object
            Returns the transformer.
        """
        X = self._validate_data(X, ensure_min_features=2)
        super()._fit(X.T)
        self._n_features_out = self.n_clusters_
        return self

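    # FeatureAgglomeration clusters features rather than samples, so the
    # sample-wise ``fit_predict`` inherited from ClusterMixin does not apply;
    # raising in the property hides it from ``hasattr`` checks.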
    @property
    def fit_predict(self):
        """Fit and return the result of each sample's clustering assignment."""
        raise AttributeError