test: AI 流式服务单元测试
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
|||||||
|
package com.emotion.service;
|
||||||
|
|
||||||
|
import com.emotion.dto.request.ai.AiRuntimeRequest;
|
||||||
|
import com.emotion.dto.response.ai.AiStreamEvent;
|
||||||
|
import com.emotion.entity.AiCallLog;
|
||||||
|
import com.emotion.entity.AiEndpointConfig;
|
||||||
|
import com.emotion.entity.AiProvider;
|
||||||
|
import com.emotion.entity.AiSceneBinding;
|
||||||
|
import com.emotion.service.ai.AiProviderAdapter;
|
||||||
|
import com.emotion.service.impl.AiRuntimeServiceImpl;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
class AiRuntimeServiceImplTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("流式调用已有输出时,末尾异常应保留输出并按成功完成")
|
||||||
|
void invokeStreamRecoversWhenOutputExists() {
|
||||||
|
AiSceneBindingService sceneBindingService = mock(AiSceneBindingService.class);
|
||||||
|
AiEndpointConfigService endpointConfigService = mock(AiEndpointConfigService.class);
|
||||||
|
AiProviderService providerService = mock(AiProviderService.class);
|
||||||
|
AiCallLogService callLogService = mock(AiCallLogService.class);
|
||||||
|
ScriptContextService scriptContextService = mock(ScriptContextService.class);
|
||||||
|
AiProviderAdapter adapter = mock(AiProviderAdapter.class);
|
||||||
|
|
||||||
|
AiSceneBinding scene = new AiSceneBinding();
|
||||||
|
scene.setSceneCode("script_generate");
|
||||||
|
scene.setEndpointId("endpoint-1");
|
||||||
|
scene.setRequiredStream(1);
|
||||||
|
AiEndpointConfig endpoint = new AiEndpointConfig();
|
||||||
|
endpoint.setId("endpoint-1");
|
||||||
|
endpoint.setEndpointCode("dify.script_generate.chat_messages");
|
||||||
|
endpoint.setProviderId("provider-1");
|
||||||
|
endpoint.setSupportStream(1);
|
||||||
|
AiProvider provider = new AiProvider();
|
||||||
|
provider.setId("provider-1");
|
||||||
|
provider.setProviderCode("dify-local");
|
||||||
|
provider.setProviderType("fake");
|
||||||
|
|
||||||
|
when(sceneBindingService.resolveScene("script_generate")).thenReturn(scene);
|
||||||
|
when(endpointConfigService.getEnabledById("endpoint-1")).thenReturn(endpoint);
|
||||||
|
when(providerService.getEnabledById("provider-1")).thenReturn(provider);
|
||||||
|
when(adapter.supports("fake")).thenReturn(true);
|
||||||
|
when(callLogService.save(any(AiCallLog.class))).thenAnswer(invocation -> {
|
||||||
|
AiCallLog log = invocation.getArgument(0);
|
||||||
|
if (log.getId() == null) {
|
||||||
|
log.setId("log-1");
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
when(callLogService.updateById(any(AiCallLog.class))).thenReturn(true);
|
||||||
|
org.mockito.Mockito.doAnswer(invocation -> {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
java.util.function.Consumer<AiStreamEvent> consumer = invocation.getArgument(3);
|
||||||
|
consumer.accept(AiStreamEvent.delta("完整输出", 1));
|
||||||
|
throw new IllegalStateException("AI_STREAM_CLIENT_DISCONNECTED");
|
||||||
|
}).when(adapter).stream(any(), any(), any(), any());
|
||||||
|
|
||||||
|
AiRuntimeServiceImpl service = new AiRuntimeServiceImpl(
|
||||||
|
sceneBindingService,
|
||||||
|
endpointConfigService,
|
||||||
|
providerService,
|
||||||
|
callLogService,
|
||||||
|
scriptContextService,
|
||||||
|
List.of(adapter)
|
||||||
|
);
|
||||||
|
|
||||||
|
AiRuntimeRequest request = new AiRuntimeRequest();
|
||||||
|
request.setSceneCode("script_generate");
|
||||||
|
request.setUserId("user-1");
|
||||||
|
request.setRequestId("client-request-1");
|
||||||
|
List<AiStreamEvent> events = new ArrayList<>();
|
||||||
|
|
||||||
|
service.invokeStream(request, events::add);
|
||||||
|
|
||||||
|
assertEquals("完整输出", events.stream()
|
||||||
|
.filter(event -> "delta".equals(event.getType()))
|
||||||
|
.map(AiStreamEvent::getContent)
|
||||||
|
.findFirst()
|
||||||
|
.orElse(""));
|
||||||
|
assertTrue(events.stream().anyMatch(event -> "done".equals(event.getType())));
|
||||||
|
assertFalse(events.stream().anyMatch(event -> "error".equals(event.getType())));
|
||||||
|
|
||||||
|
ArgumentCaptor<AiCallLog> captor = ArgumentCaptor.forClass(AiCallLog.class);
|
||||||
|
org.mockito.Mockito.verify(callLogService).updateById(captor.capture());
|
||||||
|
AiCallLog savedLog = captor.getValue();
|
||||||
|
assertEquals("client-request-1", savedLog.getRequestId());
|
||||||
|
assertEquals("success", savedLog.getStatus());
|
||||||
|
assertEquals("完整输出", savedLog.getOutputText());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user