package com.smtaiserver.smtaiserver.javaai.qwen; import java.util.ArrayList; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.apache.commons.text.similarity.JaccardSimilarity; import java.util.Set; import com.smtaiserver.smtaiserver.core.SMTAIServerApp; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.database.SMTDatabase; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords; import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError; import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect; import com.smtaiserver.smtaiserver.javaai.qwen.agent.SMTQwenAgent; import com.smtaiserver.smtaiserver.javaai.qwen.agent.SMTQwenAgentUnknowQuestion; import com.smtservlet.util.Json; import com.smtservlet.util.SMTJsonWriter; import com.smtservlet.util.SMTStatic; public class SMTQwenAgentManager { private LinkedHashMap _mapId2Agent = new LinkedHashMap<>(); private SMTQwenAgentUnknowQuestion _agentUnknown; public SMTQwenAgentManager() { // 添加缺省的unknow的agent _agentUnknown = new SMTQwenAgentUnknowQuestion(); } public SMTQwenAgent getAgentById(String agentId) { return getAgentFromMap(_mapId2Agent, agentId); } private SMTQwenAgent getAgentFromMap(Map mapId2Agent, String agentId) { return "agent_unknow".equals(agentId) ? _agentUnknown : mapId2Agent.get(agentId); } public void addAgent(SMTQwenAgent agent) { _mapId2Agent.put(agent.getAgentId(), agent); } private List getUsefulAgentList(Set setAgentGroup, Map mapId2Agent, String groupType, String question) { List listAgent = new ArrayList<>(); for(SMTQwenAgent agent : mapId2Agent.values()) { if(question != null && !agent.isQuestionMatched(question)) continue; if(!agent.isMatchGroupType(groupType)) continue; if(setAgentGroup != null && !agent.isAgentGroupMatched(setAgentGroup)) continue; listAgent.add(agent); } return listAgent; } public String getSupervisorPrompt(Set setAgentGroup, Map mapId2Agent, String groupType, String question, boolean checkInnerCall) { if(mapId2Agent == null) mapId2Agent = _mapId2Agent; List listAgent = getUsefulAgentList(setAgentGroup, mapId2Agent, groupType, question); StringBuilder sbToolsPrompt = new StringBuilder(); for(SMTQwenAgent agent : listAgent) { if(checkInnerCall && !agent.inSupervisor()) continue; sbToolsPrompt.append(agent.getAgentPrompt() + "\n"); } if(sbToolsPrompt.length() == 0) return ""; String prompt = SMTQwenApp.getApp().getAgentToolPromptTemplate().replace("{{{AGENT_TOOL_DEFINE_LIST}}}", sbToolsPrompt.toString()); return prompt; } public SMTJavaAIError executeUnknowQuestionAgent(String rawKeyword, SMTAIServerRequest tranReq) throws Exception { int maxQuestion = 5; List listKeyword = new ArrayList<>(); // 如果关键字太长,意味着需要做二次切分 if(rawKeyword.length() > 1) { tranReq.sendChunkedBlock("begin", "无法找到匹配的执行器,对问题中关键字进行切分"); SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null); String sJsonKeywords = llm.callWithMessage(new String[] {"请将输入的内容切分成独立的单词,并以json数组形式表示。例如:流量和压力,返回[\"流量\",\"和\",\"压力\"]"}, rawKeyword, tranReq); Json jsonKeywors = SMTStatic.convLLMAnswerToJson(sJsonKeywords, false); if(jsonKeywors != null && jsonKeywors.isArray()) { for(Json jsonKeyword : jsonKeywors.asJsonList()) { listKeyword.add(jsonKeyword.asString()); } } } // 如果关键字存在,则直接使用 if(listKeyword.size() == 0) { tranReq.sendChunkedBlock("begin", "关键字切分失败, 直接用原始问题进行匹配"); listKeyword.add(rawKeyword); } // 首先尝试从agent中找到相关问题 List listAgent = getUsefulAgentList(tranReq.getAgentGroupSet(), _mapId2Agent, tranReq.getCurGroupType(), null); for(String keywrod : listKeyword) { tranReq.sendChunkedBlock("begin", "开始匹配关键字:" + keywrod); for(SMTQwenAgent agent : listAgent) { agent.queryUnknowQuestionList(keywrod, tranReq); } if(tranReq.getContentSampleQuestionCount() >= maxQuestion) break; } // 如果无法从agent中找到足够多问题,则从例子中寻找 if(tranReq.getContentSampleQuestionCount() < maxQuestion) { tranReq.sendChunkedBlock("begin", "适配器中匹配的相关问题不足,从例子中寻找匹配"); Set matchQuestion = new HashSet<>(); Set randQuestion = new HashSet<>(); SMTDatabase db = SMTAIServerApp.getApp().allocDatabase(); JaccardSimilarity jaccardSimilarity = new JaccardSimilarity(); try { String curGroupType = tranReq.getCurGroupType(); Set setAgentGroup = tranReq.getAgentGroupSet(); DBRecords recs = db.querySQL("SELECT A.sample_question, A.sample_match, A.group_id, G.group_type FROM ai_scene_sample A LEFT JOIN ai_scene_group G ON A.group_id=G.group_id", null); for(DBRecord rec : recs.getRecords()) { String groupId = rec.getString("group_id"); // 如果例子不在有权限分组,则忽略 if(setAgentGroup != null && !setAgentGroup.contains(groupId)) continue; String groupType = rec.getString("group_type"); if(curGroupType != null && !curGroupType.equals(groupType)) continue; // 将例子加入随机问题列表 String question = rec.getString("sample_question"); if(randQuestion.size() < maxQuestion) randQuestion.add(question); // 如果例子中包含关键字,则直接加入 for(String keywrod : listKeyword) { // 如果关键字包含在问题中,则直接加入并退出 if(question.indexOf(keywrod) >= 0) { tranReq.sendChunkedBlock("begin", "关键字[" + keywrod + "]包含在里子[" + question + "]中"); matchQuestion.add(question); break; } // 如果存在匹配字段,则按照匹配字段和关键字匹配相似度 String sJsonMatch = rec.getString("sample_match"); if(!SMTStatic.isNullOrEmpty(sJsonMatch)) { for(Json jsonMatch : Json.read(sJsonMatch).asJsonList()) { String sMatch = jsonMatch.asString(); double match = jaccardSimilarity.apply(keywrod, sMatch); if(match > 0.3 || sMatch.indexOf(keywrod) >= 0 || keywrod.indexOf(sMatch) >= 0) { tranReq.sendChunkedBlock("begin", "关键字[" + keywrod + "]和例子[" + sMatch + "]的匹配度为:" + match); matchQuestion.add(question); break; } } } } } // 将找到的数据加入匹配列表 if(matchQuestion.size() < maxQuestion) { tranReq.sendChunkedBlock("begin", "匹配到的例子个数不足,加入随机匹配的例子"); for(String randQ : randQuestion) { if(matchQuestion.size() >= maxQuestion) break; tranReq.sendChunkedBlock("begin", "加入随机匹配例子:" + randQ); matchQuestion.add(randQ); } } // 将匹配列表数据加入返回 for(String matchQ : matchQuestion) { tranReq.addContentSampleQuestion(matchQ); } // 只保留限定问题 tranReq.sendChunkedBlock("begin", "限定匹配后的问题条目为" + maxQuestion + "条"); tranReq.limitContentSampleQuestion(maxQuestion); } finally { db.close(); } } { SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null); String result = llm.callWithMessage(null, tranReq.getAIQuestion(), tranReq); SMTJsonWriter jsonWrResult = tranReq.getResultJsonWr(); jsonWrResult.addKeyValue("answer_type", "knowledge"); jsonWrResult.beginArray("knowledge"); { jsonWrResult.beginMap(null); { jsonWrResult.addKeyValue("answer", result); } jsonWrResult.endMap(); } jsonWrResult.endArray(); tranReq.sendChunkedResultBlock(); } return new SMTJavaAIError(" "); } public SMTJavaAIError callUnknowQuestionAgent(Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception { SMTJavaAIError error = _agentUnknown.callAgents("", jsonArgs, llm, question, tranReq); return error; } public SMTJavaAIError callExtJson(SMTLLMConnect llm, String question, Json jsonCallExtList, SMTJsonWriter jsonWr, SMTAIServerRequest tranReq) throws Exception { for(Json jsonCallExt : jsonCallExtList.asJsonList()) { String agentId = jsonCallExt.getJson("call").asString(); SMTQwenAgent agent = getAgentFromMap(_mapId2Agent, agentId); if(agent == null) return new SMTJavaAIError("当前问题未在AI的知识范围内,请重新提问"); tranReq.traceLLMDebug("匹配到aget :" + agentId); SMTJavaAIError error = agent.callExtJson(jsonCallExt.getJson("args"), llm, question, tranReq); if(error != null) return error; } return null; } public SMTJavaAIError callSupervisorJson(String agentId, String jsonPath, Json callFunc, SMTJsonWriter jsonWr, SMTAIServerRequest tranReq) throws Exception { try { Map mapId2Agent = _mapId2Agent; SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId); if(agent == null) return new SMTJavaAIError("未发现可使用的agent:" + agentId); SMTJavaAIError error = agent.callSupervisorJson(agentId, jsonPath, callFunc, tranReq); if(error != null) return error; } finally { tranReq.closeQuestionResource(); } return null; } public SMTJavaAIError callAgents(SMTLLMConnect llm, Set setAgentGroup, String groupType, String question, Map mapId2Agent, SMTAIServerRequest tranReq, boolean checkInnerCall) throws Exception { SMTJsonWriter jsonWrResult = tranReq.getResultJsonWr(); if(mapId2Agent == null) mapId2Agent = _mapId2Agent; long tick = System.currentTimeMillis(); String sysPrompt = getSupervisorPrompt(setAgentGroup, mapId2Agent, groupType, question, checkInnerCall); List jsonASTList; String answer; if(SMTStatic.isNullOrEmpty(sysPrompt)) { jsonASTList = new ArrayList<>(); } else { tranReq.sendChunkedBlock("begin", "意图分析中..."); answer = llm.callWithMessage(new String[] {sysPrompt}, question, tranReq).replace("\r", ""); Json jsonASTListO = SMTStatic.convLLMAnswerToJson(answer, false); tranReq.traceLLMDebug("question : " + question); tranReq.traceLLMDebug("callAgents:[" + ((double)(System.currentTimeMillis() - tick) / 1000) + "秒] [\n" + SMTStatic.formatJson(jsonASTListO) + "\n]"); // 判断解析的意图参数格式是否正确 if(jsonASTListO.isArray()) jsonASTList = jsonASTListO.asJsonList(); else if(jsonASTListO.isObject()) jsonASTList = Json.array(jsonASTListO).asJsonList(); else return new SMTJavaAIError("解析问题失败"); // 将语法树按照agent自动分类 LinkedHashMap> mapAgent2ASTList = new LinkedHashMap<>(); for(int jsonIdx = 0; jsonIdx < jsonASTList.size(); jsonIdx ++) { Json jsonAST = jsonASTList.get(jsonIdx); if(!jsonAST.has("call")) { String error = jsonAST.safeGetStr("error", null); if(!SMTStatic.isNullOrEmpty(error)) return new SMTJavaAIError(error); } String agentId = jsonAST.getJson("call").asString(); SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId); if(agent == null) { String error = jsonAST.safeGetStr("error", null); if(!SMTStatic.isNullOrEmpty(error)) return new SMTJavaAIError(error); else return new SMTJavaAIError("未发现可使用的agent:" + agentId); } List listAST = mapAgent2ASTList.get(agent); if(listAST == null) { listAST = new ArrayList<>(); mapAgent2ASTList.put(agent, listAST); } listAST.add(jsonAST); } // 在agent中将其自动分类 jsonASTList = new ArrayList<>(); for(Entry> entry : mapAgent2ASTList.entrySet()) { SMTJavaAIError error = entry.getKey().mergeASTList(entry.getValue(), jsonASTList); if(error != null) return error; } // 输出合并后结果 StringBuilder sbMergeJson = new StringBuilder(); sbMergeJson.append("merge AST :[\n"); for(Json jsonAST : jsonASTList) { sbMergeJson.append(SMTStatic.formatJson(jsonAST) + "\n"); } sbMergeJson.append("]\n"); tranReq.traceLLMDebug(sbMergeJson.toString()); // 将解析后的作为返回值 tranReq.appendSupervisorJson(jsonASTList); } // 如果无结果解析,则调用知识库agent if(jsonASTList.size() == 0) { tranReq.sendChunkedBlock("begin", "由大语音模型回答,正在处理。。。"); answer = llm.callWithMessage(null, question, tranReq); jsonWrResult.addKeyValue("json_ok", true); jsonWrResult.addKeyValue("answer_type", "knowledge"); jsonWrResult.beginArray("knowledge"); { jsonWrResult.beginMap(null); { jsonWrResult.addKeyValue("answer", answer); } jsonWrResult.endMap(); } jsonWrResult.endArray(); tranReq.sendChunkedBlock("end", "大模型回答了您的问题"); return null; } // 解析脚本 SMTJavaAIError aiError = null; List listAIError = new ArrayList<>(); StringBuilder sbQuestions = new StringBuilder(); tranReq.sendChunkedBlock("end", "意图分析完成"); // 执行脚本 for(int jsonIdx = 0; jsonIdx < jsonASTList.size(); jsonIdx ++) { Json jsonAST = jsonASTList.get(jsonIdx); // 前面已经判断过,此处不用再判断 String agentId = jsonAST.getJson("call").asString(); SMTQwenAgent agent = getAgentFromMap(mapId2Agent, agentId); Json jsonArgs = jsonAST.safeGetJson("args"); if(jsonArgs == null) jsonArgs = Json.object(); String subQuestion; // 如果只有一个问题,则将用户的提问作为问题原型,如果包含多个问题则只能以解析出来的问题作为问题 if(jsonASTList.size() == 1) { subQuestion = question; } else { subQuestion = jsonArgs.safeGetStr("question", question); } if((aiError = agent.checkArgsValid(jsonArgs)) != null) { listAIError.add(aiError); } else { tranReq.sendChunkedBlock("begin", "思考如何执行:" + agent.getAgentTitle()); aiError = agent.callAgents("#" + SMTStatic.toString(jsonIdx) + "/", jsonArgs, llm, subQuestion, tranReq); //tranReq.sendChunkedBlock("end", agent.getAgentTitle() + "完成"); if(aiError != null) { String perfix = "__call__"; if(aiError.getErrorMsg().startsWith(perfix)) { agentId = aiError.getErrorMsg().substring(perfix.length()); agent = getAgentFromMap(mapId2Agent, agentId); if(agent != null) { aiError = agent.callAgents("#" + SMTStatic.toString(jsonIdx) + "/", jsonArgs, llm, subQuestion, tranReq); if(aiError != null) listAIError.add(aiError); } else { listAIError.add(new SMTJavaAIError("未发现" + agentId + " ")); } } else { listAIError.add(aiError); } } } sbQuestions.append(question + " "); } if(listAIError.size() == 0) return null; return listAIError.get(0); } }