package com.smtaiserver.smtaiserver.control;

import com.smtaiserver.smtaiserver.core.SMTAIServerApp;
import com.smtaiserver.smtaiserver.database.SMTDatabase;
import com.smtservlet.core.SMTRequest;
import com.smtservlet.util.Json;
import com.smtservlet.util.SMTJsonWriter;
import com.smtservlet.util.SMTStatic;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletResponse;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.ModelAndView;

/** Controller that manages LightRAG servers and forwards requests to them on localhost. */
public class SMTLightRAGController {
  private static final Logger _logger = LogManager.getLogger(SMTLightRAGController.class);

  /**
   * Shared client for JSON calls to local LightRAG servers. OkHttp clients own a connection pool
   * and a dispatcher, so one instance is reused instead of building a new one per request. The
   * read timeout is disabled (0 = no timeout) so long-running LightRAG calls are not cut off.
   */
  private static final OkHttpClient LIGHTRAG_HTTP_CLIENT =
      new OkHttpClient.Builder().readTimeout(0, TimeUnit.SECONDS).build();

  /** Shared client with OkHttp's default timeouts, used for document uploads. */
  private static final OkHttpClient UPLOAD_HTTP_CLIENT = new OkHttpClient();

  private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json");

  /** Fallback media type when the uploaded file does not declare one. */
  private static final String DEFAULT_UPLOAD_CONTENT_TYPE = "application/octet-stream";

  /** Output format for the {@code update_time} field. */
  private static final DateTimeFormatter DISPLAY_TIME_FORMAT =
      DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

  /**
   * Lists the enabled LightRAG servers together with per-server knowledge-base statistics.
   *
   * <p>For every row of {@code lightrag_server_list} with {@code is_enable = 'Y'} one map is
   * written into the {@code values} array. Each map contains:
   *
   * <ul>
   *   <li>{@code server_id}, {@code server_title}, {@code server_port} and {@code is_enable} from
   *       the row.
   *   <li>{@code agent_group} of the first {@code ai_agent_amis} row whose {@code clz_arguments}
   *       contains {@code server_id="<id>"}. The key is omitted when no agent references the
   *       server.
   *   <li>{@code documents_count}: the size of {@code statuses.processed} returned by the server's
   *       {@code GET /documents}, or 0 when absent.
   *   <li>{@code update_time}: the latest {@code updated_at} among the processed documents,
   *       formatted {@code yyyy-MM-dd HH:mm:ss}, or null when there is none.
   * </ul>
   *
   * @param tranReq the incoming request
   * @return the JSON response; a bare success state when no server is enabled
   * @throws Exception on database errors or when a LightRAG server cannot be queried
   */
  public ModelAndView getLightragServerList(SMTRequest tranReq) throws Exception {
    SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
    SMTJsonWriter jsonWr = tranReq.newReturnJsonWriter(true, null, null);
    // NOTE(review): "values" is never closed with endArray(); presumably returnJson() finalizes the
    // writer — confirm against SMTJsonWriter.
    jsonWr.beginArray("values");
    try {
      SMTDatabase.DBRecords records =
          db.querySQL("SELECT * FROM lightrag_server_list where is_enable = 'Y'", null);
      if (records.getRowCount() == 0) {
        return tranReq.returnJsonState(true, null, null);
      }

      // The agent table does not depend on the server row, so load it once rather than per server.
      SMTDatabase.DBRecords agentRecords =
          db.querySQL("SELECT clz_arguments ,agent_group FROM ai_agent_amis", null);

      for (int i = 0; i < records.getRowCount(); i++) {
        SMTDatabase.DBRecord record = records.getRecord(i);
        Object serverId = record.getValue("server_id");
        Object serverPort = record.getValue("server_port");

        jsonWr.beginMap(null);
        jsonWr.addKeyValue("server_id", serverId);
        jsonWr.addKeyValue("server_title", record.getValue("server_title"));
        jsonWr.addKeyValue("server_port", serverPort);
        jsonWr.addKeyValue("is_enable", record.getValue("is_enable"));

        // Group of the first agent whose arguments reference this server.
        for (SMTDatabase.DBRecord agentRecord : agentRecords.getRecords()) {
          if (checkServerIdInString(agentRecord.getValue("clz_arguments"), serverId)) {
            jsonWr.addKeyValue("agent_group", agentRecord.getString("agent_group"));
            break;
          }
        }

        // Ask the LightRAG server itself for its document status list.
        Json documentsJson = sendRequest("/documents", toPort(serverPort), "GET", null);
        Json statuses = documentsJson.getJson("statuses");
        List<Json> processed = statuses != null ? statuses.safeGetJsonList("processed") : null;
        jsonWr.addKeyValue("documents_count", processed != null ? processed.size() : 0);

        // The latest time is null for an empty list or when no entry carries updated_at.
        String latestUpdateTime = processed != null ? findLatestUpdateTime(processed) : null;
        jsonWr.addKeyValue("update_time", formatUpdateTime(latestUpdateTime));
        jsonWr.endMap();
      }
      return tranReq.returnJson(jsonWr);
    } finally {
      db.close();
    }
  }

