TangCheng
2025-02-28 d787e447e95c7b897c2cc9c0e832f8d2e5084934
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
package com.smtaiserver.smtaiserver.javaai.qwen.agent;
 
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
 
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.Node;
 
import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecord;
import com.smtaiserver.smtaiserver.javaai.SMTJavaAIError;
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
import com.smtaiserver.smtaiserver.javaai.qwen.SMTQwenApp;
import com.smtservlet.util.Json;
import com.smtservlet.util.SMTStatic;
 
public abstract class SMTQwenAgent 
{
    protected interface QuestionMatch
    {
        public void    initInstance(Element xmlRegex);
        public boolean isMatch(String question);
    }
    
    protected static class QuestionMatchRegex implements QuestionMatch
    {
        private Pattern _patMatch;
        
        @Override
        public void initInstance(Element xmlRegex) 
        {
            _patMatch = Pattern.compile(SMTStatic.trimStrLines(xmlRegex.getText()));
        }
 
        @Override
        public boolean isMatch(String question) 
        {
            if(!_patMatch.matcher(question).find())
                return false;
            return true;
        }
        
    }
    
    public static class AgentArgument
    {
        public String        _name;
        public String        _title;
        public char            _type;
        public boolean         _required;
        public String        _question;
        public String        _prompt;
        
        public AgentArgument(Element xmlArg) throws Exception
        {
            _required = "true".equals(SMTStatic.getXmlAttr(xmlArg, "required", "false"));
            _name = SMTStatic.getXmlAttr(xmlArg, "name");
            _type = SMTStatic.getXmlAttr(xmlArg, "type", "S").charAt(0);
            _question = SMTStatic.getXmlAttr(xmlArg, "question", _name);
            _prompt = xmlArg.getText();
        }
        
        public SMTJavaAIError checkArgumentValid(Json jsonArgs)
        {
            Json jsonArg = jsonArgs.safeGetJson(_name);
            if(jsonArg == null)
            {
                if(_required)
                    return new SMTJavaAIError("参数" + _title + "未设置");
                
                return null;
            }
            
            String sValue = jsonArg.asString();
            switch(_type)
            {
            case 'S':
                return null;
                
            case 'D':
                try
                {
                    SMTStatic.toDouble(sValue);
                }
                catch(Exception ex)
                {
                    return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是浮点数");
                }
                return null;
                
            case 'I':
                try
                {
                    SMTStatic.toInt(sValue);
                }
                catch(Exception ex)
                {
                    return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是整数");
                }
                return null;
                
            case 'T':
                try
                {
                    SMTStatic.toDate(sValue);
                }
                catch(Exception ex)
                {
                    return new SMTJavaAIError("参数" + _title + "的值" + sValue + "不是时间");
                }
                return null;
            }
            
            return null;
        }
    }
    
    protected boolean            _inSupervisor = false;
    protected String            _agentGroup;
    protected String            _agentTitle;
    protected String            _agentGroupType;
    protected String            _agentId;
    protected String            _agentType;
    protected String            _agentPrompt;
    protected Map<String, AgentArgument>    _mapName2Argument = new HashMap<>();
    protected List<QuestionMatch>    _listQuestionMatch = new ArrayList<>();
    private static Logger                     _logger = LogManager.getLogger(SMTQwenApp.class);
    
