# Repository-viewer residue captured with this file (not part of the module):
# 931 lines, 36 KiB, Python; "Raw / Normal View / History"; 2025-07-18 13:12:09 +08:00
import os
import sys
import copy
import json
import uuid
import time
import queue
import asyncio
import threading
import traceback
import subprocess
import websockets
from core.utils.util import (
extract_json_from_string,
check_vad_update,
check_asr_update,
filter_sensitive_info,
)
from typing import Dict, Any
from core.utils.modules_initialize import (
initialize_modules,
initialize_tts,
initialize_asr,
)
from core.handle.reportHandle import report
from core.providers.tts.default import DefaultTTS
from concurrent.futures import ThreadPoolExecutor
from core.utils.dialogue import Message, Dialogue
from core.providers.asr.dto.dto import InterfaceType
from core.handle.textHandle import handleTextMessage
from core.providers.tools.unified_tool_handler import UnifiedToolHandler
from plugins_func.loadplugins import auto_import_modules
from plugins_func.register import Action, ActionResponse
from core.auth import AuthMiddleware, AuthenticationError
from config.config_loader import get_private_config_from_api
from core.providers.tts.dto.dto import ContentType, TTSMessageDTO, SentenceType
from config.logger import setup_logging, build_module_string, update_module_string
from config.manage_api_client import DeviceNotFoundException, DeviceBindException
# Tag passed to ``logger.bind(tag=TAG)`` by every log call in this module.
TAG = __name__
# Import the modules under "plugins_func.functions" at load time.
# NOTE(review): presumably this registers the plugin functions - confirm
# against plugins_func.loadplugins.
auto_import_modules("plugins_func.functions")
class TTSException(RuntimeError):
    """RuntimeError subclass that, by its name, is meant for TTS failures.

    It adds no attributes or behaviour of its own.
    """
class ConnectionHandler:
def __init__(
    self,
    config: Dict[str, Any],
    _vad,
    _asr,
    _llm,
    _memory,
    _intent,
    server=None,
):
    """Create the per-connection state for one client session.

    Args:
        config: Server configuration. Kept by reference as
            ``common_config`` and deep-copied into ``config`` so this
            connection can change its own copy.
        _vad: VAD instance, stored as ``_vad`` for later assignment.
        _asr: ASR instance, stored as ``_asr`` for later initialization.
        _llm: LLM instance, stored as ``llm``.
        _memory: Memory instance, stored as ``memory``.
        _intent: Intent instance, stored as ``intent``.
        server: Optional reference to the owning server instance.

    Raises:
        KeyError: If the configuration has no ``"exit_commands"`` entry.
    """
    # Configuration and identity.
    self.common_config = config
    self.config = copy.deepcopy(config)
    self.session_id = str(uuid.uuid4())
    self.logger = setup_logging()
    self.server = server  # keep a reference to the server instance
    self.auth = AuthMiddleware(config)

    # Device binding / remote configuration flags.
    self.need_bind = False
    self.bind_code = None
    self.read_config_from_api = self.config.get("read_config_from_api", False)

    # Transport and client metadata, filled in by handle_connection.
    self.websocket = None
    self.headers = None
    self.device_id = None
    self.client_ip = None
    self.client_ip_info = {}
    self.prompt = None
    self.welcome_msg = None
    self.max_output_size = 0
    self.chat_history_conf = 0
    self.audio_format = "opus"

    # Client state.
    self.client_abort = False
    self.client_is_speaking = False
    self.client_listen_mode = "auto"

    # Event loop, stop signal and worker pool.
    self.loop = asyncio.get_event_loop()
    self.stop_event = threading.Event()
    self.executor = ThreadPoolExecutor(max_workers=5)

    # Chat-record reporting; the worker thread is started later.
    self.report_queue = queue.Queue()
    self.report_thread = None
    # ASR and TTS reporting both follow read_config_from_api for now;
    # adjust here to control them separately.
    self.report_asr_enable = self.read_config_from_api
    self.report_tts_enable = self.read_config_from_api

    # Components. vad/asr/tts are assigned during initialization; the
    # underscore-prefixed ones keep the instances passed in.
    self.vad = None
    self.asr = None
    self.tts = None
    self._asr = _asr
    self._vad = _vad
    self.llm = _llm
    self.memory = _memory
    self.intent = _intent

    # VAD state.
    self.client_audio_buffer = bytearray()
    self.client_have_voice = False
    self.last_activity_time = 0.0  # unified activity timestamp (milliseconds)
    self.client_voice_stop = False

    # ASR state. A local ASR instance may be shared between connections,
    # so per-connection ASR variables are kept here rather than on it.
    self.asr_audio = []
    self.asr_audio_queue = queue.Queue()

    # LLM state.
    self.llm_finish_task = True
    self.dialogue = Dialogue()

    # TTS state.
    self.sentence_id = None

    # IoT / tool handling state.
    self.iot_descriptors = {}
    self.func_handler = None

    # Exit commands and the length of the longest one (0 when there are none).
    self.cmd_exit = self.config["exit_commands"]
    self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)

    # Whether to close the connection once the chat has finished.
    self.close_after_chat = False
    self.load_function_plugin = False
    self.intent_type = "nointent"

    # Second-stage close: 60 seconds on top of the configured no-voice limit.
    no_voice_limit = int(self.config.get("close_connection_no_voice_time", 120))
    self.timeout_seconds = no_voice_limit + 60
    self.timeout_task = None

    # e.g. {"mcp": true} means the MCP feature is enabled.
    self.features = None