  /**
   * Sets {@code is_enable} for one row of {@code lightrag_server_list}.
   *
   * @param tranReq request carrying {@code server_id} and {@code is_enable}
   * @return a failure state when a parameter is missing or no row matched, success otherwise
   * @throws Exception wrapping any failure (the original exception is kept as the cause)
   */
  public ModelAndView updateLightragServerEnable(SMTRequest tranReq) throws Exception {
    SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
    try {
      String serverId = tranReq.convParamToString("server_id", true);
      String isEnable = tranReq.convParamToString("is_enable", true);
      if (serverId == null || isEnable == null) {
        return tranReq.returnJsonState(false, "参数缺失: server_id 或 is_enable 不能为空", null);
      }

      int affectedRows =
          db.executeSQL(
              "UPDATE lightrag_server_list SET is_enable = ? WHERE server_id = ?",
              new Object[] {isEnable, serverId});
      if (affectedRows == 0) {
        return tranReq.returnJsonState(false, "未找到对应的 server_id,更新失败", null);
      }
      return tranReq.returnJsonState(true, null, null);
    } catch (Exception e) {
      throw new Exception("updateLightragServerEnable error: " + e, e);
    } finally {
      db.close();
    }
  }

  /**
   * Registers a new LightRAG server in {@code lightrag_server_list}, then calls {@code
   * SMTAIServerApp.getLightragServer(serverId)}.
   *
   * <p>{@code is_enable} is optional and defaults to {@code "Y"}.
   *
   * @param tranReq request carrying {@code server_id}, {@code server_title}, {@code server_port}
   *     and optionally {@code is_enable}
   * @return a failure state when a parameter is missing, the id already exists or nothing was
   *     inserted; success otherwise
   * @throws Exception wrapping any failure (the original exception is kept as the cause)
   */
  public ModelAndView addLightragServer(SMTRequest tranReq) throws Exception {
    SMTDatabase db = SMTAIServerApp.getApp().allocDatabase();
    try {
      String serverId = tranReq.convParamToString("server_id", true);
      String serverTitle = tranReq.convParamToString("server_title", true);
      Integer serverPort = tranReq.convParamToInteger("server_port", true);
      String isEnable = tranReq.convParamToString("is_enable", false);
      if (SMTStatic.isNullOrEmpty(isEnable)) {
        isEnable = "Y";
      }

      if (serverId == null || serverTitle == null || serverPort == null) {
        return tranReq.returnJsonState(
            false, "参数缺失: server_id、server_title、server_port、is_enable 不能为空", null);
      }

      // Reject duplicate ids before inserting.
      SMTDatabase.DBRecords existingRecords =
          db.querySQL(
              "SELECT server_id FROM lightrag_server_list WHERE server_id = ?",
              new Object[] {serverId});
      if (existingRecords.getRowCount() > 0) {
        return tranReq.returnJsonState(false, "server_id 已存在,不能重复新增", null);
      }

      int affectedRows =
          db.executeSQL(
              "INSERT INTO lightrag_server_list (server_id, server_title, server_port, is_enable) VALUES (?, ?, ?, ?)",
              new Object[] {serverId, serverTitle, serverPort, isEnable});
      if (affectedRows == 0) {
        return tranReq.returnJsonState(false, "新增失败", null);
      }

      // The original comment says this starts the service; the effect lives in SMTAIServerApp.
      SMTAIServerApp.getApp().getLightragServer(serverId);

      return tranReq.returnJsonState(true, null, null);
    } catch (Exception e) {
      throw new Exception("addLightragServer error: " + e, e);
    } finally {
      db.close();
    }
  }

