package com.smtaiserver.smtaiserver.javaai.jsonflow.core;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.smtaiserver.smtaiserver.javaai.jsonflow.node.*;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError;
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
import com.smtservlet.util.Json;

/**
 * Loads a JSON workflow description into a graph of {@link SMTJsonFlowNode} objects
 * and executes that graph starting from its single start node.
 *
 * <p>The manager also owns a pool of {@link SMTJsonFlowScriptJet} instances that can be
 * borrowed with {@link #allocScriptJet()} and handed back with
 * {@link #freeScripteJet(SMTJsonFlowScriptJet)}.
 *
 * <p>{@link #initInstance(Json)} must complete successfully before any other method is used:
 * the pool and the start node are only created there.
 */
public class SMTJsonFlowManager {

    /** Maximum number of script jets the pool is allowed to manage at once. */
    private static final int SCRIPT_JET_POOL_MAX_TOTAL = 102400;

    /**
     * Commons-pool factory that creates new {@link SMTJsonFlowScriptJet} instances on demand.
     */
    public class SMTJsonFlowScriptPoolFactory extends BasePooledObjectFactory<SMTJsonFlowScriptJet> {

        /** Creates a fresh script jet for the pool. */
        @Override
        public SMTJsonFlowScriptJet create() throws Exception {
            return new SMTJsonFlowScriptJet();
        }

        /** Wraps a script jet in the bookkeeping object the pool tracks. */
        @Override
        public PooledObject<SMTJsonFlowScriptJet> wrap(SMTJsonFlowScriptJet obj) {
            return new DefaultPooledObject<>(obj);
        }
    }

    /** The unique start node of the workflow; {@code null} until {@link #initInstance(Json)} succeeds. */
    protected SMTJsonFlowNodeStart _startNode;

    /** All workflow nodes, keyed by node id. */
    protected Map<String, SMTJsonFlowNode> _mapId2AgentFlowNode = new HashMap<>();

    /** Pool of reusable script jets; {@code null} until {@link #initInstance(Json)} is called. */
    protected GenericObjectPool<SMTJsonFlowScriptJet> _scriptJetPool = null;

    /**
     * Builds the workflow from its JSON description.
     *
     * <p>Steps, in order: create the script-jet pool, instantiate and initialise one node per
     * entry of {@code "nodes"}, wire the nodes using {@code "edges"}, run each node's
     * {@code afterInstance()} hook, and finally locate the single start node.
     *
     * @param jsonWorkflow workflow description containing a {@code "nodes"} list (each with a
     *                     {@code "type"}) and an {@code "edges"} list (each with
     *                     {@code "source"} and {@code "target"} node ids)
     * @throws Exception if a node type is unknown, an edge references a missing node,
     *                   more than one start node exists, or no start node exists
     */
    public void initInstance(Json jsonWorkflow) throws Exception {
        // Create the script-jet pool.
        GenericObjectPoolConfig<SMTJsonFlowScriptJet> config = new GenericObjectPoolConfig<>();
        config.setMaxTotal(SCRIPT_JET_POOL_MAX_TOTAL);
        config.setMinIdle(0);
        _scriptJetPool = new GenericObjectPool<>(new SMTJsonFlowScriptPoolFactory(), config);

        // Create and initialise the nodes.
        for (Json jsonNode : jsonWorkflow.getJson("nodes").asJsonList()) {
            String nodeType = jsonNode.getJson("type").asString();
            SMTJsonFlowNode wfNode = createFlowNode(nodeType);
            wfNode.initInstane(this, jsonNode);
            _mapId2AgentFlowNode.put(wfNode.getId(), wfNode);
        }

        // Wire the edges between nodes.
        for (Json jsonEdge : jsonWorkflow.getJson("edges").asJsonList()) {
            String sourceId = jsonEdge.getJson("source").asString();
            String targetId = jsonEdge.getJson("target").asString();

            SMTJsonFlowNode sourceNode = getFlowNode(sourceId);
            SMTJsonFlowNode targetNode = getFlowNode(targetId);

            sourceNode.initEdge(targetNode, jsonEdge);
        }

        // Post-initialisation hook; runs for every node before the start node is validated.
        for (SMTJsonFlowNode flowNode : _mapId2AgentFlowNode.values()) {
            flowNode.afterInstance();
        }

        // Locate the start node; exactly one is allowed.
        _startNode = null;
        for (SMTJsonFlowNode flowNode : _mapId2AgentFlowNode.values()) {
            if (!flowNode.isStartNode()) {
                continue;
            }
            if (_startNode != null) {
                throw new Exception("start node exist : " + _startNode.getId() + " , but find other : " + flowNode.getId());
            }
            _startNode = (SMTJsonFlowNodeStart) flowNode;
        }

        if (_startNode == null) {
            throw new Exception("cant find any start node");
        }
    }

