秦芳睿
8 天以前 a7ae5db38611821cfa805556ed37065a256baefb
JAVA/SMTAIServer/src/main/java/com/smtaiserver/smtaiserver/control/SMTAIWeixinControl.java
@@ -1,185 +1,140 @@
package com.smtaiserver.smtaiserver.control;
import cn.hutool.http.HttpUtil;
import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
import com.smtaiserver.smtaiserver.database.SMTDatabase;
import com.smtaiserver.smtaiserver.database.SMTDatabase.DBRecords;
import com.smtaiserver.smtaiserver.javaai.ast.ASTDBMap;
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
import com.smtaiserver.smtaiserver.util.SMTWXSStatic.WeixinuUtil;
import com.smtservlet.core.SMTRequest;
import com.smtaiserver.smtaiserver.javaai.SMTJavaAIChat;
import com.smtaiserver.smtaiserver.util.SMTWXSStatic;
import com.smtservlet.util.Json;
import com.smtservlet.util.SMTJsonWriter;
import com.smtservlet.util.SMTStatic;
import java.util.*;
import javax.servlet.http.HttpServletRequest;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.springframework.boot.configurationprocessor.json.JSONObject;
import org.jetbrains.annotations.NotNull;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.Timestamp;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import static com.smtaiserver.smtaiserver.util.SMTWXSStatic.WeixinuUtil.*;
import static java.util.Arrays.sort;
public class SMTAIWeixinControl {
  private static final String FROM_USER_NAME = "FromUserName";
  private static final String TO_USER_NAME = "ToUserName";
  private static final String CONTENT = "Content";
  private static Logger _logger = LogManager.getLogger(SMTAIServerControl.class);
    private static final HashMap<String, String> asynchronousList = new HashMap<>();
    private static final String FROM_USER_NAME = "FromUserName";
    private static final String TO_USER_NAME = "ToUserName";
    private static final String CONTENT = "Content";
    private static final Logger _logger = LogManager.getLogger(SMTAIServerControl.class);
  private final Object _lockToken = new Object();
  private String _tokenValue = null;
  private long _tokenTicket = 0;
  private OkHttpClient _web;
    private final Object _lockToken = new Object();
  /** 微信验证 */
  public ModelAndView weChatNotify(SMTAIServerRequest tranReq) throws Exception {
    String method = tranReq.getRequest().getMethod();
    if (method.equals("GET")) return WeixinuUtil.getModelAndView(tranReq);
    return reply(tranReq);
  }
    /**
     * 微信验证
     */
    public ModelAndView weChatNotify(SMTAIServerRequest tranReq) throws Exception {
        String method = tranReq.getRequest().getMethod();
        if (method.equals("GET")) return SMTWXSStatic.getModelAndView(tranReq);
        return reply(tranReq);
    }
  /** 被动回复 */
  private ModelAndView reply(SMTAIServerRequest tranReq) {
    long l = System.currentTimeMillis() / 1000;
    String createTimeStr = String.valueOf(l);
    HttpServletRequest request = tranReq.getRequest();
    Map<String, String> requestMap = getWechatReqMap(request);
    Map<String, String> requestMap = SMTWXSStatic.getWechatReqMap(request);
    String fromUserName = requestMap.get(FROM_USER_NAME);
    String toUserName = requestMap.get(TO_USER_NAME);
    if (requestMap.isEmpty()) {
      return null;
    }
    String xmltemp =
        "<xml>\n"
            + "  <ToUserName><![CDATA[{{{toUser}}}]]></ToUserName>\n"
            + "  <FromUserName><![CDATA[{{{fromUser}}}]]></FromUserName>\n"
            + "  <CreateTime>{{{CreateTime}}}</CreateTime>\n"
            + "  <MsgType><![CDATA[text]]></MsgType>\n"
            + "  <Content><![CDATA[我正在思考哦~请稍等……]]></Content>\n"
            + "</xml>";
    // 替换占位符
    String result =
        xmltemp
            .replace("{{{toUser}}}", fromUserName)
            .replace("{{{fromUser}}}", toUserName)
            .replace("{{{CreateTime}}}", createTimeStr);
    String reqContent = requestMap.get(CONTENT);
    // 异步调用 aiReplyToTheUserASecondTime
    CompletableFuture.runAsync(
        () -> {
          try {
            String answer = callAIForAnswerQuestion(reqContent, tranReq); // Ai调用 返回结果
            aiReplyToTheUserASecondTime(answer, requestMap.get(FROM_USER_NAME));
          } catch (Exception e) {
            _logger.error("aiReplyToTheUserASecondTime error", e);
          }
        });
    try {
      _logger.info("微信消息返参:" + result);
      // 返回 XML 字符串
      return tranReq.returnText(result);
    } catch (Exception e) {
      throw new RuntimeException(e);
    if (asynchronousList.get(fromUserName) != null && !reqContent.equals("停止输出")) {
      String dissuadeReturn = dissuadeReturn(fromUserName, toUserName, createTimeStr);
      return tranReq.returnText(dissuadeReturn);
    }
  }
  /**
   * ai回复
   *
   * @throws Exception
   */
  private String callAIForAnswerQuestion(String question, SMTAIServerRequest tranReq)
      throws Exception {
    String callFunc =
        "query_water_fee:\n"
            + "    功能:\n"
            + "         查询用户用水量和水费信息\n"
            + "    参数:\n"
            + "         question:用户问题\n"
            + "         user_name:用户名\n"
            + "         value_title:'用水量'或'水费'\n"
            + "         value_name:用水量:volume, 水费:amount\n"
            + "         start_time:查询起始日期,格式:年-月-日\n"
            + "         end_time:查询结束时间,格式:年-月-日\n";
    String prompt =
        ((String) SMTAIServerApp.getApp().getGlobalConfig("prompt.agent_tools"))
            .replace("{{{AGENT_TOOL_DEFINE_LIST}}}", callFunc);
    SMTLLMConnect llm = SMTAIServerApp.getApp().allocLLMConnect(null);
    String answer = llm.callWithMessage(new String[] {prompt}, question, tranReq);
    tranReq.traceLLMDebug(answer);
    Json ojsonASTList = SMTStatic.convLLMAnswerToJson(answer, true);
    if (ojsonASTList != null && ojsonASTList.isArray()) {
      List<Json> jsonASTList = ojsonASTList.asJsonList();
      if (jsonASTList.size() > 0) {
        Json jsonAST = jsonASTList.get(0);
        if ("query_water_fee".equals(jsonAST.safeGetStr("call", null))) {
          jsonAST = jsonAST.getJson("args");
          try (ASTDBMap dbMap = new ASTDBMap()) {
            SMTDatabase db = dbMap.getDatabase("DS_74_CHENGTOU");
            DBRecords recs =
                db.querySQL(
                    " SELECT ROUND(SUM("
                        + jsonAST.getJson("value_name").asString()
                        + ")::NUMERIC(10, 2), 2) AS TOTAL"
                        + " FROM chengtou_data.bill_data WHERE billing_date BETWEEN ? AND ?",
                    new Object[] {
                      SMTStatic.toDate(jsonAST.getJson("start_time").asString()),
                      SMTStatic.toDate(jsonAST.getJson("end_time").asString())
                    });
            if (recs.getRecord(0).getString(0) == null)
              return "从"
                  + jsonAST.getJson("start_time").asString()
                  + "到"
                  + jsonAST.getJson("end_time").asString()
                  + "的"
                  + jsonAST.getJson("value_title").asString()
                  + "未查到任何数据";
            return "从"
                + jsonAST.getJson("start_time").asString()
                + "到"
                + jsonAST.getJson("end_time").asString()
                + "的"
                + jsonAST.getJson("value_title").asString()
                + "总计"
                + recs.getRecord(0).getString(0)
                + ("volume".equals(jsonAST.getJson("value_name").asString()) ? "吨" : "元");
          }
        }
      }
    UUID randomUuid = UUID.randomUUID();
    asynchronousList.put(fromUserName, randomUuid.toString());
    String reply;
    reply = getReply(reqContent);
    String textContent =
        "我正在思考哦~请稍等……\n<a href=\"weixin://bizmsgmenu?msgmenucontent=停止输出&msgmenuid=101\">停止输出</a>";
    String baiduTextContent = "我是百度\n<a href=\"www.baidu.com\">百度</a>";
    String result =
        getString(
            fromUserName,
            toUserName,
            createTimeStr,
            reply,
            null,
            "https://pics0.baidu.com/feed/d31b0ef41bd5ad6e9388c7f8d8eda0d4b6fd3c60.png@f_auto?token=3da8e06f44a46832a7d0f50fa9e92c34",
            "总书记引用的这些古语耐人寻味",
            "“没有规矩,不成方圆”、“己不正,焉能正人”、“尽小者大,慎微者著”……总书记引用的这些古语耐人寻味。",
            "https://www.baidu.com/s?&wd=%E6%80%BB%E4%B9%A6%E8%AE%B0%E5%BC%95%E7%94%A8%E7%9A%84%E8%BF%99%E4%BA%9B%E5%8F%A4%E8%AF%AD%E8%80%90%E4%BA%BA%E5%AF%BB%E5%91%B3",
            null);
    if (reqContent.equals("停止输出")) {
      asynchronousList.remove(fromUserName);
      _logger.info("用户停止输出");
      return null;
    }
    answer = llm.callWithMessage(null, question, tranReq);
    return answer;
  }
        // 异步调用 aiReplyToTheUserASecondTime
//    CompletableFuture.runAsync(
//        () -> {
//          try {
//            SMTAIServerRequest threadTranReq = new SMTAIServerRequest();
//            threadTranReq.initInstance(null, "", null);
//            String answer = callAIForAnswerQuestion(reqContent, threadTranReq); // Ai调用 返回结果
//            aiReplyToTheUserASecondTime(
//                answer, requestMap.get(FROM_USER_NAME), randomUuid.toString());
//          } catch (Exception e) {
//            _logger.error("aiReplyToTheUserASecondTime error", e);
//            asynchronousList.remove(request.getParameter("FromUserName"));
//          }
//        });
        _logger.info("微信消息返参:" + result);
        // 返回 XML 字符串
        return tranReq.returnText(result);
    }
//  public ModelAndView weChatTest(SMTAIServerRequest tranReq) throws Exception {
//    String question = "冰箱是什么";
//    tranReq.setAIQuestion(question);
//    tranReq.setTextResultMode();
//
//    Set<String> setAgentGroup = new HashSet<>();
//    String weixinGroupId =
//        (String) SMTAIServerApp.getApp().getGlobalConfig("weixin.group_id", false);
//    setAgentGroup.add(weixinGroupId);
//
//    SMTJavaAIChat.questionChat("aliyun", setAgentGroup, "业务场景", question, null, false, tranReq);
//    String text = tranReq.getTextResult();
//
//    return tranReq.returnText(text);
//  }
  /** ai回复 */
//  private String callAIForAnswerQuestion(String question, SMTAIServerRequest tranReq)
//      throws Exception {
//    tranReq.setAIQuestion(question);
//    tranReq.setTextResultMode();
//
//    Set<String> setAgentGroup = new HashSet<>();
//    String weixinGroupId =
//        (String) SMTAIServerApp.getApp().getGlobalConfig("weixin.group_id", false);
//    setAgentGroup.add(weixinGroupId);
//
//    SMTJavaAIChat.questionChat("aliyun", setAgentGroup, "业务场景", question, null, false, tranReq);
//    return tranReq.getTextResult();
//  }
  /**
   * Second reply: sends {@code answer} to {@code fromUserName} as a text message through the
   * WeChat {@code cgi-bin/message/custom/send} endpoint, using a token from getAccessToken().
   * An empty answer is replaced by a fixed apology text.
   *
   * <p>In the abortID variant, the message is sent only while {@code asynchronousList} still maps
   * {@code fromUserName} to {@code abortID}. The entry is removed in both branches.
   *
   * <p>NOTE(review): this region is a rendered diff, not compilable source. The two signatures
   * below appear to be the old (2-arg, returns ModelAndView) and new (3-arg, void) revisions of
   * the same method. The "@@" line is a hunk header with unchanged lines elided. Only the 3-arg
   * revision should remain.
   */
  public ModelAndView aiReplyToTheUserASecondTime(String answer, String fromUserName)
  public void aiReplyToTheUserASecondTime(String answer, String fromUserName, String abortID)
      throws Exception {
    String accessToken = getAccessToken();
    // NOTE(review): a null answer throws NullPointerException here; callers should pass "".
    if (answer.isEmpty()) answer = "抱歉,我暂时无法理解您的问题。";
    SMTJsonWriter jsonWr = new SMTJsonWriter(false);
    jsonWr.addKeyValue("touser", fromUserName);
    jsonWr.addKeyValue("msgtype", "text");
@@ -191,78 +146,201 @@
    String url =
        String.format(
            "https://api.weixin.qq.com/cgi-bin/message/custom/send?access_token=%s", accessToken);
    // The next three lines (send, log, return null) appear to be the old revision's ending.
    String s = sendPost(url, jsonWr.getRootJson());
    _logger.info("上传结果: {}", s);
    return null;
    // Send only if this answer is still the one registered for the user (not stopped/replaced).
    if (abortID.equals(asynchronousList.get(fromUserName))) {
      String s = SMTWXSStatic.sendPost(url, jsonWr.getRootJson());
      asynchronousList.remove(fromUserName);
      _logger.info("上传结果: : " + s);
    } else {
      // NOTE(review): this removes the entry even when it now belongs to a newer question from the
      // same user; asynchronousList.remove(fromUserName, abortID) would clear only this request's
      // own entry.
      asynchronousList.remove(fromUserName);
      _logger.info("异步调用被取消");
    }
  }
  /** 数据库获取 access_koen */
  public String getAccessToken() throws Exception {
    synchronized (this._lockToken) {
      SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
      try {
        HashMap<String, String> weixinParam = getWeixinParam();
        long ONE_HOUR_IN_MILLIS = 3600 * 1000;
        long expiresTime = System.currentTimeMillis() + ONE_HOUR_IN_MILLIS;
        Timestamp expiresTimestamp = new Timestamp(expiresTime);
        String appId = weixinParam.get("appId");
      //      try (SMTDatabase db = SMTAIServerApp.getApp().allocDatabase()) {
      HashMap<String, String> weixinParam = SMTWXSStatic.getWeixinParam();
      Date curTime = new Date();
      String appId = weixinParam.get("appId");
        // 查询未过期的 access_token
        DBRecords dbRecord =
            db.querySQL(
                "SELECT app_id, access_token, expires_time "
                    + "FROM ai_weixin_token "
                    + "WHERE app_id = ? "
                    + "  AND expires_time > ?",
                new Object[] {appId, expiresTimestamp});
      // 查询未过期的 access_token
      DBRecords dbRecord =
          db.querySQL(
              "SELECT app_id, access_token, expires_time "
                  + "FROM ai_weixin_token "
                  + "WHERE app_id = ? "
                  + "  AND ? < expires_time",
              new Object[] {appId, curTime});
        // 数据库无记录,从微信服务器获取 access_token
        List<SMTDatabase.DBRecord> records = dbRecord.getRecords();
        boolean res = false;
        if (dbRecord.getRowCount() != 0) {
          String expires_time = records.get(0).getString("expires_time");
          long dbExpiresTime = Timestamp.valueOf(expires_time).getTime();
          res = System.currentTimeMillis() <= dbExpiresTime;
        }
        // 微信取,返回token并且保存或覆盖数据
        if (dbRecord.getRowCount() == 0 || !res) {
          String accessToken = fetchAccessTokenFromWeixinServer(); // 从微信服务器获取 access_token
          String sql =
              "INSERT INTO ai_weixin_token (app_id, access_token, expires_time) "
                  + "VALUES (?, ?, ?) "
                  + "ON CONFLICT (app_id) "
                  + // 如果 app_id 冲突
                  "DO UPDATE SET "
                  + "    access_token = EXCLUDED.access_token, "
                  + "    expires_time = EXCLUDED.expires_time;";
          db.executeSQL(sql, new Object[] {appId, accessToken, expiresTimestamp});
          return accessToken;
        } else { // 直接拿数据库accesstoken
          return records.get(0).getString("access_token");
        }
      } catch (Exception e) {
        throw new Exception("Failed to get access token: " + e);
      } finally {
        if (db != null) {
          db.close();
        }
      // 数据库无记录,从微信服务器获取 access_token
      List<SMTDatabase.DBRecord> records = dbRecord.getRecords();
      if (dbRecord.getRowCount() > 0) {
        return records.get(0).getString("access_token");
      }
      // 微信取,返回token并且保存或覆盖数据
      else {
        Object[] accessToken = fetchAccessTokenFromWeixinServer(); // 从微信服务器获取 access_token
        Date expiresTime =
            SMTStatic.calculateTime(
                curTime, SMTStatic.SMTCalcTime.ADD_SECOND, ((int) accessToken[1]) / 2);
        String sql =
            "INSERT INTO ai_weixin_token (app_id, access_token, expires_time) "
                + "VALUES (?, ?, ?) "
                + "ON CONFLICT (app_id) "
                + // 如果 app_id 冲突
                "DO UPDATE SET "
                + "    access_token = EXCLUDED.access_token, "
                + "    expires_time = EXCLUDED.expires_time;";
        db.executeSQL(sql, new Object[] {appId, accessToken[0], expiresTime});
        return (String) accessToken[0];
      }
      //      } catch (Exception e) {
      //        throw new Exception("Failed to get access token", e);
      //
      //      }
    }
  }
  private String fetchAccessTokenFromWeixinServer() throws Exception {
    HashMap<String, String> weixinParam = getWeixinParam();
    String url =
        "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid="
            + weixinParam.get("appId")
            + "&secret="
            + weixinParam.get("secret");
    String response = HttpUtil.get(url);
    Json json = Json.read(response);
  /** 从微信服务器获取 access_token */
  private Object[] fetchAccessTokenFromWeixinServer() throws Exception {
    HashMap<String, String> weixinParam = SMTWXSStatic.getWeixinParam();
    OkHttpClient okHttpClient = new OkHttpClient();
    // 创建请求
    Request request =
        new Request.Builder()
            .url(
                "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid="
                    + weixinParam.get("appId")
                    + "&secret="
                    + weixinParam.get("secret")) // 请求URL
            .get() // 使用GET方法
            .build();
    Response response = okHttpClient.newCall(request).execute();
    if (!response.isSuccessful() || response.body() == null) {
      throw new Exception("can't get weixin token");
    }
    Json json = Json.read(response.body().string());
    String accessToken = json.safeGetStr("access_token", null);
    String expiresIn = json.safeGetStr("expires_in", null);
    if (accessToken != null) {
      return accessToken;
      return new Object[] {accessToken, SMTStatic.toInt(expiresIn)};
    } else {
      throw new Exception("can't get weixin token : " + json);
      _logger.info("can't get weixin token : " + json);
      return null;
      //      throw new Exception("can't get weixin token : " + json);
    }
  }
    @NotNull
    private static String getString(String fromUserName, String toUserName, String createTimeStr, String msgType, String content, String mediaId, String title, String description, String musicUrl, String hqMusicUrl) {
        StringBuilder xmlBuilder = new StringBuilder();
        xmlBuilder.append("<xml>\n")
                .append("  <ToUserName><![CDATA[").append(fromUserName).append("]]></ToUserName>\n")
                .append("  <FromUserName><![CDATA[").append(toUserName).append("]]></FromUserName>\n")
                .append("  <CreateTime>").append(createTimeStr).append("</CreateTime>\n")
                .append("  <MsgType><![CDATA[").append(msgType).append("]]></MsgType>\n");
        switch (msgType) {
            case "text":
                xmlBuilder.append("  <Content><![CDATA[").append(content).append("]]></Content>\n");
                break;
            case "image":
                xmlBuilder.append("  <Image><MediaId><![CDATA[").append(mediaId).append("]]></MediaId></Image>\n");
                break;
            case "voice":
                xmlBuilder.append("  <Voice><MediaId><![CDATA[").append(mediaId).append("]]></MediaId></Voice>\n");
                break;
            case "video":
                xmlBuilder.append("  <Video>\n")
                        .append("    <MediaId><![CDATA[").append(mediaId).append("]]></MediaId>\n")
                        .append("    <Title><![CDATA[").append(title).append("]]></Title>\n")
                        .append("    <Description><![CDATA[").append(description).append("]]></Description>\n")
                        .append("  </Video>\n");
                break;
            case "music":
                xmlBuilder.append("  <Music>\n")
                        .append("    <Title><![CDATA[").append(title).append("]]></Title>\n")
                        .append("    <Description><![CDATA[").append(description).append("]]></Description>\n")
                        .append("    <MusicUrl><![CDATA[").append(musicUrl).append("]]></MusicUrl>\n")
                        .append("    <HQMusicUrl><![CDATA[").append(hqMusicUrl).append("]]></HQMusicUrl>\n")
                        .append("    <ThumbMediaId><![CDATA[").append(mediaId).append("]]></ThumbMediaId>\n")
                        .append("  </Music>\n");
                break;
            case "news": // 图文消息
                xmlBuilder.append("  <ArticleCount>2</ArticleCount>\n")
                        .append("  <Articles>\n")
                        .append("    <item>\n")
                        .append("      <Title><![CDATA[").append(title).append("]]></Title>\n")
                        .append("      <Description><![CDATA[").append(description).append("]]></Description>\n")
                        .append("      <PicUrl><![CDATA[").append(mediaId).append("]]></PicUrl>\n")
                        .append("      <Url><![CDATA[").append(musicUrl).append("]]></Url>\n")
                        .append("    </item>\n")
                        .append("  </Articles>\n");
                break;
            default:
                xmlBuilder.append("  <Content><![CDATA[未知的消息类型]]></Content>\n");
        }
        xmlBuilder.append("</xml>");
        return xmlBuilder.toString();
    }
    /**
     * 用于测试,正常可以不用
     *
     * @param reqContent
     * @return
     */
    @NotNull
    private static String getReply(String reqContent) {
        String reply;
        switch (reqContent) {
            case "文字":
                reply = "text";
                break;
            case "图片":
                reply = "image";
                break;
            case "语音":
                reply = "voice";
                break;
            case "视频":
                reply = "video";
                break;
            case "音乐":
                reply = "music";
                break;
            case "图文":
                reply = "news";
                break;
            default:
                reply = "text";
                break;
        }
        return reply;
    }
    private static String dissuadeReturn(
            String fromUserName, String toUserName, String createTimeStr) {
        String xmltemp =
                "<xml>\n"
                        + "  <ToUserName><![CDATA[{{{toUser}}}]]></ToUserName>\n"
                        + "  <FromUserName><![CDATA[{{{fromUser}}}]]></FromUserName>\n"
                        + "  <CreateTime>{{{CreateTime}}}</CreateTime>\n"
                        + "  <MsgType><![CDATA[text]]></MsgType>\n"
                        + "  <Content><![CDATA[上一条消息还在加载中哦,请稍等或点击下次停止上一轮回复\n<a href=\"weixin://bizmsgmenu?msgmenucontent=停止输出&msgmenuid=101\">停止输出</a>]]></Content>\n"
                        + "  <Content><![CDATA[]]></Content>\n"
                        + "</xml>";
        // 替换占位符
        return xmltemp
                .replace("{{{toUser}}}", fromUserName)
                .replace("{{{fromUser}}}", toUserName)
                .replace("{{{CreateTime}}}", createTimeStr);
    }
}