test: AI 流式服务单元测试

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-24 18:39:31 +08:00
parent fc14051073
commit 84af570841
@@ -0,0 +1,102 @@
package com.emotion.service;
import com.emotion.dto.request.ai.AiRuntimeRequest;
import com.emotion.dto.response.ai.AiStreamEvent;
import com.emotion.entity.AiCallLog;
import com.emotion.entity.AiEndpointConfig;
import com.emotion.entity.AiProvider;
import com.emotion.entity.AiSceneBinding;
import com.emotion.service.ai.AiProviderAdapter;
import com.emotion.service.impl.AiRuntimeServiceImpl;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class AiRuntimeServiceImplTest {
@Test
@DisplayName("流式调用已有输出时,末尾异常应保留输出并按成功完成")
void invokeStreamRecoversWhenOutputExists() {
AiSceneBindingService sceneBindingService = mock(AiSceneBindingService.class);
AiEndpointConfigService endpointConfigService = mock(AiEndpointConfigService.class);
AiProviderService providerService = mock(AiProviderService.class);
AiCallLogService callLogService = mock(AiCallLogService.class);
ScriptContextService scriptContextService = mock(ScriptContextService.class);
AiProviderAdapter adapter = mock(AiProviderAdapter.class);
AiSceneBinding scene = new AiSceneBinding();
scene.setSceneCode("script_generate");
scene.setEndpointId("endpoint-1");
scene.setRequiredStream(1);
AiEndpointConfig endpoint = new AiEndpointConfig();
endpoint.setId("endpoint-1");
endpoint.setEndpointCode("dify.script_generate.chat_messages");
endpoint.setProviderId("provider-1");
endpoint.setSupportStream(1);
AiProvider provider = new AiProvider();
provider.setId("provider-1");
provider.setProviderCode("dify-local");
provider.setProviderType("fake");
when(sceneBindingService.resolveScene("script_generate")).thenReturn(scene);
when(endpointConfigService.getEnabledById("endpoint-1")).thenReturn(endpoint);
when(providerService.getEnabledById("provider-1")).thenReturn(provider);
when(adapter.supports("fake")).thenReturn(true);
when(callLogService.save(any(AiCallLog.class))).thenAnswer(invocation -> {
AiCallLog log = invocation.getArgument(0);
if (log.getId() == null) {
log.setId("log-1");
}
return true;
});
when(callLogService.updateById(any(AiCallLog.class))).thenReturn(true);
org.mockito.Mockito.doAnswer(invocation -> {
@SuppressWarnings("unchecked")
java.util.function.Consumer<AiStreamEvent> consumer = invocation.getArgument(3);
consumer.accept(AiStreamEvent.delta("完整输出", 1));
throw new IllegalStateException("AI_STREAM_CLIENT_DISCONNECTED");
}).when(adapter).stream(any(), any(), any(), any());
AiRuntimeServiceImpl service = new AiRuntimeServiceImpl(
sceneBindingService,
endpointConfigService,
providerService,
callLogService,
scriptContextService,
List.of(adapter)
);
AiRuntimeRequest request = new AiRuntimeRequest();
request.setSceneCode("script_generate");
request.setUserId("user-1");
request.setRequestId("client-request-1");
List<AiStreamEvent> events = new ArrayList<>();
service.invokeStream(request, events::add);
assertEquals("完整输出", events.stream()
.filter(event -> "delta".equals(event.getType()))
.map(AiStreamEvent::getContent)
.findFirst()
.orElse(""));
assertTrue(events.stream().anyMatch(event -> "done".equals(event.getType())));
assertFalse(events.stream().anyMatch(event -> "error".equals(event.getType())));
ArgumentCaptor<AiCallLog> captor = ArgumentCaptor.forClass(AiCallLog.class);
org.mockito.Mockito.verify(callLogService).updateById(captor.capture());
AiCallLog savedLog = captor.getValue();
assertEquals("client-request-1", savedLog.getRequestId());
assertEquals("success", savedLog.getStatus());
assertEquals("完整输出", savedLog.getOutputText());
}
}