package com.smtaiscript.lib.aliyunai;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.mozilla.javascript.Context;
import org.mozilla.javascript.Function;
import org.mozilla.javascript.NativeArray;
import org.mozilla.javascript.NativeObject;

import com.alibaba.dashscope.aigc.conversation.ConversationParam.ResultFormat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.GenerationOutput.Choice;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolBase;
import com.alibaba.dashscope.tools.ToolCallBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.tools.ToolFunction;
import com.alibaba.dashscope.utils.JsonUtils;
import com.smtaiscript.lib.JSStaticSMTAI;
import com.smtscript.lib.JSComment;
import com.smtscript.utils.Json;
import com.smtscript.utils.JsonWriter;
import com.smtscript.utils.SMTStatic;

/**
 * Script-facing session that sends a question to a DashScope {@link Generation}
 * model, optionally exposing JavaScript functions to the model as callable tools.
 */
public class SMTAliyunAISession
{
    /**
     * One parameter of a tool, read from a JS object with the keys
     * {@code name}, {@code desc} and (optional, default {@code true}) {@code required}.
     */
    private static class AIToolParam
    {
        public String _name;
        public String _desc;
        public boolean _required;

        public AIToolParam(NativeObject nvToolParam) throws Exception
        {
            _name = (String) SMTStatic.getJSValue(nvToolParam, "name");
            _desc = (String) SMTStatic.getJSValue(nvToolParam, "desc");
            _required = (boolean) SMTStatic.getJSValue(nvToolParam, "required", true);
        }
    }

    /**
     * A tool definition read from a JS object with the keys {@code name},
     * {@code desc}, {@code func} and (optional) {@code params}.
     */
    private static class AIToolDef
    {
        public String _name;
        public String _desc;
        public Function _func;
        public List<AIToolParam> _listParam = new ArrayList<>();

        public AIToolDef(NativeObject nvToolDef) throws Exception
        {
            _name = (String) SMTStatic.getJSValue(nvToolDef, "name");
            _desc = (String) SMTStatic.getJSValue(nvToolDef, "desc");
            _func = (Function) SMTStatic.getJSValue(nvToolDef, "func");

            NativeArray arrJSParams = (NativeArray) SMTStatic.getJSValue(nvToolDef, "params", null);
            if(arrJSParams != null)
            {
                for(int i = 0; i < arrJSParams.size(); i ++)
                {
                    NativeObject nvToolParam = (NativeObject) SMTStatic.unwrapObject(arrJSParams.get(i));
                    _listParam.add(new AIToolParam(nvToolParam));
                }
            }
        }

        /**
         * Invokes the JS tool function as {@code func(jsParam, nvTag)} and
         * converts its return value to a string with {@code SMTStatic.toString}.
         */
        public String call(NativeObject jsParam, Context cx, NativeObject nvTag)
        {
            return SMTStatic.toString(_func.call(cx, _func, null, new Object[] {jsParam, nvTag}));
        }

        /**
         * Builds the DashScope tool descriptor for this definition.
         *
         * The JSON handed to {@code FunctionDefinition.parameters(...)} is only the
         * schema of the arguments: {@code {"type":"object","properties":{...},"required":[...]}}.
         * The tool name and description are supplied through the builder, and
         * {@code required} sits next to {@code properties} as JSON Schema expects.
         * Every parameter is declared with type {@code string}.
         */
        public ToolFunction getToolFunction()
        {
            JsonWriter jsonWr = new JsonWriter(false);
            jsonWr.addKeyValue("type", "object");

            // Parameter list; collect the names flagged as required while writing it.
            List<String> listParamRequired = new ArrayList<>();
            jsonWr.beginMap("properties");
            for(AIToolParam paramInfo : _listParam)
            {
                if(paramInfo._required)
                    listParamRequired.add(paramInfo._name);

                jsonWr.beginMap(paramInfo._name);
                {
                    jsonWr.addKeyValue("type", "string");
                    jsonWr.addKeyValue("description", paramInfo._desc);
                }
                jsonWr.endMap();
            }
            jsonWr.endMap();

            // Names of the parameters the model must always supply.
            if(listParamRequired.size() > 0)
            {
                jsonWr.beginArray("required");
                for(String paramName : listParamRequired)
                {
                    jsonWr.addKeyValue(null, paramName);
                }
                jsonWr.endArray();
            }

            FunctionDefinition funcDefine = FunctionDefinition.builder()
                    .name(_name)
                    .description(_desc)
                    .parameters(JsonUtils.parseString(jsonWr.getFullJson()).getAsJsonObject())
                    .build();

            return ToolFunction.builder().function(funcDefine).build();
        }
    }

    /**
     * Index of tool definitions by tool name. A later definition with the same
     * name replaces an earlier one.
     */
    private static class AIToolManager
    {
        private Map<String, AIToolDef> _mapName2ToolDef = new HashMap<>();

