Files
AI-Health/health_app/lib/providers/chat_provider.dart

272 lines
8.0 KiB
Dart
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import 'dart:async';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'auth_provider.dart';
import 'data_providers.dart';
import '../utils/sse_handler.dart';
/// How a [ChatMessage] is rendered in the transcript.
///
/// Non-[text] values are selected from the `type` string of streamed
/// `answer` events (see `ChatNotifier._parseMessageType`);
/// [agentWelcome] is also used for the locally inserted welcome card.
enum MessageType {
  text,
  dataConfirm,
  medicationConfirm,
  dietAnalysis,
  reportAnalysis,
  quickOptions,
  agentWelcome,
}
/// One entry in the chat transcript.
///
/// [content] and [type] are deliberately non-final: an assistant reply is
/// created with empty content and filled in place as streamed chunks arrive.
class ChatMessage {
  ChatMessage({
    required this.id,
    required this.role,
    required this.content,
    required this.createdAt,
    this.metadata,
    this.type = MessageType.text,
  });

  /// Identifier used to find this message again when it is updated.
  final String id;

  /// Author role, e.g. `'user'` or `'assistant'`.
  final String role;

  /// Message text; appended to while a reply is streaming.
  String content;

  /// When the message was created.
  final DateTime createdAt;

  /// Rendering style of the message.
  MessageType type;

  /// Optional extra payload (for example `{'agent': ...}` on welcome cards).
  final Map<String, dynamic>? metadata;

  /// Whether this message was authored by the user.
  bool get isUser {
    return role == 'user';
  }
}
/// The specialised assistant a conversation is routed to.
///
/// [default_] carries a trailing underscore because `default` is a reserved
/// word in Dart; the underscore is stripped before the name is sent to the
/// backend.
enum ActiveAgent {
  default_,
  consultation,
  health,
  diet,
  medication,
  report,
  exercise,
}
/// Immutable snapshot of the chat screen state.
class ChatState {
  const ChatState({
    this.activeAgent = ActiveAgent.default_,
    this.messages = const [],
    this.conversationId,
    this.isStreaming = false,
    this.thinkingText,
  });

  /// Agent that new messages are sent to.
  final ActiveAgent activeAgent;

  /// Chat transcript; new messages are appended at the end.
  final List<ChatMessage> messages;

  /// Server conversation id, or `null` until the stream reports one.
  final String? conversationId;

  /// Whether an assistant reply is currently being received.
  final bool isStreaming;

  /// Transient status text (set from stream `notice` events); `null` when
  /// there is none.
  final String? thinkingText;

  /// Marks "argument omitted" for the `thinkingText` parameter of
  /// [copyWith], so that an explicit `null` can be told apart from it.
  static const Object _unset = Object();

  /// Returns a copy of this state with the given fields replaced.
  ///
  /// For `activeAgent`, `messages`, `conversationId` and `isStreaming`,
  /// passing `null` (or omitting the argument) keeps the current value.
  /// `thinkingText` differs: omitting it keeps the current value, while
  /// passing `null` explicitly clears it. A `thinkingText` argument that is
  /// not a `String?` throws a [TypeError].
  ChatState copyWith({
    ActiveAgent? activeAgent,
    List<ChatMessage>? messages,
    String? conversationId,
    bool? isStreaming,
    Object? thinkingText = _unset,
  }) =>
      ChatState(
        activeAgent: activeAgent ?? this.activeAgent,
        messages: messages ?? this.messages,
        conversationId: conversationId ?? this.conversationId,
        isStreaming: isStreaming ?? this.isStreaming,
        // `?? this.thinkingText` would make `thinkingText: null` a no-op,
        // leaving the status text visible after the reply finished.
        thinkingText: identical(thinkingText, _unset)
            ? this.thinkingText
            : thinkingText as String?,
      );
}
/// Summary of one past conversation, as listed by `conversationListProvider`.
class ConversationItem {
  ConversationItem({
    required this.agent,
    required this.id,
    required this.lastMessage,
    required this.title,
    required this.updatedAt,
  });