  /**
   * Generic pass-through to a LightRAG server running on localhost.
   *
   * <p>{@code server_port} must be an integer and {@code server_url} must be an absolute path
   * starting with {@code /}. Both are validated so a crafted value (e.g. {@code "@other-host/"})
   * cannot redirect the call to a different host. For GET, {@code param_json} is appended as
   * URL-encoded query parameters. For POST, it is sent as the JSON body ({@code {}} when absent).
   *
   * @param tranReq request carrying {@code server_url}, {@code server_port}, {@code method_type}
   *     and optionally {@code param_json}
   * @return the upstream JSON wrapped in a success state, or in a failure state when the upstream
   *     status is not 2xx or the body is missing
   * @throws IllegalArgumentException for an unsupported method type or an invalid {@code
   *     server_url}
   * @throws Exception wrapping transport or parse failures
   */
  public ModelAndView genericInterfaceForwarding(SMTRequest tranReq) throws Exception {
    String serverUrl = tranReq.convParamToString("server_url", true);
    Integer serverPort = tranReq.convParamToInteger("server_port", true);
    Json paramJson = tranReq.convParamToJson("param_json", false);
    String methodType = tranReq.convParamToString("method_type", true);

    HttpServletResponse response = tranReq.getResponse();
    response.setContentType("application/json");
    response.setCharacterEncoding("UTF-8");

    Request request =
        buildLightragRequest(localServerUrl(serverPort, serverUrl), methodType, paramJson);
    try (Response lightRagResp = LIGHTRAG_HTTP_CLIENT.newCall(request).execute()) {
      if (lightRagResp.body() == null) {
        return tranReq.returnJsonState(false, "请求失败", null);
      }
      Json read = Json.read(lightRagResp.body().string());
      if (!lightRagResp.isSuccessful()) {
        return tranReq.returnJsonState(false, "请求失败", read);
      }
      return tranReq.returnJsonState(true, "", read);
    } catch (Exception e) {
      throw new Exception("genericInterfaceForwarding error: " + e.getMessage(), e);
    }
  }