    /**
     * Instantiates the node implementation matching a workflow node type.
     *
     * @param nodeType value of the node's {@code "type"} field; may be {@code null}
     * @return a new, not yet initialised node
     * @throws Exception if the type is {@code null} or not one of the supported types
     */
    private static SMTJsonFlowNode createFlowNode(String nodeType) throws Exception {
        // switch on a null String throws NullPointerException, so null is routed
        // to the same "unknown type" error that every other unsupported value gets.
        if (nodeType != null) {
            switch (nodeType) {
                case "start":
                    return new SMTJsonFlowNodeStart();
                case "end":
                    return new SMTJsonFlowNodeEnd();
                case "condition":
                    return new SMTJsonFlowNodeCondJson();
                case "agent":
                    return new SMTJsonFlowNodeAgent();
                case "output_msg":
                    return new SMTJsonFlowNodeUserAsk();
                case "code":
                    return new SMTJsonFlowNodeScript();
                case "text_resource":
                    return new SMTJsonFlowNodeTextResource();
                case "func":
                    return new SMTJsonFlowNodeProcedure();
                case "LLM":
                    return new SMTJsonFlowNodeLLM();
                case "python_code":
                    return new SMTJsonFlowNodePython();
                case "n8n":
                    return new SMTJsonFlowNodeN8n();
                default:
                    break;
            }
        }
        throw new Exception("unknow json workflow node type : " + nodeType);
    }

    /**
     * Borrows a script jet from the pool and binds it to this manager.
     *
     * @return a script jet whose pool manager is set to this instance
     * @throws Exception if the pool cannot provide an object
     */
    public SMTJsonFlowScriptJet allocScriptJet() throws Exception {
        SMTJsonFlowScriptJet scriptJet = _scriptJetPool.borrowObject();
        scriptJet.setPoolManager(this);
        return scriptJet;
    }

    /**
     * Unbinds a script jet from this manager and returns it to the pool.
     *
     * @param scriptJet a script jet previously obtained from {@link #allocScriptJet()}
     */
    public void freeScripteJet(SMTJsonFlowScriptJet scriptJet) {
        scriptJet.setPoolManager(null);
        _scriptJetPool.returnObject(scriptJet);
    }

    /**
     * Returns the argument list declared by the start node.
     *
     * @return the start node's flow argument list
     */
    // NOTE(review): the element type is declared by SMTJsonFlowNodeStart.getFlowArgList(),
    // which is not visible here, so the raw List return type is kept for caller compatibility.
    public List getStartArgs() {
        return _startNode.getFlowArgList();
    }

    /**
     * Looks up a workflow node by id.
     *
     * @param nodeId id of the node to fetch
     * @return the node registered under {@code nodeId}
     * @throws Exception if no node with that id exists
     */
    public SMTJsonFlowNode getFlowNode(String nodeId) throws Exception {
        SMTJsonFlowNode flowNode = _mapId2AgentFlowNode.get(nodeId);
        if (flowNode == null) {
            throw new Exception("can't find json flow node : " + nodeId);
        }
        return flowNode;
    }

    /**
     * Executes the workflow from the start node until the pending-node queue is empty.
     *
     * <p>Before each node runs, a {@code "begin"} chunked block naming the node is sent through
     * the request; a final {@code "begin"} block is sent when the workflow completes.
     *
     * @param jsonPath passed through to the execution argument
     * @param jsonArgs passed through to the execution argument
     * @param llm      passed through to the execution argument
     * @param question passed through to the execution argument
     * @param tranReq  request used to stream progress blocks
     * @return {@code null} on success; otherwise the first error returned by a node, or a
     *         dead-loop error when the only nodes left in the queue are suspend nodes
     * @throws Exception if a node execution or progress notification throws
     */
    public SMTJavaAIError executeJsonFlow(String jsonPath, Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception {
        // Create the execution context and queue the start node.
        SMTJsonFlowExecArg execArg = new SMTJsonFlowExecArg(jsonPath, jsonArgs, llm, question, tranReq);
        execArg._stackNodeExec.addLast(_startNode.createFlowNodeExec());

        // Run queued nodes until none are left.
        while (!execArg._stackNodeExec.isEmpty()) {
            // Execute the next node.
            SMTJsonFlowNodeExec flowNodeExec = execArg._stackNodeExec.pollFirst();
            execArg._tranReq.sendChunkedBlock("begin", "执行节点:" + flowNodeExec._flowNode.getTitle() + "(" + flowNodeExec._flowNode.getId() + ")");
            SMTJavaAIError error = flowNodeExec.executeFlowNode(execArg);
            if (error != null) {
                return error;
            }

            // If ONLY suspend nodes remain queued, report a dead loop.
            // A suspend node queued alongside other nodes is not an error by itself,
            // and an empty queue simply ends the workflow.
            boolean hasNonSuspendNode = false;
            for (SMTJsonFlowNodeExec pendingExec : execArg._stackNodeExec) {
                if (!pendingExec._flowNode.isSuspendNode()) {
                    hasNonSuspendNode = true;
                    break;
                }
            }

            if (!execArg._stackNodeExec.isEmpty() && !hasNonSuspendNode) {
                return new SMTJavaAIError("工作流陷入死循环");
            }
        }

        execArg._tranReq.sendChunkedBlock("begin", "工作流执行完成");

        return null;
    }
}