test: AI 流式服务单元测试
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
package com.emotion.service;
|
||||
|
||||
import com.emotion.dto.request.ai.AiRuntimeRequest;
|
||||
import com.emotion.dto.response.ai.AiStreamEvent;
|
||||
import com.emotion.entity.AiCallLog;
|
||||
import com.emotion.entity.AiEndpointConfig;
|
||||
import com.emotion.entity.AiProvider;
|
||||
import com.emotion.entity.AiSceneBinding;
|
||||
import com.emotion.service.ai.AiProviderAdapter;
|
||||
import com.emotion.service.impl.AiRuntimeServiceImpl;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class AiRuntimeServiceImplTest {
|
||||
|
||||
@Test
|
||||
@DisplayName("流式调用已有输出时,末尾异常应保留输出并按成功完成")
|
||||
void invokeStreamRecoversWhenOutputExists() {
|
||||
AiSceneBindingService sceneBindingService = mock(AiSceneBindingService.class);
|
||||
AiEndpointConfigService endpointConfigService = mock(AiEndpointConfigService.class);
|
||||
AiProviderService providerService = mock(AiProviderService.class);
|
||||
AiCallLogService callLogService = mock(AiCallLogService.class);
|
||||
ScriptContextService scriptContextService = mock(ScriptContextService.class);
|
||||
AiProviderAdapter adapter = mock(AiProviderAdapter.class);
|
||||
|
||||
AiSceneBinding scene = new AiSceneBinding();
|
||||
scene.setSceneCode("script_generate");
|
||||
scene.setEndpointId("endpoint-1");
|
||||
scene.setRequiredStream(1);
|
||||
AiEndpointConfig endpoint = new AiEndpointConfig();
|
||||
endpoint.setId("endpoint-1");
|
||||
endpoint.setEndpointCode("dify.script_generate.chat_messages");
|
||||
endpoint.setProviderId("provider-1");
|
||||
endpoint.setSupportStream(1);
|
||||
AiProvider provider = new AiProvider();
|
||||
provider.setId("provider-1");
|
||||
provider.setProviderCode("dify-local");
|
||||
provider.setProviderType("fake");
|
||||
|
||||
when(sceneBindingService.resolveScene("script_generate")).thenReturn(scene);
|
||||
when(endpointConfigService.getEnabledById("endpoint-1")).thenReturn(endpoint);
|
||||
when(providerService.getEnabledById("provider-1")).thenReturn(provider);
|
||||
when(adapter.supports("fake")).thenReturn(true);
|
||||
when(callLogService.save(any(AiCallLog.class))).thenAnswer(invocation -> {
|
||||
AiCallLog log = invocation.getArgument(0);
|
||||
if (log.getId() == null) {
|
||||
log.setId("log-1");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
when(callLogService.updateById(any(AiCallLog.class))).thenReturn(true);
|
||||
org.mockito.Mockito.doAnswer(invocation -> {
|
||||
@SuppressWarnings("unchecked")
|
||||
java.util.function.Consumer<AiStreamEvent> consumer = invocation.getArgument(3);
|
||||
consumer.accept(AiStreamEvent.delta("完整输出", 1));
|
||||
throw new IllegalStateException("AI_STREAM_CLIENT_DISCONNECTED");
|
||||
}).when(adapter).stream(any(), any(), any(), any());
|
||||
|
||||
AiRuntimeServiceImpl service = new AiRuntimeServiceImpl(
|
||||
sceneBindingService,
|
||||
endpointConfigService,
|
||||
providerService,
|
||||
callLogService,
|
||||
scriptContextService,
|
||||
List.of(adapter)
|
||||
);
|
||||
|
||||
AiRuntimeRequest request = new AiRuntimeRequest();
|
||||
request.setSceneCode("script_generate");
|
||||
request.setUserId("user-1");
|
||||
request.setRequestId("client-request-1");
|
||||
List<AiStreamEvent> events = new ArrayList<>();
|
||||
|
||||
service.invokeStream(request, events::add);
|
||||
|
||||
assertEquals("完整输出", events.stream()
|
||||
.filter(event -> "delta".equals(event.getType()))
|
||||
.map(AiStreamEvent::getContent)
|
||||
.findFirst()
|
||||
.orElse(""));
|
||||
assertTrue(events.stream().anyMatch(event -> "done".equals(event.getType())));
|
||||
assertFalse(events.stream().anyMatch(event -> "error".equals(event.getType())));
|
||||
|
||||
ArgumentCaptor<AiCallLog> captor = ArgumentCaptor.forClass(AiCallLog.class);
|
||||
org.mockito.Mockito.verify(callLogService).updateById(captor.capture());
|
||||
AiCallLog savedLog = captor.getValue();
|
||||
assertEquals("client-request-1", savedLog.getRequestId());
|
||||
assertEquals("success", savedLog.getStatus());
|
||||
assertEquals("完整输出", savedLog.getOutputText());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user