“wangquan” 57fdf656b9 add_xiaozhi
2025-07-18 13:12:09 +08:00

149 lines
4.5 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import argparse
import yaml
from collections.abc import Mapping
from config.manage_api_client import init_service, get_server_config, get_agent_models
# Module-level cache of the loaded configuration: load_config() returns it
# directly once set, and stores its freshly built result here otherwise.
_config_cache = None
def get_project_dir():
    """Return the project root directory as a string ending in "/".

    The root is taken to be the parent of the directory containing this
    module file.
    """
    module_path = os.path.abspath(__file__)
    package_dir = os.path.dirname(module_path)
    root_dir = os.path.dirname(package_dir)
    return f"{root_dir}/"
def read_config(config_path):
    """Parse the UTF-8 YAML file at ``config_path`` with ``yaml.safe_load``.

    Returns whatever ``yaml.safe_load`` produces for the document (a dict for
    a mapping document, None for an empty file). Errors from ``open`` (e.g. a
    missing file) propagate to the caller.
    """
    with open(config_path, encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def load_config():
    """Load, merge, and cache the application configuration.

    Reads ``config.yaml`` (defaults) and ``data/.config.yaml`` (user
    overrides) from the project root. If the override file sets
    ``manager-api.url``, the configuration is fetched from the manager API
    via get_config_from_api; otherwise the two files are merged, with the
    override file taking precedence. The configured directories are then
    created. The result is stored in the module-level ``_config_cache``, so
    later calls return the same object without re-reading anything.

    Returns:
        The configuration dict.

    Raises:
        FileNotFoundError: if either YAML file is missing (raised by open()).
        Exception: propagated from get_config_from_api when the API returns
            no configuration.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    project_dir = get_project_dir()
    default_config_path = project_dir + "config.yaml"
    custom_config_path = project_dir + "data/.config.yaml"

    # yaml.safe_load yields None for an empty file; normalise to {} so the
    # .get() lookups below and merge_configs behave as for "no overrides".
    default_config = read_config(default_config_path) or {}
    custom_config = read_config(custom_config_path) or {}

    # A bare "manager-api:" key parses as None, hence `or {}` rather than a
    # .get() default (which only applies when the key is absent).
    manager_api = custom_config.get("manager-api") or {}
    if manager_api.get("url"):
        config = get_config_from_api(custom_config)
    else:
        config = merge_configs(default_config, custom_config)

    ensure_directories(config)
    _config_cache = config
    return config
def get_config_from_api(config):
    """Fetch the runtime configuration from the manager (Java) API.

    The API client is initialised from ``config`` and the server
    configuration is retrieved. The returned dict is then annotated:

    * ``read_config_from_api`` is set to True;
    * ``manager-api`` is set to the local ``url`` and ``secret`` (each
      defaulting to "");
    * if the local config has a truthy ``server`` section, the API's
      ``server`` section is replaced by the local ``ip``, ``port``,
      ``http_port``, ``vision_explain`` and ``auth_key`` (each defaulting
      to ""). Local server settings therefore take precedence.

    Args:
        config: Local configuration; must contain a ``manager-api`` mapping.

    Returns:
        The configuration dict obtained from the API, with the fields above.

    Raises:
        Exception: if the API returns no configuration (None).
    """
    init_service(config)
    remote = get_server_config()
    if remote is None:
        raise Exception("Failed to fetch server config from API")

    remote["read_config_from_api"] = True

    local_api = config["manager-api"]
    remote["manager-api"] = {
        field: local_api.get(field, "") for field in ("url", "secret")
    }

    # Server settings always come from the local file when it defines them.
    local_server = config.get("server")
    if local_server:
        remote["server"] = {
            field: local_server.get(field, "")
            for field in ("ip", "port", "http_port", "vision_explain", "auth_key")
        }
    return remote
def get_private_config_from_api(config, device_id, client_id):
    """Fetch the private (per-device) configuration from the manager (Java) API.

    Delegates to get_agent_models with the device id, the client id and the
    ``selected_module`` section of ``config``, and returns its result as-is.

    Raises:
        KeyError: if ``config`` has no ``selected_module`` entry.
    """
    selected = config["selected_module"]
    return get_agent_models(device_id, client_id, selected)
def ensure_directories(config):
    """Create every directory the configuration refers to.

    Collected paths:
      * the log directory ``log.log_dir`` (default "tmp"), joined to the
        project root;
      * the ``output_dir`` of every ASR/TTS provider, used exactly as given;
      * the ``output_dir`` of the provider chosen in ``selected_module`` for
        ASR, LLM and TTS, joined to the project root.

    Each directory is created with ``exist_ok=True``. A PermissionError is
    reported as a printed warning instead of being raised; other OS errors
    propagate.

    Args:
        config: The configuration mapping (merged locally or from the API).
    """
    project_dir = get_project_dir()
    dirs_to_create = set()

    # Log directory. `or {}` also covers a "log:" key that parsed as None.
    log_dir = (config.get("log") or {}).get("log_dir", "tmp")
    dirs_to_create.add(os.path.join(project_dir, log_dir))

    # output_dir of every ASR/TTS provider, kept exactly as configured.
    for module in ("ASR", "TTS"):
        providers = config.get(module)
        if not isinstance(providers, Mapping):
            continue
        for provider in providers.values():
            if not isinstance(provider, Mapping):
                continue
            output_dir = provider.get("output_dir", "")
            if output_dir:
                dirs_to_create.add(output_dir)

    # output_dir of the provider selected for each module type. Provider
    # settings live under config[module_type][provider_name], so look the
    # provider up there rather than at the top level of config.
    selected_modules = config.get("selected_module") or {}
    for module_type in ("ASR", "LLM", "TTS"):
        selected_provider = selected_modules.get(module_type)
        if not selected_provider:
            continue
        providers = config.get(module_type)
        if not isinstance(providers, Mapping):
            continue
        provider_config = providers.get(selected_provider)
        if not isinstance(provider_config, Mapping):
            continue
        output_dir = provider_config.get("output_dir")
        if output_dir:
            dirs_to_create.add(os.path.join(project_dir, output_dir))

    # Create everything in one pass; a permission problem only warns.
    for dir_path in dirs_to_create:
        try:
            os.makedirs(dir_path, exist_ok=True)
        except PermissionError:
            print(f"警告:无法创建目录 {dir_path},请检查写入权限")
def merge_configs(default_config, custom_config):
    """
    Recursively merge two configurations; values from custom_config win.

    Nested mappings present on both sides are merged key by key. Any other
    value from custom_config (including None) replaces the default. If
    either argument is not a Mapping, custom_config is returned unchanged.
    default_config itself is never mutated; the top level of the result is
    a new dict.

    Args:
        default_config: The default configuration.
        custom_config: The user's overriding configuration.

    Returns:
        The merged configuration.
    """
    if not (isinstance(default_config, Mapping) and isinstance(custom_config, Mapping)):
        return custom_config

    result = dict(default_config)
    for name, override in custom_config.items():
        current = result.get(name)
        # Recurse only when both sides are mappings; otherwise override wins.
        if isinstance(current, Mapping) and isinstance(override, Mapping):
            override = merge_configs(current, override)
        result[name] = override
    return result