package com.smtaiserver.smtaiserver.control; import com.smtaiserver.smtaiserver.core.SMTAIServerApp; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.database.SMTDatabase; import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords; import com.smtaiserver.smtaiserver.javaai.SMTJavaAIChat; import com.smtaiserver.smtaiserver.javaai.ast.ASTDBMap; import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect; import com.smtaiserver.smtaiserver.util.SMTWXSStatic; import com.smtservlet.util.Json; import com.smtservlet.util.SMTJsonWriter; import com.smtservlet.util.SMTStatic; import java.util.*; import java.util.concurrent.CompletableFuture; import javax.servlet.http.HttpServletRequest; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; import okhttp3.ResponseBody; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.NotNull; import org.springframework.web.servlet.ModelAndView; public class SMTAIWeixinControl { private static final String FROM_USER_NAME = "FromUserName"; private static final String TO_USER_NAME = "ToUserName"; private static final String CONTENT = "Content"; private static final Logger _logger = LogManager.getLogger(SMTAIServerControl.class); private final Object _lockToken = new Object(); public ModelAndView weChatTest(SMTAIServerRequest tranReq) throws Exception { String question = "什么是冰箱"; tranReq.setAIQuestion(question); tranReq.setTextResultMode(); Set setAgentGroup = new HashSet(); String weixinGroupId = (String)SMTAIServerApp.getApp().getGlobalConfig("weixin.group_id", false); setAgentGroup.add(weixinGroupId); SMTJavaAIChat.questionChat("aliyun", setAgentGroup, "业务场景", question, null, false, tranReq); String text = tranReq.getTextResult(); return tranReq.returnText(text); } /** 微信验证 */ public ModelAndView weChatNotify(SMTAIServerRequest tranReq) throws Exception { String method = tranReq.getRequest().getMethod(); if (method.equals("GET")) return SMTWXSStatic.getModelAndView(tranReq); return reply(tranReq); } /** 被动回复 */ private ModelAndView reply(SMTAIServerRequest tranReq) { long l = System.currentTimeMillis() / 1000; String createTimeStr = String.valueOf(l); HttpServletRequest request = tranReq.getRequest(); Map requestMap = SMTWXSStatic.getWechatReqMap(request); String fromUserName = requestMap.get(FROM_USER_NAME); String toUserName = requestMap.get(TO_USER_NAME); if (requestMap.isEmpty()) { return null; } String result = getString(fromUserName, toUserName, createTimeStr); String reqContent = requestMap.get(CONTENT); // 异步调用 aiReplyToTheUserASecondTime CompletableFuture.runAsync( () -> { try { String answer = callAIForAnswerQuestion(reqContent, tranReq); // Ai调用 返回结果 aiReplyToTheUserASecondTime(answer, requestMap.get(FROM_USER_NAME)); } catch (Exception e) { _logger.error("aiReplyToTheUserASecondTime error", e); } }); _logger.info("微信消息返参:" + result); // 返回 XML 字符串 return tranReq.returnText(result); } /** ai回复 */ private String callAIForAnswerQuestion(String question, SMTAIServerRequest tranReq) throws Exception { String callFunc = "query_water_fee:\n" + " 功能:\n" + " 查询用户用水量和水费信息\n" + " 参数:\n" + " question:用户问题\n" + " user_name:用户名\n" + " value_title:'用水量'或'水费'\n" + " value_name:用水量:volume, 水费:amount\n" + " start_time:查询起始日期,格式:年-月-日\n" + " end_time:查询结束时间,格式:年-月-日\n"; String prompt = ((String) SMTAIServerApp.getApp().getGlobalConfig("prompt.agent_tools")) .replace("{{{AGENT_TOOL_DEFINE_LIST}}}", callFunc); SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null); String answer = llm.callWithMessage(new String[] {prompt}, question, tranReq); tranReq.traceLLMDebug(answer); Json ojsonASTList = SMTStatic.convLLMAnswerToJson(answer, true); if (ojsonASTList != null && ojsonASTList.isArray()) { List jsonASTList = ojsonASTList.asJsonList(); if (!jsonASTList.isEmpty()) { Json jsonAST = jsonASTList.get(0); if ("query_water_fee".equals(jsonAST.safeGetStr("call", null))) { jsonAST = jsonAST.getJson("args"); try (ASTDBMap dbMap = new ASTDBMap()) { SMTDatabase db = dbMap.getDatabase("DS_74_CHENGTOU"); DBRecords recs = db.querySQL( " SELECT ROUND(SUM(" + jsonAST.getJson("value_name").asString() + ")::NUMERIC(10, 2), 2) AS TOTAL" + " FROM chengtou_data.bill_data WHERE billing_date BETWEEN ? AND ?", new Object[] { SMTStatic.toDate(jsonAST.getJson("start_time").asString()), SMTStatic.toDate(jsonAST.getJson("end_time").asString()) }); if (recs.getRecord(0).getString(0) == null) return "从" + jsonAST.getJson("start_time").asString() + "到" + jsonAST.getJson("end_time").asString() + "的" + jsonAST.getJson("value_title").asString() + "未查到任何数据"; return "从" + jsonAST.getJson("start_time").asString() + "到" + jsonAST.getJson("end_time").asString() + "的" + jsonAST.getJson("value_title").asString() + "总计" + recs.getRecord(0).getString(0) + ("volume".equals(jsonAST.getJson("value_name").asString()) ? "吨" : "元"); } } } } answer = llm.callWithMessage(null, question, tranReq); return answer; } /** 二次回复 */ public void aiReplyToTheUserASecondTime(String answer, String fromUserName) throws Exception { String accessToken = getAccessToken(); SMTJsonWriter jsonWr = new SMTJsonWriter(false); jsonWr.addKeyValue("touser", fromUserName); jsonWr.addKeyValue("msgtype", "text"); jsonWr.beginMap("text"); { jsonWr.addKeyValue("content", answer); } jsonWr.endMap(); String url = String.format( "https://api.weixin.qq.com/cgi-bin/message/custom/send?access_token=%s", accessToken); String s = SMTWXSStatic.sendPost(url, jsonWr.getRootJson()); _logger.info("上传结果: : " + s); } /** * 数据库获取 access_koen */ public String getAccessToken() throws Exception { synchronized (this._lockToken) { try (SMTDatabase db = SMTAIServerApp.getApp().allocDatabase()) { HashMap weixinParam = SMTWXSStatic.getWeixinParam(); Date curTime = new Date(); String appId = weixinParam.get("appId"); // 查询未过期的 access_token DBRecords dbRecord = db.querySQL( "SELECT app_id, access_token, expires_time " + "FROM ai_weixin_token " + "WHERE app_id = ? " + " AND ? < expires_time", new Object[] {appId, curTime}); // 数据库无记录,从微信服务器获取 access_token List records = dbRecord.getRecords(); if (dbRecord.getRowCount() > 0) { return records.get(0).getString("access_token"); } // 微信取,返回token并且保存或覆盖数据 else { Object[] accessToken = fetchAccessTokenFromWeixinServer(); // 从微信服务器获取 access_token Date expiresTime = SMTStatic.calculateTime( curTime, SMTStatic.SMTCalcTime.ADD_SECOND, ((int) accessToken[1]) / 2); String sql = "INSERT INTO ai_weixin_token (app_id, access_token, expires_time) " + "VALUES (?, ?, ?) " + "ON CONFLICT (app_id) " + // 如果 app_id 冲突 "DO UPDATE SET " + " access_token = EXCLUDED.access_token, " + " expires_time = EXCLUDED.expires_time;"; db.executeSQL(sql, new Object[] {appId, accessToken[0], expiresTime}); return (String) accessToken[0]; } } catch (Exception e) { throw new Exception("Failed to get access token", e); } } } /** * 从微信服务器获取 access_token */ private Object[] fetchAccessTokenFromWeixinServer() throws Exception { HashMap weixinParam = SMTWXSStatic.getWeixinParam(); OkHttpClient okHttpClient = new OkHttpClient(); // 创建请求 Request request = new Request.Builder() .url( "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=" + weixinParam.get("appId") + "&secret=" + weixinParam.get("secret")) // 请求URL .get() // 使用GET方法 .build(); Response response = okHttpClient.newCall(request).execute(); if (!response.isSuccessful() || response.body() == null) { throw new Exception("can't get weixin token"); } Json json = Json.read(response.body().string()); String accessToken = json.safeGetStr("access_token", null); String expiresIn = json.safeGetStr("expires_in", null); if (accessToken != null) { return new Object[] {accessToken, SMTStatic.toInt(expiresIn)}; } else { throw new Exception("can't get weixin token : " + json); } } @NotNull private static String getString(String fromUserName, String toUserName, String createTimeStr) { String xmltemp = "\n" + " \n" + " \n" + " {{{CreateTime}}}\n" + " \n" + " \n" + ""; // 替换占位符 return xmltemp .replace("{{{toUser}}}", fromUserName) .replace("{{{fromUser}}}", toUserName) .replace("{{{CreateTime}}}", createTimeStr); } }