  /// Conversation identifier.
  final String id;

  /// Display title of the conversation.
  final String title;

  /// Text of the most recent message.
  final String lastMessage;

  /// Time of the last update.
  final DateTime updatedAt;

  /// Agent the conversation belongs to.
  final ActiveAgent agent;
}
/// Holds an optional [ActiveAgent] selection.
///
/// Starts as `null`; [select] replaces the current value (including with
/// `null`).
class SelectedAgentNotifier extends Notifier<ActiveAgent?> {
  @override
  ActiveAgent? build() {
    return null;
  }

  /// Sets the selection to [agent].
  void select(ActiveAgent? agent) {
    state = agent;
  }
}
/// Provider exposing the optional agent selection held by
/// [SelectedAgentNotifier].
final selectedAgentProvider =
    NotifierProvider<SelectedAgentNotifier, ActiveAgent?>(
  SelectedAgentNotifier.new,
);

/// Provider exposing [ChatNotifier] and its [ChatState].
final chatProvider = NotifierProvider<ChatNotifier, ChatState>(
  ChatNotifier.new,
);
/// Loads the conversation list from `GET /api/conversations`.
///
/// Resolves to an empty list when there is no access token. If the request
/// (or reading its `data` field) throws, resolves to [_mockConversations]
/// instead. Entries are converted by [_conversationFromJson]; entries that
/// are not JSON objects are skipped rather than failing the whole list.
final conversationListProvider =
    FutureProvider<List<ConversationItem>>((ref) async {
  final api = ref.watch(apiClientProvider);
  final token = await api.accessToken;
  if (token == null) return [];
  try {
    final res = await api.get('/api/conversations');
    final list = res.data['data'] as List? ?? [];
    return [
      for (final item in list)
        if (item is Map<Object?, Object?>) _conversationFromJson(item),
    ];
  } catch (_) {
    return _mockConversations;
  }
});

/// Builds a [ConversationItem] from one raw conversation entry.
///
/// Missing fields fall back to defaults. A missing or unparseable
/// `updatedAt` becomes the current time instead of throwing, so a single
/// bad timestamp does not discard the whole server response.
ConversationItem _conversationFromJson(Map<Object?, Object?> data) =>
    ConversationItem(
      id: data['id']?.toString() ?? '',
      title: data['title']?.toString() ?? '对话',
      lastMessage: data['lastMessage']?.toString() ?? '',
      updatedAt: DateTime.tryParse(data['updatedAt']?.toString() ?? '') ??
          DateTime.now(),
      agent: _parseAgent(data['agentType']?.toString()),
    );
/// Converts a backend `agentType` value into an [ActiveAgent].
///
/// Matching ignores case; `null` or any unrecognised value yields
/// [ActiveAgent.default_].
ActiveAgent _parseAgent(String? type) => switch (type?.toLowerCase()) {
      'consultation' => ActiveAgent.consultation,
      'health' => ActiveAgent.health,
      'diet' => ActiveAgent.diet,
      'medication' => ActiveAgent.medication,
      'report' => ActiveAgent.report,
      'exercise' => ActiveAgent.exercise,
      _ => ActiveAgent.default_,
    };