  /**
   * Uploads a document to the knowledge base served on {@code server_port} via {@code POST
   * /documents/upload}.
   *
   * @param file the uploaded document; a failure state is returned when it is absent
   * @param res request carrying {@code server_port}
   * @return success with the upstream JSON when the upload returned 2xx, failure otherwise
   * @throws Exception on transport errors or an unparsable upstream body
   */
  public ModelAndView uploadFileInLightRAG(
      @RequestParam(value = "file", required = false) MultipartFile file, SMTRequest res)
      throws Exception {
    Integer port = res.convParamToInteger("server_port", true);
    if (file == null) {
      return res.returnJsonState(false, "未上传文件", null);
    }

    String contentType = file.getContentType();
    RequestBody fileBody =
        RequestBody.create(
            MediaType.parse(contentType != null ? contentType : DEFAULT_UPLOAD_CONTENT_TYPE),
            file.getBytes());
    MultipartBody requestBody =
        new MultipartBody.Builder()
            .setType(MultipartBody.FORM)
            .addFormDataPart("file", file.getOriginalFilename(), fileBody)
            .build();

    String baseUrl = "http://localhost:" + port;
    // Content-Type is deliberately not set by hand: OkHttp derives it from the multipart body,
    // including the required boundary parameter.
    Request request =
        new Request.Builder()
            .url(baseUrl + "/documents/upload")
            .addHeader("Accept", "application/json, text/plain, */*")
            .addHeader("Accept-Language", "zh-CN,zh;q=0.9,ak;q=0.8")
            .addHeader("Connection", "keep-alive")
            .addHeader("Origin", baseUrl)
            .addHeader("Referer", baseUrl + "/webui/")
            .addHeader("Sec-Fetch-Dest", "empty")
            .addHeader("Sec-Fetch-Mode", "cors")
            .addHeader("Sec-Fetch-Site", "same-origin")
            .addHeader(
                "User-Agent",
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/135.0.0.0 Safari/537.36")
            .post(requestBody)
            .build();

    // try-with-resources releases the connection back to the pool.
    try (Response response = UPLOAD_HTTP_CLIENT.newCall(request).execute()) {
      Json read = null;
      if (response.body() != null) {
        read = Json.read(response.body().string());
      }
      if (response.isSuccessful()) {
        return res.returnJsonState(true, "请求成功", read);
      }
      return res.returnJsonState(false, "请求失败", read);
    }
  }

  /**
   * Appends the entries of a JSON object to a URL as URL-encoded query parameters.
   *
   * <p>Null values become empty strings. The URL is returned unchanged when there is nothing to
   * append.
   *
   * @param url the base URL, which may already contain a query string
   * @param paramJson a JSON object of parameters; may be null
   * @return the URL with the parameters appended
   */
  private String appendParamsToUrl(String url, Json paramJson) {
    if (paramJson == null || paramJson.isNull()) {
      return url;
    }
    Map<String, Object> paramMap = paramJson.asMap();
    if (paramMap.isEmpty()) {
      return url;
    }

    StringBuilder sb = new StringBuilder(url);
    if (!url.contains("?")) {
      sb.append("?");
    } else if (!url.endsWith("&") && !url.endsWith("?")) {
      sb.append("&");
    }

    for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
      String key = entry.getKey();
      String value = entry.getValue() != null ? entry.getValue().toString() : "";
      try {
        sb.append(URLEncoder.encode(key, "UTF-8"))
            .append("=")
            .append(URLEncoder.encode(value, "UTF-8"))
            .append("&");
      } catch (UnsupportedEncodingException e) {
        // UTF-8 is always supported by the JVM; fall back to the raw text just in case.
        sb.append(key).append("=").append(value).append("&");
      }
    }

    // Drop the trailing separator left by the loop.
    if (sb.charAt(sb.length() - 1) == '&') {
      sb.deleteCharAt(sb.length() - 1);
    }
    return sb.toString();
  }

  /**
   * Tests whether an agent's argument text references a server.
   *
   * @param clzArgumentsObj the {@code clz_arguments} value; may be null
   * @param serverIdObj the server id; may be null
   * @return true when the text contains the literal {@code server_id="<id>"}; false if either
   *     argument is null
   */
  public boolean checkServerIdInString(Object clzArgumentsObj, Object serverIdObj) {
    if (clzArgumentsObj == null || serverIdObj == null) {
      return false;
    }
    String clzArguments = clzArgumentsObj.toString();
    String serverId = serverIdObj.toString();
    String searchStr = "server_id=\"" + serverId + "\"";
    return clzArguments.contains(searchStr);
  }

  /**
   * Sends a JSON request to a local LightRAG server and returns the parsed response.
   *
   * @param serverUrl absolute path on the server, starting with {@code /}
   * @param serverPort the server's port
   * @param methodType {@code GET} or {@code POST} (case-insensitive)
   * @param paramJson query parameters (GET) or body (POST); may be null
   * @return the parsed response body
   * @throws IllegalArgumentException for an unsupported method type or an invalid path
   * @throws Exception when the call fails, the status is not 2xx or the body is missing
   */
  private Json sendRequest(String serverUrl, Integer serverPort, String methodType, Json paramJson)
      throws Exception {
    Request request =
        buildLightragRequest(localServerUrl(serverPort, serverUrl), methodType, paramJson);
    try (Response lightRagResp = LIGHTRAG_HTTP_CLIENT.newCall(request).execute()) {
      if (lightRagResp.body() == null) {
        throw new Exception("请求返回空 body");
      }
      String bodyText = lightRagResp.body().string();
      if (!lightRagResp.isSuccessful()) {
        throw new Exception("HTTP " + lightRagResp.code() + ": " + bodyText);
      }
      return Json.read(bodyText);
    } catch (Exception e) {
      throw new Exception("sendRequest error: " + e.getMessage(), e);
    }
  }