    public abstract SMTJavaAIError callAgents(String jsonPath, Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception;
    
    
    public void queryUnknowQuestionList(String keyword, SMTAIServerRequest tranReq) throws Exception
    {
    }
    
    public boolean inSupervisor()
    {
        return _inSupervisor;
    }
    
    public SMTJavaAIError mergeASTList(List<Json> listASTSrc, List<Json> r_listASTTag)
    {
        for(Json jsonAST : listASTSrc)
        {
            r_listASTTag.add(jsonAST);
        }
        
        return null;
    }
    
    public boolean isMatchGroupType(String groupType)
    {
        if(SMTStatic.isNullOrEmpty(groupType))
            return true;
        if(groupType.equals(_agentGroupType))
            return true;
        
        return false;
    }
    
    public SMTJavaAIError callSupervisorJson(String agentId, String jsonPath, Json callFunc, SMTAIServerRequest tranReq) throws Exception
    {
        return new SMTJavaAIError("agent不支持二次操作");
    }
    
    public SMTJavaAIError checkArgsValid(Json jsonArgs)
    {
        for(AgentArgument agentArgument : _mapName2Argument.values())
        {
            SMTJavaAIError error = agentArgument.checkArgumentValid(jsonArgs);
            if(error != null)
                return error;
        }
        return null;
    }
    
    public SMTJavaAIError callExtJson(Json jsonArgs, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception
    {
        return this.callAgents("", jsonArgs, llm, question, tranReq);
    }
    
    public AgentArgument getAgentArgument(String name)
    {
        return _mapName2Argument.get(name);
    }
    
    //agent_id,agent_type,tool_desc,tool_arguments,tool_return,clz_name,clz_arguments
    public void initInstance(DBRecord rec) throws Exception
    {
        
        try
        {
            _inSupervisor = !("Y".equals(rec.getString("inner_call")));
            _agentTitle = rec.getString("agent_title");
            if(_agentTitle == null)
                _agentTitle = "业务查询";
            _agentType = rec.getString("agent_type");
            _agentId = rec.getString("agent_id");
            _agentGroup = rec.getString("agent_group");
            if(rec.getFieldMap().containsKey("GROUP_TYPE"))
                _agentGroupType = rec.getString("group_type");
            else
                _agentGroupType = "";
    
            _agentPrompt = getAgentPrompt(rec);
            
            _logger.info("workflow : " + _agentId + ":\n" + _agentPrompt);
        }
        catch(Exception ex)
        {
            throw new Exception("init supervisor agent error : " + _agentId, ex);
        }
        
    }
    
    public boolean isAgentGroupMatched(Set<String> setAgentGroup)
    {
        if(setAgentGroup == null)
            return true;
        if(setAgentGroup.contains(_agentGroup))
            return true;
        
        return false;
    }
    
    public String getAgentTitle()
    {
        return _agentTitle == null ? "" : _agentTitle;
    }
    
    public String getAgentId()
    {
        return _agentId;
    }
    
    public String getAgentType()
    {
        return _agentType;
    }
    
    public String getAgentPrompt()
    {
        return _agentPrompt;
    }
    
    public boolean isQuestionMatched(String question)
    {
        for(QuestionMatch match : _listQuestionMatch)
        {
            if(!match.isMatch(question))
                return false;
        }
        
        return true;
    }
    
    protected String getToolDesc(Element xmlTitle) throws Exception
    {
        String toolDesc = SMTStatic.trimStrLines(xmlTitle.getText());
        
        return toolDesc;
    }
    
    protected String getAgentPrompt(DBRecord rec) throws Exception
    {
        String agentId = rec.getString("agent_id");
        try
        {
            Document doc = SMTStatic.convStrToXmlDoc("<ROOT>" + rec.getString("agent_xml") + "</ROOT>");
            Element xmlTitle = (Element) doc.selectSingleNode("ROOT/TITLE");
            String toolDesc = getToolDesc(xmlTitle);
            
            StringBuilder sbPrompt = new StringBuilder();
            sbPrompt.append(
                  agentId + "\n"
                + "    功能:\n"
                + "        " + SMTStatic.indentStrLines(toolDesc, "        ") + "\n"
                + "    参数:\n"
                + "        question:提出的问题\n"
            );        
            
            for(Node nodeArg : doc.selectNodes("ROOT/ARGS/ARG"))
            {
                Element xmlArg = (Element)nodeArg;
                AgentArgument agentArg = new AgentArgument(xmlArg);
                _mapName2Argument.put(agentArg._name, agentArg);
                String argName = SMTStatic.getXmlAttr(xmlArg, "name");
                String argDesc = SMTStatic.trimStrLines(xmlArg.getText());
                sbPrompt.append("        " + argName + ":" + argDesc + "\n");
            }
            sbPrompt.append(
                  "    返回:\n"
                + "        返回查询结果" 
            );
            
            for(Node nodeMatchRegex : doc.selectNodes("ROOT/MATCHES/REGEX"))
            {
                QuestionMatchRegex match = new QuestionMatchRegex();
                match.initInstance((Element)nodeMatchRegex);
                _listQuestionMatch.add(match);
            }
            
            return sbPrompt.toString();
        }
        catch(Exception ex)
        {
            throw new Exception("get agent exception : " + agentId, ex);
        }
    }
    
    protected SMTJavaAIError queryUnknowQuestion(String errorMsg, SMTLLMConnect llm, String question, SMTAIServerRequest tranReq) throws Exception
    {
        SMTJavaAIError orgError = new SMTJavaAIError(errorMsg);
        
        tranReq.sendChunkedBlock("begin", "无法精确匹配已有的模式,进行模糊匹配");
        
        String prompt = (String) SMTAIServerApp.getApp().getGlobalConfig("prompt.metric_unknow");
        String answer = llm.callWithMessage(new String[] {prompt}, question, tranReq).replace("\r", "");
        Json jsonAST = SMTStatic.convLLMAnswerToJson(answer, true);
        
        tranReq.traceLLMDebug("unknow question Agents:[\n" + SMTStatic.formatJson(jsonAST) + "\n]");
        
        if(!jsonAST.isArray())
            return orgError;
        
        List<Json> jsonCalls = jsonAST.asJsonList();
        if(jsonCalls.size() == 0)
            return orgError;
        Json jsonCall = jsonCalls.get(0);
        if(!"agent_unknow".equals(jsonCall.safeGetStr("call", "")))
            return orgError;
        Json jsonArgs = jsonCall.safeGetJson("args");
        if(jsonArgs == null)
            return orgError;
        if(!jsonArgs.has("keyword"))
            return orgError;
        
        SMTJavaAIError error = SMTAIServerApp.getApp().getQwenAgentManager().callUnknowQuestionAgent(jsonArgs, llm, question, tranReq);
        
        return error;
    }
}