package com.smtaiserver.smtaiserver.javaai.qwen.agent;
|
|
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.Node;

import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord;
import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError;
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
import com.smtaiserver.smtaiserver.javaai.qwen.SMTQwenApp;
import com.smtservlet.util.Json;
import com.smtservlet.util.SMTStatic;
|
|
public abstract class SMTQwenAgent
|
{
|
protected interface QuestionMatch
|
{
|
public void initInstance(Element xmlRegex);
|
public boolean isMatch(String question);
|
}
|
|
protected static class QuestionMatchRegex implements QuestionMatch
|
{
|
private Pattern _patMatch;
|
|
@Override
|
public void initInstance(Element xmlRegex)
|
{
|
_patMatch = Pattern.compile(SMTStatic.trimStrLines(xmlRegex.getText()));
|
}
|
|
@Override
|
public boolean isMatch(String question)
|
{
|
if(!_patMatch.matcher(question).find())
|
return false;
|
return true;
|
}
|
|
}
|
|
public static class AgentArgument
|
{
|
public String _name;
|
public String _title;
|
public char _type;
|
public boolean _required;
|
public String _question;
|
public String _prompt;
|
|
public AgentArgument(Element xmlArg) throws Exception
|
{
|
_required = "true".equals(SMTStatic.getXmlAttr(xmlArg, "required", "false"));
|
_name = SMTStatic.getXmlAttr(xmlArg, "name");
|
_type = SMTStatic.getXmlAttr(xmlArg, "type", "S").charAt(0);
|
_question = SMTStatic.getXmlAttr(xmlArg, "question", _name);
|
_prompt = xmlArg.getText();
|
}
|
|
public SMTJavaAIError checkArgumentValid(Json jsonArgs)
|
{
|
Json jsonArg = jsonArgs.safeGetJson(_name);
|
if(jsonArg == null)
|
{
|
if(_required)
|
return new SMTJavaAIError("参数" + _title + "未设置");
|
|
return null;
|
}
|
|
String sValue = jsonArg.asString();
|
switch(_type)
|
{
|
case 'S':
|
return null;
|
|
case 'D':
|
try
|
{
|
SMTStatic.toDouble(sValue);
|
}
|
catch(Exception ex)
|
{
|
return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是浮点数");
|
}
|
return null;
|
|
case 'I':
|
try
|
{
|
SMTStatic.toInt(sValue);
|
}
|
catch(Exception ex)
|
{
|
return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是整数");
|
}
|
return null;
|
|
case 'T':
|
try
|
{
|
SMTStatic.toDate(sValue);
|
}
|
catch(Exception ex)
|
{
|
return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是时间");
|
}
|
return null;
|
}
|
|
return null;
|
}
|
}
|
|
protected boolean _inSupervisor = false;
|
protected String _agentGroup;
|
protected String _agentTitle;
|
protected String _agentGroupType;
|
protected String _agentId;
|
protected String _agentType;
|
protected String _agentPrompt;
|
protected Map<String, AgentArgument> _mapName2Argument = new HashMap<>();
|
protected List<QuestionMatch> _listQuestionMatch = new ArrayList<>();
|
private static Logger _logger = LogManager.getLogger(SMTQwenApp.class);
|
|
public abstract SMTJavaAIError callAgents(String jsonPath, Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception;
|
|
|
public void queryUnknowQuestionList(String keyword, SMTAIServerRequest tranReq) throws Exception
|
{
|
}
|
|
public boolean inSupervisor()
|
{
|
return _inSupervisor;
|
}
|
|
public SMTJavaAIError mergeASTList(List<Json> listASTSrc, List<Json> r_listASTTag)
|
{
|
for(Json jsonAST : listASTSrc)
|
{
|
r_listASTTag.add(jsonAST);
|
}
|
|
return null;
|
}
|
|
public boolean isMatchGroupType(String groupType)
|
{
|
if(SMTStatic.isNullOrEmpty(groupType))
|
return true;
|
if(groupType.equals(_agentGroupType))
|
return true;
|
|
return false;
|
}
|
|
public SMTJavaAIError callSupervisorJson(String agentId, String jsonPath, Json callFunc, SMTAIServerRequest tranReq) throws Exception
|
{
|
return new SMTJavaAIError("agent不支持二次操作");
|
}
|
|
public SMTJavaAIError checkArgsValid(Json jsonArgs)
|
{
|
for(AgentArgument agentArgument : _mapName2Argument.values())
|
{
|
SMTJavaAIError error = agentArgument.checkArgumentValid(jsonArgs);
|
if(error != null)
|
return error;
|
}
|
return null;
|
}
|
|
public SMTJavaAIError callExtJson(Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception
|
{
|
return this.callAgents("", jsonArgs, llm, question, tranReq);
|
}
|
|
public AgentArgument getAgentArgument(String name)
|
{
|
return _mapName2Argument.get(name);
|
}
|
|
//agent_id,agent_type,tool_desc,tool_arguments,tool_return,clz_name,clz_arguments
|
public void initInstance(DBRecord rec) throws Exception
|
{
|
|
try
|
{
|
_inSupervisor = !("Y".equals(rec.getString("inner_call")));
|
_agentTitle = rec.getString("agent_title");
|
if(_agentTitle == null)
|
_agentTitle = "业务查询";
|
_agentType = rec.getString("agent_type");
|
_agentId = rec.getString("agent_id");
|
_agentGroup = rec.getString("agent_group");
|
if(rec.getFieldMap().containsKey("GROUP_TYPE"))
|
_agentGroupType = rec.getString("group_type");
|
else
|
_agentGroupType = "";
|
|
_agentPrompt = getAgentPrompt(rec);
|
|
_logger.info("workflow : " + _agentId + ":\n" + _agentPrompt);
|
}
|
catch(Exception ex)
|
{
|
throw new Exception("init supervisor agent error : " + _agentId, ex);
|
}
|
|
}
|
|
public boolean isAgentGroupMatched(Set<String> setAgentGroup)
|
{
|
if(setAgentGroup == null)
|
return true;
|
if(setAgentGroup.contains(_agentGroup))
|
return true;
|
|
return false;
|
}
|
|
public String getAgentTitle()
|
{
|
return _agentTitle == null ? "" : _agentTitle;
|
}
|
|
public String getAgentId()
|
{
|
return _agentId;
|
}
|
|
public String getAgentType()
|
{
|
return _agentType;
|
}
|
|
public String getAgentPrompt()
|
{
|
return _agentPrompt;
|
}
|
|
public boolean isQuestionMatched(String question)
|
{
|
for(QuestionMatch match : _listQuestionMatch)
|
{
|
if(!match.isMatch(question))
|
return false;
|
}
|
|
return true;
|
}
|
|
protected String getToolDesc(Element xmlTitle) throws Exception
|
{
|
String toolDesc = SMTStatic.trimStrLines(xmlTitle.getText());
|
|
return toolDesc;
|
}
|
|
protected String getAgentPrompt(DBRecord rec) throws Exception
|
{
|
String agentId = rec.getString("agent_id");
|
try
|
{
|
Document doc = SMTStatic.convStrToXmlDoc("<ROOT>" + rec.getString("agent_xml") + "</ROOT>");
|
Element xmlTitle = (Element) doc.selectSingleNode("ROOT/TITLE");
|
String toolDesc = getToolDesc(xmlTitle);
|
|
StringBuilder sbPrompt = new StringBuilder();
|
sbPrompt.append(
|
agentId + "\n"
|
+ " 功能:\n"
|
+ " " + SMTStatic.indentStrLines(toolDesc, " ") + "\n"
|
+ " 参数:\n"
|
+ " question:提出的问题\n"
|
);
|
|
for(Node nodeArg : doc.selectNodes("ROOT/ARGS/ARG"))
|
{
|
Element xmlArg = (Element)nodeArg;
|
AgentArgument agentArg = new AgentArgument(xmlArg);
|
_mapName2Argument.put(agentArg._name, agentArg);
|
String argName = SMTStatic.getXmlAttr(xmlArg, "name");
|
String argDesc = SMTStatic.trimStrLines(xmlArg.getText());
|
sbPrompt.append(" " + argName + ":" + argDesc + "\n");
|
}
|
sbPrompt.append(
|
" 返回:\n"
|
+ " 返回查询结果"
|
);
|
|
for(Node nodeMatchRegex : doc.selectNodes("ROOT/MATCHES/REGEX"))
|
{
|
QuestionMatchRegex match = new QuestionMatchRegex();
|
match.initInstance((Element)nodeMatchRegex);
|
_listQuestionMatch.add(match);
|
}
|
|
return sbPrompt.toString();
|
}
|
catch(Exception ex)
|
{
|
throw new Exception("get agent exception : " + agentId, ex);
|
}
|
}
|
|
protected SMTJavaAIError queryUnknowQuestion(String errorMsg, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception
|
{
|
SMTJavaAIError orgError = new SMTJavaAIError(errorMsg);
|
|
tranReq.sendChunkedBlock("begin", "无法精确匹配已有的模式,进行模糊匹配");
|
|
String prompt = (String) SMTAIServerApp.getApp().getGlobalConfig("prompt.metric_unknow");
|
String answer = llm.callWithMessage(new String[] {prompt}, question, tranReq).replace("\r", "");
|
Json jsonAST = SMTStatic.convLLMAnswerToJson(answer, true);
|
|
tranReq.traceLLMDebug("unknow question Agents:[\n" + SMTStatic.formatJson(jsonAST) + "\n]");
|
|
if(!jsonAST.isArray())
|
return orgError;
|
|
List<Json> jsonCalls = jsonAST.asJsonList();
|
if(jsonCalls.size() == 0)
|
return orgError;
|
Json jsonCall = jsonCalls.get(0);
|
if(!"agent_unknow".equals(jsonCall.safeGetStr("call", "")))
|
return orgError;
|
Json jsonArgs = jsonCall.safeGetJson("args");
|
if(jsonArgs == null)
|
return orgError;
|
if(!jsonArgs.has("keyword"))
|
return orgError;
|
|
SMTJavaAIError error = SMTAIServerApp.getApp().getQwenAgentManager().callUnknowQuestionAgent(jsonArgs, llm, question, tranReq);
|
|
return error;
|
}
|
}
|