async def handle_connection(self, ws):
    """Serve one WebSocket client from handshake until it disconnects.

    Reads the request headers (falling back to the URL query string for
    ``device-id`` / ``client-id``), authenticates, sends the welcome
    message, submits component initialization to the thread pool and then
    routes every incoming message. ``_save_and_close`` always runs at the
    end, including on the early-return paths.

    Args:
        ws: The accepted WebSocket connection. The method uses its
            ``request.headers``, ``request.path``, ``remote_address``,
            ``send`` and async iteration.
    """
    try:
        # Collect the request headers.
        self.headers = dict(ws.request.headers)

        if self.headers.get("device-id", None) is None:
            # No device-id header: try the URL query parameters instead.
            from urllib.parse import parse_qs, urlparse

            request_path = ws.request.path
            if not request_path:
                self.logger.bind(tag=TAG).error("无法获取请求路径")
                return

            query_params = parse_qs(urlparse(request_path).query)
            if "device-id" not in query_params:
                # Plain probe of the port: answer and hang up.
                await ws.send("端口正常如需测试连接请使用test_page.html")
                await self.close(ws)
                return

            self.headers["device-id"] = query_params["device-id"][0]
            # client-id is optional: _initialize_private_config falls back
            # to device-id when the header is absent, so a URL without it
            # must not abort the connection with a KeyError.
            client_ids = query_params.get("client-id")
            if client_ids:
                self.headers["client-id"] = client_ids[0]

        # Client IP address.
        self.client_ip = ws.remote_address[0]
        # NOTE(review): this logs every header, which may include
        # credentials - consider redacting before logging.
        self.logger.bind(tag=TAG).info(
            f"{self.client_ip} conn - Headers: {self.headers}"
        )

        # Authenticate; AuthenticationError is handled below.
        await self.auth.authenticate(self.headers)

        # Authenticated: keep the socket and device id.
        self.websocket = ws
        self.device_id = self.headers.get("device-id", None)

        # Initialize the activity timestamp (milliseconds).
        self.last_activity_time = time.time() * 1000

        # Start the idle-timeout watcher.
        self.timeout_task = asyncio.create_task(self._check_timeout())

        self.welcome_msg = self.config["xiaozhi"]
        self.welcome_msg["session_id"] = self.session_id
        await self.websocket.send(json.dumps(self.welcome_msg))

        # Fetch the per-device configuration (runs synchronously here).
        self._initialize_private_config()
        # Initialize the components on the thread pool.
        self.executor.submit(self._initialize_components)

        try:
            async for message in self.websocket:
                await self._route_message(message)
        except websockets.exceptions.ConnectionClosed:
            self.logger.bind(tag=TAG).info("客户端断开连接")

    except AuthenticationError as e:
        self.logger.bind(tag=TAG).error(f"Authentication failed: {str(e)}")
        return
    except Exception as e:
        stack_trace = traceback.format_exc()
        self.logger.bind(tag=TAG).error(f"Connection error: {str(e)}-{stack_trace}")
        return
    finally:
        await self._save_and_close(ws)