/// Placeholder conversations returned by `conversationListProvider` when
/// loading the real list throws.
///
/// Top-level variables are initialised lazily in Dart, so these timestamps
/// are relative to the first time this list is read.
final List<ConversationItem> _mockConversations = <ConversationItem>[
  ConversationItem(
    id: '1',
    agent: ActiveAgent.medication,
    title: '用药咨询',
    lastMessage: '阿司匹林应该什么时候吃?',
    updatedAt: DateTime.now().subtract(const Duration(hours: 2)),
  ),
  ConversationItem(
    id: '2',
    agent: ActiveAgent.health,
    title: '血压偏高',
    lastMessage: '血压145/90需要注意什么',
    updatedAt: DateTime.now().subtract(const Duration(hours: 5)),
  ),
  ConversationItem(
    id: '3',
    agent: ActiveAgent.diet,
    title: '饮食建议',
    lastMessage: '今天吃了米饭和红烧肉',
    updatedAt: DateTime.now().subtract(const Duration(days: 1)),
  ),
];
/// Owns the chat transcript and streams assistant replies into it.
///
/// [sendMessage] sends the user's text to the active agent through
/// [SseHandler.connect] and applies each streamed event to [state].
class ChatNotifier extends Notifier<ChatState> {
  /// Iterator over the reply stream currently being read, if any.
  ///
  /// Cancelling it makes a pending `moveNext` complete with `false`, which
  /// ends the read loop in [sendMessage] immediately.
  StreamIterator<Object?>? _events;

  /// Bumped whenever an in-flight reply must be abandoned (agent switch or
  /// disposal). [sendMessage] captures the value up front and stops writing
  /// to [state] once it changes, so a stale reply cannot leak into a newly
  /// started conversation.
  int _session = 0;

  @override
  ChatState build() {
    ref.onDispose(_abandonReply);
    return const ChatState();
  }

  /// Stops reading the current reply stream (if any) and invalidates the
  /// running [sendMessage] call.
  void _abandonReply() {
    _session++;
    unawaited(_events?.cancel());
    _events = null;
  }

  /// Switches to agent [a] and starts an empty conversation.
  ///
  /// Choosing the agent that is already active resets to the default agent
  /// instead. Any reply that is still streaming is abandoned.
  void setAgent(ActiveAgent a) {
    _abandonReply();
    state = state.activeAgent == a
        ? const ChatState()
        : ChatState(activeAgent: a);
  }

  /// Appends a welcome card for [agent] to the transcript.
  void insertAgentWelcome(ActiveAgent agent) {
    state = state.copyWith(messages: [
      ...state.messages,
      ChatMessage(
        id: 'welcome_${agent.name}_${DateTime.now().millisecondsSinceEpoch}',
        role: 'assistant',
        content: '',
        createdAt: DateTime.now(),
        type: MessageType.agentWelcome,
        metadata: {'agent': agent.name},
      ),
    ]);
  }

  /// Sends [text] to the active agent and streams the reply into [state].
  ///
  /// Does nothing when [text] is blank or a reply is already streaming.
  /// Failures are reported as an assistant error message, not thrown.
  Future<void> sendMessage(String text) async {
    if (text.trim().isEmpty || state.isStreaming) return;
    final session = _session;
    final userMsg = ChatMessage(
      id: '${DateTime.now().millisecondsSinceEpoch}',
      role: 'user',
      content: text,
      createdAt: DateTime.now(),
    );
    state = state.copyWith(
        messages: [...state.messages, userMsg], isStreaming: true);
    final aiMsg = ChatMessage(
      id: '${DateTime.now().millisecondsSinceEpoch}_ai',
      role: 'assistant',
      content: '',
      createdAt: DateTime.now(),
    );
    StreamIterator<Object?>? events;
    try {
      final token = await ref.read(apiClientProvider).accessToken;
      // The agent may have been switched while the token was loading.
      if (session != _session) return;
      if (token == null) {
        _addError(aiMsg, '未登录,请重新登录');
        return;
      }
      final agentPath =
          state.activeAgent.name.replaceFirst('default_', 'default');
      final iterator = StreamIterator(SseHandler.connect(
        agentType: agentPath,
        message: text,
        conversationId: state.conversationId,
        token: token,
      ));
      events = iterator;
      _events = iterator;
      while (await iterator.moveNext()) {
        if (session != _session) break;
        _processEvent(iterator.current, aiMsg);
      }
    } catch (_) {
      // Don't write an error into a conversation the user has left.
      if (session == _session) _addError(aiMsg, '网络异常,请稍后重试');
      return;
    } finally {
      if (events != null) {
        // Release the connection even when leaving the loop early.
        unawaited(events.cancel());
        if (identical(_events, events)) _events = null;
      }
    }
    // The stream closed without a terminal `status`/`error` event: finish
    // the reply anyway so the input is not left locked in streaming mode.
    if (session == _session && state.isStreaming) _done(aiMsg);
  }

