Files
AI-Health/backend/src/Health.WebApi/Endpoints/ai_chat_endpoints.cs

594 lines
28 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System.Drawing;
using System.Drawing.Imaging;
using Health.Infrastructure.AI;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Health.WebApi.Endpoints;
/// <summary>
/// AI chat endpoints: one SSE chat stream shared by the seven agent types, conversation
/// management (list / history / delete) and VLM-based food image recognition.
/// </summary>
public static class AiChatEndpoints
{
    /// <summary>camelCase serializer options used for SSE payloads and serialized tool results.</summary>
    private static readonly JsonSerializerOptions JsonOpts = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        PropertyNameCaseInsensitive = true,
    };

    /// <summary>Maximum number of LLM rounds (tool calls plus the final answer) per chat turn.</summary>
    private const int MaxToolIterations = 5;

    /// <summary>How many of the most recent messages (including the one just sent) are replayed to the LLM.</summary>
    private const int HistoryMessageLimit = 12;

    /// <summary>Per-file size limit for uploaded food images (20 MB).</summary>
    private const long MaxImageBytes = 20 * 1024 * 1024;

    /// <summary>
    /// Metric names accepted in tool arguments (the snake_case values advertised in the tool schemas),
    /// shared by <c>record_health_data</c> and <c>query_health_records</c> so both tools agree.
    /// </summary>
    private static readonly Dictionary<string, HealthMetricType> ToolMetricTypes = new(StringComparer.OrdinalIgnoreCase)
    {
        ["blood_pressure"] = HealthMetricType.BloodPressure,
        ["heart_rate"] = HealthMetricType.HeartRate,
        ["glucose"] = HealthMetricType.Glucose,
        ["spo2"] = HealthMetricType.SpO2,
        ["weight"] = HealthMetricType.Weight,
    };

    /// <summary>Registers the AI chat, conversation-management and food-recognition endpoints.</summary>
    public static void MapAiChatEndpoints(this WebApplication app)
    {
        // SSE chat stream. It is a GET so a browser EventSource can consume it; since EventSource cannot
        // set headers, the JWT may arrive either in the Authorization header or as the `token` query value.
        app.MapGet("/api/ai/{agentType}/chat", async (
            string message,
            string? conversationId,
            string? token,
            string agentType,
            HttpContext http,
            AppDbContext db,
            DeepSeekClient llmClient,
            PromptManager promptManager,
            ILoggerFactory loggerFactory,
            CancellationToken ct) =>
        {
            // `token` is optional so header-authenticated callers are not rejected for omitting it.
            var userId = GetUserId(http) ?? GetUserIdFromToken(http, token);
            if (userId == null)
            {
                await WriteJsonErrorAsync(http, 401, 40002, "未登录", ct);
                return;
            }
            if (string.IsNullOrWhiteSpace(message))
            {
                await WriteJsonErrorAsync(http, 400, 40001, "消息不能为空", ct);
                return;
            }
            var currentUserId = userId.Value;

            // Unknown names, and numeric strings that are not declared members, fall back to the default agent.
            if (!Enum.TryParse<AgentType>(agentType, ignoreCase: true, out var parsedType) || !Enum.IsDefined(parsedType))
                parsedType = AgentType.Default;

            // SSE response headers.
            http.Response.ContentType = "text/event-stream";
            http.Response.Headers.CacheControl = "no-cache";
            http.Response.Headers.Connection = "keep-alive";
            http.Response.Headers["X-Accel-Buffering"] = "no";

            // Continue an existing conversation only if it belongs to the caller; otherwise start a new one.
            Conversation? conversation = null;
            if (Guid.TryParse(conversationId, out var convId))
                conversation = await db.Conversations.FirstOrDefaultAsync(c => c.Id == convId && c.UserId == currentUserId, ct);
            if (conversation == null)
            {
                conversation = new Conversation
                {
                    Id = Guid.NewGuid(), UserId = currentUserId, AgentType = parsedType,
                    Title = Truncate(message, 30),
                    CreatedAt = DateTime.UtcNow, UpdatedAt = DateTime.UtcNow,
                };
                db.Conversations.Add(conversation);
                await db.SaveChangesAsync(ct);
                await SseWriteAsync(http, new { action = "conversation_id", data = conversation.Id.ToString() }, ct);
            }

            // Persist the user's message before calling the LLM so it is part of the replayed history.
            db.ConversationMessages.Add(new ConversationMessage
            {
                Id = Guid.NewGuid(), ConversationId = conversation.Id, Role = MessageRole.User,
                Content = message, CreatedAt = DateTime.UtcNow,
            });
            conversation.MessageCount++;
            conversation.UpdatedAt = DateTime.UtcNow;
            await db.SaveChangesAsync(ct);

            var messages = await BuildPromptMessagesAsync(db, promptManager, parsedType, currentUserId, conversation.Id, ct);

            var logger = loggerFactory.CreateLogger(typeof(AiChatEndpoints).FullName!);
            var fullResponse = new System.Text.StringBuilder();
            var completedNormally = false;
            try
            {
                completedNormally = await RunToolLoopAsync(
                    http, db, llmClient, messages, GetToolsForAgent(parsedType), currentUserId, fullResponse, ct);
            }
            catch (Exception ex) when (!ct.IsCancellationRequested)
            {
                // The SSE headers are already sent, so an LLM/network failure can only be reported in-band:
                // fall through to the "error" status and [DONE] marker instead of aborting the connection.
                logger.LogError(ex, "AI chat turn failed for conversation {ConversationId}", conversation.Id);
            }

            // Persist whatever answer text was streamed (possibly partial if the stream failed midway).
            if (fullResponse.Length > 0)
            {
                var reply = fullResponse.ToString();
                db.ConversationMessages.Add(new ConversationMessage
                {
                    Id = Guid.NewGuid(), ConversationId = conversation.Id, Role = MessageRole.Assistant,
                    Content = reply, CreatedAt = DateTime.UtcNow,
                });
                conversation.MessageCount++;
                conversation.Summary = Truncate(reply, 100);
                conversation.UpdatedAt = DateTime.UtcNow;
                await db.SaveChangesAsync(ct);
            }
            await SseWriteAsync(http, new { action = "status", data = completedNormally ? "done" : "error" }, ct);
            await http.Response.WriteAsync("data: [DONE]\n\n", ct);
        });

        // Conversation list of the current user, most recently updated first.
        app.MapGet("/api/ai/conversations", async (HttpContext http, AppDbContext db, CancellationToken ct) =>
        {
            var userId = GetUserId(http);
            if (userId == null) return NotLoggedIn();
            var conversations = await db.Conversations
                .Where(c => c.UserId == userId.Value)
                .OrderByDescending(c => c.UpdatedAt)
                .Select(c => new { c.Id, AgentType = c.AgentType.ToString(), c.Title, c.Summary, c.MessageCount, c.CreatedAt, c.UpdatedAt })
                .ToListAsync(ct);
            return Results.Ok(new { code = 0, data = conversations, message = (string?)null });
        });

        // Message history of one conversation; empty when it does not exist or belongs to someone else.
        app.MapGet("/api/ai/conversations/{id:guid}", async (Guid id, HttpContext http, AppDbContext db, CancellationToken ct) =>
        {
            var userId = GetUserId(http);
            if (userId == null) return NotLoggedIn();
            var messages = await db.ConversationMessages
                .Where(m => m.ConversationId == id && m.Conversation.UserId == userId.Value)
                .OrderBy(m => m.CreatedAt)
                .Select(m => new { m.Id, Role = m.Role.ToString(), m.Content, m.Intent, m.MetadataJson, m.CreatedAt })
                .ToListAsync(ct);
            return Results.Ok(new { code = 0, data = messages, message = (string?)null });
        });

        // Delete a conversation owned by the caller; reports success even when nothing matched.
        app.MapDelete("/api/ai/conversations/{id:guid}", async (Guid id, HttpContext http, AppDbContext db, CancellationToken ct) =>
        {
            var userId = GetUserId(http);
            if (userId == null) return NotLoggedIn();
            var conv = await db.Conversations.FirstOrDefaultAsync(c => c.Id == id && c.UserId == userId.Value, ct);
            if (conv != null)
            {
                db.Conversations.Remove(conv);
                await db.SaveChangesAsync(ct);
            }
            return Results.Ok(new { code = 0, data = new { success = true }, message = (string?)null });
        });

        // VLM food recognition: uploaded images are downscaled/re-encoded as JPEG and sent as data URLs.
        app.MapPost("/api/ai/analyze-food-image", async (
            HttpRequest httpRequest, HttpContext http,
            QwenVisionClient visionClient,
            CancellationToken ct) =>
        {
            if (GetUserId(http) == null) return NotLoggedIn();
            if (!httpRequest.HasFormContentType)
                return Results.Ok(new { code = 40001, data = (object?)null, message = "请上传至少一张图片" });

            var form = await httpRequest.ReadFormAsync(ct);
            var files = form.Files.GetFiles("images");
            if (files.Count == 0)
                return Results.Ok(new { code = 40001, data = (object?)null, message = "请上传至少一张图片" });

            // Validate every file before anything is written to disk.
            foreach (var file in files)
            {
                if (file.Length > MaxImageBytes)
                    return Results.Ok(new { code = 40001, data = (object?)null, message = "文件大小超过 20MB 限制" });
                var ext = Path.GetExtension(file.FileName).ToLowerInvariant();
                if (ext is not ".jpg" and not ".jpeg" and not ".png" and not ".heic")
                    return Results.Ok(new { code = 40001, data = (object?)null, message = "不支持的图片格式,仅支持 JPG/PNG/HEIC" });
            }

            var uploadsDir = Path.Combine(Directory.GetCurrentDirectory(), "uploads");
            Directory.CreateDirectory(uploadsDir);
            var imageUrls = new List<string>(files.Count);
            foreach (var file in files)
            {
                // NOTE(review): the original upload stays in ./uploads and is not referenced again in this
                // file — confirm that retaining patients' photos is intended.
                var safeName = $"{Guid.NewGuid()}_{Path.GetFileName(file.FileName)}";
                var filePath = Path.Combine(uploadsDir, safeName);
                await using (var stream = new FileStream(filePath, FileMode.Create))
                    await file.CopyToAsync(stream, ct);

                // Re-encode to a smaller JPEG before base64-encoding (the VLM API limits the request body size).
                var compressedPath = Path.Combine(uploadsDir, $"compressed_{safeName}");
                try
                {
                    CompressImage(filePath, compressedPath, maxWidth: 2048, quality: 92L);
                    var compressedBytes = await File.ReadAllBytesAsync(compressedPath, ct);
                    imageUrls.Add($"data:image/jpeg;base64,{Convert.ToBase64String(compressedBytes)}");
                }
                catch (Exception ex) when (!ct.IsCancellationRequested)
                {
                    // Image.FromFile throws for content it cannot decode (e.g. HEIC when no decoder is
                    // installed, or a mislabelled file); report it instead of failing the request with a 500.
                    _ = ex;
                    return Results.Ok(new { code = 40001, data = (object?)null, message = "图片解析失败,请尝试上传 JPG/PNG 格式" });
                }
                finally
                {
                    // The compressed copy is only needed to build the data URL.
                    File.Delete(compressedPath);
                }
            }

            // NOTE(review): this prompt looks truncated (e.g. the "calories" value is missing) — confirm
            // against the intended prompt text.
            var prompt = """
                JSON
                {
                "foods": [{"name":"食物名","portion":"份量","calories":}]
                }
                """;
            try
            {
                var response = await visionClient.VisionAsync(prompt, imageUrls, ct: ct);
                var result = response.Choices?.FirstOrDefault()?.Message?.Content ?? "{}";
                return Results.Ok(new { code = 0, data = result, message = (string?)null });
            }
            catch (Exception ex) when (!ct.IsCancellationRequested)
            {
                return Results.Ok(new { code = 50001, data = (object?)null, message = $"食物识别失败:{ex.Message}" });
            }
        });
    }

    /// <summary>
    /// Builds the LLM input: the agent's system prompt plus a patient summary, followed by the
    /// conversation's most recent <see cref="HistoryMessageLimit"/> messages in chronological order.
    /// </summary>
    private static async Task<List<ChatMessage>> BuildPromptMessagesAsync(
        AppDbContext db, PromptManager promptManager, AgentType agentType, Guid userId, Guid conversationId, CancellationToken ct)
    {
        var systemPrompt = promptManager.GetSystemPrompt(agentType);
        var patientContext = await BuildPatientContext(db, userId, ct);
        var messages = new List<ChatMessage>
        {
            new() { Role = "system", Content = systemPrompt + "\n\n当前患者信息\n" + patientContext },
        };

        // Query newest-first so Take() keeps the latest messages, then replay them oldest-first.
        var history = await db.ConversationMessages
            .Where(m => m.ConversationId == conversationId)
            .OrderByDescending(m => m.CreatedAt)
            .Take(HistoryMessageLimit)
            .ToListAsync(ct);
        foreach (var h in Enumerable.Reverse(history))
        {
            messages.Add(new ChatMessage
            {
                Role = h.Role == MessageRole.User ? "user" : "assistant",
                Content = h.Content,
            });
        }
        return messages;
    }

    /// <summary>
    /// Runs up to <see cref="MaxToolIterations"/> LLM rounds. Tool calls are executed and their results fed
    /// back; once the model finishes with "stop" the final answer is streamed to the client as "answer" events.
    /// </summary>
    /// <param name="fullResponse">Receives the streamed answer text (also when streaming later fails).</param>
    /// <returns>
    /// <c>true</c> when a final answer was streamed; <c>false</c> when the model returned no choice, an
    /// unexpected finish reason, or was still calling tools when the iteration limit was reached.
    /// </returns>
    private static async Task<bool> RunToolLoopAsync(
        HttpContext http, AppDbContext db, DeepSeekClient llmClient, List<ChatMessage> messages,
        List<ToolDefinition> tools, Guid userId, System.Text.StringBuilder fullResponse, CancellationToken ct)
    {
        for (var i = 0; i < MaxToolIterations; i++)
        {
            await SseWriteAsync(http, new { action = "notice", message = i == 0 ? "正在分析..." : "正在处理..." }, ct);
            var response = await llmClient.ChatAsync(messages, tools: tools.Count > 0 ? tools : null, ct: ct);
            var choice = response.Choices?.FirstOrDefault();
            if (choice == null) return false;

            if (choice.FinishReason == "stop")
            {
                // The final answer is requested again in streaming mode with the full tool-call history.
                // NOTE(review): the non-streamed answer in `choice` is discarded, so each turn costs one extra
                // completion and the streamed text may differ from the answer that ended the loop.
                await foreach (var chunk in llmClient.ChatStreamAsync(messages, tools: null, ct: ct))
                {
                    try
                    {
                        var delta = JsonSerializer.Deserialize<ChatCompletionResponse>(chunk, JsonOpts);
                        var content = delta?.Choices?.FirstOrDefault()?.Delta?.Content;
                        if (!string.IsNullOrEmpty(content))
                        {
                            fullResponse.Append(content);
                            await SseWriteAsync(http, new { action = "answer", data = content }, ct);
                        }
                    }
                    catch (JsonException)
                    {
                        // Deliberately skip chunks that are not completion JSON (e.g. a terminal marker).
                    }
                }
                return true;
            }

            if (choice.FinishReason != "tool_calls" || choice.Message?.ToolCalls == null) return false;

            // A single assistant message carries all tool calls (OpenAI-compatible protocol).
            messages.Add(new ChatMessage
            {
                Role = "assistant",
                Content = choice.Message.Content ?? "",
                ToolCalls = choice.Message.ToolCalls,
            });
            foreach (var tc in choice.Message.ToolCalls)
            {
                object toolResult;
                try
                {
                    toolResult = await ExecuteToolCall(tc.Function.Name, tc.Function.Arguments, db, userId, ct);
                }
                catch (Exception ex) when (!ct.IsCancellationRequested)
                {
                    // A failed insert would otherwise be retried by the next SaveChangesAsync and fail again.
                    DiscardPendingInserts(db);
                    toolResult = new { success = false, message = $"工具执行异常: {ex.Message}" };
                }
                await SseWriteAsync(http, new { action = "tool_result", tool = tc.Function.Name, data = toolResult }, ct);
                messages.Add(new ChatMessage { Role = "tool", Content = JsonSerializer.Serialize(toolResult, JsonOpts), ToolCallId = tc.Id });
            }
        }
        return false;
    }

    /// <summary>Detaches entities still in the <c>Added</c> state after a failed tool write.</summary>
    private static void DiscardPendingInserts(AppDbContext db)
    {
        foreach (var entry in db.ChangeTracker.Entries().Where(e => e.State == EntityState.Added).ToList())
            entry.State = EntityState.Detached;
    }

    /// <summary>Serializes <paramref name="data"/> as one SSE <c>data:</c> event and flushes it immediately.</summary>
    private static async Task SseWriteAsync(HttpContext http, object data, CancellationToken ct)
    {
        var json = JsonSerializer.Serialize(data, JsonOpts);
        await http.Response.WriteAsync($"data: {json}\n\n", ct);
        await http.Response.Body.FlushAsync(ct);
    }

    /// <summary>Writes a JSON error body (<c>{ code, data: null, message }</c>) with the given HTTP status.</summary>
    private static async Task WriteJsonErrorAsync(HttpContext http, int statusCode, int code, string message, CancellationToken ct)
    {
        http.Response.StatusCode = statusCode;
        http.Response.ContentType = "application/json";
        await http.Response.WriteAsync(JsonSerializer.Serialize(new { code, data = (object?)null, message }), ct);
    }

    /// <summary>The 401 response shared by the JSON endpoints.</summary>
    private static IResult NotLoggedIn() =>
        Results.Json(new { code = 40002, data = (object?)null, message = "未登录" }, statusCode: 401);

    /// <summary>User id from the authenticated principal's NameIdentifier claim, or null.</summary>
    private static Guid? GetUserId(HttpContext http) =>
        Guid.TryParse(http.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value, out var id) ? id : null;

    /// <summary>
    /// Resolves the user id from a JWT passed in the query string (browser EventSource cannot send headers).
    /// The token's signature, issuer, audience and lifetime are validated with the parameters configured
    /// for the JwtBearer scheme; any failure yields null ("not logged in").
    /// </summary>
    /// <remarks>
    /// NOTE(review): assumes authentication is configured through the default JwtBearer scheme — TODO confirm.
    /// If it is not registered, validation has no signing keys and query-string tokens are rejected.
    /// Tokens in URLs can end up in access logs; consider short-lived tokens for this path.
    /// </remarks>
    private static Guid? GetUserIdFromToken(HttpContext http, string? token)
    {
        if (string.IsNullOrEmpty(token)) return null;
        // ReadJwtToken only decodes a token, so a forged or expired one would be trusted; validate instead.
        if (http.RequestServices.GetService(typeof(IOptionsMonitor<JwtBearerOptions>)) is not IOptionsMonitor<JwtBearerOptions> jwtOptions)
            return null;
        var parameters = jwtOptions.Get(JwtBearerDefaults.AuthenticationScheme).TokenValidationParameters;
        try
        {
            var principal = new System.IdentityModel.Tokens.Jwt.JwtSecurityTokenHandler().ValidateToken(token, parameters, out _);
            // Claim types depend on the handler's inbound claim mapping, so also accept the raw JWT names.
            var sub = principal.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                      ?? principal.FindFirst("sub")?.Value
                      ?? principal.FindFirst("nameid")?.Value;
            return Guid.TryParse(sub, out var id) ? id : null;
        }
        catch (Exception)
        {
            // Malformed, expired or badly signed tokens are treated as "not logged in".
            return null;
        }
    }

    /// <summary>Returns at most <paramref name="maxLength"/> UTF-16 units of <paramref name="text"/>
    /// without splitting a surrogate pair (e.g. an emoji).</summary>
    private static string Truncate(string text, int maxLength)
    {
        if (text.Length <= maxLength) return text;
        var cut = char.IsHighSurrogate(text[maxLength - 1]) ? maxLength - 1 : maxLength;
        return text[..cut];
    }

    /// <summary>
    /// Tool set offered to each agent type.
    /// NOTE(review): <c>estimate_food_text</c>, <c>analyze_report</c> and <c>request_doctor</c> are advertised
    /// here but have no handler in <see cref="ExecuteToolCall"/>, so calls to them return "未知工具".
    /// </summary>
    private static List<ToolDefinition> GetToolsForAgent(AgentType agentType) => agentType switch
    {
        AgentType.Health => [RecordHealthDataTool, QueryHealthRecordsTool],
        AgentType.Medication => [ManageMedicationTool, CheckArchiveTool],
        AgentType.Diet => [EstimateFoodTool, CheckArchiveTool],
        AgentType.Consultation => [QueryHealthRecordsTool, CheckArchiveTool, RequestDoctorTool],
        AgentType.Report => [AnalyzeReportTool, QueryHealthRecordsTool],
        AgentType.Exercise => [ManageExerciseTool],
        _ => [QueryHealthRecordsTool, CheckArchiveTool],
    };

    /// <summary>
    /// Dispatches a model-issued tool call. <paramref name="arguments"/> is the raw JSON argument string
    /// from the model; an empty string is treated as <c>{}</c>.
    /// </summary>
    /// <returns>An anonymous result object that is serialized back to the model and the client.</returns>
    private static async Task<object> ExecuteToolCall(string toolName, string arguments, AppDbContext db, Guid userId, CancellationToken ct = default)
    {
        // Models may send an empty argument string for parameterless tools such as check_archive.
        using var jsonDoc = JsonDocument.Parse(string.IsNullOrWhiteSpace(arguments) ? "{}" : arguments);
        var root = jsonDoc.RootElement;
        if (root.ValueKind != JsonValueKind.Object)
            return new { success = false, message = "工具参数必须是 JSON 对象" };
        return toolName switch
        {
            "record_health_data" => await ExecuteRecordHealthData(db, userId, root, ct),
            "query_health_records" => await ExecuteQueryHealthRecords(db, userId, root, ct),
            "check_archive" => await ExecuteCheckArchive(db, userId, ct),
            "manage_medication" => await ExecuteManageMedication(db, userId, root, ct),
            "manage_exercise" => await ExecuteManageExercise(db, userId, root, ct),
            _ => new { success = false, message = $"未知工具: {toolName}" }
        };
    }

    /// <summary>Reads a string argument (numbers are returned as raw JSON text); null if absent or another kind.</summary>
    private static string? ReadString(JsonElement args, string name)
    {
        if (!args.TryGetProperty(name, out var value)) return null;
        return value.ValueKind switch
        {
            JsonValueKind.String => value.GetString(),
            JsonValueKind.Number => value.GetRawText(),
            _ => null,
        };
    }

    /// <summary>Reads a numeric argument given as a JSON number or numeric string; null if absent or invalid.</summary>
    private static decimal? ReadNumber(JsonElement args, string name)
    {
        if (!args.TryGetProperty(name, out var value)) return null;
        if (value.ValueKind == JsonValueKind.Number && value.TryGetDecimal(out var number)) return number;
        if (value.ValueKind == JsonValueKind.String &&
            decimal.TryParse(value.GetString(), System.Globalization.NumberStyles.Number,
                System.Globalization.CultureInfo.InvariantCulture, out var parsed))
            return parsed;
        return null;
    }

    /// <summary>Rounds to the nearest integer; null when absent or outside the <see cref="int"/> range.</summary>
    private static int? ToInt(decimal? value)
    {
        if (value is not { } v) return null;
        var rounded = Math.Round(v);
        return rounded >= int.MinValue && rounded <= int.MaxValue ? (int)rounded : null;
    }

    /// <summary>Resolves a metric name (tool snake_case name or enum member name); false if empty or unknown.</summary>
    private static bool TryParseMetricType(string? name, out HealthMetricType metricType)
    {
        metricType = default;
        if (string.IsNullOrEmpty(name)) return false;
        if (ToolMetricTypes.TryGetValue(name, out metricType)) return true;
        return Enum.TryParse(name, ignoreCase: true, out metricType) && Enum.IsDefined(metricType);
    }

    /// <summary>Tool result for a metric type whose value argument(s) are missing or not numeric.</summary>
    private static object MissingMetricValue(string type) => new { success = false, message = $"缺少指标数值: {type}" };

    /// <summary>
    /// <c>record_health_data</c>: stores one AI-entered measurement for the user and flags it abnormal
    /// using the fixed thresholds below. Optional <c>recorded_at</c> must be an ISO date string.
    /// </summary>
    private static async Task<object> ExecuteRecordHealthData(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        var type = ReadString(args, "type") ?? "";
        if (!ToolMetricTypes.TryGetValue(type, out var metricType))
            return new { success = false, message = $"未知指标类型: {type}" };

        var record = new HealthRecord
        {
            Id = Guid.NewGuid(), UserId = userId, Source = HealthRecordSource.AiEntry, MetricType = metricType,
            // TryGetDateTime throws on non-string values, hence the ValueKind check.
            RecordedAt = args.TryGetProperty("recorded_at", out var ra) && ra.ValueKind == JsonValueKind.String && ra.TryGetDateTime(out var dt)
                ? dt
                : DateTime.UtcNow,
            CreatedAt = DateTime.UtcNow,
        };
        switch (metricType)
        {
            case HealthMetricType.BloodPressure:
                record.Systolic = ToInt(ReadNumber(args, "systolic"));
                record.Diastolic = ToInt(ReadNumber(args, "diastolic"));
                if (record.Systolic == null || record.Diastolic == null) return MissingMetricValue(type);
                record.Unit = "mmHg";
                // Abnormal outside 90–139 systolic / 60–89 diastolic.
                record.IsAbnormal = record.Systolic >= 140 || record.Diastolic >= 90 || record.Systolic <= 89 || record.Diastolic <= 59;
                break;
            case HealthMetricType.HeartRate:
                record.Value = ReadNumber(args, "heart_rate");
                if (record.Value == null) return MissingMetricValue(type);
                record.Unit = "次/分";
                record.IsAbnormal = record.Value > 100 || record.Value < 60;
                break;
            case HealthMetricType.Glucose:
                record.Value = ReadNumber(args, "glucose");
                if (record.Value == null) return MissingMetricValue(type);
                record.Unit = "mmol/L";
                record.IsAbnormal = record.Value >= 7.0m || record.Value <= 3.8m;
                break;
            case HealthMetricType.SpO2:
                record.Value = ReadNumber(args, "spo2");
                if (record.Value == null) return MissingMetricValue(type);
                record.Unit = "%";
                record.IsAbnormal = record.Value <= 94;
                break;
            case HealthMetricType.Weight:
                record.Value = ReadNumber(args, "weight");
                if (record.Value == null) return MissingMetricValue(type);
                record.Unit = "kg";
                break;
            default:
                return new { success = false, message = $"未知指标类型: {type}" };
        }
        db.HealthRecords.Add(record);
        await db.SaveChangesAsync(ct);
        return new { success = true, record_id = record.Id, type = record.MetricType.ToString() };
    }

    /// <summary>
    /// <c>query_health_records</c>: up to 30 of the user's newest records from the last <c>days</c>
    /// (default 7, clamped to 1–3650), optionally filtered by metric type.
    /// </summary>
    private static async Task<object> ExecuteQueryHealthRecords(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        var type = ReadString(args, "type");
        var days = Math.Clamp(ToInt(ReadNumber(args, "days")) ?? 7, 1, 3650);
        var since = DateTime.UtcNow.AddDays(-days);

        var query = db.HealthRecords.Where(r => r.UserId == userId && r.RecordedAt >= since);
        // Unknown type names leave the query unfiltered.
        if (TryParseMetricType(type, out var metricType))
            query = query.Where(r => r.MetricType == metricType);

        var records = await query.OrderByDescending(r => r.RecordedAt).Take(30).Select(r => new
        {
            r.Id, Type = r.MetricType.ToString(), r.Systolic, r.Diastolic, r.Value, r.Unit, r.IsAbnormal, r.RecordedAt,
        }).ToListAsync(ct);
        return new { count = records.Count, records };
    }

    /// <summary><c>check_archive</c>: the user's health archive summary, or <c>{ found = false }</c>.</summary>
    private static async Task<object> ExecuteCheckArchive(AppDbContext db, Guid userId, CancellationToken ct = default)
    {
        var archive = await db.HealthArchives.FirstOrDefaultAsync(a => a.UserId == userId, ct);
        if (archive == null) return new { found = false };
        return new
        {
            found = true, archive.Diagnosis, archive.SurgeryType,
            SurgeryDate = archive.SurgeryDate?.ToString("yyyy-MM-dd"),
            archive.Allergies, archive.DietRestrictions, archive.ChronicDiseases, archive.FamilyHistory,
        };
    }

    /// <summary><c>manage_medication</c>: dispatches on <c>action</c> (create / query / confirm; default query).</summary>
    private static async Task<object> ExecuteManageMedication(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        var action = ReadString(args, "action") ?? "query";
        return action switch
        {
            "create" => await CreateMedication(db, userId, args, ct),
            "query" => await QueryMedications(db, userId, ct),
            "confirm" => await ConfirmMedication(db, userId, args, ct),
            _ => new { success = false, message = $"未知操作: {action}" }
        };
    }

    /// <summary>Creates an active, AI-entered medication; requires a non-empty <c>name</c>.</summary>
    private static async Task<object> CreateMedication(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        var name = ReadString(args, "name")?.Trim();
        if (string.IsNullOrEmpty(name))
            return new { success = false, message = "缺少药品名称" };
        var med = new Medication
        {
            Id = Guid.NewGuid(), UserId = userId,
            Name = name,
            Dosage = ReadString(args, "dosage"),
            Source = MedicationSource.AiEntry, IsActive = true,
        };
        db.Medications.Add(med);
        await db.SaveChangesAsync(ct);
        return new { success = true, medication_id = med.Id, med.Name };
    }

    /// <summary>Lists the user's active medications.</summary>
    private static async Task<object> QueryMedications(AppDbContext db, Guid userId, CancellationToken ct = default)
    {
        var meds = await db.Medications.Where(m => m.UserId == userId && m.IsActive)
            .Select(m => new { m.Id, m.Name, m.Dosage, m.TimeOfDay }).ToListAsync(ct);
        return new { count = meds.Count, medications = meds };
    }

    /// <summary>Logs a "taken" entry for <c>medication_id</c>, which must be one of the user's medications.</summary>
    private static async Task<object> ConfirmMedication(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        // TryGetGuid throws on non-string values, hence the ValueKind check.
        var medId = args.TryGetProperty("medication_id", out var mid) && mid.ValueKind == JsonValueKind.String && mid.TryGetGuid(out var parsed)
            ? parsed
            : Guid.Empty;
        // The id comes from model output, so only the caller's own medications are accepted.
        if (medId == Guid.Empty || !await db.Medications.AnyAsync(m => m.Id == medId && m.UserId == userId, ct))
            return new { success = false, message = "未找到对应的用药记录" };

        db.MedicationLogs.Add(new MedicationLog
        {
            Id = Guid.NewGuid(), MedicationId = medId, UserId = userId,
            // NOTE(review): server-local time of day — confirm the server time zone matches the patient's.
            Status = MedicationLogStatus.Taken, ScheduledTime = TimeOnly.FromDateTime(DateTime.Now), ConfirmedAt = DateTime.UtcNow,
        });
        await db.SaveChangesAsync(ct);
        return new { success = true };
    }

    /// <summary><c>manage_exercise</c>: only "query" is implemented; returns the newest plan and its items.</summary>
    private static async Task<object> ExecuteManageExercise(AppDbContext db, Guid userId, JsonElement args, CancellationToken ct = default)
    {
        var action = ReadString(args, "action") ?? "query";
        if (action != "query") return new { success = false, message = "运动计划管理暂未实现" };
        var plan = await db.ExercisePlans.Where(p => p.UserId == userId)
            .OrderByDescending(p => p.WeekStartDate).FirstOrDefaultAsync(ct);
        if (plan == null) return new { found = false };
        var items = await db.ExercisePlanItems.Where(i => i.PlanId == plan.Id).OrderBy(i => i.DayOfWeek).ToListAsync(ct);
        return new { found = true, plan_id = plan.Id, items = items.Select(i => new { i.DayOfWeek, i.ExerciseType, i.DurationMinutes, i.IsCompleted }) };
    }

    /// <summary>
    /// Plain-text patient summary for the system prompt: archive highlights plus the 10 newest health records.
    /// </summary>
    private static async Task<string> BuildPatientContext(AppDbContext db, Guid userId, CancellationToken ct)
    {
        var archive = await db.HealthArchives.FirstOrDefaultAsync(a => a.UserId == userId, ct);
        var recentRecords = await db.HealthRecords.Where(r => r.UserId == userId)
            .OrderByDescending(r => r.RecordedAt).Take(10).ToListAsync(ct);
        var sb = new System.Text.StringBuilder();
        if (archive != null)
        {
            if (!string.IsNullOrEmpty(archive.Diagnosis)) sb.AppendLine($"诊断: {archive.Diagnosis}");
            // Same date format as check_archive, independent of the server culture.
            if (!string.IsNullOrEmpty(archive.SurgeryType)) sb.AppendLine($"手术: {archive.SurgeryType} ({archive.SurgeryDate?.ToString("yyyy-MM-dd")})");
            if (archive.Allergies.Count > 0) sb.AppendLine($"过敏: {string.Join(", ", archive.Allergies)}");
            if (archive.DietRestrictions.Count > 0) sb.AppendLine($"饮食限制: {string.Join(", ", archive.DietRestrictions)}");
        }
        if (recentRecords.Count > 0)
        {
            sb.AppendLine("近期健康数据:");
            foreach (var r in recentRecords)
                sb.AppendLine($" {r.MetricType}: {RecordValue(r)} ({r.RecordedAt:MM-dd HH:mm})");
        }
        return sb.ToString();
    }

    /// <summary>Human-readable value of a record for the patient summary.</summary>
    private static string RecordValue(HealthRecord r) => r.MetricType switch
    {
        HealthMetricType.BloodPressure => $"{r.Systolic}/{r.Diastolic}",
        HealthMetricType.HeartRate => $"{r.Value}次/分",
        HealthMetricType.Glucose => $"{r.Value}{r.Unit}",
        HealthMetricType.SpO2 => $"{r.Value}%",
        HealthMetricType.Weight => $"{r.Value}kg",
        _ => "—"
    };

    // ---- Tool definitions (JSON-schema parameters sent to the LLM) ----

    private static readonly ToolDefinition RecordHealthDataTool = new()
    {
        Function = new()
        {
            Name = "record_health_data", Description = "记录健康数据(血压/心率/血糖/血氧/体重)",
            Parameters = new
            {
                type = "object",
                properties = new
                {
                    type = new { type = "string", @enum = new[] { "blood_pressure", "heart_rate", "glucose", "spo2", "weight" } },
                    systolic = new { type = "integer" },
                    diastolic = new { type = "integer" },
                    heart_rate = new { type = "number" },
                    glucose = new { type = "number" },
                    spo2 = new { type = "number" },
                    weight = new { type = "number" },
                },
                required = new[] { "type" },
            }
        }
    };

    private static readonly ToolDefinition QueryHealthRecordsTool = new()
    {
        Function = new()
        {
            Name = "query_health_records", Description = "查询近期健康数据",
            Parameters = new
            {
                type = "object",
                properties = new
                {
                    type = new { type = "string", @enum = new[] { "blood_pressure", "heart_rate", "glucose", "spo2", "weight" } },
                    days = new { type = "integer" },
                },
            }
        }
    };

    private static readonly ToolDefinition CheckArchiveTool = new()
    {
        Function = new() { Name = "check_archive", Description = "查询患者健康档案", Parameters = new { type = "object", properties = new { } } }
    };

    private static readonly ToolDefinition ManageMedicationTool = new()
    {
        Function = new()
        {
            Name = "manage_medication", Description = "用药管理",
            Parameters = new
            {
                type = "object",
                properties = new
                {
                    action = new { type = "string", @enum = new[] { "create", "query", "confirm" } },
                    name = new { type = "string" },
                    dosage = new { type = "string" },
                    // Read by the "confirm" action; previously missing from the schema.
                    medication_id = new { type = "string", description = "confirm 时必填, 可通过 query 获取" },
                },
                required = new[] { "action" },
            }
        }
    };

    private static readonly ToolDefinition ManageExerciseTool = new()
    {
        Function = new()
        {
            Name = "manage_exercise", Description = "运动计划管理",
            Parameters = new { type = "object", properties = new { action = new { type = "string" } }, required = new[] { "action" } }
        }
    };

    private static readonly ToolDefinition EstimateFoodTool = new()
    {
        Function = new() { Name = "estimate_food_text", Description = "根据文字描述估算食物份量和热量", Parameters = new { type = "object", properties = new { text = new { type = "string" } }, required = new[] { "text" } } }
    };

    private static readonly ToolDefinition AnalyzeReportTool = new()
    {
        Function = new() { Name = "analyze_report", Description = "分析报告图片", Parameters = new { type = "object", properties = new { image_url = new { type = "string" } }, required = new[] { "image_url" } } }
    };

    private static readonly ToolDefinition RequestDoctorTool = new()
    {
        Function = new()
        {
            Name = "request_doctor", Description = "请求转接真人医生",
            Parameters = new { type = "object", properties = new { reason = new { type = "string" }, urgency_level = new { type = "string" } } }
        }
    };

    /// <summary>
    /// Re-encodes an image as JPEG so it fits the VLM API's request size limit. The image is scaled down
    /// proportionally when it is wider than <paramref name="maxWidth"/>.
    /// </summary>
    /// <param name="inputPath">Source image file.</param>
    /// <param name="outputPath">Destination path for the JPEG output.</param>
    /// <param name="maxWidth">Maximum output width in pixels; height keeps the aspect ratio.</param>
    /// <param name="quality">Value passed to the JPEG encoder's <c>Quality</c> parameter.</param>
    /// <remarks>NOTE(review): System.Drawing is only supported on Windows in .NET 6+ — confirm the deployment OS.</remarks>
    private static void CompressImage(string inputPath, string outputPath, int maxWidth, long quality)
    {
        using var image = Image.FromFile(inputPath);
        var width = image.Width;
        var height = image.Height;
        if (width > maxWidth)
        {
            // Never 0: extremely wide images would otherwise produce a zero-height bitmap, which Bitmap rejects.
            height = Math.Max(1, (int)((double)height / width * maxWidth));
            width = maxWidth;
        }

        using var bitmap = new Bitmap(width, height);
        using (var graphics = Graphics.FromImage(bitmap))
        {
            // JPEG has no alpha channel; paint white first so transparent PNG areas are not encoded as black.
            graphics.Clear(Color.White);
            graphics.InterpolationMode = System.Drawing.Drawing2D.InterpolationMode.HighQualityBicubic;
            graphics.DrawImage(image, 0, 0, width, height);
        }

        var jpegCodec = ImageCodecInfo.GetImageEncoders().First(c => c.MimeType == "image/jpeg");
        using var parameters = new EncoderParameters(1);
        parameters.Param[0] = new EncoderParameter(Encoder.Quality, quality);
        bitmap.Save(outputPath, jpegCodec, parameters);
    }
}
/// <summary>
/// Payload shape for an AI chat request: the message text plus an optional conversation id.
/// </summary>
/// <param name="Message">Text of the user's message.</param>
/// <param name="ConversationId">Identifier of the conversation the message belongs to, if any.</param>
/// <remarks>NOTE(review): not referenced by the chat endpoint above, which binds these values from the query string.</remarks>
public sealed record ChatRequest(
    string Message,
    string? ConversationId);