        public AIToolManager(NativeArray arrJSToolList) throws Exception
        {
            for(int i = 0; i < arrJSToolList.size(); i ++)
            {
                NativeObject nvToolDef = (NativeObject) SMTStatic.unwrapObject(arrJSToolList.get(i));
                AIToolDef toolDef = new AIToolDef(nvToolDef);
                _mapName2ToolDef.put(toolDef._name, toolDef);
            }
        }

        /** Returns the definition registered under {@code funcName}, or {@code null}. */
        public AIToolDef getToolDef(String funcName)
        {
            return _mapName2ToolDef.get(funcName);
        }

        /** Returns the DashScope descriptors of all registered tools. */
        public List<ToolBase> getFunctList()
        {
            List<ToolBase> list = new ArrayList<>();
            for(AIToolDef toolDef : _mapName2ToolDef.values())
            {
                list.add(toolDef.getToolFunction());
            }
            return list;
        }
    }

    ///////////////////////////////////////////////////////////////////////
    private JSStaticSMTAI _parent;
    private Generation _gen;

    public SMTAliyunAISession(JSStaticSMTAI parent)
    {
        _parent = parent;
        _gen = new Generation();
    }

    // NOTE(review): the first and last literals of this annotation were split across
    // line breaks in the source and are rejoined here as "\n"; confirm no markup
    // was lost from them.
    /**
     * Asks the model one question and returns {@code {answer, run_ms}}.
     *
     * Without {@code tools} a single plain chat call is made. With {@code tools}
     * the tool-calling loop in {@code callWithTools} is run inside a script context
     * obtained from the parent runtime, which is exited when the loop finishes.
     *
     * @param nvConfig request configuration; see the {@code @JSComment} text
     * @param nvTag    opaque value passed through to tool functions and {@code call_tool}
     * @return object with {@code answer} (if not null) and {@code run_ms} (elapsed milliseconds)
     */
    @JSComment(
        "\n" +
        "nvConfig:\n" +
        "    model    - String   : (qwen-turbo,qwen-max,qwen-plus)\n" +
        "               llm:qwen-turbo, tool:qwen-max\n" +
        "    systems  - String[] : system message\n" +
        "    question - String   : question\n" +
        "    call_tool- Function : function(run_ms, llmResult, nvTag, jsParam, toolResult)\n" +
        "    tools    - Map[]    : tool define\n" +
        "       name  - String   : tool name\n" +
        "       desc  - String   : tool desc\n" +
        "       func  - Function : function(nvParam, nvTag)\n" +
        "       params- Map[]    : tool params\n" +
        "          name-String   : param name\n" +
        "          desc-String   : param desc\n" +
        "\n"
    )
    public NativeObject askQuestion(NativeObject nvConfig, NativeObject nvTag) throws Exception
    {
        String modelName = (String) SMTStatic.getJSValue(nvConfig, "model", null);

        long tick = System.currentTimeMillis();

        String question = SMTStatic.toString(SMTStatic.getJSValue(nvConfig, "question"));

        List<String> systemMessages = new ArrayList<>();
        NativeArray nvSysMsg = (NativeArray) SMTStatic.getJSValue(nvConfig, "systems", null);
        if(nvSysMsg != null)
        {
            for(int i = 0; i < nvSysMsg.size(); i ++)
            {
                systemMessages.add(SMTStatic.toString(SMTStatic.unwrapObject(nvSysMsg.get(i))));
            }
        }

        String answer = null;
        NativeArray nvTools = (NativeArray) SMTStatic.getJSValue(nvConfig, "tools", null);
        if(nvTools == null || nvTools.size() == 0)
        {
            answer = callWithMessage(modelName, systemMessages, question);
        }
        else
        {
            Function funcCallTool = (Function) SMTStatic.getJSValue(nvConfig, "call_tool", null);

            // Tool functions are JS callbacks, so a script context is needed while they run.
            Context cx = _parent.__parentScope__().__runtime__().entryContext();
            try
            {
                AIToolManager toolManager = new AIToolManager(nvTools);
                answer = callWithTools(modelName, systemMessages, question, toolManager, cx, funcCallTool, nvTag);
            }
            finally
            {
                Context.exit();
            }
        }

        tick = System.currentTimeMillis() - tick;

        NativeObject nvAnswer = new NativeObject();
        SMTStatic.putJSNotNullValue(nvAnswer, "answer", answer);
        SMTStatic.putJSNotNullValue(nvAnswer, "run_ms", tick);

        return nvAnswer;
    }

    /**
     * Sends the system messages followed by the user message in a single call and
     * returns the content of the first choice. Uses {@code Generation.Models.QWEN_TURBO}
     * when {@code modelName} is null or empty.
     */
    private String callWithMessage(String modelName, List<String> listSysMsg, String userMsg) throws Exception
    {
        List<Message> messages = new ArrayList<>();

        if(listSysMsg != null)
        {
            for(String sysMsg : listSysMsg)
            {
                Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
                messages.add(systemMsg);
            }
        }

        Message userMessage = Message.builder().role(Role.USER.getValue()).content(userMsg).build();
        messages.add(userMessage);

        GenerationParam param = GenerationParam.builder()
                .model(SMTStatic.isNullOrEmpty(modelName) ? Generation.Models.QWEN_TURBO : modelName)
                .messages(messages)
                .resultFormat(GenerationParam.ResultFormat.MESSAGE)
                .build();

        GenerationResult result = _gen.call(param);

        return result.getOutput().getChoices().get(0).getMessage().getContent();
    }

