288 lines
10 KiB
C#
288 lines
10 KiB
C#
using System.Net.Http.Headers;
|
||
using System.Text;
|
||
using System.Text.Json;
|
||
using Health.Domain.Entities;
|
||
using Health.Domain.Enums;
|
||
using Health.Infrastructure.AI;
|
||
using Health.Infrastructure.Data;
|
||
using Microsoft.EntityFrameworkCore;
|
||
using Microsoft.Extensions.Configuration;
|
||
|
||
namespace Health.Tests;
|
||
|
||
/// <summary>
/// AI agent integration tests — simulate real user conversations and verify tool calling
/// and database writes (observed here only indirectly, through <c>tool_result</c> events).
/// Start the backend before running the integration tests:
/// <c>dotnet run --project src/Health.WebApi</c>
/// </summary>
public class AiAgentTests
{
    /// <summary>Phone number every integration test logs in with.</summary>
    private const string TestPhone = "13800000001";

    /// <summary>Field name that introduces a payload line in a server-sent event stream.</summary>
    private const string SseDataPrefix = "data:";

    /// <summary>Payload the backend sends to signal the end of the stream.</summary>
    private const string SseDoneMarker = "[DONE]";

    /// <summary>Shared client pointed at the locally running backend.</summary>
    private static readonly HttpClient Http = new()
    {
        BaseAddress = new Uri("http://localhost:5000"),
        Timeout = TimeSpan.FromSeconds(120)
    };

