对话逻辑修复

This commit is contained in:
2025-07-25 17:48:02 +08:00
parent a4c6140ed5
commit f576de68da
31 changed files with 2129 additions and 588 deletions
@@ -3,7 +3,6 @@ package com.emotion.controller;
import com.emotion.common.Result;
import com.emotion.dto.request.AiChatRequest;
import com.emotion.dto.request.AiSummaryRequest;
import com.emotion.dto.request.ChatStatsRequest;
import com.emotion.dto.request.GuestChatRequest;
import com.emotion.dto.request.ConversationCreateRequest;
import com.emotion.dto.response.AiChatResponse;
@@ -14,7 +13,7 @@ import com.emotion.dto.response.GuestChatResponse;
import com.emotion.dto.response.GuestUserInfoResponse;
import com.emotion.dto.response.ConversationResponse;
import com.emotion.entity.Conversation;
import com.emotion.service.AIChatService;
import com.emotion.service.AiChatService;
import com.emotion.service.MessageService;
import com.emotion.service.ConversationService;
import lombok.extern.slf4j.Slf4j;
@@ -25,7 +24,6 @@ import org.springframework.web.bind.annotation.*;
import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;
/**
@@ -40,7 +38,7 @@ import java.util.Map;
public class AiChatController {
@Autowired
private AIChatService aiChatService;
private AiChatService aiChatService;
@Autowired
private MessageService messageService;
@@ -1,7 +1,7 @@
package com.emotion.controller;
import com.emotion.common.Result;
import com.emotion.service.AIChatService;
import com.emotion.service.AiChatService;
import com.emotion.util.CurrentUserUtil;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
@@ -24,7 +24,7 @@ import java.util.Map;
public class EmotionSummaryController {
@Autowired
private AIChatService aiChatService;
private AiChatService aiChatService;
@Operation(summary = "生成用户当天的情绪记录总结", description = "基于用户当天的聊天记录生成情绪分析和记录")
@PostMapping("/generate")
@@ -1,22 +1,19 @@
package com.emotion.controller;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.emotion.common.PageResult;
import com.emotion.common.Result;
import com.emotion.dto.request.PageRequest;
import com.emotion.dto.request.MessageCreateRequest;
import com.emotion.dto.request.MessagePageRequest;
import com.emotion.dto.request.MessageSearchRequest;
import com.emotion.dto.request.MessageRecentRequest;
import com.emotion.dto.response.MessageResponse;
import com.emotion.entity.Message;
import com.emotion.service.MessageService;
import com.emotion.util.CurrentUserUtil;
import org.springframework.beans.BeanUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import javax.validation.Valid;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;
/**
* 消息控制器
@@ -26,52 +23,31 @@ import java.util.stream.Collectors;
*/
@RestController
@RequestMapping("/message")
@Slf4j
public class MessageController {
@Autowired
private MessageService messageService;
private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
/**
* 分页查询消息
* 创建消息
*/
@GetMapping("/page")
public Result<PageResult<MessageResponse>> getPage(@Valid PageRequest request) {
IPage<Message> page = messageService.getPage(request);
List<MessageResponse> responses = page.getRecords().stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
@PostMapping
public Result<MessageResponse> create(@Valid @RequestBody MessageCreateRequest request) {
log.info("创建消息: conversationId={}", request.getConversationId());
PageResult<MessageResponse> pageResult = new PageResult<>();
pageResult.setCurrent(page.getCurrent());
pageResult.setSize(page.getSize());
pageResult.setTotal(page.getTotal());
pageResult.setPages(page.getPages());
pageResult.setRecords(responses);
try {
MessageResponse response = messageService.createMessageFromRequest(request);
log.info("创建消息成功: messageId={}", response.getId());
return Result.success(response);
return Result.success(pageResult);
}
/**
* 根据会话ID分页查询消息
*/
@GetMapping("/conversation/{conversationId}/page")
public Result<PageResult<MessageResponse>> getPageByConversationId(@PathVariable String conversationId,
@Valid PageRequest request) {
IPage<Message> page = messageService.getPageByConversationId(request, conversationId);
List<MessageResponse> responses = page.getRecords().stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
PageResult<MessageResponse> pageResult = new PageResult<>();
pageResult.setCurrent(page.getCurrent());
pageResult.setSize(page.getSize());
pageResult.setTotal(page.getTotal());
pageResult.setPages(page.getPages());
pageResult.setRecords(responses);
return Result.success(pageResult);
} catch (IllegalStateException e) {
log.error("用户未认证: {}", e.getMessage());
return Result.error(401, "用户未登录或认证失败");
} catch (Exception e) {
log.error("创建消息失败", e);
return Result.error(500, "创建消息失败,请稍后重试");
}
}
/**
@@ -79,139 +55,90 @@ public class MessageController {
*/
@GetMapping("/{id}")
public Result<MessageResponse> getById(@PathVariable String id) {
Message message = messageService.getById(id);
if (message == null) {
return Result.notFound("消息不存在");
log.info("获取消息详情: id={}", id);
try {
MessageResponse response = messageService.getMessageById(id);
if (response == null) {
return Result.error(404, "消息不存在");
}
return Result.success(response);
} catch (Exception e) {
log.error("获取消息详情失败", e);
return Result.error(500, "获取消息详情失败,请稍后重试");
}
return Result.success(convertToResponse(message));
}
/**
* 创建消息
*/
@PostMapping
public Result<MessageResponse> create(@Valid @RequestBody MessageCreateRequest request) {
Message message = new Message();
message.setConversationId(request.getConversationId());
message.setCreateBy(request.getUserId());
message.setContent(request.getContent());
message.setType(request.getContentType());
message.setSender(request.getSenderType());
// 可以根据需要设置其他字段
Message savedMessage = messageService.createMessage(message);
return Result.success(convertToResponse(savedMessage));
}
/**
* 根据会话ID查询消息
*/
@GetMapping("/conversation/{conversationId}")
public Result<List<MessageResponse>> getByConversationId(@PathVariable String conversationId) {
List<Message> messages = messageService.getByConversationId(conversationId);
List<MessageResponse> responses = messages.stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
return Result.success(responses);
}
/**
* 统计会话消息数量
*/
@GetMapping("/conversation/{conversationId}/count")
public Result<Long> countByConversationId(@PathVariable String conversationId) {
Long count = messageService.countByConversationId(conversationId);
return Result.success(count);
}
/**
* 根据用户ID分页查询消息
*/
@GetMapping("/user/page")
public Result<PageResult<MessageResponse>> getPageByUserId(@Valid PageRequest request) {
public Result<PageResult<MessageResponse>> getPageByUserId(
@RequestParam(defaultValue = "1") Long current,
@RequestParam(defaultValue = "20") Long size) {
log.info("获取用户消息分页: current={}, size={}", current, size);
try {
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
IPage<Message> page = messageService.getByUserIdWithPage(userId, Math.toIntExact(request.getCurrent()),
Math.toIntExact(request.getSize()));
List<MessageResponse> responses = page.getRecords().stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
PageResult<MessageResponse> pageResult = new PageResult<>();
pageResult.setCurrent(page.getCurrent());
pageResult.setSize(page.getSize());
pageResult.setTotal(page.getTotal());
pageResult.setPages(page.getPages());
pageResult.setRecords(responses);
// 构建请求对象
MessagePageRequest request = new MessagePageRequest();
request.setCurrent(current);
request.setSize(size);
PageResult<MessageResponse> pageResult = messageService.getUserMessagesWithPage(request);
log.info("获取用户消息分页成功: total={}", pageResult.getTotal());
return Result.success(pageResult);
} catch (IllegalStateException e) {
return Result.error(e.getMessage());
log.error("用户未认证: {}", e.getMessage());
return Result.error(401, "用户未登录或认证失败");
} catch (Exception e) {
log.error("获取用户消息失败", e);
return Result.error(500, "获取消息失败,请稍后重试");
}
}
/**
* 根据用户ID和关键词搜索消息
*/
@GetMapping("/user/search")
public Result<List<MessageResponse>> searchByUserId(
@RequestParam String keyword,
@RequestParam(defaultValue = "50") Integer limit) {
@PostMapping("/user/search")
public Result<List<MessageResponse>> searchByUserId(@Valid @RequestBody MessageSearchRequest request) {
log.info("搜索用户消息: keyword={}, limit={}", request.getKeyword(), request.getLimit());
try {
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
List<Message> messages = messageService.searchByUserIdAndKeyword(userId, keyword, limit);
List<MessageResponse> responses = messages.stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
List<MessageResponse> responses = messageService.searchUserMessages(request);
log.info("搜索用户消息成功: {} 条消息", responses.size());
return Result.success(responses);
} catch (IllegalStateException e) {
return Result.error(e.getMessage());
log.error("用户未认证: {}", e.getMessage());
return Result.error(401, "用户未登录或认证失败");
} catch (Exception e) {
log.error("搜索用户消息失败", e);
return Result.error(500, "搜索失败,请稍后重试");
}
}
/**
* 获取用户最近的聊天记录
*/
@GetMapping("/user/recent")
public Result<List<MessageResponse>> getRecentMessages(
@RequestParam(defaultValue = "10") Integer limit) {
@PostMapping("/user/recent")
public Result<List<MessageResponse>> getRecentMessages(@Valid @RequestBody MessageRecentRequest request) {
log.info("获取用户最近消息: limit={}", request.getLimit());
try {
// 从上下文中获取当前用户ID
String userId = CurrentUserUtil.requireCurrentUserId();
List<Message> messages = messageService.getRecentByUserId(userId, limit);
List<MessageResponse> responses = messages.stream()
.map(this::convertToResponse)
.collect(Collectors.toList());
List<MessageResponse> responses = messageService.getUserRecentMessages(request);
log.info("获取用户最近消息成功: {} 条消息", responses.size());
return Result.success(responses);
} catch (IllegalStateException e) {
return Result.error(e.getMessage());
log.error("用户未认证: {}", e.getMessage());
return Result.error(401, "用户未登录或认证失败");
} catch (Exception e) {
log.error("获取最近消息失败", e);
return Result.error(500, "获取最近消息失败,请稍后重试");
}
}
/**
* 转换为响应对象
*/
private MessageResponse convertToResponse(Message message) {
MessageResponse response = new MessageResponse();
BeanUtils.copyProperties(message, response);
response.setId(message.getId());
if (message.getCreateTime() != null) {
response.setCreateTime(message.getCreateTime().format(DATE_TIME_FORMATTER));
}
if (message.getUpdateTime() != null) {
response.setUpdateTime(message.getUpdateTime().format(DATE_TIME_FORMATTER));
}
return response;
}
}
@@ -1,12 +1,11 @@
package com.emotion.controller;
import com.emotion.service.AIChatService;
import com.emotion.service.AiChatService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Controller;
@@ -31,7 +30,7 @@ public class WebSocketController {
private SimpMessagingTemplate messagingTemplate;
@Autowired
private AIChatService aiChatService;
private AiChatService aiChatService;
// 已移除旧的WebSocket消息处理方法,使用新的ChatWebSocketController