package com.smtaiserver.smtaiserver.javaai.llm.ollama; import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.concurrent.TimeUnit; import com.alibaba.dashscope.common.Message; import com.alibaba.dashscope.common.Role; import com.smtaiserver.smtaiserver.core.SMTAIServerRequest; import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect; import com.smtservlet.util.Json; import com.smtservlet.util.SMTJsonWriter; import com.smtservlet.util.SMTStatic; import okhttp3.Headers; import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; public class SMTLLMConnectOllama extends SMTLLMConnect { private String _modelLLM; private String _modelVector; @SuppressWarnings("unused") private SMTLLMFactoryOllama _factory; private OkHttpClient _web; private String _baseUrl; private Headers _headers; public SMTLLMConnectOllama(SMTLLMFactoryOllama factory) { OkHttpClient.Builder builder = new OkHttpClient.Builder(); builder.connectTimeout(120, TimeUnit.SECONDS); builder.readTimeout(12000, TimeUnit.SECONDS); builder.writeTimeout(12000, TimeUnit.SECONDS); _web = builder.build(); _factory = factory; _baseUrl = factory.getFactoryJson().getJson("base_url").asString(); _modelLLM = factory.getConnectJson().getJson("model").asString(); _modelVector = "deepseek-chat"; _headers = new Headers.Builder() .add("Accept", "application/json") .build(); } @Override public void close() { } @SuppressWarnings("unchecked") @Override public String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq) throws Exception { List messages = new ArrayList<>(); Message timeMsg = Message.builder().role(Role.SYSTEM.getValue()).content("问题中的所有未指定确切时间的全部以" + SMTStatic.toString(new Date()).substring(0, 10) +"作为当天日期进行推算。").build(); messages.add(timeMsg); if(listSysMsg != null) { if(listSysMsg instanceof List) { for(String sysMsg : (List)listSysMsg) { Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build(); messages.add(systemMsg); } } else if(listSysMsg instanceof String[]) { for(String sysMsg : (String[])listSysMsg) { Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build(); messages.add(systemMsg); } } else if(listSysMsg instanceof Message[]) { for(Message sysMsg : (Message[])listSysMsg) { messages.add(sysMsg); } } else if(listSysMsg instanceof String) { Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content((String)listSysMsg).build(); messages.add(systemMsg); } else throw new Exception("unknow listSysMsg type"); } if(userMsg instanceof String) { Message userMessage = Message.builder().role(Role.USER.getValue()).content((String)userMsg).build(); messages.add(userMessage); } else { if(userMsg instanceof String[]) { for(String msg : (String[])userMsg) { messages.add(Message.builder().role(Role.USER.getValue()).content((String)msg).build()); } } } SMTJsonWriter jsonWr = new SMTJsonWriter(false); jsonWr.addKeyValue("model", _modelLLM); jsonWr.addKeyValue("stream", false); jsonWr.beginArray("messages"); for(Message message : messages) { jsonWr.beginMap(null); { jsonWr.addKeyValue("role", message.getRole()); jsonWr.addKeyValue("content", message.getContent()); } jsonWr.endMap(); } jsonWr.endArray(); // 创建一个请求头对象 System.out.println(jsonWr.getFullJson()); System.out.println(_baseUrl + "v1/chat/completions"); RequestBody body = RequestBody.create(MediaType.parse("application/json; charset=utf-8"), jsonWr.getFullJson()); Request request = new Request.Builder() .url(_baseUrl + "v1/chat/completions") // 替换为你的API端点 .headers(_headers) .post(body) .build(); Response response = _web.newCall(request).execute(); String sJsonResult = response.body().string(); System.out.println(sJsonResult); Json jsonResult = Json.read(sJsonResult); Json jsonContent = jsonResult.getJsonPath("choices|0|message|content", false); return jsonContent.asString(); } public String getVector(String text) throws Exception { SMTJsonWriter jsonWr = new SMTJsonWriter(false); jsonWr.addKeyValue("model", _modelVector); jsonWr.addKeyValue("input", text); // 创建一个请求头对象 RequestBody body = RequestBody.create(MediaType.parse("application/json; charset=utf-8"), jsonWr.getFullJson()); Request request = new Request.Builder() .url(_baseUrl + "v1/embeddings") // 替换为你的API端点 .headers(_headers) .post(body) .build(); Response response = _web.newCall(request).execute(); Json jsonResult = Json.read(response.body().string()).getJsonPath("data|0|embedding", false); StringBuilder sbResult = new StringBuilder(); for(Json jsonValue : jsonResult.asJsonList()) { if(sbResult.length() > 0) sbResult.append(","); sbResult.append(SMTStatic.toString(jsonValue.asDouble())); } return sbResult.toString(); // List list = new ArrayList(); // list.add(text); // TextEmbeddingParam param = TextEmbeddingParam // .builder() // .model(TextEmbedding.Models.TEXT_EMBEDDING_V1) // .texts(list).build(); // TextEmbedding textEmbedding = _factory.allocLLMEmbedding(); // try // { // TextEmbeddingResult result = textEmbedding.call(param); // List listResult = result.getOutput().getEmbeddings(); // // StringBuilder sbVector = new StringBuilder(); // for(Double v : listResult.get(0).getEmbedding()) // { // if(sbVector.length() > 0) // sbVector.append(", "); // sbVector.append(v); // // } // // return sbVector.toString(); // } // finally // { // _factory.freeLLMEmbedding(textEmbedding); // } // TODO SMTLLMConnectDMXApi的getVector未实现 //throw new Exception("SMTLLMConnectDMXApi的getVector未实现"); } }