    /// <summary>Serializer settings used for both request payloads and SSE event parsing.</summary>
    private static readonly JsonSerializerOptions JsonOpts = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        PropertyNameCaseInsensitive = true
    };

    // ==================== PromptManager unit tests ====================

    [Fact]
    public void PromptManager_Default_Should_Contain_HeartKeywords()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Default);

        Assert.Contains("心脏", prompt);
        Assert.Contains("健康", prompt);
    }

    [Fact]
    public void PromptManager_Consultation_Should_Contain_TriageRules()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Consultation);

        Assert.Contains("剧烈胸痛", prompt);
        Assert.Contains("呼吸困难", prompt);
        Assert.Contains("160/100", prompt);
    }

    [Fact]
    public void PromptManager_Health_Should_Contain_NormalRanges()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Health);

        Assert.Contains("139", prompt); // systolic upper bound
        // NOTE(review): the original comment called this the diastolic *lower* bound;
        // 89 reads like an upper bound — confirm against the prompt text.
        Assert.Contains("89", prompt);
        Assert.Contains("100", prompt); // heart-rate upper bound
    }

    [Fact]
    public void PromptManager_Diet_Should_Contain_VlmKeywords()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Diet);

        Assert.Contains("VLM", prompt);
        Assert.Contains("能不能吃", prompt);
    }

    [Fact]
    public void PromptManager_Medication_Should_Contain_ParseRules()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Medication);

        Assert.Contains("药名", prompt);
        Assert.Contains("剂量", prompt);
    }

    [Fact]
    public void PromptManager_Report_Should_Contain_Disclaimer()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Report);

        Assert.Contains("AI预解读", prompt);
        Assert.Contains("医生确认", prompt);
    }

    [Fact]
    public void PromptManager_Exercise_Should_Contain_RehabKeywords()
    {
        var prompt = new PromptManager().GetSystemPrompt(AgentType.Exercise);

        Assert.Contains("心脏康复", prompt);
        Assert.Contains("循序渐进", prompt);
    }

    [Fact]
    public void PromptManager_All_Agents_Should_Return_NonEmpty()
    {
        var pm = new PromptManager();

        foreach (var agent in Enum.GetValues<AgentType>())
        {
            var prompt = pm.GetSystemPrompt(agent);
            Assert.False(string.IsNullOrWhiteSpace(prompt), $"{agent} 的 SystemPrompt 为空");
        }
    }

    // ==================== DeepSeekClient connectivity tests ====================

    [Fact]
    public void DeepSeekClient_ApiKey_Should_Be_Configured()
    {
        // Only checks that an API key is configured; real calls are covered by the
        // integration tests below. The client is built first because doing so loads
        // backend/.env into the process environment.
        _ = CreateDeepSeekClient();

        var config = new ConfigurationBuilder().AddEnvironmentVariables().Build();
        var apiKey = config["DEEPSEEK_API_KEY"] ?? "";

        var missingOrPlaceholder =
            string.IsNullOrEmpty(apiKey) || apiKey.StartsWith("sk-your-key", StringComparison.Ordinal);
        Assert.False(missingOrPlaceholder,
            "DEEPSEEK_API_KEY 未配置或为占位符,请在 backend/.env 中设置真实 Key");
    }

    // ==================== AI chat + tool calling integration tests ====================

    [Fact]
    public async Task HealthAgent_RecordBloodPressure_Should_SaveToDb()
    {
        // Log in first to obtain a token.
        var token = await LoginAsync(TestPhone);

        // Send a chat message that should trigger tool calling.
        var events = await SendChatMessage(token, "health", "我刚刚测了血压,138/86");

        Assert.Contains(events, e => e.Action == "tool_result");
    }

    [Fact]
    public async Task HealthAgent_RecordHeartRate_Should_SaveToDb()
    {
        var token = await LoginAsync(TestPhone);

        var events = await SendChatMessage(token, "health", "心率72");

        Assert.Contains(events, e => e.Action == "tool_result");
    }

    [Fact]
    public async Task MedicationAgent_Query_Should_Return_Medications()
    {
        var token = await LoginAsync(TestPhone);

        var events = await SendChatMessage(token, "medication", "我现在在吃什么药?");
        var fullResponse = CollectAnswerText(events);

        Assert.NotEmpty(fullResponse);
        // The reply should mention aspirin or atorvastatin.
        Assert.True(fullResponse.Contains("阿司匹林") || fullResponse.Contains("阿托伐他汀") ||
                    fullResponse.Contains("Aspirin"));
    }

    [Fact]
    public async Task ConsultationAgent_SymptomCheck_Should_AskFollowUp()
    {
        var token = await LoginAsync(TestPhone);

        var events = await SendChatMessage(token, "consultation", "最近胸口有点不舒服");
        var fullResponse = CollectAnswerText(events);

        Assert.NotEmpty(fullResponse);
    }

    [Fact]
    public async Task DefaultAgent_GeneralQuestion_Should_Respond()
    {
        var token = await LoginAsync(TestPhone);

        var events = await SendChatMessage(token, "default", "你好,介绍一下你自己");
        var fullResponse = CollectAnswerText(events);

        Assert.NotEmpty(fullResponse);
        Assert.True(fullResponse.Length > 10, "默认 Agent 应返回有效回复");
    }

    // ==================== Helpers ====================

    /// <summary>
    /// Concatenates, in stream order, the payload of every <c>answer</c> event.
    /// </summary>
    /// <param name="events">Events returned by <see cref="SendChatMessage"/>.</param>
    /// <returns>The joined answer text; empty when there are no answer events.</returns>
    private static string CollectAnswerText(IEnumerable<SseEvent> events) =>
        string.Concat(events
            .Where(e => e.Action == "answer")
            .Select(e => e.Data?.ToString() ?? ""));

    /// <summary>
    /// Requests an SMS code and then logs in with it, returning the access token.
    /// </summary>
    /// <param name="phone">Phone number to authenticate as.</param>
    /// <returns>The <c>data.accessToken</c> value from the login response.</returns>
    /// <exception cref="HttpRequestException">Either endpoint answered with a non-success status.</exception>
    /// <exception cref="InvalidOperationException">An expected response field was JSON null.</exception>
    private static async Task<string> LoginAsync(string phone)
    {
        // Request the SMS code; the response carries it back as data.devCode.
        var smsData = await PostForDataAsync("/api/auth/send-sms", new { phone });
        var devCode = RequireString(smsData, "devCode");

        // Log in with that code.
        var loginData = await PostForDataAsync("/api/auth/login", new { phone, smsCode = devCode });
        return RequireString(loginData, "accessToken");
    }

    /// <summary>
    /// POSTs <paramref name="payload"/> as JSON and returns the <c>data</c> element of the
    /// JSON response.
    /// </summary>
    /// <param name="path">Request path relative to the shared client's base address.</param>
    /// <param name="payload">Object serialized with <see cref="JsonOpts"/> as the request body.</param>
    /// <returns>A detached copy of <c>data</c> that stays valid after the document is disposed.</returns>
    private static async Task<JsonElement> PostForDataAsync<TPayload>(string path, TPayload payload)
    {
        using var body = new StringContent(
            JsonSerializer.Serialize(payload, JsonOpts), Encoding.UTF8, "application/json");
        using var response = await Http.PostAsync(path, body);

        // Fail with the HTTP status rather than a confusing missing-property error later.
        response.EnsureSuccessStatusCode();

        using var doc = JsonDocument.Parse(await response.Content.ReadAsStringAsync());
        return doc.RootElement.GetProperty("data").Clone();
    }

    /// <summary>
    /// Reads a string property from <paramref name="data"/>, rejecting JSON null.
    /// </summary>
    private static string RequireString(JsonElement data, string propertyName) =>
        data.GetProperty(propertyName).GetString()
        ?? throw new InvalidOperationException($"Response field 'data.{propertyName}' was null.");

    /// <summary>
    /// Sends a message to the given agent and returns every SSE event received.
    /// </summary>
    /// <param name="token">Access token; passed in the query string, not in a header.</param>
    /// <param name="agentType">Agent route segment, e.g. <c>health</c> or <c>medication</c>.</param>
    /// <param name="message">User message to send.</param>
    /// <returns>Parsed events in arrival order, up to (not including) the <c>[DONE]</c> marker.</returns>
    /// <exception cref="HttpRequestException">The chat endpoint answered with a non-success status.</exception>
    private static async Task<List<SseEvent>> SendChatMessage(string token, string agentType, string message)
    {
        var url = $"/api/ai/{agentType}/chat?message={Uri.EscapeDataString(message)}&token={Uri.EscapeDataString(token)}";

        // The token travels in the query string. Nothing in this class ever sets a default
        // Authorization header on the shared client, so there is nothing to clear here.
        using var response = await Http.GetAsync(url, HttpCompletionOption.ResponseHeadersRead);
        response.EnsureSuccessStatusCode();

        var events = new List<SseEvent>();
        using var stream = await response.Content.ReadAsStreamAsync();
        using var reader = new StreamReader(stream);

        while (await reader.ReadLineAsync() is { } line)
        {
            // Blank separator lines and non-data fields carry no payload.
            if (!line.StartsWith(SseDataPrefix, StringComparison.Ordinal))
            {
                continue;
            }

            // A single space after the colon is optional in the SSE format.
            var data = line[SseDataPrefix.Length..];
            if (data.StartsWith(' '))
            {
                data = data[1..];
            }

            if (data == SseDoneMarker)
            {
                break;
            }

            try
            {
                if (JsonSerializer.Deserialize<SseEvent>(data, JsonOpts) is { } parsed)
                {
                    events.Add(parsed);
                }
            }
            catch (JsonException)
            {
                // Skip chunks that are not valid JSON; any other failure should surface.
            }
        }

        return events;
    }

    /// <summary>
    /// Creates a <see cref="DeepSeekClient"/> after loading <c>backend/.env</c> into the
    /// process environment.
    /// </summary>
    /// <returns>A client whose <see cref="HttpClient"/> targets <c>DEEPSEEK_BASE_URL</c>
    /// (default <c>https://api.deepseek.com/v1</c>) with a bearer <c>DEEPSEEK_API_KEY</c>.</returns>
    private static DeepSeekClient CreateDeepSeekClient()
    {
        // Walk up five levels from the test output directory to reach backend/.env:
        // bin/Debug/net10.0 → Health.Tests → tests → backend
        var envPath = Path.GetFullPath(
            Path.Combine(AppContext.BaseDirectory, "..", "..", "..", "..", "..", ".env"));
        LoadDotEnv(envPath);

        var config = new ConfigurationBuilder().AddEnvironmentVariables().Build();
        var apiKey = config["DEEPSEEK_API_KEY"] ?? "";

        // Exactly one trailing slash so relative request paths resolve under the base URL.
        var baseUrl = (config["DEEPSEEK_BASE_URL"] ?? "https://api.deepseek.com/v1").TrimEnd('/') + "/";

        var httpClient = new HttpClient { BaseAddress = new Uri(baseUrl), Timeout = TimeSpan.FromSeconds(120) };
        httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
        return new DeepSeekClient(httpClient, config);
    }

    /// <summary>
    /// Copies every <c>KEY=VALUE</c> line of the file at <paramref name="envPath"/> into the
    /// process environment. Blank lines, lines starting with <c>#</c>, and lines without a
    /// key before <c>=</c> are ignored; a missing file is a no-op.
    /// </summary>
    private static void LoadDotEnv(string envPath)
    {
        if (!File.Exists(envPath))
        {
            return;
        }

        foreach (var rawLine in File.ReadLines(envPath))
        {
            var entry = rawLine.Trim();
            if (entry.Length == 0 || entry.StartsWith('#'))
            {
                continue;
            }

            var separator = entry.IndexOf('=');
            if (separator <= 0)
            {
                continue;
            }

            var key = entry[..separator].Trim();
            var value = entry[(separator + 1)..].Trim();
            Environment.SetEnvironmentVariable(key, value);
        }
    }
}
|
||
|
||
/// <summary>
/// One event parsed from the chat endpoint's server-sent event stream.
/// </summary>
public class SseEvent
{
    /// <summary>Event kind; the tests filter on <c>tool_result</c> and <c>answer</c>.</summary>
    public string? Action { get; set; }

    /// <summary>
    /// Event payload. Typed as <see cref="object"/>, so callers read it via <c>ToString()</c>.
    /// </summary>
    public object? Data { get; set; }

    /// <summary>Optional message text carried by the event.</summary>
    public string? Message { get; set; }
}
|