TangCheng
2025-02-28 d787e447e95c7b897c2cc9c0e832f8d2e5084934
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package com.smtaiserver.smtaiserver.javaai.llm.core;
 
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
 
import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.database.SMTDatabase;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords;
import com.smtservlet.util.SMTStatic;
 
public abstract class SMTLLMConnect 
{
    private static Pattern         _patScope = Pattern.compile("^分数:(\\d+)");
    
    public abstract void close();
    public abstract String getVector(String text) throws Exception;
    public abstract String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq)   throws Exception;
 
    
    public String callWithVector(String question, String rawQuestion, String vectorSQL, int minScope, SMTAIServerRequest tranReq, List<String[]> r_contexts) throws Exception
    {
        if(vectorSQL == null)
        {
            vectorSQL = 
                  " SELECT * FROM ("
                + "        SELECT vector_title, vector_message, ?::vector <=> vector_value AS ratio FROM vector_doc WHERE group_id='test1') T"
                + " ORDER BY ratio LIMIT 4"
                ;
        }
        
        List<String> sysMesssage = new ArrayList<>();
        sysMesssage.add("判定问题和关联信息的匹配度,匹配分数从0分到100分。严格按照输出格式输出内容,不要添加任何不属于分数的内容。\n输出格式:\n分数:80\n");
        
        long tick = System.currentTimeMillis();
        Set<String> setContext = new HashSet<>();
        String vector = getVector(question);
        tranReq.traceLLMDebug("callWithVector回答获取向量:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + vector + "]");
        SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
        try
        {
            tick = System.currentTimeMillis();
            DBRecords recs = db.querySQL(vectorSQL, new Object[] {
                "[" + vector + "]"    
            });
            tranReq.traceLLMDebug("callWithVector回答查询向量:[" + recs.getRowCount() + "条][" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒]");
            
            for(DBRecord rec : recs.getRecords())
            {
                int scope;
                tick = System.currentTimeMillis();
                String ratio = callWithMessage(sysMesssage, "问题:" + question + "\n关联信息:" + rec.getString("vector_title") + "\n", tranReq);
                tranReq.traceLLMDebug("callWithVector回答对比匹配度:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + ratio + "][" + rec.getString("vector_title") + "]");
                try
                {
                    if(ratio == null)
                        continue;
                    Matcher m = _patScope.matcher(ratio);
                    if(!m.find())
                        continue;
                    scope = Integer.parseInt(m.group(1));
                    
                    if(scope < minScope)
                        continue;
                }
                catch(Exception ex)
                {
                    continue;
                }
                r_contexts.add(new String[] {
                    "匹配" + SMTStatic.toInt(scope) + "分 : " + rec.getString("vector_title"),
                    rec.getString("vector_message")
                });
                setContext.add(rec.getString("vector_message"));
            }
            
            StringBuilder sbContext = new StringBuilder();
            for(String context : setContext)
            {
                sbContext.append(context + "\n");
            }
            
            tick = System.currentTimeMillis();
            List<String> askMesssage = new ArrayList<>();
            askMesssage.add("以获取的上下文作为提纲,组织成一份结构完整的回答报告\n提供的上下文如下:\n" + sbContext.toString());
            String answer = callWithMessage(askMesssage, rawQuestion != null ? new String[] {question, rawQuestion} : question, tranReq);
            tranReq.traceLLMDebug("callWithVector回答向量结果:[" + question + "] 时间: " + ((double)(System.currentTimeMillis() - tick) / 1000));
            return answer;
        }
        finally
        {
            db.close();
        }
    }
 
}