package com.smtaiserver.smtaiserver.javaai.llm.qwen;
|
|
import java.util.ArrayList;
|
import java.util.Date;
|
import java.util.List;
|
import com.alibaba.dashscope.aigc.generation.Generation;
|
import com.alibaba.dashscope.aigc.generation.GenerationParam;
|
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
import com.alibaba.dashscope.common.Message;
|
import com.alibaba.dashscope.common.Role;
|
import com.alibaba.dashscope.embeddings.TextEmbedding;
|
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
|
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
|
import com.alibaba.dashscope.embeddings.TextEmbeddingResultItem;
|
import com.smtaiserver.smtaiserver.core.SMTAIServerRequest;
|
import com.smtaiserver.smtaiserver.javaai.llm.core.SMTLLMConnect;
|
import com.smtservlet.util.SMTStatic;
|
|
import io.reactivex.Flowable;
|
|
public class SMTLLMConnectQwen extends SMTLLMConnect
|
{
|
private String _model = "qwen-long";
|
private Generation _gen;
|
private SMTLLMFactoryQwen _factory;
|
|
|
public SMTLLMConnectQwen(Generation gen, SMTLLMFactoryQwen factory)
|
{
|
_gen = gen;
|
_factory = factory;
|
_model = factory.getConnectJson().getJson("model").asString();
|
}
|
|
@Override
|
public void close()
|
{
|
if(_gen != null)
|
{
|
_factory.returnGenToPool(_gen);
|
_gen = null;
|
}
|
}
|
|
@SuppressWarnings("unchecked")
|
@Override
|
public String callWithMessage(Object listSysMsg, Object userMsg, SMTAIServerRequest tranReq) throws Exception
|
{
|
{
|
List<Message> messages = new ArrayList<>();
|
|
Date dateNow = new Date();
|
Date dateyesterday = SMTStatic.calculateTime(dateNow, SMTStatic.SMTCalcTime.ADD_DATE, -1);
|
String sNow = SMTStatic.toString(dateNow).substring(0, 10);
|
String sYes = SMTStatic.toString(dateyesterday).substring(0, 10);
|
Message timeMsg = Message.builder().role(Role.SYSTEM.getValue()).content(
|
"问题中的所有未指定确切时间的全部以" + sNow + " 00:00:00到" + sNow + " 23:59:59作为当天时间范围,以" + sYes + " 00:00:00到" + sYes + " 23:59:59作为昨天时间范围,以此类推。"
|
).build();
|
messages.add(timeMsg);
|
|
if(listSysMsg != null)
|
{
|
if(listSysMsg instanceof List)
|
{
|
for(String sysMsg : (List<String>)listSysMsg)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
|
messages.add(systemMsg);
|
}
|
}
|
else if(listSysMsg instanceof String[])
|
{
|
for(String sysMsg : (String[])listSysMsg)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content(sysMsg).build();
|
messages.add(systemMsg);
|
}
|
}
|
else if(listSysMsg instanceof Message[])
|
{
|
for(Message sysMsg : (Message[])listSysMsg)
|
{
|
messages.add(sysMsg);
|
}
|
}
|
else if(listSysMsg instanceof String)
|
{
|
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content((String)listSysMsg).build();
|
messages.add(systemMsg);
|
}
|
else
|
throw new Exception("unknow listSysMsg type");
|
|
}
|
|
if(userMsg instanceof String)
|
{
|
Message userMessage = Message.builder().role(Role.USER.getValue()).content((String)userMsg).build();
|
messages.add(userMessage);
|
}
|
else
|
{
|
if(userMsg instanceof String[])
|
{
|
for(String msg : (String[])userMsg)
|
{
|
messages.add(Message.builder().role(Role.USER.getValue()).content((String)msg).build());
|
}
|
}
|
}
|
|
|
GenerationParam param =
|
GenerationParam.builder().model(
|
SMTStatic.isNullOrEmpty(_model) ? Generation.Models.QWEN_TURBO : _model
|
).messages(messages)
|
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
|
.build();
|
//GenerationResult result = _gen.call(param);
|
//return result.getOutput().getChoices().get(0).getMessage().getContent();
|
|
tranReq.sendChunkedBlock("begin_stream", "");
|
tranReq. sendChunkedBlock("send_stream", "开始分析:\n");
|
String[] fullContent = new String[] {""};
|
Flowable<GenerationResult> result = _gen.streamCall(param);
|
result.blockingForEach(message -> {
|
String content = message.getOutput().getChoices().get(0).getMessage().getContent();
|
|
int size = fullContent[0].length();
|
if(content.length() <= size)
|
return;
|
|
fullContent[0] = content;
|
tranReq.sendChunkedStreamBlock(content.substring(size));
|
});
|
tranReq.sendChunkedBlock("end_stream", "");
|
return fullContent[0];
|
}
|
}
|
|
public String getVector(String text) throws Exception
|
{
|
List<String> list = new ArrayList<String>();
|
list.add(text);
|
TextEmbeddingParam param = TextEmbeddingParam
|
.builder()
|
.model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
|
.texts(list).build();
|
TextEmbedding textEmbedding = _factory.allocLLMEmbedding();
|
try
|
{
|
TextEmbeddingResult result = textEmbedding.call(param);
|
List<TextEmbeddingResultItem> listResult = result.getOutput().getEmbeddings();
|
|
StringBuilder sbVector = new StringBuilder();
|
for(Double v : listResult.get(0).getEmbedding())
|
{
|
if(sbVector.length() > 0)
|
sbVector.append(", ");
|
sbVector.append(v);
|
|
}
|
|
return sbVector.toString();
|
}
|
finally
|
{
|
_factory.freeLLMEmbedding(textEmbedding);
|
}
|
|
|
}
|
|
}
|