"""
Paper clustering module.

Embeds papers as vectors and groups them with cluster analysis.
"""
import logging
from typing import Dict, List, Any, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

from backend.config import DEFAULT_EMBEDDING_MODEL, DEFAULT_NUM_CLUSTERS

logger = logging.getLogger(__name__)

# Upper bound for the t-SNE perplexity; this is scikit-learn's own default.
_MAX_TSNE_PERPLEXITY = 30.0


class PaperClusterer:
    """Paper clustering engine."""

    def __init__(self, embedding_model: str = DEFAULT_EMBEDDING_MODEL):
        """
        Initialise the clustering engine.

        Args:
            embedding_model: Name of the embedding model to use.
        """
        self.model_name = embedding_model
        self._model = None  # loaded lazily on first access of ``model``

    @property
    def model(self):
        """Return the embedding model, loading it on first access.

        Raises:
            Exception: Whatever ``SentenceTransformer`` raises while loading;
                the failure is logged and re-raised unchanged.
        """
        if self._model is None:
            logger.info(f"Loading embedding model: {self.model_name}")
            try:
                self._model = SentenceTransformer(self.model_name)
            except Exception as e:
                logger.error(f"Failed to load embedding model: {str(e)}")
                raise
        return self._model

    def _extract_texts(self, papers: List[Dict[str, Any]]) -> List[str]:
        """Build one text per paper for embedding.

        The title and summary are joined with a single space. A missing key
        or a falsy value (e.g. ``None``) is treated as an empty string so
        that the literal text "None" is never embedded.

        Args:
            papers: Paper dicts; the "title" and "summary" keys are read.

        Returns:
            One string per paper, in input order.
        """
        return [
            f"{paper.get('title') or ''} {paper.get('summary') or ''}"
            for paper in papers
        ]

    @staticmethod
    def _error_result(message: str) -> Dict[str, Any]:
        """Return the empty result structure carrying an error message."""
        return {
            "clusters": [],
            "cluster_info": {},
            "visualization_data": None,
            "error": message,
        }

    def cluster_papers(
        self,
        papers: List[Dict[str, Any]],
        num_clusters: int = DEFAULT_NUM_CLUSTERS
    ) -> Dict[str, Any]:
        """
        Cluster papers and compute 2-D coordinates for visualisation.

        Args:
            papers: List of paper dicts.
            num_clusters: Number of clusters; a value <= 0 means the number
                is chosen automatically. Values larger than the number of
                papers are reduced to the number of papers.

        Returns:
            A dict with "clusters" (a list of paper lists, each paper copied
            and extended with "cluster", "vis_x" and "vis_y"),
            "cluster_info" (per-cluster size and papers) and
            "visualization_data" (x, y and labels lists). On failure, or
            when fewer than 3 papers are given, the first two are empty,
            "visualization_data" is None and an "error" message is added.
        """
        num_papers = len(papers)
        if num_papers < 3:
            logger.warning("Too few papers for clustering")
            return self._error_result("Too few papers for clustering")

        try:
            # 1. Extract the texts
            texts = self._extract_texts(papers)

            # 2. Compute the embeddings
            logger.info(f"Computing embeddings for {len(texts)} papers")
            embeddings = self.model.encode(texts, show_progress_bar=True)

            # 3. Decide the number of clusters
            if num_clusters <= 0:
                # Square root of the paper count, kept within [2, 8]
                num_clusters = min(max(2, int(np.sqrt(num_papers))), 8)
            elif num_clusters > num_papers:
                # KMeans cannot form more clusters than there are samples
                logger.warning(
                    f"Requested {num_clusters} clusters for {num_papers} "
                    f"papers; using {num_papers} clusters instead"
                )
                num_clusters = num_papers

            # 4. Cluster
            logger.info(f"Clustering into {num_clusters} clusters")
            kmeans = KMeans(n_clusters=num_clusters, random_state=42)
            cluster_labels = kmeans.fit_predict(embeddings)

            # 5. Prepare visualisation data (t-SNE projection to 2-D).
            # scikit-learn requires perplexity < n_samples, so the default of
            # 30 fails for small inputs; cap it below the paper count.
            logger.info("Generating 2D visualization data with t-SNE")
            perplexity = min(_MAX_TSNE_PERPLEXITY, float(num_papers - 1))
            tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
            vis_data = tsne.fit_transform(embeddings)

            # 6. Group the papers by cluster
            clusters: List[List[Dict[str, Any]]] = [
                [] for _ in range(num_clusters)
            ]
            for paper, label, point in zip(papers, cluster_labels, vis_data):
                cluster_index = int(label)
                # Shallow copy so the caller's paper dicts are not mutated
                paper_with_cluster = paper.copy()
                paper_with_cluster["cluster"] = cluster_index
                paper_with_cluster["vis_x"] = float(point[0])
                paper_with_cluster["vis_y"] = float(point[1])
                clusters[cluster_index].append(paper_with_cluster)

            # 7. Collect per-cluster information
            cluster_info = {
                index: {"size": len(members), "papers": members}
                for index, members in enumerate(clusters)
            }

            # 8. Build the result
            return {
                "clusters": clusters,
                "cluster_info": cluster_info,
                "visualization_data": {
                    "x": vis_data[:, 0].tolist(),
                    "y": vis_data[:, 1].tolist(),
                    "labels": cluster_labels.tolist()
                }
            }

        except Exception as e:
            # Best-effort API: report the failure in the result, do not raise
            logger.error(f"Clustering error: {str(e)}", exc_info=True)
            return self._error_result(str(e))