Browse Source

feat: /api/agent/retrive

haitao 3 months ago
parent
commit
c669457b42

+ 40 - 0
rag-server/api/agent/retrive/routes.js

@@ -0,0 +1,40 @@
+// routes.js
+const express = require('express');
+const router = express.Router();
+
+// 数据库连接配置
+const psqlDB = require("../../../psql.service");
+
+// SQL that ranks "Document" rows by cosine similarity to a query vector.
+//   $1 – query embedding, cast to pgvector's `vector` type.
+//   $2 – optional story filter; when NULL every row is considered.
+// `<=>` is used as a distance, so `1 - distance` is exposed as `similarity`.
+// Rows are ordered from most to least similar; NULLS LAST keeps rows whose
+// stored vector is missing (NULL similarity) from occupying the top of the
+// result, because PostgreSQL sorts NULLs first under DESC by default.
+const cosineSimilarityQuery = `
+    SELECT "objectId","story","pageContent",(1 - (vector_array::vector <=> $1::vector)) AS similarity
+    FROM (
+        SELECT *, 
+            (SELECT array_agg(value::float) 
+                FROM jsonb_array_elements(vector512) AS value) AS vector_array
+        FROM "Document"
+        WHERE ($2 IS NULL OR story = $2)
+    ) AS subquery
+    ORDER BY similarity DESC NULLS LAST
+    LIMIT 20;
+`;
+
+// POST /retrive — runs cosineSimilarityQuery for a query embedding.
+// Request body:
+//   vector512 – query embedding; must be a non-empty array of finite numbers.
+//   story     – optional value compared against the "story" column.
+// Responds with the rows returned by the query, 400 `Invalid input` when the
+// vector is unusable, or 500 `Database query error` when the query fails.
+router.post('/retrive', async (req, res) => {
+    // Tolerate a missing body so a bad request yields 400 instead of a TypeError.
+    const { vector512, story } = req.body ?? {};
+
+    const isValidVector = Array.isArray(vector512)
+        && vector512.length > 0
+        && vector512.every((value) => typeof value === 'number' && Number.isFinite(value));
+    if (!isValidVector) {
+        return res.status(400).json({ error: 'Invalid input' });
+    }
+
+    try {
+        // An explicit null selects the "$2 IS NULL" branch when no story filter was sent.
+        const result = await psqlDB.any(cosineSimilarityQuery, [vector512, story ?? null]);
+        res.json(result);
+    } catch (error) {
+        console.error('Database query error:', error);
+        res.status(500).json({ error: 'Database query error' });
+    }
+});
+
+module.exports = router;

+ 2 - 0
rag-server/dev-server.js