async def _save_and_close(self, ws):
"""保存记忆并关闭连接"""
try:
if self.memory:
# 使用线程池异步保存记忆
def save_memory_task():
try:
# 创建新事件循环(避免与主循环冲突)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(
self.memory.save_memory(self.dialogue.dialogue)
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
loop.close()
# 启动线程保存记忆,不等待完成
threading.Thread(target=save_memory_task, daemon=True).start()
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
# 立即关闭连接,不等待记忆保存完成
await self.close(ws)
async def _route_message(self, message):
"""消息路由"""
if isinstance(message, str):
self.last_activity_time = time.time() * 1000
await handleTextMessage(self, message)
elif isinstance(message, bytes):
if self.vad is None:
return
if self.asr is None:
return
self.asr_audio_queue.put(message)
async def handle_restart(self, message):
    """Acknowledge a restart request, then restart the server process.

    A success message is sent to the client, after which a daemon thread
    waits one second, launches ``app.py`` with the current interpreter
    and terminates this process with ``os._exit(0)``. If sending the
    acknowledgement or starting the thread fails, an error message is
    sent instead.

    Args:
        message: The restart request; not read by this method.
    """

    def _status_payload(status, text):
        # Both replies share this shape; only status and message differ.
        return json.dumps(
            {
                "type": "server",
                "status": status,
                "message": text,
                "content": {"action": "restart"},
            }
        )

    def restart_server():
        # Performs the actual restart, off the event loop.
        time.sleep(1)
        self.logger.bind(tag=TAG).info("执行服务器重启...")
        subprocess.Popen(
            [sys.executable, "app.py"],
            stdin=sys.stdin,
            stdout=sys.stdout,
            stderr=sys.stderr,
            start_new_session=True,
        )
        os._exit(0)

    try:
        self.logger.bind(tag=TAG).info("收到服务器重启指令,准备执行...")
        # Confirm to the client before restarting.
        await self.websocket.send(_status_payload("success", "服务器重启中..."))
        # Run the restart on a thread so the event loop is not blocked.
        threading.Thread(target=restart_server, daemon=True).start()
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"重启失败: {str(e)}")
        await self.websocket.send(
            _status_payload("error", f"Restart failed: {str(e)}")
        )
