# maxkb/apps/common/handle/impl/mineru/maxkb_adapter/adapter.py

"""
MaxKB平台适配器 - 实现MaxKB特定的功能
这个适配器实现了MaxKB平台的特定功能包括
- MaxKB的trace上下文管理如果有
- MaxKB的锁机制如果有
- MaxKB的文件存储客户端
- MaxKB的模型客户端
"""
import contextlib
import os
from typing import Any, Dict, Optional
from .logger import get_module_logger
logger = get_module_logger('adapter')
# Import MaxKB-specific modules
from .file_storage_client import FileStorageClient
from .maxkb_model_client import maxkb_model_client
# Import the base classes from the parent package
from ..base_parser import PlatformAdapter, BaseMinerUExtractor
class MaxKBAdapter(PlatformAdapter):
"""MaxKB平台的适配器实现"""
def __init__(self):
"""初始化MaxKB适配器"""
self.file_storage = FileStorageClient()
self.model_client = maxkb_model_client
# Import the config to obtain the storage path
from .config_maxkb import MaxKBMinerUConfig
self.config = MaxKBMinerUConfig()
self.storage_path = self.config.file_storage_path
@contextlib.asynccontextmanager
async def trace_context(self, trace_id: str):
"""MaxKB的trace上下文 - 如果没有特殊实现,使用简单的上下文"""
# MaxKB可能没有trace_context这里提供一个简单的实现
logger.info(f"MaxKB: Starting trace {trace_id}")
try:
yield
finally:
logger.info(f"MaxKB: Ending trace {trace_id}")
async def lock_enter(self, temp_dir: str):
"""MaxKB的锁机制进入 - 如果没有特殊实现,创建目录即可"""
# MaxKB可能没有特殊的锁机制确保目录存在即可
os.makedirs(temp_dir, exist_ok=True)
logger.debug(f"MaxKB: Entered lock for {temp_dir}")
async def lock_release(self, temp_dir: str):
"""MaxKB的锁机制释放 - 如果没有特殊实现,简单记录即可"""
logger.debug(f"MaxKB: Released lock for {temp_dir}")
async def upload_file(self, file_path: str, options: Any = None) -> str:
"""使用MaxKB的文件存储上传文件 - 直接复制文件到存储目录"""
import shutil
import uuid
logger.info(f"MaxKB: upload_file called with path={file_path}, options={options}")
# Test mode (currently disabled): return the original image path directly
#if os.getenv('MINERU_TEST_FILE'):
# logger.info(f"MaxKB: Test mode - returning original path: {file_path}")
# return file_path
try:
# Make sure the file exists
if not os.path.exists(file_path):
logger.warning(f"MaxKB: File not found: {file_path}")
return file_path
# Get knowledge_id if it was provided in options
knowledge_id = None
if options and isinstance(options, (tuple, list)) and len(options) > 0:
knowledge_id = options[0]
# Build the storage directory structure
# Use knowledge_id (or 'mineru' as a fallback) as the sub-directory
sub_dir = knowledge_id if knowledge_id else 'mineru'
storage_dir = os.path.join(self.storage_path, sub_dir, 'images')
# Make sure the storage directory exists
os.makedirs(storage_dir, exist_ok=True)
# Generate a unique file name, keeping the original extension
file_ext = os.path.splitext(file_path)[1]
file_name = f"{uuid.uuid4().hex}{file_ext}"
dest_path = os.path.join(storage_dir, file_name)
# Copy the file into the storage directory
shutil.copy2(file_path, dest_path)
# Return a relative path or URL
# Build the path relative to the storage root directory
relative_path = os.path.relpath(dest_path, self.storage_path)
# Make sure the path uses forward slashes (portable across systems)
relative_path = relative_path.replace(os.path.sep, '/')
# Build the full URL based on the environment configuration
# Check whether a base URL is configured
base_url = os.getenv('MAXKB_BASE_URL', '')
if base_url:
# With a base URL, build the full URL
result_url = f"{base_url.rstrip('/')}/storage/{relative_path}"
else:
# Otherwise build a relative URL under the /storage/ path
result_url = f"/storage/{relative_path}"
logger.info(f"MaxKB: Copied file {file_path} -> {dest_path}")
logger.info(f"MaxKB: Returning URL: {result_url}")
return result_url
except Exception as e:
logger.error(f"MaxKB: Failed to copy file {file_path}: {str(e)}")
# If the copy fails, fall back to the local path
return file_path
def get_logger(self):
"""获取MaxKB的日志器"""
return logger
def get_settings(self) -> Dict[str, Any]:
"""获取MaxKB的配置"""
# 返回MaxKB的配置可以从环境变量或配置文件读取
return {
'api_key': os.getenv('MAXKB_API_KEY'),
'api_url': os.getenv('MAXKB_API_URL', 'https://api.maxkb.com'),
'model': os.getenv('MAXKB_DEFAULT_MODEL', 'gpt-4o'),
# Add other MaxKB-specific settings here
}
def get_learn_type(self, params: Dict[str, Any]) -> int:
"""获取learn_type参数 - MaxKB使用model_id映射到learn_type"""
# MaxKB使用llm_model_id和vision_model_id
# 这里可以根据model_id映射到learn_type
llm_model_id = params.get('llm_model_id')
vision_model_id = params.get('vision_model_id')  # reserved for a future mapping
# Map the model IDs to a learn_type (example mapping only)
if llm_model_id:
# Different model IDs could return different learn_type values here
return 9  # default
return 9
def set_trace_id(self, trace_id: str):
"""设置trace ID用于日志跟踪"""
# MaxKB可能有自己的trace机制
# 这里简单记录日志
logger.debug(f"MaxKB: Setting trace ID to: {trace_id}")
class MinerUExtractor(BaseMinerUExtractor):
"""
MinerU extractor for the MaxKB platform.
Inherits from the base class and uses the MaxKB adapter.
"""
def __init__(self, llm_model_id: Optional[str] = None,
vision_model_id: Optional[str] = None):
"""
Initialize the MaxKB MinerU extractor
Args:
llm_model_id: LLM model ID (MaxKB-specific parameter)
vision_model_id: vision model ID (MaxKB-specific parameter)
"""
# Create the MaxKB adapter
adapter = MaxKBAdapter()
# Import and create the MaxKB-specific config, passing the model IDs
from .config_maxkb import MaxKBMinerUConfig
config = MaxKBMinerUConfig(llm_model_id=llm_model_id, vision_model_id=vision_model_id)
# Call the base class initializer with the adapter, config and MaxKB-specific parameters
super().__init__(
adapter,
config=config,
llm_model_id=llm_model_id,
vision_model_id=vision_model_id
)
# Keep the MaxKB-specific parameters
self.llm_model_id = llm_model_id
self.vision_model_id = vision_model_id
# MaxKB-specific components can be initialized here if needed
self.model_client = adapter.model_client
self.file_storage = adapter.file_storage
async def process_file_with_models(self, filepath: str, src_name: Optional[str] = None,
upload_options: Any = None) -> Any:
"""
MaxKB-specific processing method - process a file with the configured models.
This method shows how a subclass can add platform-specific functionality.
"""
logger.info(f"MaxKB: Processing with LLM: {self.llm_model_id}, Vision: {self.vision_model_id}")
# Delegate to the base class processing method
result = await self.process_file(filepath, src_name, upload_options)
# MaxKB-specific post-processing can be added here
if self.llm_model_id:
logger.info(f"MaxKB: Used LLM model {self.llm_model_id} for processing")
if self.vision_model_id:
logger.info(f"MaxKB: Used Vision model {self.vision_model_id} for image processing")
return result
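# Usage sketch (assumption): driving the extractor directly with asyncio, e.g.
#   extractor = MinerUExtractor(llm_model_id='llm-1', vision_model_id='vis-1')
#   docs = asyncio.run(extractor.process_file_with_models('/tmp/sample.pdf', 'sample.pdf'))
# The model IDs and file path above are hypothetical placeholders.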
class MinerUAdapter:
"""
MinerU document processing adapter - used by MaxKB to process documents
"""
def __init__(self):
"""初始化MinerU适配器"""
self.extractor = None
self._init_extractor()
def _init_extractor(self):
"""初始化MinerU解析器"""
try:
# Read the configured model IDs
llm_model_id = os.environ.get('MINERU_LLM_MODEL_ID')
vision_model_id = os.environ.get('MINERU_VISION_MODEL_ID')
# Create the extractor
self.extractor = MinerUExtractor(
llm_model_id=llm_model_id,
vision_model_id=vision_model_id
)
logger.info("MinerU适配器初始化成功")
except Exception as e:
logger.error(f"MinerU适配器初始化失败: {str(e)}")
raise
def process_document(self, file_content: bytes, file_name: str,
save_image_func=None, **kwargs) -> Dict[str, Any]:
"""
Process a document and return structured content
Args:
file_content: file content as bytes
file_name: file name
save_image_func: function used to persist extracted images
**kwargs: extra parameters, including llm_model_id and vision_model_id
Returns:
A dict with 'sections'; each section contains 'content', 'title' and 'images'
"""
import tempfile
import asyncio
import threading
try:
# Write the content to a temporary file
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(file_name)[1],
delete=False) as tmp_file:
tmp_file.write(file_content)
tmp_file_path = tmp_file.name
try:
# Run the async code in a new thread to avoid event-loop conflicts
result = None
exception = None
def run_async():
nonlocal result, exception
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Extract the model ID parameters
llm_model_id = kwargs.get('llm_model_id')
vision_model_id = kwargs.get('vision_model_id')
if llm_model_id and vision_model_id:
logger.info(f"使用指定模型处理文档: LLM={llm_model_id}, Vision={vision_model_id}")
# TODO: 将模型ID传递给extractor
# 目前暂时使用默认配置,后续可以在这里设置模型
result = loop.run_until_complete(
self.extractor.process_file(tmp_file_path, file_name)
)
finally:
loop.close()
except Exception as e:
exception = e
thread = threading.Thread(target=run_async)
thread.start()
thread.join(timeout=300)  # 5-minute timeout
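# Note: join() returning after the timeout does not stop the worker thread;
# on timeout it may keep running in the background while an empty result is returned below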
if exception:
raise exception
if result is None:
logger.warning("MinerU处理超时或无结果")
return {'sections': []}
logger.info(f"MinerU返回结果类型: {type(result)}")
# 转换结果格式
sections = []
# 检查result的类型
if isinstance(result, list):
# result is a list of page documents (Langchain Document objects)
for page_doc in result:
# Handle a Langchain Document object
if hasattr(page_doc, 'page_content') and hasattr(page_doc, 'metadata'):
page_content = page_doc.page_content
metadata = page_doc.metadata
if page_content:
# Extract the page number
page_num = metadata.get('page', metadata.get('page_num', ''))
# Extract the list of images
images = metadata.get('images', [])
sections.append({
'content': page_content,
'title': f"Page {page_num}" if page_num else '',
'images': images
})
elif isinstance(page_doc, dict):
# Document in dict form
page_content = page_doc.get('text', page_doc.get('page_content', ''))
if page_content:
sections.append({
'content': page_content,
'title': f"Page {page_doc.get('page_num', '')}",
'images': page_doc.get('images', [])
})
elif isinstance(result, dict):
# result is a dict
if 'content' in result:
content = result['content']
if isinstance(content, str):
sections.append({
'content': content,
'title': '',
'images': []
})
elif isinstance(content, list):
for item in content:
if isinstance(item, dict):
sections.append({
'content': item.get('text', ''),
'title': item.get('title', ''),
'images': item.get('images', [])
})
else:
sections.append({
'content': str(item),
'title': '',
'images': []
})
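# At this point each entry in `sections` has the shape
# {'content': str, 'title': str, 'images': list}, which is what callers consume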
# Persist images (if a save function was provided)
if save_image_func and sections:
for section in sections:
if section.get('images'):
saved_images = []
for img_path in section['images']:
try:
# Check whether the image file exists
if os.path.exists(img_path):
with open(img_path, 'rb') as f:
img_content = f.read()
saved_path = save_image_func(img_content)
saved_images.append(saved_path)
else:
saved_images.append(img_path)
except Exception as e:
logger.warning(f"保存图片失败 {img_path}: {str(e)}")
saved_images.append(img_path)
section['images'] = saved_images
logger.info(f"MinerU处理完成提取了{len(sections)}个sections")
return {'sections': sections}
finally:
# Clean up the temporary file
if os.path.exists(tmp_file_path):
os.unlink(tmp_file_path)
except Exception as e:
logger.error(f"MinerU处理文档失败: {str(e)}")
# 返回空结果而不是抛出异常,让调用方可以回退到其他处理器
return {'sections': []}
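# Usage sketch (assumption, not part of the original module): assuming MinerUAdapter is
# importable inside the MaxKB runtime, a caller would drive it roughly like this, with
# 'sample.pdf' as a hypothetical input file:
#   mineru_adapter = MinerUAdapter()
#   with open('sample.pdf', 'rb') as f:
#       parsed = mineru_adapter.process_document(f.read(), 'sample.pdf')
#   for section in parsed['sections']:
#       print(section['title'], len(section['content']), len(section['images']))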