@@ -30,6 +30,8 @@ async function main(){
   // 加载Agent专用路由 
   const pdfRouter = require('./api/agent/loader/routes'); // 根据你的文件结构调整路径
   app.use('/api/agent', pdfRouter); // 使用路由
+  const retriveRouter = require('./api/agent/retrive/routes'); // 根据你的文件结构调整路径
+  app.use('/api/agent', retriveRouter); // 使用路由
 
   const port = 1337;
   app.listen(port, function() {

+ 11 - 0
rag-server/logs/parse-server.info.2024-12-17

@@ -2,3 +2,14 @@
 {"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-16T16:07:29.316Z"}
 {"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-16T16:07:39.449Z"}
 {"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-16T16:08:27.856Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:45:32.549Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:49:46.296Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:50:19.234Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:52:30.826Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:54:34.880Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:56:45.422Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:57:43.188Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:58:22.670Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:58:42.552Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T01:59:26.544Z"}
+{"level":"warn","message":"DeprecationWarning: The Parse Server option 'encodeParseObjectInCloudFunction' default will change to 'true' in a future version.","timestamp":"2024-12-17T02:00:13.328Z"}

+ 1 - 1
src/lib/ncloud.ts

@@ -134,7 +134,7 @@ export class CloudQuery {
         return json || {};
     }
 
-    async find() {
+    async find():Promise<Array<CloudObject>> {
         let url = `http://dev.fmode.cn:1337/parse/classes/${this.className}?`;
 
         let queryStr = ``

+ 17 - 1
src/modules/story/page-hangzhou/page-hangzhou.component.html

@@ -40,7 +40,23 @@
   }
 
   @if(tab=="retrive"){
-   
+   <ion-list>
+    <ion-item>
+      <ion-textarea [value]="userInput" (ionChange)="inputChange($event)" label="用户输入" placeholder="请输入您的问题"></ion-textarea>
+    </ion-item>
+    <ion-item>
+      <ion-button (click)="retriveFrontEnd()">前端检索</ion-button>
+      <ion-button (click)="retriveBackEnd()">后端检索</ion-button>
+    </ion-item>
+    @if(searchDocList?.length){
+      @for(doc of searchDocList;track doc){
+
+        <ion-item>
+          相似度:{{doc.similarity}} 内容:{{doc.pageContent}}
+        </ion-item>
+      }
+    }
+   </ion-list>
   }
 
   @if(tab=="story"){

+ 40 - 6
src/modules/story/page-hangzhou/page-hangzhou.component.ts

@@ -1,8 +1,9 @@
 import { Component, OnInit } from '@angular/core';
 import { CommonModule } from '@angular/common';
     
-import { IonContent,IonButton,IonSegment,IonSegmentButton,IonLabel,IonList,IonItem } from "@ionic/angular/standalone";
-import { AgentStory } from '../story-loader/story.loader';
+import { IonContent,IonButton,IonSegment,IonSegmentButton,IonLabel,IonList,IonItem,IonTextarea } from "@ionic/angular/standalone";
+import { AgentStory, EmbedQuery, RetriveAllDocument } from '../story-loader/story.loader';
+import { CloudApi } from 'src/lib/ncloud';
 
 @Component({
   selector: 'app-page-hangzhou',
@@ -13,7 +14,7 @@ import { AgentStory } from '../story-loader/story.loader';
       CommonModule,
       IonContent,IonButton,
       IonSegment,IonSegmentButton,IonLabel,
-      IonList,IonItem
+      IonList,IonItem,IonTextarea,
     ]
 })
 export class PageHangzhouComponent  implements OnInit {
@@ -31,10 +32,12 @@ export class PageHangzhouComponent  implements OnInit {
   fileList:Array<any> = [
     {
       title:`市委办公厅 市政府办公厅印发《关于服务保障“抓防控促发展”落实“人才生态37条”的补充意见》的通知`,
+      tags:["杭州","人才政策"],
       url:`https://app.fmode.cn/dev/jxnu/case/2020%E5%B9%B4%E6%9D%AD%E5%B7%9E%E5%B8%82%E4%BA%BA%E6%89%8D37%E6%9D%A1.docx`
     },
     {
       title:"杭州市余杭区服务保障高层次人才创新创业政策实施细则",
+      tags:["杭州","余杭","人才政策","创新创业","双创"],
       url:"https://app.fmode.cn/dev/jxnu/case/2022年杭州余杭.docx"
     }
   ]
@@ -43,17 +46,48 @@ export class PageHangzhouComponent  implements OnInit {
     window.open(file.url,"_blank")
   }
   async loader(file:any){
-    let story = new AgentStory();
+    let story = new AgentStory(file);
     await story.loader(file.url);
-    console.log(story.content);
+    console.log(story);
     this.storyMap[file?.url] = story;
   }
   async splitter(story:AgentStory){
-
     await story.splitter();
   }
   async embedings(story:AgentStory){
     await story.embedings()
   }
 
+  /**
+   * Text vector retrieval: state and input handler.
+   */
+  // Text currently entered in the question textarea.
+  userInput:string = "";
+  // Documents rendered in the template's result list.
+  searchDocList:Array<any> = []
+  /**
+   * Copies the textarea value from a change event into `userInput`.
+   * @param ev change event; its value is read from `ev.detail.value`.
+   */
+  inputChange(ev:any){
+    // Fall back to "" so `userInput` always holds a string, as its type declares,
+    // even when the event carries null/undefined.
+    this.userInput = ev?.detail?.value ?? ""
+  }
+  /**
+   * Front-end retrieval: embeds `userInput`, ranks the split documents of every
+   * story held in `storyMap` against it, and stores the ranking in `searchDocList`.
+   */
+  async retriveFrontEnd(){
+    // Flatten the docList of each loaded story into a single array.
+    let allDocs:any = []
+    for(const story of Object.values(this.storyMap) as Array<any>){
+      allDocs = allDocs.concat(story.docList)
+    }
+    const queryVector = await EmbedQuery(this.userInput)
+    console.log(this.userInput,queryVector)
+    console.log(allDocs)
+    this.searchDocList = RetriveAllDocument(queryVector,allDocs)
+    console.log("searchDocList",this.searchDocList)
+  }
+
+  /**
+   * Back-end retrieval: embeds `userInput`, posts the vector to the
+   * "agent/retrive" endpoint and shows the returned documents in `searchDocList`.
+   */
+  async retriveBackEnd(){
+    let vector512 = await EmbedQuery(this.userInput)
+    console.log(this.userInput,vector512)
+
+    let api = new CloudApi()
+    let result = await api.fetch("agent/retrive",{
+      vector512:vector512
+    })
+    console.log(result)
+    // Show the server's hits in the same list the front-end search fills;
+    // previously the response was only logged and the button had no visible effect.
+    // NOTE(review): assumes CloudApi.fetch resolves to the parsed JSON array sent by the route — confirm.
+    if(Array.isArray(result)){
+      this.searchDocList = result
+    }
+  }
+
 }

+ 135 - 8
src/modules/story/story-loader/story.loader.ts

@@ -1,6 +1,6 @@
 // import pdf from 'pdf-parse';
 // import fs from 'fs';
-import { CloudApi } from 'src/lib/ncloud';
+import { CloudApi, CloudObject, CloudQuery } from 'src/lib/ncloud';
 import mammoth from "mammoth";
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
 import { Document } from '@langchain/core/documents';
@@ -14,12 +14,47 @@ import { TensorFlowEmbeddings } from "@langchain/community/embeddings/tensorflow
 
 export class AgentStory{
 
-    url:string = ""
-    content:string = ""
-    docList:Array<Document> = []
-    constructor(){
+    // Persisted "Story" record; assigned by save() once the record has been stored.
+    story:CloudObject|undefined
+    // Document title
+    title:string|undefined = ""
+    // Document tags
+    tags:Array<string>|undefined
+    // Source URL of the file
+    url:string|undefined = ""
+    // Full plain-text content of the document
+    content:string|undefined = ""
+    // Unique hash of the document (hex digest produced by arrayBufferToHASH)
+    hash:string|undefined = ""
+    // List of chunks produced by splitting the document
+    docList:Array<Document|any> = []
+
+    /**
+     * Builds a story from file metadata and starts backend setup.
+     * @param metadata source `url` plus optional `title` and `tags` of the file.
+     * setBackend() is async and is started here without being awaited.
+     */
+    constructor(metadata:{
+        url:string,
+        title?:string,
+        tags?:Array<string>
+    }){
+        const { url, title, tags } = metadata
+        this.url = url
+        this.title = title
+        this.tags = tags
         setBackend()
     }
+    /**
+     * Creates or updates the "Story" record identified by `hash`.
+     * Does nothing while `hash` is empty. The object returned by the
+     * record's save() is kept in `this.story`.
+     */
+    async save(){
+        if(!this.hash){ return }
+        const lookup = new CloudQuery("Story");
+        lookup.equalTo("hash",this.hash);
+        const existing = await lookup.first();
+        // Reuse the stored record when one with this hash exists; otherwise start a new one.
+        const record = existing?.id ? existing : new CloudObject("Story");
+        const { title, url, content, hash, tags } = this
+        record.set({ title, url, content, hash, tags })
+        this.story = await record.save();
+    }
     async loader(url:string){
         let api = new CloudApi();
 
@@ -34,6 +69,7 @@ export class AgentStory{
         if(this.content){
             this.url = url
         }
+        this.save();
         return this.content
     }
 
@@ -49,8 +85,12 @@ export class AgentStory{
         } catch (err) {
             console.error(err);
         }
+
+        this.hash = await arrayBufferToHASH(arrayBuffer)
+
         // let html = mammoth.convertToHtml(buffer)
         data = text?.value || "";
+        // 正则匹配所有 多个\n换行的字符 替换成一次换行
         data = data.replaceAll(/\n+/g,"\n") // 剔除多余换行
         return {data}
     }
@@ -62,7 +102,7 @@ export class AgentStory{
         // 默认:递归字符文本分割器
         let splitter = new RecursiveCharacterTextSplitter({
             chunkSize: options?.chunkSize || 500,
-            chunkOverlap: options?.chunkOverlap || 100,
+            chunkOverlap: options?.chunkOverlap || 150,
         });
           
         let docOutput = await splitter.splitDocuments([
@@ -79,14 +119,51 @@ export class AgentStory{
      * https://js.langchain.com/docs/integrations/text_embedding/tensorflow/
      * @returns 
      */
-    //  embedding vector(1536) NOT NULL -- NOTE: 1536 for ChatGPT
+    //  TensorFlow embedding vector(512) NOT NULL -- NOTE: 512 for Tensorflow
+    //  OpenAI embedding vector(1536) NOT NULL -- NOTE: 1536 for ChatGPT
     async embedings(){
         if(!this.docList?.length){return}
         const embeddings = new TensorFlowEmbeddings();
         let documentRes = await embeddings.embedDocuments(this.docList?.map(item=>item.pageContent));
         console.log(documentRes);
+
+        // Persist the vectors: one "Document" record per chunk, keyed by the hash
+        // of its pageContent. All writes are awaited so the method only resolves
+        // once persistence has finished (an async callback passed to forEach is
+        // never awaited and its rejections go unhandled).
+        // NOTE(review): `this.story` is only set after save() resolves; if that has
+        // not happened yet the `story` pointer below is stored as undefined.
+        await Promise.all(documentRes.map(async (vector512:any,index)=>{
+            /**
+             * Each chunk carries:
+             * metadata
+             * pageContent
+             */
+            let document = this.docList[index]
+            // Attach the vector first so in-memory retrieval works even if the write below fails.
+            this.docList[index].vector512 = vector512
+            try {
+                let hash = await arrayBufferToHASH(stringToArrayBuffer(document?.pageContent))
+                let query = new CloudQuery("Document");
+                query.equalTo("hash",hash);
+                let docObj = await query.first()
+                if(!docObj?.id){
+                    docObj = new CloudObject("Document");
+                }
+                docObj.set({
+                    metadata:document?.metadata,
+                    pageContent:document?.pageContent,
+                    vector512:vector512,
+                    hash:hash,
+                    story:this.story?.toPointer(),
+                })
+                await docObj.save();
+            } catch (err) {
+                // Best effort per chunk: log and continue so one failed write does not discard the embeddings.
+                console.error(err);
+            }
+        }))
         return documentRes;
     }
+    /**
+     * Deletes every "Document" record that belongs to this story.
+     * Does nothing when the story has not been saved (no `story.id`).
+     * Resolves after all deletions have completed; a failed deletion rejects.
+     */
+    async destoryAllDocument(){
+        if(!this.story?.id){ return }
+        let query = new CloudQuery("Document");
+        // NOTE(review): embedings() writes `story` via toPointer() while this filter passes the bare id — confirm CloudQuery matches both forms.
+        query.equalTo("story",this.story?.id);
+        let docList = await query.find();
+        // Await the deletions instead of firing them from forEach, so callers
+        // that await this method see them finished and errors are not lost.
+        await Promise.all(docList.map(doc=>doc.destroy()));
+    }
 }
 
 export async function fetchFileBuffer(url: string): Promise<Buffer> {
@@ -132,4 +209,54 @@ async function setBackend(){
         backend&&await tf.setBackend(backend);
         await tf.ready();
         return
-  }
+  }
+
+  /**
+   * Computes the SHA-256 digest of a buffer with the SubtleCrypto API.
+   * @param arrayBuffer data to hash, passed straight to crypto.subtle.digest.
+   * @returns the digest as a lowercase hex string, two characters per byte.
+   */
+  export async function arrayBufferToHASH(arrayBuffer:any) {
+    const digest = await crypto.subtle.digest('SHA-256', arrayBuffer);
+    let hex = '';
+    for (const byte of new Uint8Array(digest)) {
+        // Left-pad so bytes below 0x10 still contribute two characters.
+        hex += byte.toString(16).padStart(2, '0');
+    }
+    return hex;
+}
+/**
+ * Encodes a string with TextEncoder and returns the underlying buffer.
+ * The byte length follows the encoding, so it can exceed the string length.
+ */
+export function stringToArrayBuffer(str:string) {
+    return new TextEncoder().encode(str).buffer;
+}
+/**
+ * Embeds a single query with TensorFlowEmbeddings.
+ * @param str the query passed to embedQuery.
+ * @returns the embedding produced by embedQuery.
+ */
+export async function EmbedQuery(str:any):Promise<Array<number>>{
+    const embedder = new TensorFlowEmbeddings();
+    return embedder.embedQuery(str);
+}
+
+/**
+ * Ranks documents by cosine similarity to a query vector.
+ *
+ * Each scorable document gets its score written to `doc.similarity`. A document
+ * is scorable when `doc.vector512` is an array of finite numbers with the same
+ * length as the query and is not all zeros. Unscorable documents (for example
+ * chunks that were never embedded) are left out of the result instead of
+ * raising an error or producing NaN scores that would corrupt the sort.
+ *
+ * @param vector1 query embedding.
+ * @param docList documents to rank; the array itself is not reordered.
+ * @returns a new array of the scorable documents, most similar first; empty
+ *          when the query vector itself is unusable.
+ */
+export function RetriveAllDocument(vector1: Array<number>, docList: Array<any>): Array<any> {
+    // A zero or malformed query has no direction, so nothing can be ranked against it.
+    if (!isUsableVector(vector1, vector1?.length)) {
+        return [];
+    }
+
+    const ranked = docList.filter(doc => isUsableVector(doc?.vector512, vector1.length));
+    ranked.forEach(doc => {
+        doc.similarity = cosineSimilarity(vector1, doc.vector512); // store the score on the document
+    });
+
+    // Descending order: highest similarity first.
+    ranked.sort((a, b) => b.similarity - a.similarity);
+
+    return ranked;
+}
+
+/** True when `candidate` is a non-zero array of finite numbers with `dimension` entries. */
+function isUsableVector(candidate: unknown, dimension: number): candidate is number[] {
+    return Array.isArray(candidate)
+        && candidate.length === dimension
+        && candidate.every(value => typeof value === 'number' && Number.isFinite(value))
+        && candidate.some(value => value !== 0);
+}
+
+/** Sum of the element-wise products of two vectors; indexes `vectorB` by the positions of `vectorA`. */
+function dotProduct(vectorA: number[], vectorB: number[]): number {
+    return vectorA.reduce((sum, value, index) => sum + value * vectorB[index], 0);
+}
+
+/** Euclidean length of a vector. */
+function magnitude(vector: number[]): number {
+    return Math.sqrt(vector.reduce((sum, value) => sum + value * value, 0));
+}
+
+/**
+ * Cosine similarity of two vectors: dot product divided by the product of their lengths.
+ * @throws Error when either vector has zero length, because the ratio is undefined.
+ */
+function cosineSimilarity(vectorA: number[], vectorB: number[]): number {
+    const dotProd = dotProduct(vectorA, vectorB);
+    const magnitudeA = magnitude(vectorA);
+    const magnitudeB = magnitude(vectorB);
+
+    if (magnitudeA === 0 || magnitudeB === 0) {
+        throw new Error("One or both vectors are zero vectors, cannot compute cosine similarity.");
+    }
+
+    return dotProd / (magnitudeA * magnitudeB);
+}