package com.smtaiserver.smtaiserver.javaai.qwen;
|
|
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.commons.text.similarity.JaccardSimilarity;

import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.database.SMTDatabase;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords;
import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError;
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
import com.smtaiserver.smtaiserver.javaai.qwen.agent.SMTQwenAgent;
import com.smtaiserver.smtaiserver.javaai.qwen.agent.SMTQwenAgentUnknowQuestion;
import com.smtservlet.util.Json;
import com.smtservlet.util.SMTJsonWriter;
import com.smtservlet.util.SMTStatic;
|
|
public class SMTQwenAgentManager
|
{
|
private LinkedHashMap<String, SMTQwenAgent> _mapId2Agent = new LinkedHashMap<>();
|
private SMTQwenAgentUnknowQuestion _agentUnknown;
|
|
public SMTQwenAgentManager()
|
{
|
// 添加缺省的unknow的agent
|
_agentUnknown = new SMTQwenAgentUnknowQuestion();
|
}
|
|
public SMTQwenAgent getAgentById(String agentId)
|
{
|
return getAgentFromMap(_mapId2Agent, agentId);
|
}
|
|
private SMTQwenAgent getAgentFromMap(Map<String, SMTQwenAgent> mapId2Agent, String agentId)
|
{
|
return "agent_unknow".equals(agentId) ? _agentUnknown : mapId2Agent.get(agentId);
|
}
|
|
public void addAgent(SMTQwenAgent agent)
|
{
|
_mapId2Agent.put(agent.getAgentId(), agent);
|
}
|
|
private List<SMTQwenAgent> getUsefulAgentList(Set<String> setAgentGroup, Map<String, SMTQwenAgent> mapId2Agent, String groupType, String question)
|
{
|
List<SMTQwenAgent> listAgent = new ArrayList<>();
|
|
for(SMTQwenAgent agent : mapId2Agent.values())
|
{
|
if(question != null && !agent.isQuestionMatched(question))
|
continue;
|
|
if(!agent.isMatchGroupType(groupType))
|
continue;
|
|
if(setAgentGroup != null && !agent.isAgentGroupMatched(setAgentGroup))
|
continue;
|
|
listAgent.add(agent);
|
}
|
|
return listAgent;
|
}
|
|
public String getSupervisorPrompt(Set<String> setAgentGroup, Map<String, SMTQwenAgent> mapId2Agent, String groupType, String question, boolean checkInnerCall)
|
{
|
if(mapId2Agent == null)
|
mapId2Agent = _mapId2Agent;
|
|
List<SMTQwenAgent> listAgent = getUsefulAgentList(setAgentGroup, mapId2Agent, groupType, question);
|
|
StringBuilder sbToolsPrompt = new StringBuilder();
|
for(SMTQwenAgent agent : listAgent)
|
{
|
if(checkInnerCall && !agent.inSupervisor())
|
continue;
|
|
sbToolsPrompt.append(agent.getAgentPrompt() + "\n");
|
}
|
|
if(sbToolsPrompt.length() == 0)
|
return "";
|
|
String prompt = SMTQwenApp.getApp().getAgentToolPromptTemplate().replace("{{{AGENT_TOOL_DEFINE_LIST}}}", sbToolsPrompt.toString());
|
|
return prompt;
|
}
|
|
public SMTJavaAIError executeUnknowQuestionAgent(String rawKeyword, SMTAIServerRequest tranReq) throws Exception
|
{
|
int maxQuestion = 5;
|
List<String> listKeyword = new ArrayList<>();
|
|
// 如果关键字太长,意味着需要做二次切分
|
if(rawKeyword.length() > 1)
|
{
|
tranReq.sendChunkedBlock("begin", "无法找到匹配的执行器,对问题中关键字进行切分");
|
SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null);
|
String sJsonKeywords = llm.callWithMessage(new String[] {"请将输入的内容切分成独立的单词,并以json数组形式表示。例如:流量和压力,返回[\"流量\",\"和\",\"压力\"]"}, rawKeyword, tranReq);
|
Json jsonKeywors = SMTStatic.convLLMAnswerToJson(sJsonKeywords, false);
|
if(jsonKeywors != null && jsonKeywors.isArray())
|
{
|
for(Json jsonKeyword : jsonKeywors.asJsonList())
|
{
|
listKeyword.add(jsonKeyword.asString());
|
}
|
}
|
}
|
|
// 如果关键字存在,则直接使用
|
if(listKeyword.size() == 0)
|
{
|
tranReq.sendChunkedBlock("begin", "关键字切分失败, 直接用原始问题进行匹配");
|
listKeyword.add(rawKeyword);
|
}
|
|
// 首先尝试从agent中找到相关问题
|
List<SMTQwenAgent> listAgent = getUsefulAgentList(tranReq.getAgentGroupSet(), _mapId2Agent, tranReq.getCurGroupType(), null);
|
for(String keywrod : listKeyword)
|
{
|
tranReq.sendChunkedBlock("begin", "开始匹配关键字:" + keywrod);
|
for(SMTQwenAgent agent : listAgent)
|
{
|
agent.queryUnknowQuestionList(keywrod, tranReq);
|
}
|
|
if(tranReq.getContentSampleQuestionCount() >= maxQuestion)
|
break;
|
}
|
|
// 如果无法从agent中找到足够多问题,则从例子中寻找
|
if(tranReq.getContentSampleQuestionCount() < maxQuestion)
|
{
|
tranReq.sendChunkedBlock("begin", "适配器中匹配的相关问题不足,从例子中寻找匹配");
|
Set<String> matchQuestion = new HashSet<>();
|
Set<String> randQuestion = new HashSet<>();
|
SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
|
JaccardSimilarity jaccardSimilarity = new JaccardSimilarity();
|
try
|
{
|
String curGroupType = tranReq.getCurGroupType();
|
Set<String> setAgentGroup = tranReq.getAgentGroupSet();
|
DBRecords recs = db.querySQL("SELECT A.sample_question, A.sample_match, A.group_id, G.group_type FROM ai_scene_sample A LEFT JOIN ai_scene_group G ON A.group_id=G.group_id", null);
|
for(DBRecord rec : recs.getRecords())
|
{
|
String groupId = rec.getString("group_id");
|
|
// 如果例子不在有权限分组,则忽略
|
if(setAgentGroup != null && !setAgentGroup.contains(groupId))
|
continue;
|
|
String groupType = rec.getString("group_type");
|
if(curGroupType != null && !curGroupType.equals(groupType))
|
continue;
|
|
// 将例子加入随机问题列表
|
String question = rec.getString("sample_question");
|
if(randQuestion.size() < maxQuestion)
|
randQuestion.add(question);
|
|
// 如果例子中包含关键字,则直接加入
|
for(String keywrod : listKeyword)
|
{
|
// 如果关键字包含在问题中,则直接加入并退出
|
if(question.indexOf(keywrod) >= 0)
|
{
|
tranReq.sendChunkedBlock("begin", "关键字[" + keywrod + "]包含在里子[" + question + "]中");
|
matchQuestion.add(question);
|
break;
|
}
|
|
// 如果存在匹配字段,则按照匹配字段和关键字匹配相似度
|
String sJsonMatch = rec.getString("sample_match");
|
if(!SMTStatic.isNullOrEmpty(sJsonMatch))
|
{
|
for(Json jsonMatch : Json.read(sJsonMatch).asJsonList())
|
{
|
String sMatch = jsonMatch.asString();
|
double match = jaccardSimilarity.apply(keywrod, sMatch);
|
if(match > 0.3 || sMatch.indexOf(keywrod) >= 0 || keywrod.indexOf(sMatch) >= 0)
|
{
|
tranReq.sendChunkedBlock("begin", "关键字[" + keywrod + "]和例子[" + sMatch + "]的匹配度为:" + match);
|
matchQuestion.add(question);
|
break;
|
}
|
}
|
}
|
|
}
|
}
|
|
// 将找到的数据加入匹配列表
|
if(matchQuestion.size() < maxQuestion)
|
{
|
tranReq.sendChunkedBlock("begin", "匹配到的例子个数不足,加入随机匹配的例子");
|
for(String randQ : randQuestion)
|
{
|
if(matchQuestion.size() >= maxQuestion)
|
break;
|
|
tranReq.sendChunkedBlock("begin", "加入随机匹配例子:" + randQ);
|
matchQuestion.add(randQ);
|
}
|
}
|
|
// 将匹配列表数据加入返回
|
for(String matchQ : matchQuestion)
|
{
|
tranReq.addContentSampleQuestion(matchQ);
|
}
|
|
// 只保留限定问题
|
tranReq.sendChunkedBlock("begin", "限定匹配后的问题条目为" + maxQuestion + "条");
|
tranReq.limitContentSampleQuestion(maxQuestion);
|
}
|
finally
|
{
|
db.close();
|
}
|
}
|
|
{
|
SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null);
|
String result = llm.callWithMessage(null, tranReq.getAIQuestion(), tranReq);
|
|
SMTJsonWriter jsonWrResult = tranReq.getResultJsonWr();
|
jsonWrResult.addKeyValue("answer_type", "knowledge");
|
jsonWrResult.beginArray("knowledge");
|
{
|
jsonWrResult.beginMap(null);
|
{
|
jsonWrResult.addKeyValue("answer", result);
|
}
|
jsonWrResult.endMap();
|
}
|
jsonWrResult.endArray();
|
tranReq.sendChunkedResultBlock();
|
}
|
|
return new SMTJavaAIError(" ");
|
}
|
|
public SMTJavaAIError callUnknowQuestionAgent(Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception
|
{
|
SMTJavaAIError error = _agentUnknown.callAgents("", jsonArgs, llm, question, tranReq);
|
return error;
|
}
|
|
public SMTJavaAIError callExtJson(SMTLLMConnect llm, String question, Json jsonCallExtList, SMTJsonWriter jsonWr, SMTAIServerRequest tranReq) throws Exception
|
{
|
for(Json jsonCallExt : jsonCallExtList.asJsonList())
|
{
|
String agentId = jsonCallExt.getJson("call").asString();
|
SMTQwenAgent agent = getAgentFromMap(_mapId2Agent, agentId);
|
if(agent == null)
|
return new SMTJavaAIError("当前问题未在AI的知识范围内,请重新提问");
|
tranReq.traceLLMDebug("匹配到aget :" + agentId);
|
|
SMTJavaAIError error = agent.callExtJson(jsonCallExt.getJson("args"), llm, question, tranReq);
|
if(error != null)
|
return error;
|
}
|
return null;
|
}
|
|
public SMTJavaAIError callSupervisorJson(String agentId, String jsonPath, Json callFunc, SMTJsonWriter jsonWr, SMTAIServerRequest tranReq) throws Exception
|
{
|
try
|
{
|
Map<String, SMTQwenAgent> mapId2Agent = _mapId2Agent;
|
|
SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId);
|
if(agent == null)
|
return new SMTJavaAIError("未发现可使用的agent:" + agentId);
|
|
SMTJavaAIError error = agent.callSupervisorJson(agentId, jsonPath, callFunc, tranReq);
|
if(error != null)
|
return error;
|
}
|
finally
|
{
|
tranReq.closeQuestionResource();
|
}
|
|
return null;
|
}
|
|
|
public SMTJavaAIError callAgents(SMTLLMConnect llm, Set<String> setAgentGroup, String groupType, String question, Map<String, SMTQwenAgent> mapId2Agent, SMTAIServerRequest tranReq, boolean checkInnerCall) throws Exception
|
{
|
SMTJsonWriter jsonWrResult = tranReq.getResultJsonWr();
|
|
if(mapId2Agent == null)
|
mapId2Agent = _mapId2Agent;
|
|
long tick = System.currentTimeMillis();
|
String sysPrompt = getSupervisorPrompt(setAgentGroup, mapId2Agent, groupType, question, checkInnerCall);
|
List<Json> jsonASTList;
|
String answer;
|
|
if(SMTStatic.isNullOrEmpty(sysPrompt))
|
{
|
jsonASTList = new ArrayList<>();
|
}
|
else
|
{
|
tranReq.sendChunkedBlock("begin", "意图分析中...");
|
answer = llm.callWithMessage(new String[] {sysPrompt}, question, tranReq).replace("\r", "");
|
Json jsonASTListO = SMTStatic.convLLMAnswerToJson(answer, false);
|
|
tranReq.traceLLMDebug("question : " + question);
|
tranReq.traceLLMDebug("callAgents:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒] [\n" + SMTStatic.formatJson(jsonASTListO) + "\n]");
|
|
|
// 判断解析的意图参数格式是否正确
|
if(jsonASTListO.isArray())
|
jsonASTList = jsonASTListO.asJsonList();
|
else if(jsonASTListO.isObject())
|
jsonASTList = Json.array(jsonASTListO).asJsonList();
|
else
|
return new SMTJavaAIError("解析问题失败");
|
|
// 将语法树按照agent自动分类
|
LinkedHashMap<SMTQwenAgent, List<Json>> mapAgent2ASTList = new LinkedHashMap<>();
|
for(int jsonIdx = 0; jsonIdx < jsonASTList.size(); jsonIdx ++)
|
{
|
Json jsonAST = jsonASTList.get(jsonIdx);
|
if(!jsonAST.has("call"))
|
{
|
String error = jsonAST.safeGetStr("error", null);
|
if(!SMTStatic.isNullOrEmpty(error))
|
return new SMTJavaAIError(error);
|
}
|
String agentId = jsonAST.getJson("call").asString();
|
SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId);
|
if(agent == null)
|
{
|
String error = jsonAST.safeGetStr("error", null);
|
if(!SMTStatic.isNullOrEmpty(error))
|
return new SMTJavaAIError(error);
|
else
|
return new SMTJavaAIError("未发现可使用的agent:" + agentId);
|
}
|
|
List<Json> listAST = mapAgent2ASTList.get(agent);
|
if(listAST == null)
|
{
|
listAST = new ArrayList<>();
|
mapAgent2ASTList.put(agent, listAST);
|
}
|
listAST.add(jsonAST);
|
}
|
|
// 在agent中将其自动分类
|
jsonASTList = new ArrayList<>();
|
for(Entry<SMTQwenAgent, List<Json>> entry : mapAgent2ASTList.entrySet())
|
{
|
SMTJavaAIError error = entry.getKey().mergeASTList(entry.getValue(), jsonASTList);
|
if(error != null)
|
return error;
|
}
|
|
// 输出合并后结果
|
StringBuilder sbMergeJson = new StringBuilder();
|
sbMergeJson.append("merge AST :[\n");
|
for(Json jsonAST : jsonASTList)
|
{
|
sbMergeJson.append(SMTStatic.formatJson(jsonAST) + "\n");
|
}
|
sbMergeJson.append("]\n");
|
tranReq.traceLLMDebug(sbMergeJson.toString());
|
|
// 将解析后的作为返回值
|
tranReq.appendSupervisorJson(jsonASTList);
|
}
|
|
// 如果无结果解析,则调用知识库agent
|
if(jsonASTList.size() == 0)
|
{
|
tranReq.sendChunkedBlock("begin", "由大语音模型回答,正在处理。。。");
|
|
answer = llm.callWithMessage(null, question, tranReq);
|
jsonWrResult.addKeyValue("json_ok", true);
|
jsonWrResult.addKeyValue("answer_type", "knowledge");
|
jsonWrResult.beginArray("knowledge");
|
{
|
jsonWrResult.beginMap(null);
|
{
|
jsonWrResult.addKeyValue("answer", answer);
|
}
|
jsonWrResult.endMap();
|
}
|
jsonWrResult.endArray();
|
tranReq.sendChunkedBlock("end", "大模型回答了您的问题");
|
return null;
|
}
|
|
// 解析脚本
|
SMTJavaAIError aiError = null;
|
List<SMTJavaAIError> listAIError = new ArrayList<>();
|
StringBuilder sbQuestions = new StringBuilder();
|
|
tranReq.sendChunkedBlock("end", "意图分析完成");
|
|
// 执行脚本
|
for(int jsonIdx = 0; jsonIdx < jsonASTList.size(); jsonIdx ++)
|
{
|
Json jsonAST = jsonASTList.get(jsonIdx);
|
|
// 前面已经判断过,此处不用再判断
|
String agentId = jsonAST.getJson("call").asString();
|
SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId);
|
|
Json jsonArgs = jsonAST.safeGetJson("args");
|
if(jsonArgs == null)
|
jsonArgs = Json.object();
|
|
String subQuestion;
|
|
// 如果只有一个问题,则将用户的提问作为问题原型,如果包含多个问题则只能以解析出来的问题作为问题
|
if(jsonASTList.size() == 1)
|
{
|
subQuestion = question;
|
}
|
else
|
{
|
subQuestion = jsonArgs.safeGetStr("question", question);
|
}
|
if((aiError = agent.checkArgsValid(jsonArgs)) != null)
|
{
|
listAIError.add(aiError);
|
}
|
else
|
{
|
tranReq.sendChunkedBlock("begin", "思考如何执行:" + agent.getAgentTitle());
|
|
aiError = agent.callAgents("#" + SMTStatic.toString(jsonIdx) + "/", jsonArgs, llm, subQuestion, tranReq);
|
|
//tranReq.sendChunkedBlock("end", agent.getAgentTitle() + "完成");
|
|
if(aiError != null)
|
{
|
String perfix = "__call__";
|
if(aiError.getErrorMsg().startsWith(perfix))
|
{
|
agentId = aiError.getErrorMsg().substring(perfix.length());
|
agent = getAgentFromMap(mapId2Agent, agentId);
|
if(agent != null)
|
{
|
aiError = agent.callAgents("#" + SMTStatic.toString(jsonIdx) + "/", jsonArgs, llm, subQuestion, tranReq);
|
if(aiError != null)
|
listAIError.add(aiError);
|
}
|
else
|
{
|
listAIError.add(new SMTJavaAIError("未发现" + agentId + " "));
|
}
|
}
|
else
|
{
|
listAIError.add(aiError);
|
}
|
}
|
}
|
|
sbQuestions.append(question + " ");
|
}
|
|
if(listAIError.size() == 0)
|
return null;
|
|
return listAIError.get(0);
|
}
|
}
|
|
|