def _initialize_components(self):
    """Initialize prompt, VAD, ASR, TTS, memory, intent and reporting.

    Submitted to the thread pool by ``handle_connection``; coroutines
    that open the ASR/TTS channels are scheduled on ``self.loop``. Any
    exception is logged and swallowed.
    """
    try:
        module_str = build_module_string(self.config.get("selected_module", {}))
        self.selected_module_str = module_str
        update_module_string(module_str)

        # System prompt.
        prompt = self.config.get("prompt")
        if prompt is not None:
            self.prompt = prompt
            self.change_system_prompt(prompt)
            self.logger.bind(tag=TAG).info(
                f"初始化组件: prompt成功 {self.prompt[:50]}..."
            )

        # Local components: only fill in what is not set already.
        if self.vad is None:
            self.vad = self._vad

        if self.asr is None:
            self.asr = self._initialize_asr()
        # Open the speech-recognition channel.
        asyncio.run_coroutine_threadsafe(
            self.asr.open_audio_channels(self), self.loop
        )

        if self.tts is None:
            self.tts = self._initialize_tts()
        # Open the speech-synthesis channel.
        asyncio.run_coroutine_threadsafe(
            self.tts.open_audio_channels(self), self.loop
        )

        # Memory, intent recognition, then the reporting thread.
        self._initialize_memory()
        self._initialize_intent()
        self._init_report_threads()
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"实例化组件失败: {e}")
def _init_report_threads(self):
"""初始化ASR和TTS上报线程"""
if not self.read_config_from_api or self.need_bind:
return
if self.chat_history_conf == 0:
return
if self.report_thread is None or not self.report_thread.is_alive():
self.report_thread = threading.Thread(
target=self._report_worker, daemon=True
)
self.report_thread.start()
self.logger.bind(tag=TAG).info("TTS上报线程已启动")
def _initialize_tts(self):
    """Return the TTS instance for this connection.

    ``initialize_tts`` is only tried when the device does not need
    binding. If it is skipped or returns None, a ``DefaultTTS`` built
    with ``delete_audio_file=True`` is returned instead.
    """
    engine = None if self.need_bind else initialize_tts(self.config)
    if engine is not None:
        return engine
    return DefaultTTS(self.config, delete_audio_file=True)
def _initialize_asr(self):
    """Return the ASR instance for this connection.

    A local ASR instance can be shared by several connections, so the
    one passed to the constructor is reused. A remote ASR involves a
    WebSocket connection and a receiving thread, so each connection
    gets a new instance.
    """
    if self._asr.interface_type == InterfaceType.LOCAL:
        return self._asr
    return initialize_asr(self.config)
