对话逻辑修复

This commit is contained in:
2025-07-25 17:48:02 +08:00
parent a4c6140ed5
commit f576de68da
31 changed files with 2129 additions and 588 deletions
@@ -7,7 +7,7 @@ import com.emotion.entity.Conversation;
import com.emotion.entity.CozeApiCall;
import com.emotion.entity.EmotionRecord;
import com.emotion.entity.EmotionAnalysis;
import com.emotion.service.AIChatService;
import com.emotion.service.AiChatService;
import com.emotion.service.MessageService;
import com.emotion.service.ConversationService;
import com.emotion.service.CozeApiCallService;
@@ -34,7 +34,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
/**
* AI聊天服务实现类
@@ -44,7 +43,7 @@ import java.util.stream.Collectors;
*/
@Slf4j
@Service
public class AiChatServiceImpl implements AIChatService {
public class AiChatServiceImpl implements AiChatService {
@Autowired
private RestTemplate restTemplate;
@@ -138,6 +137,34 @@ public class AiChatServiceImpl implements AIChatService {
} catch (Exception e) {
log.error("发送聊天消息失败", e);
return "抱歉,AI服务暂时不可用,请稍后再试。";
}
}
@Override
public String sendChatMessageForWebSocket(String conversationId, String message, String userId) {
log.info("WebSocket发送聊天消息: conversationId={}, userId={}, message={}", conversationId, userId, message);
try {
// 调用Coze API
String aiReply = sendMessage(conversationId, message, userId);
// 注意:不保存用户消息,因为WebSocket处理器已经保存了
// 只保存AI回复
Message aiMessage = new Message();
aiMessage.setConversationId(conversationId);
aiMessage.setCreateBy("ai");
aiMessage.setContent(aiReply);
aiMessage.setType("text");
aiMessage.setSender("ai");
aiMessage = messageService.createMessage(aiMessage);
log.info("WebSocket聊天消息处理完成: aiMessageId={}", aiMessage.getId());
return aiReply;
} catch (Exception e) {
log.error("WebSocket发送聊天消息失败", e);
return "抱歉,我暂时无法回复,请稍后再试。";
}
}
@@ -1051,6 +1078,39 @@ public class AiChatServiceImpl implements AIChatService {
.build();
}
/**
* 根据主要情绪确定情绪极性
*/
private String determinePolarity(String primaryEmotion) {
if (primaryEmotion == null || primaryEmotion.trim().isEmpty()) {
return "neutral";
}
String emotion = primaryEmotion.toLowerCase().trim();
// 积极情绪
if (emotion.contains("快乐") || emotion.contains("高兴") || emotion.contains("喜悦") ||
emotion.contains("兴奋") || emotion.contains("满足") || emotion.contains("感激") ||
emotion.contains("") || emotion.contains("希望") || emotion.contains("自信") ||
emotion.contains("平静") || emotion.contains("放松") || emotion.contains("开心") ||
emotion.contains("幸福") || emotion.contains("乐观") || emotion.contains("满意")) {
return "positive";
}
// 消极情绪
if (emotion.contains("悲伤") || emotion.contains("愤怒") || emotion.contains("恐惧") ||
emotion.contains("焦虑") || emotion.contains("沮丧") || emotion.contains("失望") ||
emotion.contains("孤独") || emotion.contains("痛苦") || emotion.contains("绝望") ||
emotion.contains("愧疚") || emotion.contains("羞耻") || emotion.contains("嫉妒") ||
emotion.contains("厌恶") || emotion.contains("烦躁") || emotion.contains("压抑") ||
emotion.contains("无助") || emotion.contains("困惑") || emotion.contains("担心")) {
return "negative";
}
// 默认为中性
return "neutral";
}
/**
* 从AI回复中提取JSON字符串
*/
@@ -5,14 +5,25 @@ import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.emotion.common.BasePageRequest;
import com.emotion.common.PageResult;
import com.emotion.dto.request.MessagePageRequest;
import com.emotion.dto.request.MessageSearchRequest;
import com.emotion.dto.request.MessageRecentRequest;
import com.emotion.dto.request.MessageCreateRequest;
import com.emotion.dto.response.MessageResponse;
import com.emotion.entity.Message;
import com.emotion.mapper.MessageMapper;
import com.emotion.service.MessageService;
import com.emotion.util.CurrentUserUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;
/**
* 消息服务实现类
@@ -20,9 +31,12 @@ import java.util.List;
* @author emotion-museum
* @date 2025-07-24
*/
@Slf4j
@Service
public class MessageServiceImpl extends ServiceImpl<MessageMapper, Message> implements MessageService {
private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
@Override
public IPage<Message> getPage(BasePageRequest request) {
Page<Message> page = new Page<>(request.getCurrent(), request.getSize());
@@ -201,4 +215,124 @@ public class MessageServiceImpl extends ServiceImpl<MessageMapper, Message> impl
// 获取用户最近的消息,按时间倒序
return this.baseMapper.getRecentByUserId(userId, limit);
}
@Override
public PageResult<MessageResponse> getUserMessagesWithPage(MessagePageRequest request) {
log.info("获取用户消息分页: current={}, size={}", request.getCurrent(), request.getSize());
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
log.info("当前用户ID: {}", userId);
// 调用原有的分页查询方法
IPage<Message> page = getByUserIdWithPage(userId, Math.toIntExact(request.getCurrent()),
Math.toIntExact(request.getSize()));
log.info("查询结果: total={}, records={}", page.getTotal(), page.getRecords().size());
// 转换为响应对象
List<MessageResponse> responses = page.getRecords().stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
// 构建分页结果
PageResult<MessageResponse> pageResult = new PageResult<>();
pageResult.setCurrent(page.getCurrent());
pageResult.setSize(page.getSize());
pageResult.setTotal(page.getTotal());
pageResult.setPages(page.getPages());
pageResult.setRecords(responses);
return pageResult;
}
@Override
public List<MessageResponse> searchUserMessages(MessageSearchRequest request) {
log.info("搜索用户消息: keyword={}, limit={}", request.getKeyword(), request.getLimit());
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
log.info("当前用户ID: {}", userId);
// 调用原有的搜索方法
List<Message> messages = searchByUserIdAndKeyword(userId, request.getKeyword(), request.getLimit());
log.info("搜索结果: {} 条消息", messages.size());
// 转换为响应对象
return messages.stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
}
@Override
public List<MessageResponse> getUserRecentMessages(MessageRecentRequest request) {
log.info("获取用户最近消息: limit={}", request.getLimit());
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
log.info("当前用户ID: {}", userId);
// 调用原有的获取最近消息方法
List<Message> messages = getRecentByUserId(userId, request.getLimit());
log.info("查询结果: {} 条最近消息", messages.size());
// 转换为响应对象
return messages.stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
}
@Override
public MessageResponse createMessageFromRequest(MessageCreateRequest request) {
log.info("根据请求创建消息: conversationId={}", request.getConversationId());
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
log.info("当前用户ID: {}", userId);
// 构建消息对象
Message message = new Message();
message.setConversationId(request.getConversationId());
message.setCreateBy(userId);
message.setContent(request.getContent());
message.setType(request.getContentType());
message.setSender(request.getSenderType());
// 调用原有的创建方法
Message savedMessage = createMessage(message);
log.info("创建消息成功: messageId={}", savedMessage.getId());
// 转换为响应对象
return convertToResponse(savedMessage);
}
@Override
public MessageResponse getMessageById(String id) {
log.info("根据ID获取消息: id={}", id);
Message message = getById(id);
if (message == null) {
log.warn("消息不存在: id={}", id);
return null;
}
// 转换为响应对象
return convertToResponse(message);
}
/**
* 转换为响应对象
*/
private MessageResponse convertToResponse(Message message) {
MessageResponse response = new MessageResponse();
BeanUtils.copyProperties(message, response);
response.setId(message.getId());
if (message.getCreateTime() != null) {
response.setCreateTime(message.getCreateTime().format(DATE_TIME_FORMATTER));
}
if (message.getUpdateTime() != null) {
response.setUpdateTime(message.getUpdateTime().format(DATE_TIME_FORMATTER));
}
return response;
}
}
@@ -0,0 +1,399 @@
package com.emotion.service.impl;
import com.emotion.dto.websocket.ChatRequest;
import com.emotion.dto.websocket.ConnectRequest;
import com.emotion.dto.websocket.WebSocketMessage;
import com.emotion.entity.Message;
import com.emotion.entity.Conversation;
import com.emotion.service.WebSocketService;
import com.emotion.service.AiChatService;
import com.emotion.service.MessageService;
import com.emotion.service.ConversationService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import java.security.Principal;
import java.time.LocalDateTime;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
/**
* WebSocket服务实现类
*
* @author emotion-museum
* @date 2025-07-25
*/
@Slf4j
@Service
public class WebSocketServiceImpl implements WebSocketService {
@Autowired
private SimpMessagingTemplate messagingTemplate;
@Autowired
private AiChatService aiChatService;
@Autowired
private MessageService messageService;
@Autowired
private ConversationService conversationService;
// 在线用户管理
private final ConcurrentHashMap<String, String> onlineUsers = new ConcurrentHashMap<>();
/**
* 处理聊天消息
*/
@Override
public void handleChatMessage(ChatRequest request, String sessionId, Principal principal) {
try {
log.info("处理聊天消息: request={}, sessionId={}, principal={}", request, sessionId, principal);
// 验证请求参数
if (request.getContent() == null || request.getContent().trim().isEmpty()) {
sendErrorMessage(request.getSenderId(), "消息内容不能为空");
return;
}
// 确定用户身份和类型
String userId = request.getSenderId();
WebSocketMessage.SenderType senderType = WebSocketMessage.SenderType.GUEST;
if (principal != null) {
userId = principal.getName();
// 如果用户ID不是以guest_开头,说明是认证用户
if (!userId.startsWith("guest_")) {
senderType = WebSocketMessage.SenderType.USER;
}
}
// 更新请求中的用户信息
request.setSenderId(userId);
request.setSenderType(senderType == WebSocketMessage.SenderType.USER ? ChatRequest.SenderType.USER
: ChatRequest.SenderType.GUEST);
log.info("确定用户身份: userId={}, senderType={}", userId, senderType);
// 构建用户消息
WebSocketMessage userMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.conversationId(request.getConversationId())
.type(WebSocketMessage.MessageType.TEXT)
.content(request.getContent())
.senderId(userId)
.senderType(senderType)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
// 发送用户消息到会话
if (request.getConversationId() != null) {
messagingTemplate.convertAndSend("/topic/conversation/" + request.getConversationId(), userMessage);
}
// 发送给用户私有队列
messagingTemplate.convertAndSendToUser(request.getSenderId(), "/queue/messages", userMessage);
// 发送AI思考状态
sendAiThinkingMessage(request.getSenderId(), request.getConversationId());
// 异步调用AI服务
processAiResponse(request);
} catch (Exception e) {
log.error("处理聊天消息失败", e);
sendErrorMessage(request.getSenderId(), "消息处理失败,请稍后重试");
}
}
/**
* 处理用户连接
*/
@Override
public void handleUserConnect(ConnectRequest request, String sessionId, Principal principal) {
try {
String userId = request.getUserId();
boolean isAuthenticated = false;
// 优先从Principal获取认证用户信息
if (principal != null) {
userId = principal.getName();
// 检查是否是认证用户(不是访客)
isAuthenticated = !userId.startsWith("guest_");
}
// 如果还没有userId,生成访客ID
if (userId == null) {
userId = "guest_" + sessionId;
}
log.info("用户连接WebSocket: userId={}, sessionId={}, authenticated={}",
userId, sessionId, isAuthenticated);
// 记录在线用户
onlineUsers.put(sessionId, userId);
// 发送连接成功消息
WebSocketMessage connectMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.type(WebSocketMessage.MessageType.CONNECTION)
.content("连接成功")
.senderId("system")
.senderType(WebSocketMessage.SenderType.SYSTEM)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
messagingTemplate.convertAndSendToUser(userId, "/queue/messages", connectMessage);
} catch (Exception e) {
log.error("处理用户连接失败", e);
}
}
/**
* 处理用户断开连接
*/
@Override
public void handleUserDisconnect(String sessionId, Principal principal) {
try {
String userId = onlineUsers.remove(sessionId);
log.info("用户断开WebSocket连接: userId={}, sessionId={}", userId, sessionId);
} catch (Exception e) {
log.error("处理用户断开连接失败", e);
}
}
/**
* 处理心跳消息
*/
@Override
public void handleHeartbeat(String sessionId, Principal principal) {
try {
String userId = onlineUsers.get(sessionId);
if (userId == null && principal != null) {
userId = principal.getName();
}
// 发送心跳响应
WebSocketMessage heartbeatMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.type(WebSocketMessage.MessageType.HEARTBEAT)
.content("pong")
.senderId("system")
.senderType(WebSocketMessage.SenderType.SYSTEM)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
if (userId != null) {
messagingTemplate.convertAndSendToUser(userId, "/queue/messages", heartbeatMessage);
}
} catch (Exception e) {
log.error("处理心跳消息失败", e);
}
}
/**
* 发送AI思考状态消息
*/
private void sendAiThinkingMessage(String userId, String conversationId) {
WebSocketMessage thinkingMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.conversationId(conversationId)
.type(WebSocketMessage.MessageType.AI_THINKING)
.content("AI正在思考中...")
.senderId("ai")
.senderType(WebSocketMessage.SenderType.AI)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
messagingTemplate.convertAndSendToUser(userId, "/queue/messages", thinkingMessage);
if (conversationId != null) {
messagingTemplate.convertAndSend("/topic/conversation/" + conversationId, thinkingMessage);
}
}
/**
* 异步处理AI响应
*/
private void processAiResponse(ChatRequest request) {
// 使用线程池异步处理AI响应
new Thread(() -> {
try {
String userId = request.getSenderId();
String conversationId = request.getConversationId();
// 如果没有会话ID,创建新会话
if (conversationId == null || conversationId.trim().isEmpty()) {
conversationId = createNewConversation(userId, request);
request.setConversationId(conversationId);
}
// 确保会话存在并更新活跃时间
ensureConversationExists(conversationId, userId, request);
// 保存用户消息到数据库
Message userMessage = new Message();
userMessage.setConversationId(conversationId);
userMessage.setUserId(userId);
userMessage
.setUserType(request.getSenderType() == ChatRequest.SenderType.USER ? "registered" : "guest");
userMessage.setContent(request.getContent());
userMessage.setType("text");
userMessage.setSender("user");
userMessage.setCozeRole("user");
userMessage.setCozeContentType("text");
messageService.createMessage(userMessage);
// 调用AI服务(WebSocket专用方法,不重复保存用户消息)
String aiReply = aiChatService.sendChatMessageForWebSocket(
conversationId,
request.getContent(),
userId
);
// 构建AI回复消息(不分割,保持完整性)
WebSocketMessage aiMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.conversationId(conversationId)
.type(WebSocketMessage.MessageType.TEXT)
.content(aiReply)
.senderId("ai")
.senderType(WebSocketMessage.SenderType.AI)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
// AI回复已经在sendChatMessageForWebSocket中保存了,这里不需要重复保存
// 发送AI回复
messagingTemplate.convertAndSendToUser(userId, "/queue/messages", aiMessage);
if (conversationId != null) {
messagingTemplate.convertAndSend("/topic/conversation/" + conversationId, aiMessage);
}
// 更新会话的最后活跃时间和消息数量
updateConversationActivity(conversationId);
} catch (Exception e) {
log.error("AI响应处理失败", e);
sendErrorMessage(request.getSenderId(), "AI服务暂时不可用,请稍后重试");
}
}).start();
}
/**
* 发送错误消息
*/
private void sendErrorMessage(String userId, String errorContent) {
WebSocketMessage errorMessage = WebSocketMessage.builder()
.messageId(UUID.randomUUID().toString())
.type(WebSocketMessage.MessageType.ERROR)
.content(errorContent)
.senderId("system")
.senderType(WebSocketMessage.SenderType.SYSTEM)
.status(WebSocketMessage.MessageStatus.SENT)
.createTime(LocalDateTime.now())
.build();
messagingTemplate.convertAndSendToUser(userId, "/queue/messages", errorMessage);
}
/**
* 获取在线用户数量
*/
@Override
public int getOnlineUserCount() {
return onlineUsers.size();
}
/**
* 创建新会话
*/
private String createNewConversation(String userId, ChatRequest request) {
try {
String conversationId = "conv_" + System.currentTimeMillis() + "_" + UUID.randomUUID().toString().substring(0, 8);
Conversation conversation = Conversation.builder()
.userId(userId)
.userType(request.getSenderType() == ChatRequest.SenderType.USER ? "registered" : "guest")
.title("新对话")
.type("chat")
.conversationStatus("active")
.startTime(LocalDateTime.now())
.lastActiveTime(LocalDateTime.now())
.messageCount(0)
.build();
// 设置ID
conversation.setId(conversationId);
conversationService.save(conversation);
log.info("创建新会话: conversationId={}, userId={}", conversationId, userId);
return conversationId;
} catch (Exception e) {
log.error("创建新会话失败: userId={}", userId, e);
throw new RuntimeException("创建会话失败", e);
}
}
/**
* 确保会话存在并更新活跃时间
*/
private void ensureConversationExists(String conversationId, String userId, ChatRequest request) {
try {
Conversation conversation = conversationService.getById(conversationId);
if (conversation == null) {
// 如果会话不存在,创建一个
conversation = Conversation.builder()
.userId(userId)
.userType(request.getSenderType() == ChatRequest.SenderType.USER ? "registered" : "guest")
.title("对话")
.type("chat")
.conversationStatus("active")
.startTime(LocalDateTime.now())
.lastActiveTime(LocalDateTime.now())
.messageCount(0)
.build();
// 设置ID
conversation.setId(conversationId);
conversationService.save(conversation);
log.info("创建会话: conversationId={}, userId={}", conversationId, userId);
} else {
// 更新最后活跃时间
conversation.setLastActiveTime(LocalDateTime.now());
conversationService.updateById(conversation);
}
} catch (Exception e) {
log.error("确保会话存在失败: conversationId={}, userId={}", conversationId, userId, e);
}
}
/**
* 更新会话活跃状态
*/
private void updateConversationActivity(String conversationId) {
try {
Conversation conversation = conversationService.getById(conversationId);
if (conversation != null) {
conversation.setLastActiveTime(LocalDateTime.now());
conversation.setMessageCount((conversation.getMessageCount() != null ? conversation.getMessageCount() : 0) + 1);
conversationService.updateById(conversation);
}
} catch (Exception e) {
log.error("更新会话活跃状态失败: conversationId={}", conversationId, e);
}
}
}