    /**
     * Runs the tool-calling conversation: calls the model, executes every function
     * tool call it requests, appends the tool results to the conversation and
     * repeats until a round produces no tool result; the content of the first
     * choice of that round is returned.
     *
     * {@code funcCallTool}, when given, is invoked as
     * {@code (run_ms, llmResult, nvTag, jsParam, toolResult)} after each tool
     * execution and once more (with null {@code jsParam}/{@code toolResult}) for the
     * final round.
     *
     * @param listSysMsg a {@code List<String>}, {@code String[]}, single {@code String} or null
     * @throws Exception if {@code listSysMsg} has any other type, or on API/script errors
     */
    @SuppressWarnings("unchecked")
    private String callWithTools(String modelName, Object listSysMsg, String userMsgText, AIToolManager toolManager, Context cx, Function funcCallTool, NativeObject nvTag) throws Exception
    {
        List<Message> messages = new ArrayList<>();

        Message topSystemMsg = Message.builder().role(Role.SYSTEM.getValue())
                .content("You are a helpful assistant. When asked a question, use tools wherever possible.")
                .build();
        messages.add(topSystemMsg);

        if(listSysMsg != null)
        {
            if(listSysMsg instanceof List)
            {
                for(String sysMsg : (List<String>) listSysMsg)
                {
                    Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
                    messages.add(systemMsg);
                }
            }
            else if(listSysMsg instanceof String[])
            {
                for(String sysMsg : (String[]) listSysMsg)
                {
                    Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
                    messages.add(systemMsg);
                }
            }
            else if(listSysMsg instanceof String)
            {
                Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content((String) listSysMsg).build();
                messages.add(systemMsg);
            }
            else
                throw new Exception("unknow listSysMsg type");
        }

        Message userMessage = Message.builder().role(Role.USER.getValue()).content(userMsgText).build();
        messages.add(userMessage);

        // NOTE(review): the @JSComment text mentions qwen-max for tools, but the
        // fallback used here is "qwen-long" — confirm which default is intended.
        // The loop below relies on `param` seeing messages appended to `messages`
        // after it was built.
        GenerationParam param = GenerationParam.builder()
                .model(SMTStatic.isNullOrEmpty(modelName) ? "qwen-long" : modelName)
                .messages(messages)
                .resultFormat(ResultFormat.MESSAGE)
                .tools(toolManager.getFunctList())
                .build();

        while(true)
        {
            long tick = System.currentTimeMillis();
            GenerationResult result = _gen.call(param);
            tick = System.currentTimeMillis() - tick;

            boolean existCallTool = false;

            for(Choice choice : result.getOutput().getChoices())
            {
                Message assistantMsg = choice.getMessage();
                messages.add(assistantMsg);

                // Handle the tool calls of THIS choice (not always the first one).
                List<ToolCallBase> toolCalls = assistantMsg.getToolCalls();
                if(toolCalls == null)
                    continue;

                for(ToolCallBase toolCall : toolCalls)
                {
                    if(!"function".equals(toolCall.getType()))
                        continue;

                    // Tool function name and its JSON-encoded arguments.
                    String functionName = ((ToolCallFunction) toolCall).getFunction().getName();
                    String functionArgument = ((ToolCallFunction) toolCall).getFunction().getArguments();

                    AIToolDef toolDef = toolManager.getToolDef(functionName);
                    if(toolDef == null)
                        continue;

                    NativeObject jsParam = (NativeObject) SMTStatic.convJsonToJS(Json.read(functionArgument));
                    String toolResult = toolDef.call(jsParam, cx, nvTag);

                    if(funcCallTool != null)
                    {
                        Object llmResult = SMTStatic.convJsonToJS(Json.read(JsonUtils.toJson(result)));
                        funcCallTool.call(cx, funcCallTool, null, new Object[] {tick, llmResult, nvTag, jsParam, toolResult});
                    }

                    if(toolResult != null)
                    {
                        existCallTool = true;
                        // The reply must reference the id of the call it answers,
                        // not the tool's name.
                        Message toolResultMessage = Message.builder()
                                .role("tool")
                                .content(toolResult)
                                .toolCallId(toolCall.getId())
                                .build();
                        messages.add(toolResultMessage);
                    }
                }
            }

            if(!existCallTool)
            {
                if(funcCallTool != null)
                {
                    Object llmResult = SMTStatic.convJsonToJS(Json.read(JsonUtils.toJson(result)));
                    funcCallTool.call(cx, funcCallTool, null, new Object[] {tick, llmResult, nvTag, null, null});
                }

                return result.getOutput().getChoices().get(0).getMessage().getContent();
            }
        }
    }
}