  /// Returns [s] with its thinking text set to [text]; `null` clears it.
  ///
  /// Builds the state directly so clearing works regardless of how
  /// `ChatState.copyWith` treats a `null` argument.
  ChatState _withThinking(ChatState s, String? text) => ChatState(
        activeAgent: s.activeAgent,
        messages: s.messages,
        conversationId: s.conversationId,
        isStreaming: s.isStreaming,
        thinkingText: text,
      );

  /// Appends an assistant message containing [errorText] and ends streaming.
  ///
  /// [aiMsg] is unused here. Any partial reply it holds was already placed
  /// in the transcript by [_update].
  void _addError(ChatMessage aiMsg, String errorText) {
    final withError = state.copyWith(
      messages: [
        ...state.messages,
        ChatMessage(
          id: 'err_${DateTime.now().millisecondsSinceEpoch}',
          role: 'assistant',
          content: errorText,
          createdAt: DateTime.now(),
        ),
      ],
      isStreaming: false,
    );
    state = _withThinking(withError, null);
  }

  /// Applies one decoded stream event to [state], accumulating reply text in
  /// [aiMsg].
  void _processEvent(Map<String, dynamic> event, ChatMessage aiMsg) {
    switch (event['action'] as String?) {
      case 'conversation_id':
        state = state.copyWith(conversationId: event['data']?.toString());
      case 'answer':
        aiMsg.content += (event['data'] as String?) ?? '';
        aiMsg.type = _parseMessageType(event['type'] as String? ?? 'text');
        state = _withThinking(state, null);
        _update(aiMsg);
      case 'notice':
        state = state.copyWith(thinkingText: event['message'] as String?);
      case 'tool_result':
        if ((event['tool'] as String? ?? '') == 'record_health_data') {
          // New health data was recorded; refresh the cached latest values.
          ref.invalidate(latestHealthProvider);
        }
      case 'status':
        _done(aiMsg);
      case 'error':
        // With no reply text, show an error instead of ending silently.
        if (aiMsg.content.isEmpty) {
          _addError(aiMsg, '网络异常,请稍后重试');
        } else {
          _done(aiMsg);
        }
    }
  }

  /// Maps the `type` string of an `answer` event to a [MessageType];
  /// unknown values become [MessageType.text].
  MessageType _parseMessageType(String type) => switch (type) {
        'data_confirm' => MessageType.dataConfirm,
        'medication_confirm' => MessageType.medicationConfirm,
        'diet_analysis' => MessageType.dietAnalysis,
        'report_analysis' => MessageType.reportAnalysis,
        'quick_options' => MessageType.quickOptions,
        'agent_welcome' => MessageType.agentWelcome,
        _ => MessageType.text,
      };

  /// Replaces the message with [m]'s id in the transcript. If it is not
  /// present yet, appends [m] once it has non-empty content.
  void _update(ChatMessage m) {
    final u = state.messages.toList();
    final i = u.indexWhere((x) => x.id == m.id);
    if (i >= 0) {
      u[i] = m;
    } else if (m.content.isNotEmpty) {
      u.add(m);
    }
    state = state.copyWith(messages: u);
  }

  /// Finalises the reply. Appends [m] if it was never added and has text,
  /// stops streaming, and clears the thinking text.
  void _done(ChatMessage m) {
    final u = state.messages.toList();
    if (!u.any((x) => x.id == m.id) && m.content.isNotEmpty) u.add(m);
    state = _withThinking(
        state.copyWith(messages: u, isStreaming: false), null);
  }
}