From 84af570841d26d7925b62e484887405574ed3be5 Mon Sep 17 00:00:00 2001 From: Peanut Date: Sun, 24 May 2026 18:39:31 +0800 Subject: [PATCH] =?UTF-8?q?test:=20AI=20=E6=B5=81=E5=BC=8F=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.7 --- .../service/AiRuntimeServiceImplTest.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 backend-single/src/test/java/com/emotion/service/AiRuntimeServiceImplTest.java diff --git a/backend-single/src/test/java/com/emotion/service/AiRuntimeServiceImplTest.java b/backend-single/src/test/java/com/emotion/service/AiRuntimeServiceImplTest.java new file mode 100644 index 0000000..c1c844f --- /dev/null +++ b/backend-single/src/test/java/com/emotion/service/AiRuntimeServiceImplTest.java @@ -0,0 +1,102 @@ +package com.emotion.service; + +import com.emotion.dto.request.ai.AiRuntimeRequest; +import com.emotion.dto.response.ai.AiStreamEvent; +import com.emotion.entity.AiCallLog; +import com.emotion.entity.AiEndpointConfig; +import com.emotion.entity.AiProvider; +import com.emotion.entity.AiSceneBinding; +import com.emotion.service.ai.AiProviderAdapter; +import com.emotion.service.impl.AiRuntimeServiceImpl; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class AiRuntimeServiceImplTest { + + @Test + @DisplayName("流式调用已有输出时,末尾异常应保留输出并按成功完成") + void invokeStreamRecoversWhenOutputExists() { + AiSceneBindingService sceneBindingService = mock(AiSceneBindingService.class); + AiEndpointConfigService endpointConfigService = mock(AiEndpointConfigService.class); + AiProviderService providerService = mock(AiProviderService.class); + AiCallLogService callLogService = mock(AiCallLogService.class); + ScriptContextService scriptContextService = mock(ScriptContextService.class); + AiProviderAdapter adapter = mock(AiProviderAdapter.class); + + AiSceneBinding scene = new AiSceneBinding(); + scene.setSceneCode("script_generate"); + scene.setEndpointId("endpoint-1"); + scene.setRequiredStream(1); + AiEndpointConfig endpoint = new AiEndpointConfig(); + endpoint.setId("endpoint-1"); + endpoint.setEndpointCode("dify.script_generate.chat_messages"); + endpoint.setProviderId("provider-1"); + endpoint.setSupportStream(1); + AiProvider provider = new AiProvider(); + provider.setId("provider-1"); + provider.setProviderCode("dify-local"); + provider.setProviderType("fake"); + + when(sceneBindingService.resolveScene("script_generate")).thenReturn(scene); + when(endpointConfigService.getEnabledById("endpoint-1")).thenReturn(endpoint); + when(providerService.getEnabledById("provider-1")).thenReturn(provider); + when(adapter.supports("fake")).thenReturn(true); + when(callLogService.save(any(AiCallLog.class))).thenAnswer(invocation -> { + AiCallLog log = invocation.getArgument(0); + if (log.getId() == null) { + log.setId("log-1"); + } + return true; + }); + when(callLogService.updateById(any(AiCallLog.class))).thenReturn(true); + org.mockito.Mockito.doAnswer(invocation -> { + @SuppressWarnings("unchecked") + java.util.function.Consumer consumer = invocation.getArgument(3); + consumer.accept(AiStreamEvent.delta("完整输出", 1)); + throw new IllegalStateException("AI_STREAM_CLIENT_DISCONNECTED"); + }).when(adapter).stream(any(), any(), any(), any()); + + AiRuntimeServiceImpl service = new AiRuntimeServiceImpl( + sceneBindingService, + endpointConfigService, + providerService, + callLogService, + scriptContextService, + List.of(adapter) + ); + + AiRuntimeRequest request = new AiRuntimeRequest(); + request.setSceneCode("script_generate"); + request.setUserId("user-1"); + request.setRequestId("client-request-1"); + List events = new ArrayList<>(); + + service.invokeStream(request, events::add); + + assertEquals("完整输出", events.stream() + .filter(event -> "delta".equals(event.getType())) + .map(AiStreamEvent::getContent) + .findFirst() + .orElse("")); + assertTrue(events.stream().anyMatch(event -> "done".equals(event.getType()))); + assertFalse(events.stream().anyMatch(event -> "error".equals(event.getType()))); + + ArgumentCaptor captor = ArgumentCaptor.forClass(AiCallLog.class); + org.mockito.Mockito.verify(callLogService).updateById(captor.capture()); + AiCallLog savedLog = captor.getValue(); + assertEquals("client-request-1", savedLog.getRequestId()); + assertEquals("success", savedLog.getStatus()); + assertEquals("完整输出", savedLog.getOutputText()); + } +}