| | |
| | | import com.smtaiserver.smtaiserver.core.SMTAIServerApp; |
| | | import com.smtaiserver.smtaiserver.database.SMTDatabase; |
| | | import com.smtservlet.core.SMTRequest; |
| | | import com.smtservlet.util.Json; |
| | | import com.smtservlet.util.SMTJsonWriter; |
| | | import com.smtservlet.util.SMTStatic; |
| | | import java.io.UnsupportedEncodingException; |
| | | import java.net.URLEncoder; |
| | | import java.time.LocalDateTime; |
| | | import java.time.format.DateTimeFormatter; |
| | | import java.util.HashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.Objects; |
| | | import java.util.concurrent.TimeUnit; |
| | | import javax.servlet.http.HttpServletResponse; |
| | | import okhttp3.MediaType; |
| | | import okhttp3.MultipartBody; |
| | | import okhttp3.OkHttpClient; |
| | | import okhttp3.Request; |
| | | import okhttp3.RequestBody; |
| | | import okhttp3.Response; |
| | | import org.apache.logging.log4j.LogManager; |
| | | import org.apache.logging.log4j.Logger; |
| | | import org.springframework.web.bind.annotation.RequestParam; |
| | | import org.springframework.web.multipart.MultipartFile; |
| | | import org.springframework.web.servlet.ModelAndView; |
| | | |
| | | import java.util.HashMap; |
| | | import java.util.Map; |
| | | |
| | | /** lightRag controller */ |
| | | public class SMTLightRAGController { |
| | |
| | | */ |
| | | public ModelAndView getLightragServerList(SMTRequest tranReq) throws Exception { |
| | | SMTDatabase db = SMTAIServerApp.getApp().allocDatabase(); |
| | | SMTJsonWriter jsonWr = tranReq.newReturnJsonWriter(true, null, null); |
| | | jsonWr.beginArray("values"); |
| | | |
| | | try { |
| | | SMTDatabase.DBRecords records = db.querySQL("SELECT * FROM lightrag_server_list", null); |
| | | SMTDatabase.DBRecords records = |
| | | db.querySQL("SELECT * FROM lightrag_server_list where is_enable = 'Y'", null); |
| | | if (records.getRowCount() == 0) return tranReq.returnJsonState(true, null, null); |
| | | SMTJsonWriter jsonWr = tranReq.newReturnJsonWriter(true, null, null); |
| | | for (int i = 0; i < records.getRowCount(); i++) { |
| | | SMTDatabase.DBRecord record = records.getRecord(i); |
| | | jsonWr.addKeyValue("server_id", record.getValue("server_id")); |
| | | jsonWr.addKeyValue("server_title", record.getValue("server_title")); |
| | | jsonWr.addKeyValue("server_port", record.getValue("server_port")); |
| | | jsonWr.addKeyValue("is_enable", record.getValue("is_enable")); |
| | | jsonWr.beginMap(null); |
| | | { |
| | | SMTDatabase.DBRecord record = records.getRecord(i); |
| | | Object serverId = record.getValue("server_id"); |
| | | Object serverPort = record.getValue("server_port"); |
| | | jsonWr.addKeyValue("server_id", serverId); |
| | | jsonWr.addKeyValue("server_title", record.getValue("server_title")); |
| | | jsonWr.addKeyValue("server_port", serverPort); |
| | | jsonWr.addKeyValue("is_enable", record.getValue("is_enable")); |
| | | |
| | | // 获取分组 |
| | | SMTDatabase.DBRecords dbRecords = |
| | | db.querySQL("SELECT clz_arguments ,agent_group FROM ai_agent_amis", null); |
| | | for (SMTDatabase.DBRecord dbRecordsRecord : dbRecords.getRecords()) { |
| | | Object clzArguments = dbRecordsRecord.getValue("clz_arguments"); |
| | | boolean exist = checkServerIdInString(clzArguments, serverId); |
| | | if (exist) { |
| | | String agentGroup = dbRecordsRecord.getString("agent_group"); |
| | | jsonWr.addKeyValue("agent_group", agentGroup); |
| | | break; |
| | | } |
| | | } |
| | | |
| | | // 知识库内文档数量查询内部接口 |
| | | Json responeJson = sendRequest("/documents", (Integer) serverPort, "GET", null); |
| | | Json statuses = responeJson.getJson("statuses"); |
| | | List<Json> processed = statuses.safeGetJsonList("processed"); |
| | | if (processed != null) { |
| | | int documentsSize = processed.size(); |
| | | jsonWr.addKeyValue("documents_count", documentsSize); |
| | | } else { |
| | | jsonWr.addKeyValue("documents_count", 0); |
| | | } |
| | | |
| | | // 获取上一次创建新文档的时间 |
| | | if (processed != null) { |
| | | String latestUpdateTime = findLatestUpdateTime(processed); |
| | | LocalDateTime dateTime = |
| | | LocalDateTime.parse(latestUpdateTime, DateTimeFormatter.ISO_LOCAL_DATE_TIME); |
| | | DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); |
| | | String formattedTime = dateTime.format(formatter); |
| | | jsonWr.addKeyValue("update_time", formattedTime); |
| | | } else { |
| | | jsonWr.addKeyValue("update_time", null); |
| | | } |
| | | } |
| | | jsonWr.endMap(); |
| | | } |
| | | |
| | | return tranReq.returnJson(jsonWr); |
| | | } catch (Exception e) { |
| | | throw new Exception("getLightragServerList error" + e); |
| | | } finally { |
| | | db.close(); |
| | | } |
| | |
| | | // 从请求中获取参数 |
| | | String serverId = tranReq.convParamToString("server_id", true); |
| | | String serverTitle = tranReq.convParamToString("server_title", true); |
| | | String serverPort = tranReq.convParamToString("server_port", true); |
| | | String isEnable = tranReq.convParamToString("is_enable", true); |
| | | |
| | | Integer serverPort = tranReq.convParamToInteger("server_port", true); |
| | | String isEnable = tranReq.convParamToString("is_enable", false); |
| | | if (SMTStatic.isNullOrEmpty(isEnable)) isEnable = "Y"; |
| | | // 参数校验 |
| | | if (serverId == null || serverTitle == null || serverPort == null || isEnable == null) { |
| | | if (serverId == null || serverTitle == null || serverPort == null) { |
| | | return tranReq.returnJsonState( |
| | | false, "参数缺失: server_id、server_title、server_port、is_enable 不能为空", null); |
| | | } |
| | |
| | | if (affectedRows == 0) { |
| | | return tranReq.returnJsonState(false, "新增失败", null); |
| | | } |
| | | // 启动服务 |
| | | SMTAIServerApp.getApp().getLightragServer(serverId); |
| | | |
| | | // 成功返回 |
| | | return tranReq.returnJsonState(true, null, null); |
| | |
| | | db.close(); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * 通用转发接口 lightrag |
| | | * |
| | | * @param tranReq |
| | | * @return |
| | | * @throws Exception |
| | | */ |
| | | public ModelAndView genericInterfaceForwarding(SMTRequest tranReq) throws Exception { |
| | | String serverUrl = tranReq.convParamToString("server_url", true); |
| | | String serverPort = tranReq.convParamToString("server_port", true); |
| | | Json paramJson = tranReq.convParamToJson("param_json", false); |
| | | String methodType = tranReq.convParamToString("method_type", true); |
| | | |
| | | OkHttpClient okHttpClient = |
| | | new OkHttpClient.Builder() |
| | | .readTimeout(0, TimeUnit.SECONDS) // 不超时,支持流 |
| | | .build(); |
| | | |
| | | HttpServletResponse response = tranReq.getResponse(); |
| | | response.setContentType("application/json"); |
| | | response.setCharacterEncoding("UTF-8"); |
| | | |
| | | String fullUrl = "http://localhost:" + serverPort + serverUrl; |
| | | Request.Builder requestBuilder = new Request.Builder(); |
| | | |
| | | if ("GET".equalsIgnoreCase(methodType)) { |
| | | // GET请求,把paramJson拼接到URL参数上 |
| | | fullUrl = appendParamsToUrl(fullUrl, paramJson); |
| | | requestBuilder.url(fullUrl).get(); |
| | | } else if ("POST".equalsIgnoreCase(methodType)) { |
| | | // POST请求,paramJson放body |
| | | RequestBody body = |
| | | RequestBody.create(MediaType.parse("application/json"), paramJson.toString()); |
| | | requestBuilder.url(fullUrl).post(body); |
| | | } else { |
| | | throw new IllegalArgumentException("不支持的请求类型: " + methodType); |
| | | } |
| | | |
| | | Request request = requestBuilder.build(); |
| | | |
| | | try (Response lightRagResp = okHttpClient.newCall(request).execute()) { |
| | | if (lightRagResp.body() == null) { |
| | | return tranReq.returnJsonState(false, "请求失败", null); |
| | | } |
| | | Json read = Json.read(lightRagResp.body().string()); |
| | | return tranReq.returnJsonState(true, "", read); |
| | | } catch (Exception e) { |
| | | throw new Exception("genericInterfaceForwarding error: " + e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * 上传文档至对应知识库中 |
| | | * |
| | | * @param file |
| | | * @param res |
| | | * @return |
| | | * @throws Exception |
| | | */ |
| | | public ModelAndView uploadFileInLightRAG( |
| | | @RequestParam(value = "file", required = false) MultipartFile file, SMTRequest res) |
| | | throws Exception { |
| | | String port = res.convParamToString("server_port", true); |
| | | OkHttpClient client = new OkHttpClient(); |
| | | RequestBody fileBody = |
| | | RequestBody.create( |
| | | MediaType.parse(Objects.requireNonNull(file.getContentType())), file.getBytes()); |
| | | MultipartBody requestBody = |
| | | new MultipartBody.Builder() |
| | | .setType(MultipartBody.FORM) |
| | | .addFormDataPart("file", file.getOriginalFilename(), fileBody) // 使用传入的file |
| | | .build(); |
| | | Request request = |
| | | new Request.Builder() |
| | | .url("http://localhost:" + port + "/documents/upload") // 目标URL |
| | | .addHeader("Accept", "application/json, text/plain, */*") |
| | | .addHeader("Accept-Language", "zh-CN,zh;q=0.9,ak;q=0.8") |
| | | // .addHeader( |
| | | // "Authorization", |
| | | // "Bearer |
| | | // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJndWVzdCIsImV4cCI6MTc0NTgxNzE1Nywicm9sZSI6Imd1ZXN0IiwibWV0YXRhY29kZSI6eyJhdXRoX21vZGUiOiJkaXNhYmxlZCJ9fQ.6V-k4FG_RmwQfg3-ivNLNyF7AntEUsyJf4yNFPiudLQ") // 替换成你的Authorization token |
| | | .addHeader("Content-Type", "multipart/form-data") // 必须是 multipart/form-data |
| | | .addHeader("Connection", "keep-alive") |
| | | .addHeader("Origin", "http://localhost:" + port) |
| | | .addHeader("Referer", "http://localhost:" + port + "/webui/") |
| | | .addHeader("Sec-Fetch-Dest", "empty") |
| | | .addHeader("Sec-Fetch-Mode", "cors") |
| | | .addHeader("Sec-Fetch-Site", "same-origin") |
| | | .addHeader( |
| | | "User-Agent", |
| | | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/135.0.0.0 Safari/537.36") |
| | | .post(requestBody) // 使用POST方法 |
| | | .build(); |
| | | |
| | | Response response = client.newCall(request).execute(); |
| | | Json read = null; |
| | | if (response.body() != null) { |
| | | read = Json.read(response.body().string()); |
| | | } |
| | | if (response.isSuccessful()) { |
| | | return res.returnJsonState(true, "请求成功", read); |
| | | } else { |
| | | return res.returnJsonState(false, "请求失败", read); |
| | | } |
| | | } |
| | | |
| | | /** 辅助方法:把Json参数拼接到URL */ |
| | | private String appendParamsToUrl(String url, Json paramJson) { |
| | | if (paramJson == null || paramJson.isNull()) { |
| | | return url; |
| | | } |
| | | StringBuilder sb = new StringBuilder(url); |
| | | if (!url.contains("?")) { |
| | | sb.append("?"); |
| | | } else if (!url.endsWith("&") && !url.endsWith("?")) { |
| | | sb.append("&"); |
| | | } |
| | | Map<String, Object> paramMap = paramJson.asMap(); |
| | | for (Map.Entry<String, Object> entry : paramMap.entrySet()) { |
| | | String key = entry.getKey(); |
| | | String value = entry.getValue() != null ? entry.getValue().toString() : ""; |
| | | try { |
| | | sb.append(URLEncoder.encode(key, "UTF-8")) |
| | | .append("=") |
| | | .append(URLEncoder.encode(value, "UTF-8")) |
| | | .append("&"); |
| | | } catch (UnsupportedEncodingException e) { |
| | | // 正常不会走到这里,忽略编码异常 |
| | | sb.append(key).append("=").append(value).append("&"); |
| | | } |
| | | } |
| | | // 删除最后一个多余的& |
| | | if (sb.charAt(sb.length() - 1) == '&') { |
| | | sb.deleteCharAt(sb.length() - 1); |
| | | } |
| | | return sb.toString(); |
| | | } |
| | | |
| | | public boolean checkServerIdInString(Object clzArgumentsObj, Object serverIdObj) { |
| | | if (clzArgumentsObj == null || serverIdObj == null) { |
| | | return false; |
| | | } |
| | | String clzArguments = clzArgumentsObj.toString(); |
| | | String serverId = serverIdObj.toString(); |
| | | |
| | | String searchStr = "server_id=\"" + serverId + "\""; |
| | | return clzArguments.contains(searchStr); |
| | | } |
| | | |
| | | /** 内部方法:通用转发请求 */ |
| | | private Json sendRequest(String serverUrl, Integer serverPort, String methodType, Json paramJson) |
| | | throws Exception { |
| | | OkHttpClient okHttpClient = |
| | | new OkHttpClient.Builder() |
| | | .readTimeout(0, TimeUnit.SECONDS) // 不超时,支持流 |
| | | .build(); |
| | | |
| | | String fullUrl = "http://localhost:" + serverPort + serverUrl; |
| | | Request.Builder requestBuilder = new Request.Builder(); |
| | | |
| | | if ("GET".equalsIgnoreCase(methodType)) { |
| | | fullUrl = appendParamsToUrl(fullUrl, paramJson); |
| | | requestBuilder.url(fullUrl).get(); |
| | | } else if ("POST".equalsIgnoreCase(methodType)) { |
| | | RequestBody body = |
| | | RequestBody.create(MediaType.parse("application/json"), paramJson.toString()); |
| | | requestBuilder.url(fullUrl).post(body); |
| | | } else { |
| | | throw new IllegalArgumentException("不支持的请求类型: " + methodType); |
| | | } |
| | | |
| | | Request request = requestBuilder.build(); |
| | | |
| | | try (Response lightRagResp = okHttpClient.newCall(request).execute()) { |
| | | if (lightRagResp.body() == null) { |
| | | throw new Exception("请求返回空 body"); |
| | | } |
| | | return Json.read(lightRagResp.body().string()); |
| | | } catch (Exception e) { |
| | | throw new Exception("sendRequest error: " + e.getMessage(), e); |
| | | } |
| | | } |
| | | |
| | | /** |
| | | * 找到 processed 列表里 updated_at 最大的时间 |
| | | * |
| | | * @param processed Json对象列表 |
| | | * @return 最新的 updated_at 字符串,找不到返回 null |
| | | */ |
| | | public static String findLatestUpdateTime(List<Json> processed) { |
| | | String latestTime = null; |
| | | for (Json one : processed) { |
| | | String updatedAt = one.safeGetStr("updated_at", null); |
| | | if (updatedAt == null) { |
| | | continue; |
| | | } |
| | | if (latestTime == null || updatedAt.compareTo(latestTime) > 0) { |
| | | latestTime = updatedAt; |
| | | } |
| | | } |
| | | return latestTime; |
| | | } |
| | | |
| | | /** |
| | | * 转义字符串里的反斜杠(\),变成双反斜杠(\\) |
| | | * |
| | | * @param input |
| | | * @return |
| | | */ |
| | | public String escapeBackslashes(String input) { |
| | | if (input == null) { |
| | | return null; |
| | | } |
| | | return input.replace("\\", "\\\\"); |
| | | } |
| | | } |