package com.smtaiserver.smtaiserver.javaai.qwen.agent; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.regex.Pattern; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.dom4j.Document; import org.dom4j.Element; import org.dom4j.Node; import com.smtaiserver.smtaiserver.core.SMTAIServerApp; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord; import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError; import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect; import com.smtaiserver.smtaiserver.javaai.qwen.SMTQwenApp; import com.smtservlet.util.Json; import com.smtservlet.util.SMTStatic; public abstract class SMTQwenAgent { protected interface QuestionMatch { public void initInstance(Element xmlRegex); public boolean isMatch(String question); } protected static class QuestionMatchRegex implements QuestionMatch { private Pattern _patMatch; @Override public void initInstance(Element xmlRegex) { _patMatch = Pattern.compile(SMTStatic.trimStrLines(xmlRegex.getText())); } @Override public boolean isMatch(String question) { if(!_patMatch.matcher(question).find()) return false; return true; } } public static class AgentArgument { public String _name; public String _title; public char _type; public boolean _required; public String _question; public String _prompt; public AgentArgument(Element xmlArg) throws Exception { _required = "true".equals(SMTStatic.getXmlAttr(xmlArg, "required", "false")); _name = SMTStatic.getXmlAttr(xmlArg, "name"); _type = SMTStatic.getXmlAttr(xmlArg, "type", "S").charAt(0); _question = SMTStatic.getXmlAttr(xmlArg, "question", _name); _prompt = xmlArg.getText(); } public SMTJavaAIError checkArgumentValid(Json jsonArgs) { Json jsonArg = jsonArgs.safeGetJson(_name); if(jsonArg == null) { if(_required) return new SMTJavaAIError("参数" + _title + "未设置"); return null; } String sValue = jsonArg.asString(); switch(_type) { case 'S': return null; case 'D': try { SMTStatic.toDouble(sValue); } catch(Exception ex) { return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是浮点数"); } return null; case 'I': try { SMTStatic.toInt(sValue); } catch(Exception ex) { return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是整数"); } return null; case 'T': try { SMTStatic.toDate(sValue); } catch(Exception ex) { return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是时间"); } return null; } return null; } } protected boolean _inSupervisor = false; protected String _agentGroup; protected String _agentTitle; protected String _agentGroupType; protected String _agentId; protected String _agentType; protected String _agentPrompt; protected Map _mapName2Argument = new HashMap<>(); protected List _listQuestionMatch = new ArrayList<>(); private static Logger _logger = LogManager.getLogger(SMTQwenApp.class); public abstract SMTJavaAIError callAgents(String jsonPath, Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception; public void queryUnknowQuestionList(String keyword, SMTAIServerRequest tranReq) throws Exception { } public boolean inSupervisor() { return _inSupervisor; } public SMTJavaAIError mergeASTList(List listASTSrc, List r_listASTTag) { for(Json jsonAST : listASTSrc) { r_listASTTag.add(jsonAST); } return null; } public boolean isMatchGroupType(String groupType) { if(SMTStatic.isNullOrEmpty(groupType)) return true; if(groupType.equals(_agentGroupType)) return true; return false; } public SMTJavaAIError callSupervisorJson(String agentId, String jsonPath, Json callFunc, SMTAIServerRequest tranReq) throws Exception { return new SMTJavaAIError("agent不支持二次操作"); } public SMTJavaAIError checkArgsValid(Json jsonArgs) { for(AgentArgument agentArgument : _mapName2Argument.values()) { SMTJavaAIError error = agentArgument.checkArgumentValid(jsonArgs); if(error != null) return error; } return null; } public SMTJavaAIError callExtJson(Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception { return this.callAgents("", jsonArgs, llm, question, tranReq); } public AgentArgument getAgentArgument(String name) { return _mapName2Argument.get(name); } //agent_id,agent_type,tool_desc,tool_arguments,tool_return,clz_name,clz_arguments public void initInstance(DBRecord rec) throws Exception { try { _inSupervisor = !("Y".equals(rec.getString("inner_call"))); _agentTitle = rec.getString("agent_title"); if(_agentTitle == null) _agentTitle = "业务查询"; _agentType = rec.getString("agent_type"); _agentId = rec.getString("agent_id"); _agentGroup = rec.getString("agent_group"); if(rec.getFieldMap().containsKey("GROUP_TYPE")) _agentGroupType = rec.getString("group_type"); else _agentGroupType = ""; _agentPrompt = getAgentPrompt(rec); _logger.info("workflow : " + _agentId + ":\n" + _agentPrompt); } catch(Exception ex) { throw new Exception("init supervisor agent error : " + _agentId, ex); } } public boolean isAgentGroupMatched(Set setAgentGroup) { if(setAgentGroup == null) return true; if(setAgentGroup.contains(_agentGroup)) return true; return false; } public String getAgentTitle() { return _agentTitle == null ? "" : _agentTitle; } public String getAgentId() { return _agentId; } public String getAgentType() { return _agentType; } public String getAgentPrompt() { return _agentPrompt; } public boolean isQuestionMatched(String question) { for(QuestionMatch match : _listQuestionMatch) { if(!match.isMatch(question)) return false; } return true; } protected String getToolDesc(Element xmlTitle) throws Exception { String toolDesc = SMTStatic.trimStrLines(xmlTitle.getText()); return toolDesc; } protected String getAgentPrompt(DBRecord rec) throws Exception { String agentId = rec.getString("agent_id"); try { Document doc = SMTStatic.convStrToXmlDoc("" + rec.getString("agent_xml") + ""); Element xmlTitle = (Element) doc.selectSingleNode("ROOT/TITLE"); String toolDesc = getToolDesc(xmlTitle); StringBuilder sbPrompt = new StringBuilder(); sbPrompt.append( agentId + "\n" + " 功能:\n" + " " + SMTStatic.indentStrLines(toolDesc, " ") + "\n" + " 参数:\n" + " question:提出的问题\n" ); for(Node nodeArg : doc.selectNodes("ROOT/ARGS/ARG")) { Element xmlArg = (Element)nodeArg; AgentArgument agentArg = new AgentArgument(xmlArg); _mapName2Argument.put(agentArg._name, agentArg); String argName = SMTStatic.getXmlAttr(xmlArg, "name"); String argDesc = SMTStatic.trimStrLines(xmlArg.getText()); sbPrompt.append(" " + argName + ":" + argDesc + "\n"); } sbPrompt.append( " 返回:\n" + " 返回查询结果" ); for(Node nodeMatchRegex : doc.selectNodes("ROOT/MATCHES/REGEX")) { QuestionMatchRegex match = new QuestionMatchRegex(); match.initInstance((Element)nodeMatchRegex); _listQuestionMatch.add(match); } return sbPrompt.toString(); } catch(Exception ex) { throw new Exception("get agent exception : " + agentId, ex); } } protected SMTJavaAIError queryUnknowQuestion(String errorMsg, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception { SMTJavaAIError orgError = new SMTJavaAIError(errorMsg); tranReq.sendChunkedBlock("begin", "无法精确匹配已有的模式,进行模糊匹配"); String prompt = (String) SMTAIServerApp.getApp().getGlobalConfig("prompt.metric_unknow"); String answer = llm.callWithMessage(new String[] {prompt}, question, tranReq).replace("\r", ""); Json jsonAST = SMTStatic.convLLMAnswerToJson(answer, true); tranReq.traceLLMDebug("unknow question Agents:[\n" + SMTStatic.formatJson(jsonAST) + "\n]"); if(!jsonAST.isArray()) return orgError; List jsonCalls = jsonAST.asJsonList(); if(jsonCalls.size() == 0) return orgError; Json jsonCall = jsonCalls.get(0); if(!"agent_unknow".equals(jsonCall.safeGetStr("call", ""))) return orgError; Json jsonArgs = jsonCall.safeGetJson("args"); if(jsonArgs == null) return orgError; if(!jsonArgs.has("keyword")) return orgError; SMTJavaAIError error = SMTAIServerApp.getApp().getQwenAgentManager().callUnknowQuestionAgent(jsonArgs, llm, question, tranReq); return error; } }