def _initialize_private_config(self):
"""如果是从配置文件获取,则进行二次实例化"""
if not self.read_config_from_api:
return
"""从接口获取差异化的配置进行二次实例化,非全量重新实例化"""
try:
begin_time = time.time()
private_config = get_private_config_from_api(
self.config,
self.headers.get("device-id"),
self.headers.get("client-id", self.headers.get("device-id")),
)
private_config["delete_audio"] = bool(self.config.get("delete_audio", True))
self.logger.bind(tag=TAG).info(
f"{time.time() - begin_time} 秒,获取差异化配置成功: {json.dumps(filter_sensitive_info(private_config), ensure_ascii=False)}"
)
except DeviceNotFoundException as e:
self.need_bind = True
private_config = {}
except DeviceBindException as e:
self.need_bind = True
self.bind_code = e.bind_code
private_config = {}
except Exception as e:
self.need_bind = True
self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}")
private_config = {}
init_llm, init_tts, init_memory, init_intent = (
False,
False,
False,
False,
)
init_vad = check_vad_update(self.common_config, private_config)
init_asr = check_asr_update(self.common_config, private_config)
if init_vad:
self.config["VAD"] = private_config["VAD"]
self.config["selected_module"]["VAD"] = private_config["selected_module"][
"VAD"
]
if init_asr:
self.config["ASR"] = private_config["ASR"]
self.config["selected_module"]["ASR"] = private_config["selected_module"][
"ASR"
]
if private_config.get("TTS", None) is not None:
init_tts = True
self.config["TTS"] = private_config["TTS"]
self.config["selected_module"]["TTS"] = private_config["selected_module"][
"TTS"
]
if private_config.get("LLM", None) is not None:
init_llm = True
self.config["LLM"] = private_config["LLM"]
self.config["selected_module"]["LLM"] = private_config["selected_module"][
"LLM"
]
if private_config.get("Memory", None) is not None:
init_memory = True
self.config["Memory"] = private_config["Memory"]
self.config["selected_module"]["Memory"] = private_config[
"selected_module"
]["Memory"]
if private_config.get("Intent", None) is not None:
init_intent = True
self.config["Intent"] = private_config["Intent"]
model_intent = private_config.get("selected_module", {}).get("Intent", {})
self.config["selected_module"]["Intent"] = model_intent
# 加载插件配置
if model_intent != "Intent_nointent":
plugin_from_server = private_config.get("plugins", {})
for plugin, config_str in plugin_from_server.items():
plugin_from_server[plugin] = json.loads(config_str)
self.config["plugins"] = plugin_from_server
self.config["Intent"][self.config["selected_module"]["Intent"]][
"functions"
] = plugin_from_server.keys()
if private_config.get("prompt", None) is not None:
self.config["prompt"] = private_config["prompt"]
if private_config.get("summaryMemory", None) is not None:
self.config["summaryMemory"] = private_config["summaryMemory"]
if private_config.get("device_max_output_size", None) is not None:
self.max_output_size = int(private_config["device_max_output_size"])
if private_config.get("chat_history_conf", None) is not None:
self.chat_history_conf = int(private_config["chat_history_conf"])
if private_config.get("mcp_endpoint", None) is not None:
self.config["mcp_endpoint"] = private_config["mcp_endpoint"]
try:
modules = initialize_modules(
self.logger,
private_config,
init_vad,
init_asr,
init_llm,
init_tts,
init_memory,
init_intent,
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
modules = {}
if modules.get("tts", None) is not None:
self.tts = modules["tts"]
if modules.get("vad", None) is not None:
self.vad = modules["vad"]
if modules.get("asr", None) is not None:
self.asr = modules["asr"]
if modules.get("llm", None) is not None:
self.llm = modules["llm"]
if modules.get("intent", None) is not None:
self.intent = modules["intent"]
if modules.get("memory", None) is not None:
self.memory = modules["memory"]
def _initialize_memory(self):
if self.memory is None:
return
"""初始化记忆模块"""
self.memory.init_memory(
role_id=self.device_id,
llm=self.llm,
summary_memory=self.config.get("summaryMemory", None),
save_to_file=not self.read_config_from_api,
)
# 获取记忆总结配置
memory_config = self.config["Memory"]
memory_type = self.config["Memory"][self.config["selected_module"]["Memory"]][
"type"
]
# 如果使用 nomen直接返回
if memory_type == "nomem":
return
# 使用 mem_local_short 模式
elif memory_type == "mem_local_short":
memory_llm_name = memory_config[self.config["selected_module"]["Memory"]][
"llm"
]
if memory_llm_name and memory_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
memory_llm_config = self.config["LLM"][memory_llm_name]
memory_llm_type = memory_llm_config.get("type", memory_llm_name)
memory_llm = llm_utils.create_instance(
memory_llm_type, memory_llm_config
)
self.logger.bind(tag=TAG).info(
f"为记忆总结创建了专用LLM: {memory_llm_name}, 类型: {memory_llm_type}"
)
self.memory.set_llm(memory_llm)
else:
# 否则使用主LLM
self.memory.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
def _initialize_intent(self):
if self.intent is None:
return
self.intent_type = self.config["Intent"][
self.config["selected_module"]["Intent"]
]["type"]
if self.intent_type == "function_call" or self.intent_type == "intent_llm":
self.load_function_plugin = True
"""初始化意图识别模块"""
# 获取意图识别配置
intent_config = self.config["Intent"]
intent_type = self.config["Intent"][self.config["selected_module"]["Intent"]][
"type"
]
# 如果使用 nointent直接返回
if intent_type == "nointent":
return
# 使用 intent_llm 模式
elif intent_type == "intent_llm":
intent_llm_name = intent_config[self.config["selected_module"]["Intent"]][
"llm"
]
if intent_llm_name and intent_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
intent_llm_config = self.config["LLM"][intent_llm_name]
intent_llm_type = intent_llm_config.get("type", intent_llm_name)
intent_llm = llm_utils.create_instance(
intent_llm_type, intent_llm_config
)
self.logger.bind(tag=TAG).info(
f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
)
self.intent.set_llm(intent_llm)
else:
# 否则使用主LLM
self.intent.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
"""加载统一工具处理器"""
self.func_handler = UnifiedToolHandler(self)
# 异步初始化工具处理器
if hasattr(self, "loop") and self.loop:
asyncio.run_coroutine_threadsafe(self.func_handler._initialize(), self.loop)
def change_system_prompt(self, prompt):
    """Store ``prompt`` and push it into the dialogue's system message.

    Args:
        prompt: The new system prompt.
    """
    self.prompt = prompt
    # Keep the dialogue context in sync with the stored prompt.
    self.dialogue.update_system_message(prompt)
def chat(self, query, tool_call=False):
    """Send ``query`` to the LLM and stream the reply to TTS.

    The reply is consumed chunk by chunk. Text chunks are queued for TTS
    and collected for the dialogue history. When the reply turns out to
    be a tool call, the call is executed through ``func_handler`` and
    the outcome is passed to ``_handle_function_result``.

    Args:
        query: The user text, or the tool output when ``tool_call`` is
            true.
        tool_call: True when this call follows a tool result; the query
            is then not added to the dialogue as a user message.

    Returns:
        True when the reply was processed, None when the LLM request
        could not be started.
    """
    self.logger.bind(tag=TAG).info(f"大模型收到用户消息: {query}")
    self.llm_finish_task = False

    if not tool_call:
        self.dialogue.put(Message(role="user", content=query))

    # Tool definitions. func_handler exists as an attribute from
    # construction but stays None until the intent module has been
    # initialized, so it is checked for None rather than with hasattr.
    functions = None
    func_handler = getattr(self, "func_handler", None)
    if self.intent_type == "function_call" and func_handler is not None:
        functions = func_handler.get_functions()
    use_functions = self.intent_type == "function_call" and functions is not None
    response_message = []

    try:
        # Query the memory so the dialogue can be sent with it.
        memory_str = None
        if self.memory is not None:
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()

        self.sentence_id = str(uuid.uuid4().hex)

        if use_functions:
            # Streaming interface with function support.
            llm_responses = self.llm.response_with_functions(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(memory_str),
                functions=functions,
            )
        else:
            llm_responses = self.llm.response(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(memory_str),
            )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        # No request is in flight any more; do not leave the flag stuck.
        self.llm_finish_task = True
        return None

    # Consume the streamed reply.
    tool_call_flag = False
    function_name = None
    function_id = None
    function_arguments = ""
    content_arguments = ""
    text_index = 0
    self.client_abort = False
    for response in llm_responses:
        if self.client_abort:
            break
        if use_functions:
            content, tools_call = response
            # NOTE(review): presumably covers providers that yield a
            # chunk holding "content" directly - confirm.
            if "content" in response:
                content = response["content"]
                tools_call = None
            if content is not None and len(content) > 0:
                content_arguments += content

            # A reply starting with "<tool_call>" carries the call as text.
            if not tool_call_flag and content_arguments.startswith("<tool_call>"):
                tool_call_flag = True

            if tools_call is not None and len(tools_call) > 0:
                tool_call_flag = True
                first_call = tools_call[0]
                if first_call.id is not None:
                    function_id = first_call.id
                if first_call.function.name is not None:
                    function_name = first_call.function.name
                if first_call.function.arguments is not None:
                    function_arguments += first_call.function.arguments
        else:
            content = response

        if content is not None and len(content) > 0:
            if not tool_call_flag:
                response_message.append(content)
                if text_index == 0:
                    # Announce the first sentence before any text.
                    self.tts.tts_text_queue.put(
                        TTSMessageDTO(
                            sentence_id=self.sentence_id,
                            sentence_type=SentenceType.FIRST,
                            content_type=ContentType.ACTION,
                        )
                    )
                self.tts.tts_text_queue.put(
                    TTSMessageDTO(
                        sentence_id=self.sentence_id,
                        sentence_type=SentenceType.MIDDLE,
                        content_type=ContentType.TEXT,
                        content_detail=content,
                    )
                )
                text_index += 1

    # Handle the function call, if any.
    if tool_call_flag:
        parse_failed = False
        if function_id is None:
            # No structured call was received: parse it from the text.
            extracted = extract_json_from_string(content_arguments)
            if extracted is not None:
                try:
                    content_arguments_json = json.loads(extracted)
                    function_name = content_arguments_json["name"]
                    function_arguments = json.dumps(
                        content_arguments_json["arguments"], ensure_ascii=False
                    )
                    function_id = str(uuid.uuid4().hex)
                except Exception:
                    parse_failed = True
                    response_message.append(extracted)
            else:
                parse_failed = True
                response_message.append(content_arguments)
            if parse_failed:
                self.logger.bind(tag=TAG).error(
                    f"function call error: {content_arguments}"
                )
        if not parse_failed:
            response_message.clear()
            self.logger.bind(tag=TAG).debug(
                f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}"
            )
            function_call_data = {
                "name": function_name,
                "id": function_id,
                "arguments": function_arguments,
            }

            # All tool calls go through the unified tool handler.
            result = asyncio.run_coroutine_threadsafe(
                self.func_handler.handle_llm_function_call(self, function_call_data),
                self.loop,
            ).result()
            self._handle_function_result(result, function_call_data)

    # Store the reply in the dialogue.
    if len(response_message) > 0:
        self.dialogue.put(
            Message(role="assistant", content="".join(response_message))
        )
    if text_index > 0:
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.LAST,
                content_type=ContentType.ACTION,
            )
        )
    self.llm_finish_task = True
    self.logger.bind(tag=TAG).debug(
        json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
    )

    return True
