package com.smtaiserver.smtaiserver.javaai.sse; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.springframework.http.MediaType; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import com.smtaiserver.smtaiserver.control.SMTAIServerControl; public class SMTSSEBroadcastChat { private static Logger _logger = LogManager.getLogger(SMTAIServerControl.class); private String _userId; private Map> _mapSessionId2SSE = new HashMap<>(); private static MediaType _sseMediaType = new MediaType("text", "plain", StandardCharsets.UTF_8); public SMTSSEBroadcastChat(String userId) { _userId = userId; } public void addSSEConnect(String sessionId, String from, SseEmitter sse) { int setCount = 0; int sessCount = 0; synchronized(_mapSessionId2SSE) { // 如果存在旧的sse,则首先关闭 Set setSSE = _mapSessionId2SSE.get(sessionId); if(setSSE == null) { setSSE = new HashSet<>(); _mapSessionId2SSE.put(sessionId, setSSE); } // 加入新的sse setSSE.add(sse); setCount = setSSE.size(); sessCount = _mapSessionId2SSE.size(); } _logger.info("========> add session to sse : user=" + _userId + ", session=" + sessionId + ", sess count=" + sessCount + ", sse count=" + setCount + ", from=" + from); } public boolean removeSSEConnect(String sessionId) { int setCount = 0; int sessCount = 0; synchronized(_mapSessionId2SSE) { Set setSSE = _mapSessionId2SSE.remove(sessionId); if(setSSE != null) { setCount = setSSE.size(); for(SseEmitter sse : setSSE) { try { sse.complete(); } catch(Exception ex) { sse.completeWithError(ex); } } } sessCount = _mapSessionId2SSE.size(); } if(setCount > 0) { _logger.info("========> remove session to sse : user=" + _userId + ", sess count=" + sessCount + ", sse count=" + setCount); return true; } return false; } public void sendChatNotify(String session, String jsonSend) { List listSender = new ArrayList<>(); // 快速获取要发送的列表 synchronized(_mapSessionId2SSE) { for(Entry> entry : _mapSessionId2SSE.entrySet()) { for(SseEmitter sse : entry.getValue()) { listSender.add(new Object[] {sse, entry.getKey()}); } } } // 发送通知 List removeList = null; _logger.info("===================================begin send"); for(Object[] sender : listSender) { SseEmitter sse = (SseEmitter)sender[0]; try { sse.send(jsonSend.getBytes("UTF-8"), _sseMediaType); _logger.info("===================================>send:" + _userId + " : " + sender[1] + " : " + jsonSend); } catch(Exception ex) { _logger.info("===================================>remove:" + _userId + " : " + sender[1] + " : " + ex.getMessage()); if(removeList == null) removeList = new ArrayList<>(); removeList.add(sender); sse.completeWithError(ex); } } _logger.info("===================================end send"); // 删除已经发送失败的对象 if(removeList != null) { synchronized(_mapSessionId2SSE) { for(Object[] sender : removeList) { Set setSSE = _mapSessionId2SSE.get(sender[1]); if(setSSE != null) { setSSE.remove(sender[0]); } } } } } }