vectorizer.py

from transformers import AutoModel, AutoTokenizer
import os
import torch
import faiss
import numpy as np
import pickle
import re
from sklearn.preprocessing import normalize

# Local model path and resource paths
model_path = r"D:/STUDY/Project/jizhouyao/RAG/models/all-MiniLM-L6-v2"         # Locally stored pre-trained model
resources_path = r"D:/STUDY/Project/jizhouyao/RAG/resources"                   # Folder containing the source .txt files
faiss_index_path = r"D:/STUDY/Project/jizhouyao/RAG/index/vector_index.faiss"  # Where the FAISS index is saved
texts_pickle_path = r"D:/STUDY/Project/jizhouyao/RAG/index/texts.pkl"          # Where the text data is saved

# Check that the required paths exist
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model path does not exist: {model_path}")
if not os.path.exists(resources_path):
    raise FileNotFoundError(f"Resources folder does not exist: {resources_path}")

# Load the model and tokenizer from the local files
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModel.from_pretrained(model_path, local_files_only=True)

# Text storage
all_texts = []  # All raw document texts

# Sentence splitter: break Chinese text into sentences with a regular expression
def split_sentences(text):
    sentences = re.split(r'[。!?;::\n]', text)  # Split on sentence-ending punctuation and newlines
    return [sentence.strip() for sentence in sentences if sentence.strip()]
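
# Illustrative example of the splitter above (the input sentence is made up, not from the resource files):
#   split_sentences("今天天气很好。我们去公园走走吧!")
#   -> ['今天天气很好', '我们去公园走走吧']
# The character class covers sentence-ending punctuation and newlines only, so commas and
# other clause-level punctuation stay inside a single sentence.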
# Vectorize the texts and choose the FAISS index type dynamically
def vectorize_and_store(texts):
    """
    Split the texts into sentences, vectorize them, and add them to a FAISS index (sentence level).
    """
    global all_sentences     # All sentences plus their source metadata (module level)
    print("Vectorizing sentences...")
    embeddings = []          # Sentence vectors
    all_sentences = []       # Reset storage

    # Iterate over all text files
    for text_data in texts:
        file_name = text_data["name"]   # File name
        content = text_data["content"]  # File content

        # Split the document into sentences
        sentences = split_sentences(content)
        for sentence in sentences:
            if not sentence.strip():    # Skip empty sentences
                continue
            # Vectorize each sentence
            inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                outputs = model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()  # [CLS] token embedding
            # L2-normalize the sentence vector
            cls_embedding = normalize(cls_embedding.reshape(1, -1)).flatten()
            embeddings.append(cls_embedding)
            # Store the sentence together with its source metadata
            all_sentences.append({
                "file": file_name,
                "sentence": sentence
            })

    # Convert the vectors to a float32 numpy array (FAISS expects float32)
    embeddings = np.array(embeddings, dtype="float32")

    # Choose the FAISS index type based on the amount of data
    if len(embeddings) < 10:
        print("Small dataset, using IndexFlatL2...")
        index = faiss.IndexFlatL2(embeddings.shape[1])  # Exact index, no training required
    else:
        num_clusters = max(1, min(10, len(embeddings) // 10))  # Scale the number of clusters with the data size
        print(f"Larger dataset, using an IVFFlat index with {num_clusters} clusters")
        quantizer = faiss.IndexFlatL2(embeddings.shape[1])
        index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], num_clusters, faiss.METRIC_L2)
        print("Training the IVFFlat index...")
        index.train(embeddings)  # Train the coarse quantizer on the embeddings

    # Add the vectors to the index
    index.add(embeddings)
    print(f"Vectorized and indexed {len(embeddings)} sentences!")
    return index
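
# Note on the metric above: because every sentence vector is L2-normalized, ranking by the
# L2 distance used here is equivalent to ranking by cosine similarity, since for unit
# vectors ||a - b||^2 = 2 - 2 * cos(a, b).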
# Save the index and the sentence data
def save_index_and_texts(index):
    os.makedirs(os.path.dirname(faiss_index_path), exist_ok=True)
    faiss.write_index(index, faiss_index_path)
    with open(texts_pickle_path, 'wb') as f:
        # Save the sentence-level metadata so FAISS ids can be mapped back to sentences at query time
        pickle.dump(all_sentences, f)
    print("FAISS index and sentence data saved!")
# Read the resource files
def process_txt_files(resources_path):
    for filename in os.listdir(resources_path):
        if filename.endswith('.txt'):
            file_path = os.path.join(resources_path, filename)
            print(f"Processing file: {file_path}")
            with open(file_path, 'r', encoding='utf-8') as file:
                all_texts.append({"name": filename, "content": file.read()})
# Main entry point
if __name__ == '__main__':
    process_txt_files(resources_path)       # Read the text data
    index = vectorize_and_store(all_texts)  # Vectorize and build the index
    save_index_and_texts(index)             # Save the index and sentence data
    print("Index construction complete!")