def _handle_function_result(self, result, function_call_data):
if result.action == Action.RESPONSE: # 直接回复前端
text = result.response
self.tts.tts_one_sentence(self, ContentType.TEXT, content_detail=text)
self.dialogue.put(Message(role="assistant", content=text))
elif result.action == Action.REQLLM: # 调用函数后再请求llm生成回复
text = result.result
if text is not None and len(text) > 0:
function_id = function_call_data["id"]
function_name = function_call_data["name"]
function_arguments = function_call_data["arguments"]
self.dialogue.put(
Message(
role="assistant",
tool_calls=[
{
"id": function_id,
"function": {
"arguments": function_arguments,
"name": function_name,
},
"type": "function",
"index": 0,
}
],
)
)
self.dialogue.put(
Message(
role="tool",
tool_call_id=(
str(uuid.uuid4()) if function_id is None else function_id
),
content=text,
)
)
self.chat(text, tool_call=True)
elif result.action == Action.NOTFOUND or result.action == Action.ERROR:
text = result.response if result.response else result.result
self.tts.tts_one_sentence(self, ContentType.TEXT, content_detail=text)
self.dialogue.put(Message(role="assistant", content=text))
else:
pass
def _report_worker(self):
"""聊天记录上报工作线程"""
while not self.stop_event.is_set():
try:
# 从队列获取数据,设置超时以便定期检查停止事件
item = self.report_queue.get(timeout=1)
if item is None: # 检测毒丸对象
break
type, text, audio_data, report_time = item
try:
# 检查线程池状态
if self.executor is None:
continue
# 提交任务到线程池
self.executor.submit(
self._process_report, type, text, audio_data, report_time
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"聊天记录上报线程异常: {e}")
except queue.Empty:
continue
except Exception as e:
self.logger.bind(tag=TAG).error(f"聊天记录上报工作线程异常: {e}")
self.logger.bind(tag=TAG).info("聊天记录上报线程已退出")
def _process_report(self, type, text, audio_data, report_time):
"""处理上报任务"""
try:
# 执行上报(传入二进制数据)
report(self, type, text, audio_data, report_time)
except Exception as e:
self.logger.bind(tag=TAG).error(f"上报处理异常: {e}")
finally:
# 标记任务完成
self.report_queue.task_done()
def clearSpeakStatus(self):
    """Reset ``client_is_speaking`` to False and log it at debug level."""
    self.client_is_speaking = False
    self.logger.bind(tag=TAG).debug("清除服务端讲话状态")
