package com.smtaiserver.smtaiserver.javaai.qwen.agent; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.dom4j.Document; import org.dom4j.Element; import org.dom4j.Node; import org.dom4j.tree.DefaultCDATA; import org.dom4j.tree.DefaultText; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.database.SMTDatabase; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords; import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError; import com.smtaiserver.smtaiserver.javaai.ast.ASTDBMap; import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect; import com.smtservlet.util.Json; import com.smtservlet.util.SMTJsonWriter; import com.smtservlet.util.SMTStatic; public class SMTQwenAgentSummaryAMIS extends SMTQwenAgent { /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLExecArg { public Json _toolArguments = null; public StringBuilder _sbSQLText = null; public Map _mapAmisData = new HashMap<>(); public List _sqlParams = new ArrayList<>(); public Map _mapId2Recs = new HashMap<>(); public SMTAIServerRequest _tranReq; public Map _mapAmisPath2Replace = new LinkedHashMap<>(); public ASTDBMap _dbMap = new ASTDBMap(); public SMTDatabase _curDB = null; public DBRecords _curRecs = null; public SQLXMLExecArg(Json toolArguments, SMTJsonWriter jsonWr, SMTAIServerRequest tranReq) { _toolArguments = toolArguments; _tranReq = tranReq; } public void close() { _dbMap.close(); } } /////////////////////////////////////////////////////////////////////////////////////// private static abstract class SQLXMLNode { public abstract SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception; } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeAMIS_DATA extends SQLXMLNode { private String _keyField; private String _valueField; private String _databaseId; private List _listChildren = new ArrayList<>(); public SQLXMLNodeAMIS_DATA(Element xmlRoot) throws Exception { _databaseId = SMTStatic.getXmlAttr(xmlRoot, "database", ""); _keyField = SMTStatic.getXmlAttr(xmlRoot, "key_field"); _valueField = SMTStatic.getXmlAttr(xmlRoot, "value_field"); for (Iterator iterInner = xmlRoot.nodeIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE || nodeInner instanceof DefaultCDATA) { String text; if(nodeInner instanceof DefaultCDATA) text = ((DefaultCDATA)nodeInner).getText(); else text = ((DefaultText)nodeInner).getText(); int lastPos = _listChildren.size() - 1; if(lastPos >= 0 && _listChildren.get(lastPos) instanceof String) { _listChildren.set(lastPos, (String)_listChildren.get(lastPos) + text); } else { _listChildren.add(text); } } else { _listChildren.add(createSQLXMLNode((Element)nodeInner)); } } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { StringBuilder orgSQLText = execArg._sbSQLText; List orgSqlParams = execArg._sqlParams; SMTDatabase orgDB = execArg._curDB; try { // 生成查询SQL execArg._sbSQLText = new StringBuilder(); execArg._sqlParams = new ArrayList<>(); execArg._curDB = execArg._dbMap.getDatabase(_databaseId); for(Object oxmlNode : _listChildren) { if(oxmlNode instanceof String) { execArg._sbSQLText.append(oxmlNode); } else if(oxmlNode instanceof SQLXMLNode) { ((SQLXMLNode)oxmlNode).execute(execArg); } } // 执行查询SQL DBRecords recs = execArg._curRecs; if(recs == null) recs = execArg._curDB.querySQL(execArg._sbSQLText.toString(), execArg._sqlParams.toArray(new Object[execArg._sqlParams.size()])); // 设置值 for(DBRecord rec : recs.getRecords()) { execArg._mapAmisData.put(rec.getString(_keyField), rec.getString(_valueField)); } } finally { execArg._curDB = orgDB; execArg._sbSQLText = orgSQLText; execArg._sqlParams = orgSqlParams; } return null; } } private static class SQLXMLNodeAMIS_FOREACH_TABLE extends SQLXMLNode { private String _noRecErrFmt = null; private String _databaseId = ""; private String _loopDatabaseId = ""; private List _listChildren = new ArrayList<>(); private List _listSQL = new ArrayList<>(); public SQLXMLNodeAMIS_FOREACH_TABLE(Element xmlRoot) throws Exception { Element nodeChild = (Element)xmlRoot.selectSingleNode("LOOP"); _loopDatabaseId = SMTStatic.getXmlAttr(nodeChild, "database", ""); for (Iterator iterInner = nodeChild.nodeIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE || nodeInner instanceof DefaultCDATA) { String text; if(nodeInner instanceof DefaultCDATA) text = ((DefaultCDATA)nodeInner).getText(); else text = ((DefaultText)nodeInner).getText(); int lastPos = _listChildren.size() - 1; if(lastPos >= 0 && _listChildren.get(lastPos) instanceof String) { _listChildren.set(lastPos, (String)_listChildren.get(lastPos) + text); } else { _listChildren.add(text); } } else { _listChildren.add(createSQLXMLNode((Element)nodeInner)); } } Element nodeSQL = (Element)xmlRoot.selectSingleNode("QUERY"); _noRecErrFmt = SMTStatic.getXmlAttr(nodeSQL, "no_rec_error", null); for (Iterator iterInner = nodeSQL.nodeIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE || nodeInner instanceof DefaultCDATA) { String text; if(nodeInner instanceof DefaultCDATA) text = ((DefaultCDATA)nodeInner).getText(); else text = ((DefaultText)nodeInner).getText(); int lastPos = _listSQL.size() - 1; if(lastPos >= 0 && _listSQL.get(lastPos) instanceof String) { _listSQL.set(lastPos, (String)_listSQL.get(lastPos) + text); } else { _listSQL.add(text); } } else { _listSQL.add(createSQLXMLNode((Element)nodeInner)); } } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { StringBuilder orgSQLText = execArg._sbSQLText; List orgSqlParams = execArg._sqlParams; SMTDatabase orgDB = execArg._curDB; try { // 生成查询SQL execArg._sbSQLText = new StringBuilder(); execArg._sqlParams = new ArrayList<>(); execArg._curDB = execArg._dbMap.getDatabase(_databaseId); for(Object oxmlNode : _listSQL) { if(oxmlNode instanceof String) { execArg._sbSQLText.append(oxmlNode); } else if(oxmlNode instanceof SQLXMLNode) { ((SQLXMLNode)oxmlNode).execute(execArg); } } // 执行查询SQL DBRecords recs = execArg._curRecs; if(recs == null) recs = execArg._curDB.querySQL(execArg._sbSQLText.toString(), execArg._sqlParams.toArray(new Object[execArg._sqlParams.size()])); // 如果查不到数据,且设置了错误信息,则返回错误信息 if(recs.getRowCount() == 0 && !SMTStatic.isNullOrEmpty(_noRecErrFmt)) { String errMsg = SMTStatic.stringFormat(_noRecErrFmt, new SMTStatic.StringNamedNotify() { @Override public Object getNamedValue(String name, Object[] args) throws Exception { String value = execArg._toolArguments.safeGetStr(name, ""); return value; } }); return new SMTJavaAIError(errMsg); } // 遍历查询SQL,并执行子查询 for(DBRecord rec : recs.getRecords()) { // 设置当前参数 for(Entry entry : rec.getFieldMap().entrySet()) { execArg._toolArguments.set(entry.getKey().toLowerCase(), rec.getString(entry.getValue())); } // 解析子SQL StringBuilder orgSQLText1 = execArg._sbSQLText; List orgSqlParams1 = execArg._sqlParams; SMTDatabase orgDB1 = execArg._curDB; try { execArg._curDB = execArg._dbMap.getDatabase(_loopDatabaseId); execArg._sbSQLText = new StringBuilder(); execArg._sqlParams = new ArrayList<>(); for(Object oxmlNode : _listChildren) { if(oxmlNode instanceof String) { execArg._sbSQLText.append(oxmlNode); } else if(oxmlNode instanceof SQLXMLNode) { ((SQLXMLNode)oxmlNode).execute(execArg); } } // 执行查询SQL // 如果当前不存在记录集,则将此查询作为新记录集 if(execArg._curRecs == null) { execArg._curRecs = execArg._curDB.querySQL(execArg._sbSQLText.toString(), execArg._sqlParams.toArray(new Object[execArg._sqlParams.size()])); } // 如果当前存在记录集,则确保两组记录集字段个数相同,并合并 else { DBRecords recsNew = execArg._curDB.querySQL(execArg._sbSQLText.toString(), execArg._sqlParams.toArray(new Object[execArg._sqlParams.size()])); if(execArg._curRecs.getColCount() != recsNew.getColCount()) throw new Exception("megre record field count is different"); // 合并记录集 for(DBRecord recNew : recsNew.getRecords()) { execArg._curRecs.addRecord(recNew.getValues()); } } } finally { execArg._curDB = orgDB1; execArg._sbSQLText = orgSQLText1; execArg._sqlParams = orgSqlParams1; } } } finally { execArg._curDB = orgDB; execArg._sbSQLText = orgSQLText; execArg._sqlParams = orgSqlParams; } return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeAMIS_JSON extends SQLXMLNode { private List _listChildren = new ArrayList<>(); public SQLXMLNodeAMIS_JSON(Element xmlRoot) throws Exception { for (Iterator iterInner = xmlRoot.nodeIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE || nodeInner instanceof DefaultCDATA) { String text; if(nodeInner instanceof DefaultCDATA) text = ((DefaultCDATA)nodeInner).getText(); else text = ((DefaultText)nodeInner).getText(); int lastPos = _listChildren.size() - 1; if(lastPos >= 0 && _listChildren.get(lastPos) instanceof String) { _listChildren.set(lastPos, (String)_listChildren.get(lastPos) + text); } else { _listChildren.add(text); } } else { _listChildren.add(createSQLXMLNode((Element)nodeInner)); } } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { SMTJavaAIError error; execArg._sbSQLText = new StringBuilder(); execArg._sqlParams = new ArrayList<>(); for(Object oxmlNode : _listChildren) { if(oxmlNode instanceof String) { execArg._sbSQLText.append(oxmlNode); } else if(oxmlNode instanceof SQLXMLNode) { if((error = ((SQLXMLNode)oxmlNode).execute(execArg)) != null) return error; } } if(execArg._mapAmisPath2Replace.size() > 0) { Json jsonAMISRoot = Json.read(execArg._sbSQLText.toString()); for(Entry entry : execArg._mapAmisPath2Replace.entrySet()) { Json jsonAMIS = jsonAMISRoot; String[] pathList = entry.getKey().split("/"); for(int i = 0; i < (pathList.length - 1); i ++) { String path = pathList[i]; if(jsonAMIS.isObject()) { jsonAMIS = jsonAMIS.getJson(path); } else if(jsonAMIS.isArray()) { jsonAMIS = jsonAMIS.at(SMTStatic.toInt(path)); } else { throw new Exception("json path error"); } } String lastPath = pathList[pathList.length - 1]; if(jsonAMIS.isObject()) { jsonAMIS.set(lastPath, entry.getValue()); } else { throw new Exception("json path error"); } } execArg._sbSQLText.setLength(0); execArg._sbSQLText.append(jsonAMISRoot.toString()); } return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeRECORDS extends SQLXMLNode { private List _listChildren = new ArrayList<>(); public SQLXMLNodeRECORDS(Element xmlRoot) throws Exception { for(Node nodeSQL : xmlRoot.selectNodes("SQL")) { SQLXMLNodeSQL sqlxmlSQL = new SQLXMLNodeSQL((Element)nodeSQL); _listChildren.add(sqlxmlSQL); } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { SMTJavaAIError error; for(SQLXMLNodeSQL sqlxmlSQL : _listChildren) { if((error = sqlxmlSQL.execute(execArg)) != null) return error; } return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeAMIS_ASYNCS extends SQLXMLNode { private List _listChildren = new ArrayList<>(); public SQLXMLNodeAMIS_ASYNCS(Element xmlRoot) throws Exception { for(Node nodeSQL : xmlRoot.selectNodes("AMIS_ASYNC")) { SQLXMLNodeAMIS_ASYNC_RS sqlxmlSQL = new SQLXMLNodeAMIS_ASYNC_RS((Element)nodeSQL); _listChildren.add(sqlxmlSQL); } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { SMTJavaAIError error; for(SQLXMLNodeAMIS_ASYNC_RS sqlxmlRS : _listChildren) { if((error = sqlxmlRS.execute(execArg)) != null) return error; } return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeSQL extends SQLXMLNode { private String _id; private String _databaseId; private List _listChildren = new ArrayList<>(); public SQLXMLNodeSQL(Element xmlRoot) throws Exception { _id = SMTStatic.getXmlAttr(xmlRoot, "id", null); _databaseId = SMTStatic.getXmlAttr(xmlRoot, "database"); for (Iterator iterInner = xmlRoot.nodeIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE) { String text = ((DefaultText)nodeInner).getText(); int lastPos = _listChildren.size() - 1; if(lastPos >= 0 && _listChildren.get(lastPos) instanceof String) { _listChildren.set(lastPos, (String)_listChildren.get(lastPos) + text); } else { _listChildren.add(text); } } else if(nodeInner instanceof DefaultCDATA) { String text = ((DefaultCDATA)nodeInner).getText(); int lastPos = _listChildren.size() - 1; if(lastPos >= 0 && _listChildren.get(lastPos) instanceof String) { _listChildren.set(lastPos, (String)_listChildren.get(lastPos) + text); } else { _listChildren.add(text); } } else { _listChildren.add(createSQLXMLNode((Element)nodeInner)); } } } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { if(SMTStatic.isNullOrEmpty(_id)) return null; SMTDatabase orgDB = execArg._curDB; DBRecords orgRecs = execArg._curRecs; try { execArg._tranReq.setAsynProcessText("正在查询" + _id); execArg._sbSQLText = new StringBuilder(); execArg._sqlParams = new ArrayList<>(); execArg._curDB = execArg._dbMap.getDatabase(_databaseId); execArg._curRecs = null; for(Object oxmlNode : _listChildren) { if(oxmlNode instanceof String) { execArg._sbSQLText.append(oxmlNode); } else if(oxmlNode instanceof SQLXMLNode) { SMTJavaAIError error = ((SQLXMLNode)oxmlNode).execute(execArg); if(error != null) return error; } } DBRecords recs = execArg._curRecs; if(recs == null) recs = execArg._curDB.querySQL(execArg._sbSQLText.toString(), execArg._sqlParams.toArray(new Object[execArg._sqlParams.size()])); execArg._mapId2Recs.put(_id, recs); } finally { execArg._curDB = orgDB; execArg._curRecs = orgRecs; } return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeAMIS_ASYNC_RS extends SQLXMLNode { private String _recId; private String _asyncId; private String _amisPath; public SQLXMLNodeAMIS_ASYNC_RS(Element xmlRoot) throws Exception { _recId = SMTStatic.getXmlAttr(xmlRoot, "rec_id"); _asyncId = SMTStatic.getXmlAttr(xmlRoot, "async_id"); _amisPath = SMTStatic.getXmlAttr(xmlRoot, "amis_path"); } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { execArg._mapAmisPath2Replace.put(_amisPath, "/chat/chat_async_query?rec_id=" + _asyncId + "&history_id=" + (execArg._tranReq.isAgentCheckMode() ? "__AGENT_CHECK_MODE__" : execArg._tranReq.getChatHistoryId())); SMTJsonWriter jsonWrAsync = execArg._tranReq.prepareAsyncQueryJson(); DBRecords recs = execArg._mapId2Recs.get(_recId); if(recs == null) return new SMTJavaAIError("未配置id为[" + _recId + "]的SQL查询"); jsonWrAsync.beginMap(_asyncId); { // 加入列状态 jsonWrAsync.beginArray("columns"); for(String colName : recs.getFieldMap().keySet()) { jsonWrAsync.addKeyValue(null, colName); } jsonWrAsync.endArray(); // 加入值 int colCount = recs.getColCount(); jsonWrAsync.beginArray("values"); for(DBRecord rec : recs.getRecords()) { jsonWrAsync.beginArray(null); for(int i = 0; i < colCount; i ++) { jsonWrAsync.addKeyValue(null, rec.getValue(i)); } jsonWrAsync.endArray(); } jsonWrAsync.endArray(); } jsonWrAsync.endMap(); return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodePARAM extends SQLXMLNode { private String _key; public SQLXMLNodePARAM(Element xmlRoot) throws Exception { _key = SMTStatic.getXmlAttr(xmlRoot, "key"); } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { execArg._sbSQLText.append("?"); execArg._sqlParams.add(execArg._toolArguments.getJson(_key).asString()); return null; } } /////////////////////////////////////////////////////////////////////////////////////// private static class SQLXMLNodeJSON_URL extends SQLXMLNode { private String _recId; public SQLXMLNodeJSON_URL(Element xmlRoot) throws Exception { _recId = SMTStatic.getXmlAttr(xmlRoot, "rec_id"); } @Override public SMTJavaAIError execute(SQLXMLExecArg execArg) throws Exception { execArg._sbSQLText.append("/chat/chat_async_query?rec_id=" + _recId + "&history_id=" + execArg._tranReq.getChatHistoryId()); return null; } } private List _listSQLXMLNode = new ArrayList<>(); private SQLXMLNodeAMIS_JSON _sqlxmlAmisJson = null; private static SQLXMLNode createSQLXMLNode(Element xmlRoot) throws Exception { String name = xmlRoot.getName().toUpperCase(); if("SQL".equals(name)) return new SQLXMLNodeSQL(xmlRoot); else if("PARAM".equals(name)) return new SQLXMLNodePARAM(xmlRoot); else if("AMIS_JSON".equals(name)) return new SQLXMLNodeAMIS_JSON(xmlRoot); else if("AMIS_DATA".equals(name)) return new SQLXMLNodeAMIS_DATA(xmlRoot); else if("JSON_URL".equals(name)) return new SQLXMLNodeJSON_URL(xmlRoot); else if("AMIS_ASYNC".equals(name) || "AMIS_ASYNC_RS".equals(name)) return new SQLXMLNodeAMIS_ASYNC_RS(xmlRoot); else if("RECORDS".equals(name)) return new SQLXMLNodeRECORDS(xmlRoot); else if("AMIS_ASYNCS".equals(name)) return new SQLXMLNodeAMIS_ASYNCS(xmlRoot); else if("FOREACH_TABLE".equals(name)) return new SQLXMLNodeAMIS_FOREACH_TABLE(xmlRoot); else throw new Exception("unknow SQLXML : " + name); } @Override public void initInstance(DBRecord rec) throws Exception { super.initInstance(rec); Document doc = SMTStatic.convStrToXmlDoc("" + rec.getString("clz_arguments") + ""); // 读取根节点 Element rootElement=doc.getRootElement(); // 读取所有子节点 for (Iterator iterInner = rootElement.elementIterator(); iterInner.hasNext();) { Node nodeInner = iterInner.next(); if(nodeInner.getNodeType() == Node.TEXT_NODE) { } else { String tagName = ((Element)nodeInner).getName(); if("AMIS_JSON".equals(tagName)) _sqlxmlAmisJson = new SQLXMLNodeAMIS_JSON((Element)nodeInner); else _listSQLXMLNode.add(createSQLXMLNode((Element)nodeInner)); } } } @Override public SMTJavaAIError callAgents(String jsonPath, Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception { SMTJsonWriter jsonWrResult = tranReq.getResultJsonWr(); // 生成json SQLXMLExecArg execArg = new SQLXMLExecArg(jsonArgs, jsonWrResult, tranReq); try { for(SQLXMLNode xmlNode : _listSQLXMLNode) { SMTJavaAIError error = xmlNode.execute(execArg); if(error != null) return error; } _sqlxmlAmisJson.execute(execArg); // 输出json jsonWrResult.addKeyValue("answer_type", "summary"); jsonWrResult.beginArray("summary"); { jsonWrResult.beginMap(null); { jsonWrResult.addKeyValue("type", "amis_page"); jsonWrResult.addKeyValue("title", ""); jsonWrResult.addKeyRaw("amis_json", execArg._sbSQLText.toString()); jsonWrResult.beginMap("amis_data"); { for(Entry entry : jsonArgs.asJsonMap().entrySet()) { jsonWrResult.addKeyValue(entry.getKey(), entry.getValue().asString()); } for(Entry entry : execArg._mapAmisData.entrySet()) { jsonWrResult.addKeyValue(entry.getKey(), entry.getValue()); } } jsonWrResult.endMap(); } jsonWrResult.endMap(); } jsonWrResult.endArray(); } finally { execArg.close(); } return null; } }