package com.smtaiserver.smtaiserver.javaai.llm.ollama;
|
|
import java.util.ArrayList;
|
import java.util.Date;
|
import java.util.List;
|
import java.util.concurrent.TimeUnit;
|
|
import com.alibaba.dashscope.common.Message;
|
import com.alibaba.dashscope.common.Role;
|
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
|
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
|
import com.smtservlet.util.Json;
|
import com.smtservlet.util.SMTJsonWriter;
|
import com.smtservlet.util.SMTStatic;
|
|
import okhttp3.Headers;
|
import okhttp3.MediaType;
|
import okhttp3.OkHttpClient;
|
import okhttp3.Request;
|
import okhttp3.RequestBody;
|
import okhttp3.Response;
|
|
public class SMTLLMConnectOllama extends SMTLLMConnect {
|
|
private String _modelLLM;
|
private String _modelVector;
|
@SuppressWarnings("unused")
|
private SMTLLMFactoryOllama _factory;
|
private OkHttpClient _web;
|
private String _baseUrl;
|
private Headers _headers;
|
|
|
public SMTLLMConnectOllama(SMTLLMFactoryOllama factory)
|
{
|
OkHttpClient.Builder builder = new OkHttpClient.Builder();
|
builder.connectTimeout(120, TimeUnit.SECONDS);
|
builder.readTimeout(12000, TimeUnit.SECONDS);
|
builder.writeTimeout(12000, TimeUnit.SECONDS);
|
|
_web = builder.build();
|
|
_factory = factory;
|
_baseUrl = factory.getFactoryJson().getJson("base_url").asString();
|
_modelLLM = factory.getConnectJson().getJson("model").asString();
|
_modelVector = "deepseek-chat";
|
_headers = new Headers.Builder()
|
.add("Accept", "application/json")
|
.build();
|
}
|
|
@Override
|
public void close()
|
{
|
}
|
|
@SuppressWarnings("unchecked")
|
@Override
|
public String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq) throws Exception
|
{
|
|
List<Message> messages = new ArrayList<>();
|
|
Message timeMsg = Message.builder().role(Role.SYSTEM.getValue()).content("问题中的所有未指定确切时间的全部以" + SMTStatic.toString(new Date()).substring(0, 10) +"作为当天日期进行推算。").build();
|
messages.add(timeMsg);
|
|
|
if(listSysMsg != null)
|
{
|
if(listSysMsg instanceof List)
|
{
|
for(String sysMsg : (List<String>)listSysMsg)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
|
messages.add(systemMsg);
|
}
|
}
|
else if(listSysMsg instanceof String[])
|
{
|
for(String sysMsg : (String[])listSysMsg)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
|
messages.add(systemMsg);
|
}
|
}
|
else if(listSysMsg instanceof Message[])
|
{
|
for(Message sysMsg : (Message[])listSysMsg)
|
{
|
messages.add(sysMsg);
|
}
|
}
|
else if(listSysMsg instanceof String)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content((String)listSysMsg).build();
|
messages.add(systemMsg);
|
}
|
else
|
throw new Exception("unknow listSysMsg type");
|
}
|
|
|
if(userMsg instanceof String)
|
{
|
Message userMessage = Message.builder().role(Role.USER.getValue()).content((String)userMsg).build();
|
messages.add(userMessage);
|
}
|
else
|
{
|
if(userMsg instanceof String[])
|
{
|
for(String msg : (String[])userMsg)
|
{
|
messages.add(Message.builder().role(Role.USER.getValue()).content((String)msg).build());
|
}
|
}
|
}
|
|
SMTJsonWriter jsonWr = new SMTJsonWriter(false);
|
jsonWr.addKeyValue("model", _modelLLM);
|
jsonWr.addKeyValue("stream", false);
|
jsonWr.beginArray("messages");
|
for(Message message : messages)
|
{
|
jsonWr.beginMap(null);
|
{
|
jsonWr.addKeyValue("role", message.getRole());
|
jsonWr.addKeyValue("content", message.getContent());
|
}
|
jsonWr.endMap();
|
}
|
jsonWr.endArray();
|
|
// 创建一个请求头对象
|
System.out.println(jsonWr.getFullJson());
|
System.out.println(_baseUrl + "v1/chat/completions");
|
RequestBody body = RequestBody.create(MediaType.parse("application/json; charset=utf-8"), jsonWr.getFullJson());
|
Request request = new Request.Builder()
|
.url(_baseUrl + "v1/chat/completions") // 替换为你的API端点
|
.headers(_headers)
|
.post(body)
|
.build();
|
|
Response response = _web.newCall(request).execute();
|
String sJsonResult = response.body().string();
|
System.out.println(sJsonResult);
|
Json jsonResult = Json.read(sJsonResult);
|
Json jsonContent = jsonResult.getJsonPath("choices|0|message|content", false);
|
|
return jsonContent.asString();
|
}
|
|
|
public String getVector(String text) throws Exception
|
{
|
SMTJsonWriter jsonWr = new SMTJsonWriter(false);
|
jsonWr.addKeyValue("model", _modelVector);
|
jsonWr.addKeyValue("input", text);
|
|
// 创建一个请求头对象
|
RequestBody body = RequestBody.create(MediaType.parse("application/json; charset=utf-8"), jsonWr.getFullJson());
|
Request request = new Request.Builder()
|
.url(_baseUrl + "v1/embeddings") // 替换为你的API端点
|
.headers(_headers)
|
.post(body)
|
.build();
|
|
Response response = _web.newCall(request).execute();
|
Json jsonResult = Json.read(response.body().string()).getJsonPath("data|0|embedding", false);
|
StringBuilder sbResult = new StringBuilder();
|
for(Json jsonValue : jsonResult.asJsonList())
|
{
|
if(sbResult.length() > 0)
|
sbResult.append(",");
|
sbResult.append(SMTStatic.toString(jsonValue.asDouble()));
|
}
|
|
return sbResult.toString();
|
|
// List<String> list = new ArrayList<String>();
|
// list.add(text);
|
// TextEmbeddingParam param = TextEmbeddingParam
|
// .builder()
|
// .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
|
// .texts(list).build();
|
// TextEmbedding textEmbedding = _factory.allocLLMEmbedding();
|
// try
|
// {
|
// TextEmbeddingResult result = textEmbedding.call(param);
|
// List<TextEmbeddingResultItem> listResult = result.getOutput().getEmbeddings();
|
//
|
// StringBuilder sbVector = new StringBuilder();
|
// for(Double v : listResult.get(0).getEmbedding())
|
// {
|
// if(sbVector.length() > 0)
|
// sbVector.append(", ");
|
// sbVector.append(v);
|
//
|
// }
|
//
|
// return sbVector.toString();
|
// }
|
// finally
|
// {
|
// _factory.freeLLMEmbedding(textEmbedding);
|
// }
|
|
// TODO SMTLLMConnectDMXApi的getVector未实现
|
//throw new Exception("SMTLLMConnectDMXApi的getVector未实现");
|
}
|
|
|
}
|