async def close(self, ws=None):
    """Release the resources held by this connection.

    Cancels the timeout task, cleans up the tool handler, sets the stop
    event, empties the queues, closes the WebSocket and shuts the thread
    pool down without waiting. Errors are logged, not raised.

    Args:
        ws: The WebSocket to close. When omitted, ``self.websocket`` is
            closed if there is one.
    """
    try:
        # Cancel the timeout task - unless this call is running inside
        # it. Cancelling the current task would raise CancelledError at
        # the next await below and abort the rest of the cleanup.
        timeout_task = self.timeout_task
        self.timeout_task = None
        if timeout_task and timeout_task is not asyncio.current_task():
            timeout_task.cancel()

        # Clean up the tool handler; a failure here must not prevent
        # the remaining steps from running.
        if hasattr(self, "func_handler") and self.func_handler:
            try:
                await self.func_handler.cleanup()
            except Exception as e:
                self.logger.bind(tag=TAG).error(f"清理工具处理器时出错: {e}")

        # Signal the stop event.
        if self.stop_event:
            self.stop_event.set()

        # Empty the task queues.
        self.clear_queues()

        # Close the WebSocket connection.
        if ws:
            await ws.close()
        elif self.websocket:
            await self.websocket.close()

        # Shut the thread pool down last, without blocking.
        if self.executor:
            self.executor.shutdown(wait=False)
            self.executor = None

        self.logger.bind(tag=TAG).info("连接资源已释放")
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"关闭连接时出错: {e}")
def clear_queues(self):
    """Empty all task queues without blocking.

    The report queue is always drained; the TTS text and audio queues
    are drained as well when a TTS instance exists. Queue sizes are
    logged before and after only in the TTS case.
    """
    has_tts = bool(self.tts)
    if has_tts:
        self.logger.bind(tag=TAG).debug(
            f"开始清理: TTS队列大小={self.tts.tts_text_queue.qsize()}, 音频队列大小={self.tts.tts_audio_queue.qsize()}"
        )
        pending = [
            self.tts.tts_text_queue,
            self.tts.tts_audio_queue,
            self.report_queue,
        ]
    else:
        # No TTS yet: the report queue still has to be emptied.
        pending = [self.report_queue]

    # Drain each queue with non-blocking gets.
    for q in pending:
        if not q:
            continue
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                break

    if has_tts:
        self.logger.bind(tag=TAG).debug(
            f"清理结束: TTS队列大小={self.tts.tts_text_queue.qsize()}, 音频队列大小={self.tts.tts_audio_queue.qsize()}"
        )
