- """
- 文献聚类模块
- 对文献进行向量化和聚类分析
- """
- from typing import Dict, List, Any, Optional
- import numpy as np
- from sklearn.cluster import KMeans
- from sklearn.manifold import TSNE
- import logging
- from sentence_transformers import SentenceTransformer
- from backend.config import DEFAULT_EMBEDDING_MODEL, DEFAULT_NUM_CLUSTERS
- logger = logging.getLogger(__name__)
class PaperClusterer:
    """Paper clustering engine."""

    def __init__(self, embedding_model: str = DEFAULT_EMBEDDING_MODEL):
        """
        Initialize the clustering engine.

        Args:
            embedding_model: Name of the embedding model to use.
        """
        self.model_name = embedding_model
        self._model = None  # Loaded lazily on first use

    @property
    def model(self):
        """Lazily load and cache the embedding model."""
        if self._model is None:
            logger.info(f"Loading embedding model: {self.model_name}")
            try:
                self._model = SentenceTransformer(self.model_name)
            except Exception as e:
                logger.error(f"Failed to load embedding model: {str(e)}")
                raise
        return self._model

    def _extract_texts(self, papers: List[Dict[str, Any]]) -> List[str]:
        """Extract the text used for embedding from each paper."""
        texts = []
        for paper in papers:
            # Combine title and abstract for a richer embedding
            title = paper.get("title", "")
            summary = paper.get("summary", "")
            text = f"{title} {summary}"
            texts.append(text)
        return texts

    def cluster_papers(
        self,
        papers: List[Dict[str, Any]],
        num_clusters: int = DEFAULT_NUM_CLUSTERS
    ) -> Dict[str, Any]:
        """
        Run cluster analysis on a list of papers.

        Args:
            papers: List of paper dicts with "title" and "summary" fields.
            num_clusters: Number of clusters; 0 means determine automatically.

        Returns:
            Dict describing the clustering result.
        """
        if len(papers) < 3:
            logger.warning("Too few papers for clustering")
            return {
                "clusters": [],
                "cluster_info": {},
                "visualization_data": None,
                "error": "Too few papers for clustering"
            }

        try:
            # 1. Extract the texts to embed
            texts = self._extract_texts(papers)

            # 2. Compute embeddings
            logger.info(f"Computing embeddings for {len(texts)} papers")
            embeddings = self.model.encode(texts, show_progress_bar=True)

            # 3. Determine the number of clusters if not specified
            if num_clusters <= 0:
                # Default to the square root of the paper count, between 2 and 8
                num_clusters = min(max(2, int(np.sqrt(len(papers)))), 8)
            # Never request more clusters than there are papers
            num_clusters = min(num_clusters, len(papers))

            # 4. K-means clustering
            logger.info(f"Clustering into {num_clusters} clusters")
            kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
            cluster_labels = kmeans.fit_predict(embeddings)

            # 5. Prepare 2D visualization data via t-SNE dimensionality reduction
            logger.info("Generating 2D visualization data with t-SNE")
            # t-SNE requires perplexity < n_samples, so cap it for small collections
            tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(papers) - 1))
            vis_data = tsne.fit_transform(embeddings)

            # 6. Group papers by cluster, attaching label and 2D coordinates
            clusters = [[] for _ in range(num_clusters)]
            for i, (paper, label) in enumerate(zip(papers, cluster_labels)):
                paper_with_cluster = paper.copy()
                paper_with_cluster["cluster"] = int(label)
                paper_with_cluster["vis_x"] = float(vis_data[i][0])
                paper_with_cluster["vis_y"] = float(vis_data[i][1])
                clusters[label].append(paper_with_cluster)

            # 7. Summarize each cluster
            cluster_info = {}
            for i in range(num_clusters):
                cluster_info[i] = {
                    "size": len(clusters[i]),
                    "papers": clusters[i]
                }

            # 8. Build the result payload
            result = {
                "clusters": clusters,
                "cluster_info": cluster_info,
                "visualization_data": {
                    "x": vis_data[:, 0].tolist(),
                    "y": vis_data[:, 1].tolist(),
                    "labels": cluster_labels.tolist()
                }
            }

            return result

        except Exception as e:
            logger.error(f"Clustering error: {str(e)}", exc_info=True)
            return {
                "clusters": [],
                "cluster_info": {},
                "visualization_data": None,
                "error": str(e)
            }
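

# A minimal usage sketch (illustrative only): the sample papers below are made-up
# placeholders, and the model named by DEFAULT_EMBEDDING_MODEL is downloaded on
# first use. In normal operation this module is driven by the backend instead.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_papers = [
        {"title": "Paper A", "summary": "Transformers for machine translation."},
        {"title": "Paper B", "summary": "Pretrained language models for text understanding."},
        {"title": "Paper C", "summary": "Convolutional networks for image classification."},
        {"title": "Paper D", "summary": "Object detection with deep neural networks."},
    ]

    clusterer = PaperClusterer()
    result = clusterer.cluster_papers(sample_papers, num_clusters=2)
    if result.get("error"):
        print(f"Clustering failed: {result['error']}")
    else:
        for cluster_id, info in result["cluster_info"].items():
            print(f"Cluster {cluster_id}: {info['size']} papers")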