  /**
   * Builds a GET or POST request to a LightRAG server.
   *
   * <p>For GET, the parameters are appended to the query string. For POST, they are sent as a JSON
   * body, or {@code {}} when {@code paramJson} is null.
   *
   * @throws IllegalArgumentException for any method other than GET or POST
   */
  private Request buildLightragRequest(String fullUrl, String methodType, Json paramJson) {
    Request.Builder requestBuilder = new Request.Builder();
    if ("GET".equalsIgnoreCase(methodType)) {
      requestBuilder.url(appendParamsToUrl(fullUrl, paramJson)).get();
    } else if ("POST".equalsIgnoreCase(methodType)) {
      String jsonBody = paramJson != null ? paramJson.toString() : "{}";
      requestBuilder.url(fullUrl).post(RequestBody.create(JSON_MEDIA_TYPE, jsonBody));
    } else {
      throw new IllegalArgumentException("不支持的请求类型: " + methodType);
    }
    return requestBuilder.build();
  }

  /**
   * Builds {@code http://localhost:<port><path>}.
   *
   * <p>The path must start with {@code /}. Otherwise text such as {@code "@host/"} would become
   * part of the authority and retarget the request to another host.
   *
   * @throws IllegalArgumentException when the path is null or not absolute
   */
  private static String localServerUrl(Integer port, String path) {
    if (path == null || !path.startsWith("/")) {
      throw new IllegalArgumentException("非法的 server_url: " + path);
    }
    return "http://localhost:" + port + path;
  }

  /**
   * Converts a {@code server_port} database value to an Integer.
   *
   * <p>JDBC drivers may return {@code Long} or {@code BigDecimal} for numeric columns, so any
   * {@link Number} is accepted. Other values are parsed from their string form. Null stays null.
   */
  private static Integer toPort(Object value) {
    if (value == null) {
      return null;
    }
    if (value instanceof Number) {
      return ((Number) value).intValue();
    }
    return Integer.valueOf(value.toString().trim());
  }

  /**
   * Reformats an ISO-8601 timestamp as {@code yyyy-MM-dd HH:mm:ss}.
   *
   * <p>{@link DateTimeFormatter#ISO_DATE_TIME} accepts timestamps with or without an offset. When
   * an offset is present it is dropped and the local wall-clock time is kept.
   *
   * @param isoTime the timestamp; may be null
   * @return the formatted time, or null when {@code isoTime} is null
   */
  private static String formatUpdateTime(String isoTime) {
    if (isoTime == null) {
      return null;
    }
    return LocalDateTime.parse(isoTime, DateTimeFormatter.ISO_DATE_TIME).format(DISPLAY_TIME_FORMAT);
  }

  /**
   * Finds the largest {@code updated_at} value in the {@code processed} list.
   *
   * <p>Values are compared as strings. This gives chronological order only when every timestamp
   * uses the same ISO-8601 layout and offset.
   *
   * @param processed the document status objects
   * @return the latest {@code updated_at} string, or null when no entry has one
   */
  public static String findLatestUpdateTime(List<Json> processed) {
    String latestTime = null;
    for (Json one : processed) {
      String updatedAt = one.safeGetStr("updated_at", null);
      if (updatedAt == null) {
        continue;
      }
      if (latestTime == null || updatedAt.compareTo(latestTime) > 0) {
        latestTime = updatedAt;
      }
    }
    return latestTime;
  }

  /**
   * Escapes backslashes by doubling them ({@code \} becomes {@code \\}).
   *
   * @param input the text to escape; may be null
   * @return the escaped text, or null when {@code input} is null
   */
  public String escapeBackslashes(String input) {
    if (input == null) {
      return null;
    }
    return input.replace("\\", "\\\\");
  }
}