def reset_vad_states(self):
    """Reset the VAD flags and replace the audio buffer with an empty one."""
    self.client_have_voice = False
    self.client_voice_stop = False
    self.client_audio_buffer = bytearray()
    self.logger.bind(tag=TAG).debug("VAD states reset.")
def chat_and_close(self, text):
    """Run a normal chat turn, then flag the connection to close.

    ``close_after_chat`` is only set when ``chat`` returns without
    raising; an exception is logged and the flag is left unchanged.

    Args:
        text: The text to pass to ``chat``.
    """
    try:
        self.chat(text)
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")
    else:
        # The chat finished: request closing the connection.
        self.close_after_chat = True
async def _check_timeout(self):
"""检查连接超时"""
try:
while not self.stop_event.is_set():
# 检查是否超时(只有在时间戳已初始化的情况下)
if self.last_activity_time > 0.0:
current_time = time.time() * 1000
if (
current_time - self.last_activity_time
> self.timeout_seconds * 1000
):
if not self.stop_event.is_set():
self.logger.bind(tag=TAG).info("连接超时,准备关闭")
await self.close(self.websocket)
break
# 每10秒检查一次避免过于频繁
await asyncio.sleep(10)
except Exception as e:
self.logger.bind(tag=TAG).error(f"超时检查任务出错: {e}")