package com.smtaiserver.smtaiserver.javaai.llm.core;
|
|
import java.util.ArrayList;
|
import java.util.HashSet;
|
import java.util.List;
|
import java.util.Set;
|
import java.util.regex.Matcher;
|
import java.util.regex.Pattern;
|
|
import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
|
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
|
import com.smtaiserver.smtaiserver.database.SMTDatabase;
|
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord;
|
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords;
|
import com.smtservlet.util.SMTStatic;
|
|
public abstract class SMTLLMConnect
|
{
|
private static Pattern _patScope = Pattern.compile("^分数:(\\d+)");
|
|
public abstract void close();
|
public abstract String getVector(String text) throws Exception;
|
public abstract String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq) throws Exception;
|
|
|
public String callWithVector(String question, String rawQuestion, String vectorSQL, int minScope, SMTAIServerRequest tranReq, List<String[]> r_contexts) throws Exception
|
{
|
if(vectorSQL == null)
|
{
|
vectorSQL =
|
" SELECT * FROM ("
|
+ " SELECT vector_title, vector_message, ?::vector <=> vector_value AS ratio FROM vector_doc WHERE group_id='test1') T"
|
+ " ORDER BY ratio LIMIT 4"
|
;
|
}
|
|
List<String> sysMesssage = new ArrayList<>();
|
sysMesssage.add("判定问题和关联信息的匹配度,匹配分数从0分到100分。严格按照输出格式输出内容,不要添加任何不属于分数的内容。\n输出格式:\n分数:80\n");
|
|
long tick = System.currentTimeMillis();
|
Set<String> setContext = new HashSet<>();
|
String vector = getVector(question);
|
tranReq.traceLLMDebug("callWithVector回答获取向量:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + vector + "]");
|
SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
|
try
|
{
|
tick = System.currentTimeMillis();
|
DBRecords recs = db.querySQL(vectorSQL, new Object[] {
|
"[" + vector + "]"
|
});
|
tranReq.traceLLMDebug("callWithVector回答查询向量:[" + recs.getRowCount() + "条][" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒]");
|
|
for(DBRecord rec : recs.getRecords())
|
{
|
int scope;
|
tick = System.currentTimeMillis();
|
String ratio = callWithMessage(sysMesssage, "问题:" + question + "\n关联信息:" + rec.getString("vector_title") + "\n", tranReq);
|
tranReq.traceLLMDebug("callWithVector回答对比匹配度:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + ratio + "][" + rec.getString("vector_title") + "]");
|
try
|
{
|
if(ratio == null)
|
continue;
|
Matcher m = _patScope.matcher(ratio);
|
if(!m.find())
|
continue;
|
scope = Integer.parseInt(m.group(1));
|
|
if(scope < minScope)
|
continue;
|
}
|
catch(Exception ex)
|
{
|
continue;
|
}
|
r_contexts.add(new String[] {
|
"匹配" + SMTStatic.toInt(scope) + "分 : " + rec.getString("vector_title"),
|
rec.getString("vector_message")
|
});
|
setContext.add(rec.getString("vector_message"));
|
}
|
|
StringBuilder sbContext = new StringBuilder();
|
for(String context : setContext)
|
{
|
sbContext.append(context + "\n");
|
}
|
|
tick = System.currentTimeMillis();
|
List<String> askMesssage = new ArrayList<>();
|
askMesssage.add("以获取的上下文作为提纲,组织成一份结构完整的回答报告\n提供的上下文如下:\n" + sbContext.toString());
|
String answer = callWithMessage(askMesssage, rawQuestion != null ? new String[] {question, rawQuestion} : question, tranReq);
|
tranReq.traceLLMDebug("callWithVector回答向量结果:[" + question + "] 时间: " + ((double)(System.currentTimeMillis() - tick) / 1000));
|
return answer;
|
}
|
finally
|
{
|
db.close();
|
}
|
}
|
|
}
|