Files
AI-Health/backend/tests/Health.Tests/ai_agent_tests.cs

288 lines
10 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
using Health.Domain.Entities;
using Health.Domain.Enums;
using Health.Infrastructure.AI;
using Health.Infrastructure.Data;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
namespace Health.Tests;
/// <summary>
/// AI agent integration tests: simulate real user conversations and verify Tool Calling.
/// The backend must be running first: dotnet run --project src/Health.WebApi
/// </summary>
/// <remarks>
/// The PromptManager tests run in-process. The chat tests call http://localhost:5000 and
/// parse its SSE stream. NOTE(review): the "*_Should_SaveToDb" tests only assert that a
/// tool_result event was streamed back; they do not query the database. Consider asserting
/// against the data layer so the tests match their names.
/// </remarks>
public class AiAgentTests
{
    /// <summary>Phone number of the account every integration test logs in as.</summary>
    private const string TestPhone = "13800000001";

    /// <summary>Shared client for the locally running backend.</summary>
    private static readonly HttpClient Http = new()
    {
        BaseAddress = new Uri("http://localhost:5000"),
        Timeout = TimeSpan.FromSeconds(120)
    };

    /// <summary>camelCase names on write, case-insensitive matching on read.</summary>
    private static readonly JsonSerializerOptions JsonOpts = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        PropertyNameCaseInsensitive = true
    };

    // ==================== PromptManager unit tests ====================

    [Fact]
    public void PromptManager_Default_Should_Contain_HeartKeywords()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Default);
        Assert.Contains("心脏", prompt);
        Assert.Contains("健康", prompt);
    }

    [Fact]
    public void PromptManager_Consultation_Should_Contain_TriageRules()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Consultation);
        Assert.Contains("剧烈胸痛", prompt);
        Assert.Contains("呼吸困难", prompt);
        Assert.Contains("160/100", prompt);
    }

    [Fact]
    public void PromptManager_Health_Should_Contain_NormalRanges()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Health);
        // NOTE(review): these are plain substring checks, so e.g. "100" also matches "1000".
        Assert.Contains("139", prompt); // systolic upper bound
        Assert.Contains("89", prompt);  // diastolic bound. NOTE(review): original comment said "lower bound", but 89 pairs with 139 as an upper limit; confirm
        Assert.Contains("100", prompt); // heart-rate upper bound
    }

    [Fact]
    public void PromptManager_Diet_Should_Contain_VlmKeywords()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Diet);
        Assert.Contains("VLM", prompt);
        Assert.Contains("能不能吃", prompt);
    }

    [Fact]
    public void PromptManager_Medication_Should_Contain_ParseRules()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Medication);
        Assert.Contains("药名", prompt);
        Assert.Contains("剂量", prompt);
    }

    [Fact]
    public void PromptManager_Report_Should_Contain_Disclaimer()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Report);
        Assert.Contains("AI预解读", prompt);
        Assert.Contains("医生确认", prompt);
    }

    [Fact]
    public void PromptManager_Exercise_Should_Contain_RehabKeywords()
    {
        var pm = new PromptManager();
        var prompt = pm.GetSystemPrompt(AgentType.Exercise);
        Assert.Contains("心脏康复", prompt);
        Assert.Contains("循序渐进", prompt);
    }

    [Fact]
    public void PromptManager_All_Agents_Should_Return_NonEmpty()
    {
        var pm = new PromptManager();
        foreach (AgentType agent in Enum.GetValues<AgentType>())
        {
            var prompt = pm.GetSystemPrompt(agent);
            Assert.False(string.IsNullOrWhiteSpace(prompt), $"{agent} 的 SystemPrompt 为空");
        }
    }

    // ==================== DeepSeekClient connectivity tests ====================

    [Fact]
    public void DeepSeekClient_ApiKey_Should_Be_Configured()
    {
        // Only checks that an API key is configured; real calls are covered by the integration tests.
        // CreateDeepSeekClient is invoked so that backend/.env is loaded into the process environment.
        CreateDeepSeekClient();
        var config = new ConfigurationBuilder().AddEnvironmentVariables().Build();
        var apiKey = config["DEEPSEEK_API_KEY"] ?? "";
        Assert.False(string.IsNullOrEmpty(apiKey) || apiKey.StartsWith("sk-your-key", StringComparison.Ordinal),
            "DEEPSEEK_API_KEY 未配置或为占位符,请在 backend/.env 中设置真实 Key");
    }

    // ==================== AI chat + Tool Calling integration tests ====================

    [Fact]
    public async Task HealthAgent_RecordBloodPressure_Should_SaveToDb()
    {
        // Log in first to obtain an access token.
        var token = await LoginAsync(TestPhone);
        // Send a chat message that should trigger Tool Calling.
        var events = await SendChatMessage(token, "health", "我刚刚测了血压138/86");
        var toolResults = events.Where(e => e.Action == "tool_result").ToList();
        Assert.NotEmpty(toolResults);
    }

    [Fact]
    public async Task HealthAgent_RecordHeartRate_Should_SaveToDb()
    {
        var token = await LoginAsync(TestPhone);
        var events = await SendChatMessage(token, "health", "心率72");
        var toolResults = events.Where(e => e.Action == "tool_result").ToList();
        Assert.NotEmpty(toolResults);
    }

    [Fact]
    public async Task MedicationAgent_Query_Should_Return_Medications()
    {
        var token = await LoginAsync(TestPhone);
        var events = await SendChatMessage(token, "medication", "我现在在吃什么药?");
        var fullResponse = JoinAnswers(events);
        Assert.NotEmpty(fullResponse);
        // The reply should mention aspirin or atorvastatin (presumably the test user's seeded medications; confirm).
        Assert.True(
            fullResponse.Contains("阿司匹林") || fullResponse.Contains("阿托伐他汀") ||
            fullResponse.Contains("Aspirin"),
            $"用药 Agent 回复未提及阿司匹林/阿托伐他汀: {fullResponse}");
    }

    [Fact]
    public async Task ConsultationAgent_SymptomCheck_Should_AskFollowUp()
    {
        var token = await LoginAsync(TestPhone);
        var events = await SendChatMessage(token, "consultation", "最近胸口有点不舒服");
        var fullResponse = JoinAnswers(events);
        Assert.NotEmpty(fullResponse);
    }

    [Fact]
    public async Task DefaultAgent_GeneralQuestion_Should_Respond()
    {
        var token = await LoginAsync(TestPhone);
        var events = await SendChatMessage(token, "default", "你好,介绍一下你自己");
        var fullResponse = JoinAnswers(events);
        Assert.NotEmpty(fullResponse);
        Assert.True(fullResponse.Length > 10, "默认 Agent 应返回有效回复");
    }

    // ==================== Helpers ====================

    /// <summary>
    /// Requests an SMS code, then logs in with it and returns the access token.
    /// </summary>
    /// <param name="phone">Phone number of the account to log in as.</param>
    /// <returns>The <c>data.accessToken</c> value from the login response.</returns>
    /// <exception cref="HttpRequestException">Either auth endpoint returned a non-success status.</exception>
    private static async Task<string> LoginAsync(string phone)
    {
        // The send-sms response carries the code back as data.devCode.
        var devCode = await PostForDataStringAsync("/api/auth/send-sms", new { phone }, "devCode");
        return await PostForDataStringAsync("/api/auth/login", new { phone, smsCode = devCode }, "accessToken");
    }

    /// <summary>
    /// POSTs <paramref name="payload"/> as JSON to <paramref name="path"/> and returns the
    /// string found at <c>data.{property}</c> in the response body.
    /// </summary>
    /// <exception cref="HttpRequestException">The response status was not 2xx (body included in the message).</exception>
    /// <exception cref="InvalidOperationException">The requested property was JSON null.</exception>
    private static async Task<string> PostForDataStringAsync(string path, object payload, string property)
    {
        var body = JsonSerializer.Serialize(payload, JsonOpts);
        using var content = new StringContent(body, Encoding.UTF8, "application/json");
        using var response = await Http.PostAsync(path, content);
        var raw = await response.Content.ReadAsStringAsync();
        if (!response.IsSuccessStatusCode)
        {
            // Surface the server's body instead of failing later with an opaque KeyNotFoundException.
            throw new HttpRequestException($"POST {path} failed with {(int)response.StatusCode}: {raw}");
        }

        using var doc = JsonDocument.Parse(raw);
        return doc.RootElement.GetProperty("data").GetProperty(property).GetString()
            ?? throw new InvalidOperationException($"POST {path} returned null data.{property}");
    }

    /// <summary>
    /// Sends <paramref name="message"/> to the given agent's chat endpoint and collects every SSE event.
    /// </summary>
    /// <param name="token">Access token; it is passed in the query string, not an Authorization header.</param>
    /// <param name="agentType">Route segment selecting the agent (e.g. "health", "medication").</param>
    /// <param name="message">User message to send.</param>
    /// <returns>All <c>data:</c> payloads that parsed as <see cref="SseEvent"/>, up to a <c>[DONE]</c> marker.</returns>
    private static async Task<List<SseEvent>> SendChatMessage(string token, string agentType, string message)
    {
        const string DataField = "data:";
        var url = $"/api/ai/{agentType}/chat?message={Uri.EscapeDataString(message)}&token={Uri.EscapeDataString(token)}";

        // ResponseHeadersRead lets us consume the stream as it arrives; disposing the response releases the connection.
        using var response = await Http.GetAsync(url, HttpCompletionOption.ResponseHeadersRead);
        response.EnsureSuccessStatusCode();

        var events = new List<SseEvent>();
        using var stream = await response.Content.ReadAsStreamAsync();
        using var reader = new StreamReader(stream);
        string? line;
        while ((line = await reader.ReadLineAsync()) is not null)
        {
            // Blank separators, comments, and other SSE fields (event:, id:, ...) are ignored.
            if (!line.StartsWith(DataField, StringComparison.Ordinal)) continue;

            // Per the SSE format, a single space after the colon is optional and not part of the value.
            var data = line[DataField.Length..];
            if (data.StartsWith(' ')) data = data[1..];

            if (data == "[DONE]") break;
            if (data.Length == 0) continue;

            try
            {
                var parsed = JsonSerializer.Deserialize<SseEvent>(data, JsonOpts);
                if (parsed is not null) events.Add(parsed);
            }
            catch (JsonException)
            {
                // Deliberately skip chunks that are not a JSON SseEvent.
            }
        }

        return events;
    }

    /// <summary>Concatenates the <c>Data</c> of every "answer" event into the full streamed reply.</summary>
    private static string JoinAnswers(IEnumerable<SseEvent> events) =>
        string.Concat(events.Where(e => e.Action == "answer").Select(e => e.Data?.ToString() ?? ""));

    /// <summary>
    /// Creates a DeepSeekClient (reading configuration from backend/.env plus the process environment).
    /// </summary>
    /// <remarks>
    /// Side effect: every KEY=VALUE pair in backend/.env is written into the process environment,
    /// overriding any existing variable with the same name.
    /// </remarks>
    private static DeepSeekClient CreateDeepSeekClient()
    {
        LoadDotEnv();

        var config = new ConfigurationBuilder().AddEnvironmentVariables().Build();
        var apiKey = config["DEEPSEEK_API_KEY"] ?? "";
        // Trailing slash so relative request URIs resolve under the /v1 path rather than replacing it.
        var baseUrl = (config["DEEPSEEK_BASE_URL"] ?? "https://api.deepseek.com/v1").TrimEnd('/') + "/";
        var httpClient = new HttpClient { BaseAddress = new Uri(baseUrl), Timeout = TimeSpan.FromSeconds(120) };
        httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
        return new DeepSeekClient(httpClient, config);
    }

    /// <summary>Loads KEY=VALUE lines from backend/.env into the process environment, if the file exists.</summary>
    private static void LoadDotEnv()
    {
        // The test output dir is backend/tests/Health.Tests/bin/Debug/<tfm>, so five levels up is backend/.
        var envPath = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "..", "..", "..", "..", "..", ".env"));
        if (!File.Exists(envPath)) return;

        foreach (var rawLine in File.ReadLines(envPath))
        {
            var line = rawLine.Trim();
            if (line.Length == 0 || line.StartsWith('#')) continue;

            var eq = line.IndexOf('=');
            if (eq <= 0) continue; // no key, or not an assignment

            var key = line[..eq].Trim();
            var value = Unquote(line[(eq + 1)..].Trim());
            Environment.SetEnvironmentVariable(key, value);
        }
    }

    /// <summary>Strips one pair of matching surrounding quotes (<c>"</c> or <c>'</c>) from a .env value.</summary>
    private static string Unquote(string value) =>
        value.Length >= 2 && (value[0] == '"' || value[0] == '\'') && value[^1] == value[0]
            ? value[1..^1]
            : value;
}
/// <summary>
/// One JSON payload taken from a <c>data:</c> line of the chat endpoint's SSE stream.
/// </summary>
public class SseEvent
{
    /// <summary>Event kind; the tests in this file match on "answer" and "tool_result".</summary>
    public string? Action
    {
        get;
        set;
    }

    /// <summary>
    /// Event payload. Declared as <see cref="object"/>, so System.Text.Json materializes it as a JsonElement.
    /// </summary>
    public object? Data
    {
        get;
        set;
    }

    /// <summary>Optional message text carried by the event.</summary>
    public string? Message
    {
        get;
        set;
    }
}