# clustering.py
  1. """
  2. 文献聚类模块
  3. 对文献进行向量化和聚类分析
  4. """
  5. from typing import Dict, List, Any, Optional
  6. import numpy as np
  7. from sklearn.cluster import KMeans
  8. from sklearn.manifold import TSNE
  9. import logging
  10. from sentence_transformers import SentenceTransformer
  11. from backend.config import DEFAULT_EMBEDDING_MODEL, DEFAULT_NUM_CLUSTERS
  12. logger = logging.getLogger(__name__)
  13. class PaperClusterer:
  14. """文献聚类引擎"""
  15. def __init__(self, embedding_model: str = DEFAULT_EMBEDDING_MODEL):
  16. """
  17. 初始化聚类引擎
  18. Args:
  19. embedding_model: 使用的嵌入模型名称
  20. """
  21. self.model_name = embedding_model
  22. self._model = None # 延迟加载
  23. @property
  24. def model(self):
  25. """延迟加载嵌入模型"""
  26. if self._model is None:
  27. logger.info(f"Loading embedding model: {self.model_name}")
  28. try:
  29. self._model = SentenceTransformer(self.model_name)
  30. except Exception as e:
  31. logger.error(f"Failed to load embedding model: {str(e)}")
  32. raise
  33. return self._model
  34. def _extract_texts(self, papers: List[Dict[str, Any]]) -> List[str]:
  35. """从论文中提取文本用于嵌入"""
  36. texts = []
  37. for paper in papers:
  38. # 合并标题和摘要以获得更好的嵌入
  39. title = paper.get("title", "")
  40. summary = paper.get("summary", "")
  41. text = f"{title} {summary}"
  42. texts.append(text)
  43. return texts
  44. def cluster_papers(
  45. self,
  46. papers: List[Dict[str, Any]],
  47. num_clusters: int = DEFAULT_NUM_CLUSTERS
  48. ) -> Dict[str, Any]:
  49. """
  50. 对论文进行聚类分析
  51. Args:
  52. papers: 论文列表
  53. num_clusters: 聚类数量,0表示自动确定
  54. Returns:
  55. 聚类结果字典
  56. """
  57. if len(papers) < 3:
  58. logger.warning("Too few papers for clustering")
  59. return {
  60. "clusters": [],
  61. "cluster_info": {},
  62. "visualization_data": None,
  63. "error": "Too few papers for clustering"
  64. }
  65. try:
  66. # 1. 提取文本
  67. texts = self._extract_texts(papers)
  68. # 2. 计算嵌入
  69. logger.info(f"Computing embeddings for {len(texts)} papers")
  70. embeddings = self.model.encode(texts, show_progress_bar=True)
  71. # 3. 确定聚类数量(如果未指定)
  72. if num_clusters <= 0:
  73. # 使用论文数量的平方根作为默认聚类数
  74. num_clusters = min(max(2, int(np.sqrt(len(papers)))), 8)
  75. # 4. 聚类分析
  76. logger.info(f"Clustering into {num_clusters} clusters")
  77. kmeans = KMeans(n_clusters=num_clusters, random_state=42)
  78. cluster_labels = kmeans.fit_predict(embeddings)
  79. # 5. 可视化数据准备(使用t-SNE降维)
  80. logger.info("Generating 2D visualization data with t-SNE")
  81. tsne = TSNE(n_components=2, random_state=42)
  82. vis_data = tsne.fit_transform(embeddings)
  83. # 6. 整理聚类结果
  84. clusters = [[] for _ in range(num_clusters)]
  85. for i, (paper, label) in enumerate(zip(papers, cluster_labels)):
  86. paper_with_cluster = paper.copy()
  87. paper_with_cluster["cluster"] = int(label)
  88. paper_with_cluster["vis_x"] = float(vis_data[i][0])
  89. paper_with_cluster["vis_y"] = float(vis_data[i][1])
  90. clusters[label].append(paper_with_cluster)
  91. # 7. 提取聚类信息
  92. cluster_info = {}
  93. for i in range(num_clusters):
  94. cluster_info[i] = {
  95. "size": len(clusters[i]),
  96. "papers": clusters[i]
  97. }
  98. # 8. 构造返回结果
  99. result = {
  100. "clusters": clusters,
  101. "cluster_info": cluster_info,
  102. "visualization_data": {
  103. "x": vis_data[:, 0].tolist(),
  104. "y": vis_data[:, 1].tolist(),
  105. "labels": cluster_labels.tolist()
  106. }
  107. }
  108. return result
  109. except Exception as e:
  110. logger.error(f"Clustering error: {str(e)}", exc_info=True)
  111. return {
  112. "clusters": [],
  113. "cluster_info": {},
  114. "visualization_data": None,
  115. "error": str(e)
  116. }