package com.smtaiserver.smtaiserver.javaai.llm.core; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import com.smtaiserver.smtaiserver.core.SMTAIServerApp; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.database.SMTDatabase; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords; import com.smtservlet.util.SMTStatic; public abstract class SMTLLMConnect { private static Pattern _patScope = Pattern.compile("^分数:(\\d+)"); public abstract void close(); public abstract String getVector(String text) throws Exception; public abstract String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq) throws Exception; public String callWithVector(String question, String rawQuestion, String vectorSQL, int minScope, SMTAIServerRequest tranReq, List r_contexts) throws Exception { if(vectorSQL == null) { vectorSQL = " SELECT * FROM (" + " SELECT vector_title, vector_message, ?::vector <=> vector_value AS ratio FROM vector_doc WHERE group_id='test1') T" + " ORDER BY ratio LIMIT 4" ; } List sysMesssage = new ArrayList<>(); sysMesssage.add("判定问题和关联信息的匹配度,匹配分数从0分到100分。严格按照输出格式输出内容,不要添加任何不属于分数的内容。\n输出格式:\n分数:80\n"); long tick = System.currentTimeMillis(); Set setContext = new HashSet<>(); String vector = getVector(question); tranReq.traceLLMDebug("callWithVector回答获取向量:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + vector + "]"); SMTDatabase db = SMTAIServerApp.getApp().allocDatabase(); try { tick = System.currentTimeMillis(); DBRecords recs = db.querySQL(vectorSQL, new Object[] { "[" + vector + "]" }); tranReq.traceLLMDebug("callWithVector回答查询向量:[" + recs.getRowCount() + "条][" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒]"); for(DBRecord rec : recs.getRecords()) { int scope; tick = System.currentTimeMillis(); String ratio = callWithMessage(sysMesssage, "问题:" + question + "\n关联信息:" + rec.getString("vector_title") + "\n", tranReq); tranReq.traceLLMDebug("callWithVector回答对比匹配度:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒][" + ratio + "][" + rec.getString("vector_title") + "]"); try { if(ratio == null) continue; Matcher m = _patScope.matcher(ratio); if(!m.find()) continue; scope = Integer.parseInt(m.group(1)); if(scope < minScope) continue; } catch(Exception ex) { continue; } r_contexts.add(new String[] { "匹配" + SMTStatic.toInt(scope) + "分 : " + rec.getString("vector_title"), rec.getString("vector_message") }); setContext.add(rec.getString("vector_message")); } StringBuilder sbContext = new StringBuilder(); for(String context : setContext) { sbContext.append(context + "\n"); } tick = System.currentTimeMillis(); List askMesssage = new ArrayList<>(); askMesssage.add("以获取的上下文作为提纲,组织成一份结构完整的回答报告\n提供的上下文如下:\n" + sbContext.toString()); String answer = callWithMessage(askMesssage, rawQuestion != null ? new String[] {question, rawQuestion} : question, tranReq); tranReq.traceLLMDebug("callWithVector回答向量结果:[" + question + "] 时间: " + ((double)(System.currentTimeMillis() - tick) / 1000)); return answer; } finally { db.close(); } } }