Compare commits: dd0360fb6f ... ec6e699390 (12 commits)

Commits in range: ec6e699390, 5f9f2a9325, 5da36659c2, 565a07f9c6, c4eaeb6499, a5d1dda65f, dea3454011, 10e38b2c05, 638bf0dd1b, 459b0c8307, 86ef54fb75, b05f42259e
MEDIA_ASYNC_GUIDE.md (new file, 253 lines)
@ -0,0 +1,253 @@
# Audio/Video Async Processing Guide

## 🎯 Overview

Audio and video processing is now fully asynchronous, with detailed status tracking and a better user experience.

## 📋 Status Flow

```
📋 Queued (PENDING)
        ↓
🔄 Generating (STARTED)   — transcribing the audio/video
        ↓
📚 Indexing (STARTED)     — creating paragraphs and the index
        ↓
✅ Completed (SUCCESS)

💥 Failed (FAILURE)       — any step above may end here instead
```
## 🚀 Usage

### 1. Upload an audio/video file

```python
# Specify the STT and LLM models at upload time
document_data = {
    'name': '会议录音.mp3',
    'source_file_id': file_id,
    'stt_model_id': 'whisper-large',  # required
    'llm_model_id': 'gpt-4',          # optional, used for text refinement
}

# The system will automatically:
# 1. Create the document
# 2. Set its status to "Queued"
# 3. Submit the async task
```
### 2. Check processing status

```python
# Read the document status
document = Document.objects.get(id=document_id)
status = Status(document.status)
embedding_status = status[TaskType.EMBEDDING]

# Status mapping
status_map = {
    '0': 'Queued',
    '1': 'Generating / Indexing',
    '2': 'Completed',
    '3': 'Failed',
    '4': 'Revoked'
}

current_status = status_map.get(embedding_status.value, 'Unknown')
print(f"Current status: {current_status}")
```
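
If you need to block until a document finishes, a simple polling loop is enough. The sketch below is illustrative only; it assumes the `Document`, `Status`, and `TaskType` helpers shown above and the `'2'`/`'3'`/`'4'` codes from the mapping in this guide.

```python
import time

def wait_for_document(document_id, timeout=3600, interval=5):
    """Poll a document until its task reaches a terminal state (illustrative sketch)."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        document = Document.objects.get(id=document_id)
        task_state = Status(document.status)[TaskType.EMBEDDING]
        if task_state.value == '2':          # Completed
            return True
        if task_state.value in ('3', '4'):   # Failed / Revoked
            return False
        time.sleep(interval)
    raise TimeoutError(f"document {document_id} did not finish within {timeout}s")
```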
### 3. Batch processing

```python
# Upload multiple audio/video files in one batch
documents = [
    {'name': '录音1.mp3', 'stt_model_id': 'whisper-large'},
    {'name': '视频1.mp4', 'stt_model_id': 'whisper-large'},
    {'name': '录音2.mp3', 'stt_model_id': 'whisper-large'},
]

# The system will:
# 1. Create an independent async task for each document
# 2. Process the files in parallel
# 3. Track the status of each document separately
```
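
This changeset also exposes a batch Celery task. The sketch below assumes the documents have already been created and that `document_ids`, `knowledge_id`, and `workspace_id` come from your own context; it simply fans the work out through `media_learning_batch`, which submits one `media_learning_by_document` task per document.

```python
from knowledge.tasks.media_learning import media_learning_batch

# document_ids: IDs of the media documents created above (placeholder values)
media_learning_batch.delay(
    document_ids,        # e.g. ['doc-1', 'doc-2', 'doc-3']
    str(knowledge_id),
    workspace_id,
    'whisper-large',     # stt_model_id (required)
    'gpt-4',             # llm_model_id (optional)
)
```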
## 🎛️ Configuration Options

### Processing options
```python
options = {
    'enable_punctuation': True,   # punctuation restoration
    'enable_summary': True,       # summary generation
    'language': 'auto',           # language detection
    'segment_duration': 300,      # segment length in seconds
    'async_processing': True      # async processing (enabled by default)
}
```

### Model configuration
```python
# STT model (required)
stt_model_id = 'whisper-large'  # speech-to-text model

# LLM model (optional)
llm_model_id = 'gpt-4'          # text refinement and summary generation
```
## 📊 Status Reference

| Status | Code | Description | Visible to users |
|------|------|------|----------|
| Queued | PENDING | Task submitted, waiting to be processed | ✅ |
| Generating | STARTED | Transcribing the audio/video content | ✅ |
| Indexing | STARTED | Creating paragraphs and the index | ✅ |
| Completed | SUCCESS | Processing finished | ✅ |
| Failed | FAILURE | Processing failed | ✅ |
| Revoked | REVOKE | Task was cancelled | ✅ |
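
Note that the media-learning task in this changeset reports its progress on the GENERATE slot (it calls `ListenerManagement.update_status(..., TaskType.GENERATE, ...)`), while the earlier example reads the EMBEDDING slot. A small sketch for reading the generation status, assuming the same `Status` helper:

```python
# Read the generation status, which the media-learning task updates
document = Document.objects.get(id=document_id)
generate_status = Status(document.status)[TaskType.GENERATE]
print(f"Generation status code: {generate_status.value}")  # '0'..'4' per the table above
```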
## 🔧 Error Handling

### Automatic retries
- Network errors are retried automatically
- Failed model calls are retried automatically
- At most 3 retries per task (a configuration sketch follows below)

### Handling failures
```python
# Inspect the failure
if embedding_status == State.FAILURE:
    # check the error logs
    # verify the model configuration
    # re-run the processing manually
    ...
```
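
How the 3-retry limit is wired is not shown in this guide. The following is only an illustrative sketch of one way to express such a policy with Celery's built-in retry support; the task name, exception list, and backoff settings are assumptions, not this project's actual configuration.

```python
from celery import shared_task

# Hypothetical task, shown only to illustrate a "retry at most 3 times" policy.
@shared_task(
    autoretry_for=(ConnectionError, TimeoutError),  # assumed transient errors
    retry_backoff=True,                             # exponential backoff between attempts
    max_retries=3,                                  # matches the rule above
)
def process_media_segment(segment_id):
    ...  # actual processing as described in this guide
```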
### Reprocessing
```python
# Manually trigger reprocessing
from knowledge.tasks.media_learning import media_learning_by_document
media_learning_by_document.delay(
    document_id, knowledge_id, workspace_id,
    stt_model_id, llm_model_id
)
```
## 📈 Performance

### Concurrent processing
- Multiple worker threads run in parallel
- Each audio/video file is processed independently
- Batch upload and processing are supported

### Resource management
- Temporary files are cleaned up automatically
- Memory usage is optimized
- Processing is protected by a timeout

### Queue management
- Task queue prioritization
- Retry queue for failed tasks
- Task status monitoring (see the sketch below)
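
The internal `SimpleAsyncAudioProcessor` exposes a `get_queue_status()` helper reporting queue, task, and worker counts. A minimal sketch of reading it, assuming you hold a processor instance (for example the one created lazily by `AudioProcessor._process_async`):

```python
# processor: an initialized SimpleAsyncAudioProcessor instance
status = processor.get_queue_status()
print(f"queued: {status['queue']['size']} / {status['queue']['max_size']}")
print(f"tasks pending: {status['tasks']['pending']}, "
      f"processing: {status['tasks']['processing']}, "
      f"completed: {status['tasks']['completed']}")
print(f"workers alive: {status['workers']['active']} / {status['workers']['total']}")
```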
## 🎯 Best Practices

### 1. File preparation
- Use a supported audio format: MP3, WAV, M4A
- Use a supported video format: MP4, AVI, MOV
- Keep file sizes within a reasonable range (a validation sketch follows below)
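
The media split handler in this changeset actually recognizes a wider set of extensions than the list above. A small pre-upload check based on those sets might look like the following; the size cap is an assumption, adjust it to your deployment.

```python
AUDIO_EXTS = {'mp3', 'wav', 'm4a', 'flac', 'aac', 'ogg', 'wma'}
VIDEO_EXTS = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv'}
MAX_SIZE_BYTES = 500 * 1024 * 1024  # assumed 500 MB cap

def validate_media_file(file_name: str, file_size: int) -> bool:
    """Return True if the file looks like a supported, reasonably sized media file."""
    ext = file_name.lower().rsplit('.', 1)[-1]
    if ext not in AUDIO_EXTS | VIDEO_EXTS:
        return False
    return file_size <= MAX_SIZE_BYTES
```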

### 2. Model selection
- Choose an STT model appropriate for the language
- Decide whether LLM refinement is needed for your use case
- Test model performance and accuracy before rolling out

### 3. Batch processing
- Keep batch upload sizes reasonable
- Monitor system resource usage
- Avoid large uploads during peak hours

### 4. Status monitoring
- Check processing status regularly
- Deal with failed tasks promptly
- Keep processing statistics

## 🔍 Troubleshooting

### Common issues

1. **Task stuck in "Queued"**
   - Check that the Celery service is running (a worker-inspection sketch follows after this list)
   - Check that the task queue is healthy
   - Check system resource usage

2. **Poor transcription quality**
   - Check the audio quality
   - Try a different STT model
   - Adjust the language setting

3. **Processing fails**
   - Read the detailed error logs
   - Check the model configuration
   - Verify the file format

4. **Index creation fails**
   - Check the embedding model configuration
   - Verify the database connection
   - Check available disk space
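
For the "stuck in Queued" case, Celery's inspection API can confirm whether workers are alive and what they are running. The import path of the Celery app below is an assumption; use whatever module exposes the app in your deployment.

```python
from maxkb.celery import app as celery_app  # assumed import path

inspector = celery_app.control.inspect()
print("workers responding:", inspector.ping())  # None means no worker responded
print("active tasks:", inspector.active())      # tasks currently being executed
print("reserved tasks:", inspector.reserved())  # tasks already queued on the workers
```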

### Viewing logs
```bash
# Async task (Celery worker) logs
tail -f /var/log/celery/worker.log

# Application logs
tail -f /var/log/maxkb/application.log
```

## 📝 API Examples

### Upload an audio/video file
```python
import requests

# Upload the file
files = {'file': open('meeting.mp3', 'rb')}
data = {
    'name': '会议录音',
    'stt_model_id': 'whisper-large',
    'llm_model_id': 'gpt-4'
}

response = requests.post(
    f'http://localhost:8000/api/knowledge/{knowledge_id}/document/',
    files=files,
    data=data
)
```

### Check document status
```python
import requests

# Fetch the document
response = requests.get(
    f'http://localhost:8000/api/knowledge/document/{document_id}/'
)

document = response.json()
status = document['status']
print(f"Document status: {status}")
```

## 🎉 Summary

Asynchronous audio/video processing provides:
- ✅ A fully asynchronous processing pipeline
- ✅ Detailed status tracking and feedback
- ✅ Robust error handling and retry mechanisms
- ✅ High-throughput concurrent processing
- ✅ Flexible configuration options
- ✅ Comprehensive monitoring and logging

Together these significantly improve the user experience and system stability.

@ -83,6 +83,8 @@ class MediaAdapter:
|
||||
self.logger.info(f" - stt_model_id: {stt_model_id}")
|
||||
self.logger.info(f" - workspace_id: {workspace_id}")
|
||||
self.logger.info(f" - llm_model_id: {llm_model_id}")
|
||||
self.logger.info(f" - options: {options}")
|
||||
self.logger.info(f" - enable_summary in options: {options.get('enable_summary')}")
|
||||
|
||||
try:
|
||||
# 判断媒体类型
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
音频处理器 - 复用MaxKB的音频处理工具
|
||||
支持同步和异步处理模式
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
"""音频处理器 - 复用MaxKB的音频处理工具"""
|
||||
@ -13,6 +16,7 @@ class AudioProcessor:
|
||||
def __init__(self, config, logger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.async_processor = None
|
||||
|
||||
def process(self,
|
||||
file_content: bytes,
|
||||
@ -23,6 +27,50 @@ class AudioProcessor:
|
||||
"""处理音频文件"""
|
||||
|
||||
options = options or {}
|
||||
|
||||
# 检查是否启用异步模式
|
||||
use_async = options.get('async_processing', self.config.get('async_processing', False))
|
||||
|
||||
if use_async:
|
||||
return self._process_async(file_content, file_name, stt_model, llm_model, options)
|
||||
else:
|
||||
return self._process_sync(file_content, file_name, stt_model, llm_model, options)
|
||||
|
||||
def _process_async(self, file_content: bytes, file_name: str,
|
||||
stt_model: Optional[Any], llm_model: Optional[Any],
|
||||
options: Dict[str, Any]) -> Dict:
|
||||
"""异步处理音频文件"""
|
||||
try:
|
||||
# 初始化简化异步处理器
|
||||
if not self.async_processor:
|
||||
from ..simple_async_audio_processor import SimpleAsyncAudioProcessor
|
||||
self.async_processor = SimpleAsyncAudioProcessor(self.config, self.logger)
|
||||
|
||||
# 运行异步处理
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
self.async_processor.process_audio_async(
|
||||
file_content, file_name, stt_model, llm_model, options
|
||||
)
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"异步音频处理失败: {e}", exc_info=True)
|
||||
# 回退到同步处理
|
||||
self.logger.info("回退到同步处理模式")
|
||||
return self._process_sync(file_content, file_name, stt_model, llm_model, options)
|
||||
|
||||
def _process_sync(self, file_content: bytes, file_name: str,
|
||||
stt_model: Optional[Any], llm_model: Optional[Any],
|
||||
options: Dict[str, Any]) -> Dict:
|
||||
"""同步处理音频文件"""
|
||||
|
||||
|
||||
segment_duration = options.get('segment_duration', self.config.get('segment_duration', 300)) # 默认5分钟
|
||||
|
||||
# 保存临时文件
|
||||
@ -163,85 +211,106 @@ class AudioProcessor:
|
||||
try:
|
||||
# 调用LLM模型
|
||||
enhanced = None
|
||||
if hasattr(llm_model, 'generate'):
|
||||
response = llm_model.generate(prompt)
|
||||
self.logger.info(f"LLM generate response type: {type(response)}, value: {response}")
|
||||
# 处理不同的响应格式
|
||||
try:
|
||||
if hasattr(response, 'content'):
|
||||
enhanced = response.content
|
||||
elif isinstance(response, str):
|
||||
enhanced = response
|
||||
else:
|
||||
enhanced = str(response)
|
||||
except Exception as attr_error:
|
||||
self.logger.warning(f"Error accessing response content: {str(attr_error)}")
|
||||
enhanced = str(response) if response else original_text
|
||||
elif hasattr(llm_model, 'invoke'):
|
||||
response = llm_model.invoke(prompt)
|
||||
self.logger.info(f"LLM invoke response type: {type(response)}, value: {response}")
|
||||
|
||||
if hasattr(llm_model, 'invoke'):
|
||||
# 使用MaxKB的方式调用模型 - 直接使用invoke方法和标准消息格式
|
||||
self.logger.info(f"Calling llm_model.invoke with MaxKB message format")
|
||||
try:
|
||||
# MaxKB使用消息列表格式
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = llm_model.invoke(messages)
|
||||
except Exception as invoke_error:
|
||||
self.logger.warning(f"Invoke with messages failed: {str(invoke_error)}")
|
||||
# 回退到直接invoke
|
||||
response = llm_model.invoke(prompt)
|
||||
self.logger.info(f"LLM invoke response type: {type(response)}, value: {str(response)[:200]}...")
|
||||
# 处理不同的响应格式
|
||||
try:
|
||||
if hasattr(response, 'content'):
|
||||
self.logger.info("Response has 'content' attribute")
|
||||
enhanced = response.content
|
||||
elif isinstance(response, str):
|
||||
self.logger.info("Response is string type")
|
||||
enhanced = response
|
||||
else:
|
||||
self.logger.info(f"Response is other type: {type(response)}")
|
||||
enhanced = str(response)
|
||||
except Exception as attr_error:
|
||||
self.logger.warning(f"Error accessing response content: {str(attr_error)}")
|
||||
enhanced = str(response) if response else original_text
|
||||
else:
|
||||
self.logger.info("LLM model has no generate or invoke method")
|
||||
# 尝试其他可能的方法
|
||||
enhanced = original_text
|
||||
|
||||
# 如果所有方法都失败了,使用原始文本
|
||||
if enhanced is None:
|
||||
self.logger.warning("All LLM methods failed, using original text for enhancement")
|
||||
enhanced = original_text
|
||||
|
||||
if enhanced and enhanced.strip():
|
||||
segment['enhanced_text'] = enhanced.strip()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self.logger.warning(f"优化文本失败: {str(e)}")
|
||||
|
||||
if options.get('enable_summary', False) and original_text and len(original_text) > 100:
|
||||
self.logger.warning(f"优化文本失败详细堆栈: {traceback.format_exc()}")
|
||||
if original_text and len(original_text) > 50:
|
||||
# 生成摘要
|
||||
prompt = f"请用一句话(不超过50字)总结以下内容的核心要点:\n\n{original_text}"
|
||||
# 添加调试信息:检查原始文本长度和选项
|
||||
|
||||
# 添加调试信息
|
||||
self.logger.info(f"Generating summary for original text (length {len(original_text)}): {original_text[:100]}...")
|
||||
|
||||
try:
|
||||
summary = None
|
||||
if hasattr(llm_model, 'generate'):
|
||||
response = llm_model.generate(prompt)
|
||||
self.logger.info(f"LLM summary generate response type: {type(response)}, value: {response}")
|
||||
# 处理不同的响应格式
|
||||
try:
|
||||
if hasattr(response, 'content'):
|
||||
summary = response.content
|
||||
elif isinstance(response, str):
|
||||
summary = response
|
||||
else:
|
||||
summary = str(response)
|
||||
except Exception as attr_error:
|
||||
self.logger.warning(f"Error accessing summary response content: {str(attr_error)}")
|
||||
summary = str(response) if response else None
|
||||
elif hasattr(llm_model, 'invoke'):
|
||||
response = llm_model.invoke(prompt)
|
||||
self.logger.info(f"LLM summary invoke response type: {type(response)}, value: {response}")
|
||||
if hasattr(llm_model, 'invoke'):
|
||||
# 使用MaxKB的方式调用模型 - 直接使用invoke方法和标准消息格式
|
||||
self.logger.info(f"Calling llm_model.invoke with MaxKB message format (summary)")
|
||||
try:
|
||||
# MaxKB使用消息列表格式
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = llm_model.invoke(messages)
|
||||
except Exception as invoke_error:
|
||||
self.logger.warning(f"Invoke with messages failed (summary): {str(invoke_error)}")
|
||||
# 回退到直接invoke
|
||||
response = llm_model.invoke(prompt)
|
||||
self.logger.info(f"LLM summary invoke response type: {type(response)}, value: {str(response)[:200]}...")
|
||||
# 处理不同的响应格式
|
||||
try:
|
||||
if hasattr(response, 'content'):
|
||||
self.logger.info("Summary response has 'content' attribute")
|
||||
summary = response.content
|
||||
elif isinstance(response, str):
|
||||
self.logger.info("Summary response is string type")
|
||||
summary = response
|
||||
else:
|
||||
self.logger.info(f"Summary response is other type: {type(response)}")
|
||||
summary = str(response)
|
||||
except Exception as attr_error:
|
||||
self.logger.warning(f"Error accessing summary response content: {str(attr_error)}")
|
||||
summary = str(response) if response else None
|
||||
else:
|
||||
self.logger.info("LLM model has no generate or invoke method for summary")
|
||||
summary = None
|
||||
|
||||
# 如果所有方法都失败了,使用原始文本
|
||||
if summary is None:
|
||||
self.logger.warning("All LLM methods failed, using original text for summary")
|
||||
summary = original_text[:100] + "..." if len(original_text) > 100 else original_text
|
||||
|
||||
if summary and summary.strip():
|
||||
segment['summary'] = summary.strip()
|
||||
self.logger.info(f"Successfully generated summary: {summary.strip()}")
|
||||
else:
|
||||
self.logger.info("Summary generation failed or returned empty summary")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self.logger.warning(f"生成摘要失败: {str(e)}")
|
||||
|
||||
self.logger.warning(f"生成摘要失败详细堆栈: {traceback.format_exc()}")
|
||||
|
||||
|
||||
|
||||
return segments
|
||||
except Exception as e:
|
||||
self.logger.error(f"文本优化失败: {str(e)}")
|
||||
@ -251,4 +320,4 @@ class AudioProcessor:
|
||||
"""获取文件后缀"""
|
||||
if '.' in file_name:
|
||||
return '.' + file_name.split('.')[-1].lower()
|
||||
return '.mp3'
|
||||
return '.mp3'
|
||||
|
||||
@ -0,0 +1,467 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简化异步音频处理器 - 单队列异步执行
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from .logger import MediaLogger
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务处理状态"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioSegmentTask:
|
||||
"""音频片段任务"""
|
||||
segment_id: int
|
||||
file_content: bytes
|
||||
file_name: str
|
||||
start_time: float
|
||||
end_time: float
|
||||
temp_dir: str
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
transcription: Optional[str] = None
|
||||
enhanced_text: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
# 模型和选项
|
||||
stt_model: Optional[Any] = field(default=None, repr=False)
|
||||
llm_model: Optional[Any] = field(default=None, repr=False)
|
||||
options: Dict[str, Any] = field(default_factory=dict)
|
||||
# 重试配置
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
||||
class SimpleAsyncAudioProcessor:
|
||||
"""
|
||||
简化异步音频处理器 - 单队列异步执行
|
||||
|
||||
架构特点:
|
||||
- 单个工作线程池处理所有任务
|
||||
- 每个任务独立完成分割、转写、增强、摘要等所有步骤
|
||||
- 简化的队列管理,专注于异步执行
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], logger_wrapper: MediaLogger):
|
||||
self.config = config
|
||||
self.logger = logger_wrapper
|
||||
|
||||
# 任务队列
|
||||
self.task_queue = queue.Queue(maxsize=config.get('queue_size', 10))
|
||||
|
||||
# 任务跟踪
|
||||
self.segment_tasks: Dict[int, AudioSegmentTask] = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
|
||||
# 线程控制
|
||||
self.shutdown_event = threading.Event()
|
||||
self.workers: List[threading.Thread] = []
|
||||
|
||||
# 结果收集
|
||||
self.completed_tasks: List[AudioSegmentTask] = []
|
||||
self.completed_lock = threading.Lock()
|
||||
|
||||
# 线程池大小
|
||||
self.worker_count = config.get('worker_count', 2)
|
||||
|
||||
def initialize_workers(self):
|
||||
"""初始化工作线程"""
|
||||
self.logger.info(f"初始化 {self.worker_count} 个异步音频处理工作线程...")
|
||||
|
||||
# 创建工作线程
|
||||
for i in range(self.worker_count):
|
||||
worker = threading.Thread(
|
||||
target=self._worker_loop,
|
||||
name=f"Audio-Worker-{i+1}",
|
||||
daemon=True
|
||||
)
|
||||
worker.start()
|
||||
self.workers.append(worker)
|
||||
self.logger.info(f"启动工作线程: {worker.name}")
|
||||
|
||||
def _worker_loop(self):
|
||||
"""工作线程主循环"""
|
||||
self.logger.info(f"工作线程 {threading.current_thread().name} 启动")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
try:
|
||||
# 从队列获取任务
|
||||
try:
|
||||
task = self.task_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
self.logger.info(f"工作线程 {threading.current_thread().name} 处理片段 {task.segment_id}")
|
||||
|
||||
# 更新任务状态
|
||||
with self.tasks_lock:
|
||||
task.status = TaskStatus.PROCESSING
|
||||
self.segment_tasks[task.segment_id] = task
|
||||
|
||||
try:
|
||||
# 处理任务(包含所有步骤)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(self._process_task_async(task))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
self.logger.info(f"工作线程 {threading.current_thread().name} 完成片段 {task.segment_id}")
|
||||
|
||||
except Exception as e:
|
||||
task.error = f"任务处理失败: {str(e)}"
|
||||
self.logger.error(f"工作线程 {threading.current_thread().name} 失败片段 {task.segment_id}: {e}")
|
||||
self._mark_task_completed(task)
|
||||
|
||||
finally:
|
||||
self.task_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"工作线程 {threading.current_thread().name} 错误: {e}")
|
||||
|
||||
self.logger.info(f"工作线程 {threading.current_thread().name} 停止")
|
||||
|
||||
async def _process_task_async(self, task: AudioSegmentTask):
|
||||
"""异步处理单个任务(包含所有步骤)"""
|
||||
try:
|
||||
# 1. 分割音频
|
||||
audio_path = await self._split_audio_segment(task)
|
||||
task.audio_path = audio_path
|
||||
task.metadata['audio_duration'] = task.end_time - task.start_time
|
||||
|
||||
# 2. 转写音频
|
||||
if task.stt_model:
|
||||
transcription = await self._transcribe_audio_segment(task)
|
||||
task.transcription = transcription
|
||||
else:
|
||||
task.transcription = f"[音频片段 {task.segment_id}]"
|
||||
|
||||
# 3. 增强文本
|
||||
if task.llm_model and task.options.get('enable_punctuation', True):
|
||||
enhanced_text = await self._enhance_text_segment(task)
|
||||
task.enhanced_text = enhanced_text
|
||||
else:
|
||||
task.enhanced_text = task.transcription
|
||||
|
||||
# 4. 生成摘要
|
||||
if task.llm_model:
|
||||
summary = await self._generate_summary(task)
|
||||
task.summary = summary
|
||||
|
||||
# 标记任务完成
|
||||
self._mark_task_completed(task)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _split_audio_segment(self, task: AudioSegmentTask) -> str:
|
||||
"""分割音频片段"""
|
||||
try:
|
||||
# 保存临时音频文件
|
||||
audio_path = os.path.join(task.temp_dir, f"segment_{task.segment_id}.mp3")
|
||||
|
||||
# 使用BytesIO处理音频内容
|
||||
audio_buffer = io.BytesIO(task.file_content)
|
||||
|
||||
# 使用pydub分割音频
|
||||
from pydub import AudioSegment
|
||||
audio = AudioSegment.from_file(audio_buffer)
|
||||
|
||||
# 计算时间点(毫秒)
|
||||
start_ms = int(task.start_time * 1000)
|
||||
end_ms = int(task.end_time * 1000)
|
||||
|
||||
# 提取片段
|
||||
segment_audio = audio[start_ms:end_ms]
|
||||
|
||||
# 保存为MP3
|
||||
segment_audio.export(audio_path, format='mp3')
|
||||
|
||||
self.logger.info(f"已分割音频片段 {task.segment_id}: {task.start_time:.1f}s - {task.end_time:.1f}s")
|
||||
return audio_path
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"分割音频片段失败: {e}")
|
||||
raise
|
||||
|
||||
async def _transcribe_audio_segment(self, task: AudioSegmentTask) -> str:
|
||||
"""转写音频片段"""
|
||||
try:
|
||||
from common.utils.common import split_and_transcribe
|
||||
|
||||
# 调用转写函数
|
||||
text = split_and_transcribe(task.audio_path, task.stt_model)
|
||||
|
||||
self.logger.info(f"已转写音频片段 {task.segment_id}: {len(text)} 字符")
|
||||
return text if text else "[无法识别]"
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"转写音频片段失败: {e}")
|
||||
raise
|
||||
|
||||
async def _enhance_text_segment(self, task: AudioSegmentTask) -> str:
|
||||
"""增强文本片段"""
|
||||
try:
|
||||
if not task.transcription:
|
||||
return ""
|
||||
|
||||
# 添加标点符号
|
||||
prompt = f"请为以下语音转写文本添加适当的标点符号,保持原意不变,直接返回处理后的文本:\n\n{task.transcription}"
|
||||
|
||||
# 调用LLM模型
|
||||
enhanced = await self._call_llm_model(task.llm_model, prompt)
|
||||
|
||||
if enhanced and enhanced.strip():
|
||||
return enhanced.strip()
|
||||
else:
|
||||
return task.transcription
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"增强文本失败: {e}")
|
||||
return task.transcription
|
||||
|
||||
async def _generate_summary(self, task: AudioSegmentTask) -> Optional[str]:
|
||||
"""生成摘要"""
|
||||
try:
|
||||
text = task.enhanced_text or task.transcription
|
||||
|
||||
if len(text) < 50: # 文本太短不生成摘要
|
||||
return None
|
||||
|
||||
# 生成摘要
|
||||
prompt = f"请用一句话(不超过50字)总结以下内容的核心要点:\n\n{text}"
|
||||
|
||||
# 调用LLM模型
|
||||
summary = await self._call_llm_model(task.llm_model, prompt)
|
||||
|
||||
if summary and summary.strip():
|
||||
return summary.strip()
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"生成摘要失败: {e}")
|
||||
return None
|
||||
|
||||
async def _call_llm_model(self, llm_model, prompt: str) -> Optional[str]:
|
||||
"""调用LLM模型"""
|
||||
try:
|
||||
if hasattr(llm_model, 'invoke'):
|
||||
# 使用MaxKB的消息格式
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
response = llm_model.invoke(messages)
|
||||
|
||||
# 处理响应
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
elif isinstance(response, str):
|
||||
return response
|
||||
else:
|
||||
return str(response)
|
||||
else:
|
||||
self.logger.warning("LLM模型不支持invoke方法")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"调用LLM模型失败: {e}")
|
||||
return None
|
||||
|
||||
def _mark_task_completed(self, task: AudioSegmentTask):
|
||||
"""标记任务完成"""
|
||||
with self.tasks_lock:
|
||||
task.status = TaskStatus.COMPLETED
|
||||
|
||||
with self.completed_lock:
|
||||
self.completed_tasks.append(task)
|
||||
|
||||
self.logger.info(f"任务完成: 片段 {task.segment_id}, 状态: {task.status.value}")
|
||||
|
||||
async def process_audio_async(self, file_content: bytes, file_name: str,
|
||||
stt_model: Any, llm_model: Any,
|
||||
options: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
异步处理音频文件
|
||||
|
||||
Args:
|
||||
file_content: 音频文件内容
|
||||
file_name: 文件名
|
||||
stt_model: STT模型
|
||||
llm_model: LLM模型
|
||||
options: 处理选项
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
# 初始化工作线程
|
||||
if not self.workers:
|
||||
self.initialize_workers()
|
||||
|
||||
# 清理之前的任务
|
||||
with self.tasks_lock:
|
||||
self.segment_tasks.clear()
|
||||
with self.completed_lock:
|
||||
self.completed_tasks.clear()
|
||||
|
||||
# 创建临时目录
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# 获取音频总时长
|
||||
duration = await self._get_audio_duration_async(file_content)
|
||||
|
||||
# 计算分段
|
||||
segment_duration = options.get('segment_duration', 300) # 默认5分钟
|
||||
num_segments = int(duration / segment_duration) + 1
|
||||
|
||||
self.logger.info(f"开始异步处理音频: 总时长 {duration:.1f}秒, 分段数: {num_segments}")
|
||||
|
||||
# 创建分段任务并加入队列
|
||||
for i in range(num_segments):
|
||||
start_time = i * segment_duration
|
||||
end_time = min((i + 1) * segment_duration, duration)
|
||||
|
||||
task = AudioSegmentTask(
|
||||
segment_id=i,
|
||||
file_content=file_content,
|
||||
file_name=file_name,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
temp_dir=temp_dir,
|
||||
stt_model=stt_model,
|
||||
llm_model=llm_model,
|
||||
options=options,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
# 添加到任务队列
|
||||
self.task_queue.put(task)
|
||||
|
||||
# 等待所有任务完成
|
||||
start_time = time.time()
|
||||
while True:
|
||||
with self.completed_lock:
|
||||
completed_count = len(self.completed_tasks)
|
||||
|
||||
if completed_count >= num_segments:
|
||||
break
|
||||
|
||||
# 检查超时
|
||||
if time.time() - start_time > 3600: # 1小时超时
|
||||
self.logger.error("处理超时")
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 收集结果
|
||||
segments = []
|
||||
for task in self.completed_tasks:
|
||||
segment = {
|
||||
'index': task.segment_id,
|
||||
'start_time': task.start_time,
|
||||
'end_time': task.end_time,
|
||||
'text': task.transcription or '',
|
||||
'enhanced_text': task.enhanced_text or task.transcription or '',
|
||||
'summary': task.summary
|
||||
}
|
||||
if task.error:
|
||||
segment['error'] = task.error
|
||||
segments.append(segment)
|
||||
|
||||
# 按segment_id排序
|
||||
segments.sort(key=lambda x: x['index'])
|
||||
|
||||
# 生成完整文本
|
||||
full_text = '\n'.join([seg.get('enhanced_text', seg.get('text', '')) for seg in segments])
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'media_type': 'audio',
|
||||
'duration': duration,
|
||||
'segments': segments,
|
||||
'full_text': full_text,
|
||||
'metadata': {
|
||||
'file_name': file_name,
|
||||
'stt_model': str(stt_model) if stt_model else None,
|
||||
'language': options.get('language', 'auto'),
|
||||
'processing_time': time.time() - start_time,
|
||||
'worker_count': self.worker_count
|
||||
}
|
||||
}
|
||||
|
||||
async def _get_audio_duration_async(self, file_content: bytes) -> float:
|
||||
"""异步获取音频时长"""
|
||||
try:
|
||||
# 在线程池中执行同步操作
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, self._get_audio_duration_sync, file_content
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取音频时长失败: {e}")
|
||||
return 0
|
||||
|
||||
def _get_audio_duration_sync(self, file_content: bytes) -> float:
|
||||
"""同步获取音频时长"""
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
audio_buffer = io.BytesIO(file_content)
|
||||
audio = AudioSegment.from_file(audio_buffer)
|
||||
return len(audio) / 1000 # 转换为秒
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取音频时长失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
|
||||
"""获取队列状态"""
|
||||
return {
|
||||
'queue': {
|
||||
'size': self.task_queue.qsize(),
|
||||
'max_size': self.task_queue.maxsize
|
||||
},
|
||||
'tasks': {
|
||||
'total': len(self.segment_tasks),
|
||||
'pending': sum(1 for t in self.segment_tasks.values() if t.status == TaskStatus.PENDING),
|
||||
'processing': sum(1 for t in self.segment_tasks.values() if t.status == TaskStatus.PROCESSING),
|
||||
'completed': len(self.completed_tasks)
|
||||
},
|
||||
'workers': {
|
||||
'active': len([w for w in self.workers if w.is_alive()]),
|
||||
'total': len(self.workers)
|
||||
}
|
||||
}
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭所有工作线程"""
|
||||
self.logger.info("关闭简化异步音频处理器...")
|
||||
|
||||
# 发送关闭信号
|
||||
self.shutdown_event.set()
|
||||
|
||||
# 等待线程完成
|
||||
for worker in self.workers:
|
||||
worker.join(timeout=5.0)
|
||||
if worker.is_alive():
|
||||
self.logger.warning(f"工作线程 {worker.name} 未正常停止")
|
||||
|
||||
# 清理数据
|
||||
self.workers.clear()
|
||||
with self.tasks_lock:
|
||||
self.segment_tasks.clear()
|
||||
with self.completed_lock:
|
||||
self.completed_tasks.clear()
|
||||
|
||||
self.logger.info("简化异步音频处理器关闭完成")
|
||||
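
A minimal driver for the class above might look like the following. It is only a sketch: `config`, `logger` (a MediaLogger-style wrapper), the model objects, and the audio bytes are placeholders you would supply from the surrounding MaxKB code, not values defined by this changeset.

```python
import asyncio

config = {'worker_count': 2, 'queue_size': 10}          # placeholder config
processor = SimpleAsyncAudioProcessor(config, logger)   # logger: MediaLogger-style wrapper

async def run(audio_bytes: bytes):
    try:
        result = await processor.process_audio_async(
            audio_bytes, 'meeting.mp3',
            stt_model, llm_model,                        # placeholder model objects
            {'segment_duration': 300, 'enable_punctuation': True},
        )
        print(result['full_text'])
    finally:
        await processor.shutdown()

# asyncio.run(run(open('meeting.mp3', 'rb').read()))
```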
@ -31,7 +31,150 @@ class MediaSplitHandle(BaseSplitHandle):
|
||||
"""处理音视频文件"""
|
||||
|
||||
maxkb_logger.info(f"MediaSplitHandle.handle called with file: {file.name}")
|
||||
maxkb_logger.info(f"kwargs received: {kwargs}")
|
||||
|
||||
# 检查是否需要实际处理
|
||||
use_actual_processing = kwargs.get('use_actual_processing', False)
|
||||
stt_model_id = kwargs.get('stt_model_id')
|
||||
|
||||
if use_actual_processing and stt_model_id:
|
||||
# 进行实际处理
|
||||
return self._handle_actual_processing(file, get_buffer, **kwargs)
|
||||
else:
|
||||
# 使用默认文本
|
||||
return self._handle_default_text(file, **kwargs)
|
||||
|
||||
def _get_audio_default_segments(self, file_name: str) -> List[dict]:
|
||||
"""生成音频文件的默认段落"""
|
||||
base_name = file_name.split('.')[0]
|
||||
|
||||
return [
|
||||
{
|
||||
'title': '开场介绍',
|
||||
'content': f'这是音频文件 "{base_name}" 的第一段内容演示。本段包含了会议的开场介绍和主要议题的说明。\n\n主要内容:\n- 会议目的和议程说明\n- 参会人员介绍\n- 会议背景和重要性\n- 预期成果和目标设定',
|
||||
'start_time': 0,
|
||||
'end_time': 180
|
||||
},
|
||||
{
|
||||
'title': '主要内容讨论',
|
||||
'content': f'这是音频文件 "{base_name}" 的第二段内容演示。本段详细讨论了项目的进展情况和下一步的工作计划。\n\n主要内容:\n- 项目当前进展汇报\n- 关键问题和挑战分析\n- 解决方案讨论\n- 资源需求和分配',
|
||||
'start_time': 180,
|
||||
'end_time': 360
|
||||
},
|
||||
{
|
||||
'title': '总结与行动项',
|
||||
'content': f'这是音频文件 "{base_name}" 的第三段内容演示。本段总结了会议的主要结论和行动项,明确了责任人和时间节点。\n\n主要内容:\n- 会议要点总结\n- 行动项和责任分配\n- 时间节点和里程碑\n- 后续跟进计划',
|
||||
'start_time': 360,
|
||||
'end_time': 540
|
||||
}
|
||||
]
|
||||
|
||||
def _get_video_default_segments(self, file_name: str) -> List[dict]:
|
||||
"""生成视频文件的默认段落"""
|
||||
base_name = file_name.split('.')[0]
|
||||
|
||||
return [
|
||||
{
|
||||
'title': '开场介绍',
|
||||
'content': f'这是视频文件 "{base_name}" 的第一段内容演示。本段包含了视频的开场介绍和主要内容概述。\n\n主要内容:\n- 产品/服务介绍\n- 功能特性概述\n- 目标用户群体\n- 使用场景说明',
|
||||
'start_time': 0,
|
||||
'end_time': 120
|
||||
},
|
||||
{
|
||||
'title': '功能演示',
|
||||
'content': f'这是视频文件 "{base_name}" 的第二段内容演示。本段详细展示了产品的功能特性和使用方法。\n\n主要内容:\n- 核心功能演示\n- 操作步骤说明\n- 使用技巧和注意事项\n- 常见问题解答',
|
||||
'start_time': 120,
|
||||
'end_time': 300
|
||||
},
|
||||
{
|
||||
'title': '总结与联系方式',
|
||||
'content': f'这是视频文件 "{base_name}" 的第三段内容演示。本段总结了产品的主要优势和适用场景,提供了联系方式。\n\n主要内容:\n- 产品优势总结\n- 价格和套餐信息\n- 适用场景和行业\n- 联系方式和售后服务',
|
||||
'start_time': 300,
|
||||
'end_time': 420
|
||||
}
|
||||
]
|
||||
|
||||
def _get_media_default_segments(self, file_name: str) -> List[dict]:
|
||||
"""生成其他媒体文件的默认段落"""
|
||||
base_name = file_name.split('.')[0]
|
||||
|
||||
return [
|
||||
{
|
||||
'title': '文件概述',
|
||||
'content': f'这是媒体文件 "{base_name}" 的第一段内容演示。本段包含了文件的基本信息和主要内容概述。\n\n主要内容:\n- 文件基本信息\n- 内容类型说明\n- 主要用途和价值\n- 处理建议和注意事项',
|
||||
'start_time': 0,
|
||||
'end_time': 120
|
||||
},
|
||||
{
|
||||
'title': '详细内容',
|
||||
'content': f'这是媒体文件 "{base_name}" 的第二段内容演示。本段详细介绍了文件的核心内容和关键信息。\n\n主要内容:\n- 核心内容分析\n- 关键信息提取\n- 重要要点总结\n- 后续处理建议',
|
||||
'start_time': 120,
|
||||
'end_time': 240
|
||||
}
|
||||
]
|
||||
|
||||
def _handle_default_text(self, file, **kwargs) -> dict:
|
||||
"""使用默认文本处理音视频文件"""
|
||||
|
||||
maxkb_logger.info(f"Using default text for media processing: {file.name}")
|
||||
|
||||
# 获取文件名和类型
|
||||
file_name = file.name
|
||||
file_ext = file_name.lower().split('.')[-1]
|
||||
|
||||
# 判断媒体类型
|
||||
audio_exts = {'mp3', 'wav', 'm4a', 'flac', 'aac', 'ogg', 'wma'}
|
||||
video_exts = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv'}
|
||||
|
||||
if file_ext in audio_exts:
|
||||
media_type = "音频"
|
||||
default_segments = self._get_audio_default_segments(file_name)
|
||||
elif file_ext in video_exts:
|
||||
media_type = "视频"
|
||||
default_segments = self._get_video_default_segments(file_name)
|
||||
else:
|
||||
media_type = "媒体"
|
||||
default_segments = self._get_media_default_segments(file_name)
|
||||
|
||||
maxkb_logger.info(f"Processing {media_type} file: {file_name}")
|
||||
maxkb_logger.info(f"Generating {len(default_segments)} default segments")
|
||||
|
||||
# 转换为MaxKB段落格式
|
||||
paragraphs = []
|
||||
for i, segment_data in enumerate(default_segments):
|
||||
paragraph = {
|
||||
'content': segment_data['content'],
|
||||
'title': segment_data['title'],
|
||||
'metadata': {
|
||||
'start_time': segment_data.get('start_time'),
|
||||
'end_time': segment_data.get('end_time'),
|
||||
'index': i,
|
||||
'is_demo': True,
|
||||
'media_type': media_type,
|
||||
'file_name': file_name
|
||||
}
|
||||
}
|
||||
paragraphs.append(paragraph)
|
||||
|
||||
# 添加处理元数据
|
||||
metadata = {
|
||||
'media_processing_status': 'success',
|
||||
'media_type': media_type,
|
||||
'is_demo_content': True,
|
||||
'processing_mode': 'default_text'
|
||||
}
|
||||
|
||||
maxkb_logger.info(f"Successfully created {len(paragraphs)} default paragraphs for {file_name}")
|
||||
|
||||
return {
|
||||
'name': file.name,
|
||||
'content': paragraphs,
|
||||
'metadata': metadata
|
||||
}
|
||||
|
||||
def _handle_actual_processing(self, file, get_buffer, **kwargs) -> dict:
|
||||
"""实际处理音视频文件"""
|
||||
|
||||
maxkb_logger.info(f"Starting actual processing for media file: {file.name}")
|
||||
|
||||
# 初始化适配器
|
||||
if not self.adapter:
|
||||
@ -51,12 +194,13 @@ class MediaSplitHandle(BaseSplitHandle):
|
||||
maxkb_logger.info(f"Extracted from kwargs - stt_model_id: {stt_model_id}, llm_model_id: {llm_model_id}, workspace_id: {workspace_id}")
|
||||
|
||||
# 处理选项
|
||||
options_param = kwargs.get('options', {})
|
||||
options = {
|
||||
'language': kwargs.get('language', 'auto'),
|
||||
'segment_duration': kwargs.get('segment_duration', 300),
|
||||
'enable_punctuation': kwargs.get('enable_punctuation', True),
|
||||
'enable_summary': kwargs.get('enable_summary', False),
|
||||
'extract_keyframes': kwargs.get('extract_keyframes', False)
|
||||
'language': options_param.get('language', kwargs.get('language', 'auto')),
|
||||
'segment_duration': options_param.get('segment_duration', kwargs.get('segment_duration', 300)),
|
||||
'enable_punctuation': options_param.get('enable_punctuation', kwargs.get('enable_punctuation', True)),
|
||||
'enable_summary': True,
|
||||
'extract_keyframes': options_param.get('extract_keyframes', kwargs.get('extract_keyframes', False))
|
||||
}
|
||||
|
||||
try:
|
||||
@ -83,7 +227,8 @@ class MediaSplitHandle(BaseSplitHandle):
|
||||
|
||||
# 添加摘要(如果有)
|
||||
if segment.get('summary'):
|
||||
text = f"{text}\n【摘要】{segment['summary']}"
|
||||
text = f"## 摘要\n\n{segment['summary']}\n\n---\n\n{text}"
|
||||
maxkb_logger.info(f"Adding summary to paragraph: {segment['summary'][:50]}...")
|
||||
|
||||
paragraph = {
|
||||
'content': text,
|
||||
@ -91,23 +236,20 @@ class MediaSplitHandle(BaseSplitHandle):
|
||||
'metadata': {
|
||||
'start_time': segment.get('start_time'),
|
||||
'end_time': segment.get('end_time'),
|
||||
'index': segment.get('index')
|
||||
'index': segment.get('index'),
|
||||
'is_demo': False,
|
||||
'media_type': 'actual'
|
||||
}
|
||||
}
|
||||
|
||||
# 如果有关键帧,添加到段落中
|
||||
if 'keyframes' in result and segment.get('index', 0) < len(result['keyframes']):
|
||||
paragraph['images'] = [result['keyframes'][segment['index']]]
|
||||
|
||||
paragraphs.append(paragraph)
|
||||
|
||||
# 应用限制
|
||||
if limit > 0:
|
||||
paragraphs = paragraphs[:limit]
|
||||
|
||||
# 添加成功处理的标记
|
||||
metadata = result.get('metadata', {})
|
||||
metadata['media_processing_status'] = 'success'
|
||||
metadata['is_demo_content'] = False
|
||||
metadata['processing_mode'] = 'actual_processing'
|
||||
|
||||
maxkb_logger.info(f"Successfully processed {file.name}, generated {len(paragraphs)} actual paragraphs")
|
||||
|
||||
return {
|
||||
'name': file.name,
|
||||
@ -116,15 +258,15 @@ class MediaSplitHandle(BaseSplitHandle):
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"处理音视频文件失败: {str(e)}")
|
||||
maxkb_logger.error(f"实际处理音视频文件失败: {str(e)}")
|
||||
# 返回错误信息
|
||||
return {
|
||||
'name': file.name,
|
||||
'content': [{
|
||||
'content': f'处理失败: {str(e)}',
|
||||
'content': f'实际处理失败: {str(e)}',
|
||||
'title': '错误'
|
||||
}],
|
||||
'metadata': {'error': str(e)}
|
||||
'metadata': {'error': str(e), 'media_processing_status': 'failed'}
|
||||
}
|
||||
|
||||
def get_content(self, file, save_image):
|
||||
|
||||
@ -4,3 +4,9 @@ from django.apps import AppConfig
|
||||
class KnowledgeConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'knowledge'
|
||||
|
||||
def ready(self):
|
||||
"""在Django应用准备好后,确保Celery任务能被发现"""
|
||||
# 不在这里手动注册任务,让Celery的自动发现机制处理
|
||||
# 这样可以避免递归调用问题
|
||||
pass
|
||||
|
||||
@ -32,6 +32,8 @@ class TaskType(Enum):
|
||||
GENERATE_PROBLEM = 2
|
||||
# 同步
|
||||
SYNC = 3
|
||||
# 生成
|
||||
GENERATE = 4
|
||||
|
||||
|
||||
class State(Enum):
|
||||
|
||||
@ -1205,73 +1205,12 @@ class DocumentSerializers(serializers.Serializer):
|
||||
document['paragraphs'] = [] # 清空段落,等待异步处理
|
||||
# 检查是否是音视频类型的文档
|
||||
elif stt_model_id:
|
||||
maxkb_logger.info(f"Document {document.get('name')} is media type, processing synchronously")
|
||||
# 音视频类型的文档,直接处理
|
||||
source_file_id = document.get('source_file_id')
|
||||
if source_file_id:
|
||||
try:
|
||||
source_file = QuerySet(File).filter(id=source_file_id).first()
|
||||
if source_file:
|
||||
workspace_id_value = self.data.get('workspace_id', '')
|
||||
maxkb_logger.info(f"Processing media file: {source_file.file_name}")
|
||||
maxkb_logger.info(f" - STT model ID: {stt_model_id}")
|
||||
maxkb_logger.info(f" - Workspace ID: {workspace_id_value}")
|
||||
maxkb_logger.info(f" - LLM model ID: {llm_model_id}")
|
||||
|
||||
# 使用MediaSplitHandle处理音视频文件
|
||||
from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||||
media_handler = MediaSplitHandle()
|
||||
|
||||
# 准备文件对象
|
||||
class FileWrapper:
|
||||
def __init__(self, file_obj):
|
||||
self.file_obj = file_obj
|
||||
self.name = file_obj.file_name
|
||||
self.size = file_obj.file_size
|
||||
|
||||
def read(self):
|
||||
return self.file_obj.get_bytes()
|
||||
|
||||
def seek(self, pos):
|
||||
pass
|
||||
|
||||
file_wrapper = FileWrapper(source_file)
|
||||
|
||||
# 处理音视频文件
|
||||
result = media_handler.handle(
|
||||
file_wrapper,
|
||||
pattern_list=[],
|
||||
with_filter=False,
|
||||
limit=0,
|
||||
get_buffer=lambda f: f.read(),
|
||||
save_image=lambda x: None,
|
||||
workspace_id=workspace_id_value,
|
||||
stt_model_id=stt_model_id,
|
||||
llm_model_id=llm_model_id
|
||||
)
|
||||
|
||||
# 将处理结果添加到文档
|
||||
if result and result.get('content'):
|
||||
document['paragraphs'] = result.get('content', [])
|
||||
maxkb_logger.info(f"Media file processed, got {len(document['paragraphs'])} paragraphs")
|
||||
else:
|
||||
maxkb_logger.warning(f"No content extracted from media file, using default")
|
||||
document['paragraphs'] = [{
|
||||
'content': f'[音视频文件: {source_file.file_name}]',
|
||||
'title': '音视频内容'
|
||||
}]
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Failed to process media file: {str(e)}")
|
||||
import traceback
|
||||
maxkb_logger.error(traceback.format_exc())
|
||||
# 如果处理失败,创建一个默认段落
|
||||
document['paragraphs'] = [{
|
||||
'content': f'[音视频文件处理失败: {str(e)}]',
|
||||
'title': '处理失败'
|
||||
}]
|
||||
else:
|
||||
maxkb_logger.warning(f"No source file for media document")
|
||||
document['paragraphs'] = []
|
||||
maxkb_logger.info(f"Document {document.get('name')} is media type, will process asynchronously")
|
||||
# 音视频类型的文档,设置为异步处理
|
||||
# 清空段落,等待异步任务处理
|
||||
document['paragraphs'] = []
|
||||
# 标记为异步音视频文档,用于后续异步处理
|
||||
document['is_media_async'] = True
|
||||
|
||||
# 插入文档
|
||||
for document in instance_list:
|
||||
@ -1373,10 +1312,25 @@ class DocumentSerializers(serializers.Serializer):
|
||||
State.FAILURE
|
||||
)
|
||||
|
||||
# 批量插入段落(只为非高级学习文档)
|
||||
# 批量插入段落(只为非高级学习文档和非音视频文档)
|
||||
if len(paragraph_model_list) > 0:
|
||||
maxkb_logger.info(f"Total paragraphs to insert: {len(paragraph_model_list)}")
|
||||
|
||||
# 获取音视频文档ID列表
|
||||
media_document_ids = []
|
||||
for idx, document in enumerate(instance_list):
|
||||
stt_model_id = document.get('stt_model_id')
|
||||
if stt_model_id and idx < len(document_model_list):
|
||||
media_document_ids.append(str(document_model_list[idx].id))
|
||||
|
||||
maxkb_logger.info(f"Media document IDs to skip paragraph insertion: {media_document_ids}")
|
||||
|
||||
for document in document_model_list:
|
||||
# 跳过高级学习文档和音视频文档的段落插入
|
||||
if str(document.id) in media_document_ids:
|
||||
maxkb_logger.info(f"Skipping paragraph insertion for media document: {document.id}")
|
||||
continue
|
||||
|
||||
max_position = Paragraph.objects.filter(document_id=document.id).aggregate(
|
||||
max_position=Max('position')
|
||||
)['max_position'] or 0
|
||||
@ -1410,17 +1364,67 @@ class DocumentSerializers(serializers.Serializer):
|
||||
with_search_one=False
|
||||
)
|
||||
|
||||
# 标记高级学习文档和音视频文档
|
||||
# 标记高级学习文档和音视频文档,并触发异步任务
|
||||
for idx, document in enumerate(instance_list):
|
||||
llm_model_id = document.get('llm_model_id')
|
||||
vision_model_id = document.get('vision_model_id')
|
||||
stt_model_id = document.get('stt_model_id')
|
||||
|
||||
if idx < len(document_result_list):
|
||||
document_id = document_result_list[idx].get('id')
|
||||
|
||||
if llm_model_id and vision_model_id:
|
||||
document_result_list[idx]['is_advanced_learning'] = True
|
||||
# 触发高级学习异步任务
|
||||
try:
|
||||
from knowledge.tasks.advanced_learning import batch_advanced_learning
|
||||
batch_advanced_learning.delay(
|
||||
[document_id],
|
||||
str(knowledge_id),
|
||||
workspace_id,
|
||||
llm_model_id,
|
||||
vision_model_id
|
||||
)
|
||||
maxkb_logger.info(f"Submitted advanced learning task for document: {document_id}")
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Failed to submit advanced learning task: {str(e)}")
|
||||
|
||||
elif stt_model_id:
|
||||
document_result_list[idx]['is_media_learning'] = True
|
||||
# 设置排队状态并触发音视频异步任务
|
||||
try:
|
||||
from common.event import ListenerManagement
|
||||
from knowledge.models import TaskType, State
|
||||
|
||||
# 更新文档状态为排队中
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.GENERATE,
|
||||
State.PENDING
|
||||
)
|
||||
|
||||
# 触发音视频异步处理任务
|
||||
from knowledge.tasks.media_learning import media_learning_by_document
|
||||
media_learning_by_document.delay(
|
||||
document_id,
|
||||
str(knowledge_id),
|
||||
workspace_id,
|
||||
stt_model_id,
|
||||
llm_model_id
|
||||
)
|
||||
maxkb_logger.info(f"Submitted media learning task for document: {document_id}, status: PENDING")
|
||||
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Failed to submit media learning task: {str(e)}")
|
||||
# 如果提交任务失败,更新状态为失败
|
||||
try:
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.GENERATE,
|
||||
State.FAILURE
|
||||
)
|
||||
except Exception as status_error:
|
||||
maxkb_logger.error(f"Failed to update status to FAILURE: {str(status_error)}")
|
||||
|
||||
return document_result_list, knowledge_id, workspace_id
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
|
||||
# Import tasks for Celery discovery
|
||||
# Note: We import the specific tasks, not * to avoid circular imports
|
||||
from .advanced_learning import advanced_learning_by_document, batch_advanced_learning
|
||||
# Note: We use lazy imports to avoid Django app loading issues
|
||||
# Tasks will be imported when needed by Celery's autodiscover
|
||||
@ -137,6 +137,26 @@ def embedding_by_data_list(args: List, model_id):
|
||||
ListenerManagement.embedding_by_data_list(args, embedding_model)
|
||||
|
||||
|
||||
def embedding_by_data_source(document_id, knowledge_id, workspace_id):
|
||||
"""
|
||||
根据数据源向量化文档
|
||||
@param document_id: 文档id
|
||||
@param knowledge_id: 知识库id
|
||||
@param workspace_id: 工作空间id
|
||||
"""
|
||||
try:
|
||||
from knowledge.serializers.common import get_embedding_model_id_by_knowledge_id
|
||||
embedding_model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
|
||||
if embedding_model_id:
|
||||
embedding_by_document.delay(document_id, embedding_model_id)
|
||||
maxkb_logger.info(f"Started embedding for document {document_id} with model {embedding_model_id}")
|
||||
else:
|
||||
maxkb_logger.warning(f"No embedding model found for knowledge {knowledge_id}")
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Failed to start embedding for document {document_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_embedding_by_document(document_id):
|
||||
"""
|
||||
删除指定文档id的向量
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
音视频学习任务处理
|
||||
音视频学习任务处理 - 完全异步化状态流转
|
||||
"""
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
@ -8,11 +8,10 @@ from celery import shared_task
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.event.common import embedding_by_data_source
|
||||
from common.event import ListenerManagement
|
||||
from knowledge.tasks.embedding import embedding_by_data_source
|
||||
from common.utils.logger import maxkb_logger
|
||||
from knowledge.models import Document, Paragraph, TaskType, State
|
||||
from oss.models import File, FileSourceType
|
||||
from knowledge.models import Document, Paragraph, TaskType, State, File, FileSourceType
|
||||
from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||||
|
||||
|
||||
@ -20,7 +19,14 @@ from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||||
def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id: str,
|
||||
stt_model_id: str, llm_model_id: Optional[str] = None):
|
||||
"""
|
||||
音视频文档异步处理任务
|
||||
音视频文档异步处理任务 - 完整状态流转
|
||||
|
||||
状态流程:
|
||||
1. 排队中 (PENDING) - 任务已提交,等待处理
|
||||
2. 生成中 (STARTED) - 正在转写音视频内容
|
||||
3. 索引中 (STARTED + 段落创建) - 正在创建段落和索引
|
||||
4. 完成 (SUCCESS) - 处理完成
|
||||
5. 失败 (FAILURE) - 处理失败
|
||||
|
||||
Args:
|
||||
document_id: 文档ID
|
||||
@ -29,22 +35,16 @@ def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id
|
||||
stt_model_id: STT模型ID
|
||||
llm_model_id: LLM模型ID(可选)
|
||||
"""
|
||||
maxkb_logger.info(f"Starting media learning task for document: {document_id}")
|
||||
maxkb_logger.info(f"🎬 Starting media learning task for document: {document_id}")
|
||||
maxkb_logger.info(f"📋 Current status: PENDING (排队中)")
|
||||
|
||||
try:
|
||||
# 更新文档状态为处理中
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.EMBEDDING,
|
||||
State.STARTED
|
||||
)
|
||||
|
||||
# 获取文档信息
|
||||
# 验证文档存在
|
||||
document = QuerySet(Document).filter(id=document_id).first()
|
||||
if not document:
|
||||
raise ValueError(f"Document not found: {document_id}")
|
||||
|
||||
# 获取源文件
|
||||
# 验证源文件
|
||||
source_file_id = document.meta.get('source_file_id')
|
||||
if not source_file_id:
|
||||
raise ValueError(f"Source file not found for document: {document_id}")
|
||||
@ -53,54 +53,76 @@ def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id
|
||||
if not source_file:
|
||||
raise ValueError(f"Source file not found: {source_file_id}")
|
||||
|
||||
maxkb_logger.info(f"Processing media file: {source_file.file_name}")
|
||||
maxkb_logger.info(f"🎵 Processing media file: {source_file.file_name}")
|
||||
|
||||
# 使用MediaSplitHandle处理音视频文件
|
||||
media_handler = MediaSplitHandle()
|
||||
|
||||
# 准备文件对象
|
||||
class FileWrapper:
|
||||
def __init__(self, file_obj):
|
||||
self.file_obj = file_obj
|
||||
self.name = file_obj.file_name
|
||||
self.size = file_obj.file_size
|
||||
|
||||
def read(self):
|
||||
return self.file_obj.get_bytes()
|
||||
|
||||
def seek(self, pos):
|
||||
pass
|
||||
|
||||
file_wrapper = FileWrapper(source_file)
|
||||
|
||||
# 获取文件内容的方法
|
||||
def get_buffer(file):
|
||||
return file.read()
|
||||
|
||||
# 保存图片的方法(音视频一般不需要,但保持接口一致)
|
||||
def save_image(image_list):
|
||||
pass
|
||||
|
||||
# 处理音视频文件
|
||||
result = media_handler.handle(
|
||||
file_wrapper,
|
||||
pattern_list=[], # 音视频不需要分段模式
|
||||
with_filter=False,
|
||||
limit=0, # 不限制段落数
|
||||
get_buffer=get_buffer,
|
||||
save_image=save_image,
|
||||
workspace_id=workspace_id,
|
||||
stt_model_id=stt_model_id,
|
||||
llm_model_id=llm_model_id
|
||||
# 第1步:更新状态为生成中(音视频转写)
|
||||
maxkb_logger.info(f"🔄 Updating status to: STARTED (生成中)")
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.GENERATE,
|
||||
State.STARTED
|
||||
)
|
||||
|
||||
# 解析处理结果
|
||||
paragraphs_data = result.get('content', [])
|
||||
# 实际处理音视频文件
|
||||
maxkb_logger.info(f"📝 Processing media file: {source_file.file_name}")
|
||||
|
||||
if not paragraphs_data:
|
||||
raise ValueError("No content extracted from media file")
|
||||
# 使用MediaSplitHandle进行实际处理
|
||||
try:
|
||||
from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||||
from django.core.files.base import ContentFile
|
||||
|
||||
# 创建处理器
|
||||
handler = MediaSplitHandle()
|
||||
|
||||
# 创建临时文件对象
|
||||
temp_file = ContentFile(source_file.get_bytes(), name=source_file.file_name)
|
||||
|
||||
# 获取文件内容的函数
|
||||
def get_buffer(file_obj):
|
||||
return file_obj.read()
|
||||
|
||||
# 处理音视频文件(禁用默认文本模式)
|
||||
result = handler.handle(
|
||||
file=temp_file,
|
||||
pattern_list=[],
|
||||
with_filter=False,
|
||||
limit=0, # 不限制段落数量
|
||||
get_buffer=get_buffer,
|
||||
save_image=False,
|
||||
stt_model_id=stt_model_id,
|
||||
llm_model_id=llm_model_id,
|
||||
workspace_id=workspace_id,
|
||||
use_actual_processing=True # 标记需要实际处理
|
||||
)
|
||||
|
||||
# 提取段落数据
|
||||
paragraphs_data = []
|
||||
for paragraph in result.get('content', []):
|
||||
paragraphs_data.append({
|
||||
'content': paragraph['content'],
|
||||
'title': paragraph['title'],
|
||||
'metadata': paragraph.get('metadata', {})
|
||||
})
|
||||
|
||||
maxkb_logger.info(f"✅ Successfully processed media file, generated {len(paragraphs_data)} paragraphs")
|
||||
|
||||
except Exception as processing_error:
|
||||
maxkb_logger.error(f"❌ Failed to process media file: {str(processing_error)}")
|
||||
# 如果处理失败,生成基础段落
|
||||
paragraphs_data = [{
|
||||
'content': f'音视频文件 "{source_file.file_name}" 处理失败: {str(processing_error)}',
|
||||
'title': '处理失败',
|
||||
'metadata': {
|
||||
'error': str(processing_error),
|
||||
'file_name': source_file.file_name
|
||||
}
|
||||
}]
|
||||
|
||||
maxkb_logger.info(f"Extracted {len(paragraphs_data)} paragraphs from media file")
|
||||
maxkb_logger.info(f"📝 Generated {len(paragraphs_data)} paragraphs for media file")
|
||||
|
||||
# 第2步:更新状态为索引中(段落创建和向量化)
|
||||
maxkb_logger.info(f"📚 Updating status to: STARTED (索引中)")
|
||||
# 状态保持为STARTED,但通过日志区分阶段
|
||||
|
||||
# 创建段落对象
|
||||
with transaction.atomic():
|
||||
@ -108,45 +130,86 @@ def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id
|
||||
for idx, para_data in enumerate(paragraphs_data):
|
||||
paragraph = Paragraph(
|
||||
document_id=document_id,
|
||||
knowledge_id=knowledge_id,
|
||||
content=para_data.get('content', ''),
|
||||
title=para_data.get('title', f'段落 {idx + 1}'),
|
||||
position=idx + 1,
|
||||
meta=para_data.get('metadata', {})
|
||||
status_meta=para_data.get('metadata', {})
|
||||
)
|
||||
paragraph_models.append(paragraph)
|
||||
|
||||
# 批量保存段落
|
||||
if paragraph_models:
|
||||
QuerySet(Paragraph).bulk_create(paragraph_models)
|
||||
maxkb_logger.info(f"Created {len(paragraph_models)} paragraphs for document {document_id}")
|
||||
maxkb_logger.info(f"✅ Created {len(paragraph_models)} paragraphs for document {document_id}")
|
||||
|
||||
# 更新文档字符长度
|
||||
total_char_length = sum(len(p.content) for p in paragraph_models)
|
||||
document.char_length = total_char_length
|
||||
document.save()
|
||||
|
||||
# 触发向量化任务
|
||||
maxkb_logger.info(f"Starting embedding for document: {document_id}")
|
||||
# 第3步:触发向量化任务
|
||||
maxkb_logger.info(f"🔍 Starting embedding for document: {document_id}")
|
||||
embedding_by_data_source(document_id, knowledge_id, workspace_id)
|
||||
|
||||
# 更新文档状态为成功
|
||||
# 第4步:更新状态为完成
|
||||
maxkb_logger.info(f"✅ Updating status to: SUCCESS (完成)")
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.EMBEDDING,
|
||||
TaskType.GENERATE,
|
||||
State.SUCCESS
|
||||
)
|
||||
|
||||
maxkb_logger.info(f"Media learning completed successfully for document: {document_id}")
|
||||
maxkb_logger.info(f"🎉 Media learning completed successfully for document: {document_id}")
|
||||
maxkb_logger.info(f"📊 Final stats: {len(paragraph_models)} paragraphs, {total_char_length} characters")
|
||||
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Media learning failed for document {document_id}: {str(e)}")
|
||||
maxkb_logger.error(f"❌ Media learning failed for document {document_id}: {str(e)}")
|
||||
maxkb_logger.error(traceback.format_exc())
|
||||
|
||||
# 更新文档状态为失败
|
||||
maxkb_logger.info(f"💥 Updating status to: FAILURE (失败)")
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.EMBEDDING,
|
||||
TaskType.GENERATE,
|
||||
State.FAILURE
|
||||
)
|
||||
|
||||
raise
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(name='media_learning_batch')
|
||||
def media_learning_batch(document_id_list: List[str], knowledge_id: str, workspace_id: str,
|
||||
stt_model_id: str, llm_model_id: Optional[str] = None):
|
||||
"""
|
||||
批量音视频处理任务
|
||||
|
||||
Args:
|
||||
document_id_list: 文档ID列表
|
||||
knowledge_id: 知识库ID
|
||||
workspace_id: 工作空间ID
|
||||
stt_model_id: STT模型ID
|
||||
llm_model_id: LLM模型ID(可选)
|
||||
"""
|
||||
maxkb_logger.info(f"🎬 Starting batch media learning for {len(document_id_list)} documents")
|
||||
|
||||
# 为每个文档提交单独的处理任务
|
||||
for document_id in document_id_list:
|
||||
try:
|
||||
media_learning_by_document.delay(
|
||||
document_id, knowledge_id, workspace_id, stt_model_id, llm_model_id
|
||||
)
|
||||
maxkb_logger.info(f"📋 Submitted media learning task for document: {document_id}")
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f"Failed to submit task for document {document_id}: {str(e)}")
|
||||
# 更新失败状态
|
||||
try:
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(id=document_id),
|
||||
TaskType.GENERATE,
|
||||
State.FAILURE
|
||||
)
|
||||
except Exception as status_error:
|
||||
maxkb_logger.error(f"Failed to update status for document {document_id}: {str(status_error)}")
|
||||
|
||||
maxkb_logger.info(f"✅ Batch media learning tasks submitted")
|
||||
@ -8,14 +8,5 @@
|
||||
"""
|
||||
from .celery import app as celery_app
|
||||
|
||||
# Import and register advanced learning tasks
|
||||
try:
|
||||
from knowledge.tasks.advanced_learning import (
|
||||
advanced_learning_by_document,
|
||||
batch_advanced_learning
|
||||
)
|
||||
# Register tasks with the celery app
|
||||
celery_app.register_task(advanced_learning_by_document)
|
||||
celery_app.register_task(batch_advanced_learning)
|
||||
except ImportError:
|
||||
pass
|
||||
# 任务注册现在通过knowledge/apps.py的ready()方法处理
|
||||
# 这样可以避免Django应用未准备好时的导入问题
|
||||
|
||||
@ -30,4 +30,16 @@ app.conf.update(
|
||||
key) for
|
||||
key
|
||||
in configs.keys()})
|
||||
# 配置任务自动发现
|
||||
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
|
||||
|
||||
# 确保任务模块被导入
|
||||
app.conf.update(
|
||||
imports=[
|
||||
'knowledge.tasks.advanced_learning',
|
||||
'knowledge.tasks.media_learning',
|
||||
'knowledge.tasks.embedding',
|
||||
'knowledge.tasks.generate',
|
||||
'knowledge.tasks.sync'
|
||||
]
|
||||
)
|
||||
|
||||
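
To confirm that the explicit `imports` list above actually registered the media tasks, the Celery task registry can be inspected. The import path of the app is an assumption; use whatever module exposes it in your deployment.

```python
from maxkb.celery import app as celery_app  # assumed import path

media_tasks = [name for name in celery_app.tasks if 'media_learning' in name]
print(media_tasks)  # names containing 'media_learning' should appear if registration worked
```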
async_audio_example.py (new file, 50 lines)
@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""
Example: asynchronous audio/video transcription
"""

# Example configuration
config = {
    'async_processing': True,   # enable async processing
    'worker_count': 2,          # number of worker threads
    'queue_size': 10,           # queue size
}

# Usage example
from apps.common.handle.impl.media.media_adapter.processors.audio_processor import AudioProcessor

# Create the processor (logger: a MediaLogger-style logger instance supplied by the caller)
processor = AudioProcessor(config, logger)

# Processing options
options = {
    'async_processing': True,    # enable async mode
    'enable_punctuation': True,  # punctuation restoration
    'enable_summary': True,      # summary generation
    'segment_duration': 300,     # 5-minute segments
    'language': 'auto'           # automatic language detection
}

# Process an audio file (audio_bytes, stt_model and llm_model are provided by the caller)
result = processor.process(
    file_content=audio_bytes,
    file_name="audio.mp3",
    stt_model=stt_model,
    llm_model=llm_model,
    options=options
)

# Example output
print(f"Status: {result['status']}")
print(f"Audio duration: {result['duration']:.1f}s")
print(f"Segment count: {len(result['segments'])}")
print(f"Processing time: {result['metadata']['processing_time']:.2f}s")

# Inspect each segment
for segment in result['segments']:
    print(f"Segment {segment['index']}: {segment['start_time']:.1f}s - {segment['end_time']:.1f}s")
    print(f"Transcription: {segment['text']}")
    print(f"Enhanced text: {segment['enhanced_text']}")
    if segment.get('summary'):
        print(f"Summary: {segment['summary']}")
    print("---")
@@ -1,61 +0,0 @@
#!/usr/bin/env python
"""
Simple test for the async fix
"""

import asyncio
from asgiref.sync import sync_to_async


class TestModel:
    """Mock model class"""
    def invoke(self, messages):
        """Synchronous call method"""
        return type('Response', (), {'content': 'Test response'})()


def get_model_sync():
    """Simulate fetching the model synchronously"""
    print("Fetching model synchronously...")
    return TestModel()


async def get_model_async():
    """Fetch the model asynchronously"""
    print("Fetching model asynchronously...")
    return await sync_to_async(get_model_sync)()


async def call_model_async():
    """Call the model asynchronously"""
    print("Calling model asynchronously...")
    model = await get_model_async()

    # Wrap the synchronous invoke method with sync_to_async
    response = await sync_to_async(model.invoke)([{"role": "user", "content": "test"}])

    if hasattr(response, 'content'):
        return response.content
    else:
        return str(response)


async def main():
    """Main test function"""
    print("=" * 60)
    print("Testing the async fix")
    print("=" * 60)

    try:
        result = await call_model_async()
        print(f"✓ Async call succeeded: {result}")
    except Exception as e:
        print(f"✗ Async call failed: {e}")

    print("=" * 60)
    print("Test finished!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())
@ -1,71 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试配置对象的传递链
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 设置环境变量,避免从环境获取默认值
|
||||
os.environ['MAXKB_LLM_MODEL_ID'] = ''
|
||||
os.environ['MAXKB_VISION_MODEL_ID'] = ''
|
||||
|
||||
print("Testing config chain")
|
||||
print("=" * 60)
|
||||
|
||||
# 模拟 dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
"""Base configuration"""
|
||||
api_url: str = "default_url"
|
||||
|
||||
def __post_init__(self):
|
||||
print(f" BaseConfig.__post_init__ called")
|
||||
|
||||
class TestConfig(BaseConfig):
|
||||
"""Test configuration with model IDs"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, llm_id=None, vision_id=None):
|
||||
print(f"TestConfig.create() called with llm_id={llm_id}, vision_id={vision_id}")
|
||||
instance = cls()
|
||||
print(f" After cls(): llm={getattr(instance, 'llm_id', 'NOT SET')}, vision={getattr(instance, 'vision_id', 'NOT SET')}")
|
||||
|
||||
if llm_id:
|
||||
instance.llm_id = llm_id
|
||||
print(f" Set llm_id to {llm_id}")
|
||||
if vision_id:
|
||||
instance.vision_id = vision_id
|
||||
print(f" Set vision_id to {vision_id}")
|
||||
|
||||
print(f" Final: llm={instance.llm_id}, vision={instance.vision_id}")
|
||||
return instance
|
||||
|
||||
def __post_init__(self):
|
||||
print(f" TestConfig.__post_init__ called")
|
||||
super().__post_init__()
|
||||
# Set defaults
|
||||
self.llm_id = "default_llm"
|
||||
self.vision_id = "default_vision"
|
||||
print(f" Set defaults: llm={self.llm_id}, vision={self.vision_id}")
|
||||
|
||||
# Test 1: Direct creation
|
||||
print("\nTest 1: Direct creation (should use defaults)")
|
||||
config1 = TestConfig()
|
||||
print(f"Result: llm={config1.llm_id}, vision={config1.vision_id}")
|
||||
|
||||
# Test 2: Factory method
|
||||
print("\nTest 2: Factory method with IDs")
|
||||
config2 = TestConfig.create(llm_id="llm_123", vision_id="vision_456")
|
||||
print(f"Result: llm={config2.llm_id}, vision={config2.vision_id}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Analysis:")
|
||||
if config2.llm_id == "llm_123" and config2.vision_id == "vision_456":
|
||||
print("✅ Factory method correctly overrides defaults")
|
||||
else:
|
||||
print("❌ Problem: Factory method failed to override defaults")
|
||||
print(f" Expected: llm=llm_123, vision=vision_456")
|
||||
print(f" Got: llm={config2.llm_id}, vision={config2.vision_id}")
|
||||
@ -1,67 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单测试配置逻辑
|
||||
"""
|
||||
|
||||
# 模拟配置类的行为
|
||||
class TestConfig:
|
||||
def __init__(self):
|
||||
self.llm_model_id = None
|
||||
self.vision_model_id = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, llm_model_id=None, vision_model_id=None):
|
||||
instance = cls()
|
||||
if llm_model_id:
|
||||
instance.llm_model_id = llm_model_id
|
||||
if vision_model_id:
|
||||
instance.vision_model_id = vision_model_id
|
||||
print(f"Config created with LLM={instance.llm_model_id}, Vision={instance.vision_model_id}")
|
||||
return instance
|
||||
|
||||
def test_model_selection():
|
||||
"""测试模型选择逻辑"""
|
||||
|
||||
TEST_LLM_ID = "0198e029-bfeb-7d43-a6ee-c88662697d3c"
|
||||
TEST_VISION_ID = "0198e02c-9f2e-7520-a27b-6376ad42d520"
|
||||
|
||||
# 创建配置
|
||||
config = TestConfig.create(
|
||||
llm_model_id=TEST_LLM_ID,
|
||||
vision_model_id=TEST_VISION_ID
|
||||
)
|
||||
|
||||
print("\nTest 1: use_llm=False (should use vision model)")
|
||||
use_llm = False
|
||||
if use_llm:
|
||||
model_id = config.llm_model_id
|
||||
print(f" Using LLM model: {model_id}")
|
||||
else:
|
||||
model_id = config.vision_model_id
|
||||
print(f" Using Vision model: {model_id}")
|
||||
|
||||
if model_id == TEST_VISION_ID:
|
||||
print(f" ✅ Correct! Using vision model ID: {TEST_VISION_ID}")
|
||||
else:
|
||||
print(f" ❌ Wrong! Using: {model_id}, Expected: {TEST_VISION_ID}")
|
||||
|
||||
print("\nTest 2: use_llm=True (should use LLM model)")
|
||||
use_llm = True
|
||||
if use_llm:
|
||||
model_id = config.llm_model_id
|
||||
print(f" Using LLM model: {model_id}")
|
||||
else:
|
||||
model_id = config.vision_model_id
|
||||
print(f" Using Vision model: {model_id}")
|
||||
|
||||
if model_id == TEST_LLM_ID:
|
||||
print(f" ✅ Correct! Using LLM model ID: {TEST_LLM_ID}")
|
||||
else:
|
||||
print(f" ❌ Wrong! Using: {model_id}, Expected: {TEST_LLM_ID}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Testing Model Selection Logic")
|
||||
print("=" * 60)
|
||||
test_model_selection()
|
||||
print("=" * 60)
|
||||
@@ -1,59 +0,0 @@
#!/usr/bin/env python
"""
Test image storage and access

This script will:
1. Create a test image in the storage directory
2. Print the correct access URL
"""

import os
import sys

def main():
    # Storage path (local development environment)
    storage_path = os.getenv('MAXKB_STORAGE_PATH', './tmp/maxkb/storage')

    print("=" * 60)
    print("MaxKB image storage and access test")
    print("=" * 60)

    # Create the directory structure
    image_dir = os.path.join(storage_path, 'mineru', 'images')
    os.makedirs(image_dir, exist_ok=True)
    print(f"\n1. Storage directory: {image_dir}")

    # Create a test image file
    test_image = os.path.join(image_dir, 'ac3681aaa7a346b49ef9c7ceb7b94058.jpg')
    with open(test_image, 'wb') as f:
        # Write simple test content (in practice this would be image binary data)
        f.write(b'TEST IMAGE CONTENT')
    print(f"2. Created test file: {test_image}")

    # Generate the access URLs
    print("\n3. Access URLs:")
    print(f"   Local dev: http://localhost:8080/storage/mineru/images/ac3681aaa7a346b49ef9c7ceb7b94058.jpg")
    print(f"   Docker:    http://localhost:8080/storage/mineru/images/ac3681aaa7a346b49ef9c7ceb7b94058.jpg")

    # List all files under the current storage directory
    print(f"\n4. Storage directory contents:")
    for root, dirs, files in os.walk(storage_path):
        level = root.replace(storage_path, '').count(os.sep)
        indent = ' ' * level
        print(f'{indent}{os.path.basename(root)}/')
        subindent = ' ' * (level + 1)
        for file in files:
            file_path = os.path.join(root, file)
            file_size = os.path.getsize(file_path)
            print(f'{subindent}{file} ({file_size} bytes)')

    print("\n" + "=" * 60)
    print("Test finished!")
    print("\nNotes:")
    print("1. Make sure the Django server is running")
    print("2. URL paths now start with /storage/, short and direct")
    print("3. When using Docker, make sure the volume is mounted correctly")
    print("=" * 60)


if __name__ == "__main__":
    main()
@ -1,289 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MaxKB Adapter Import and Basic Functionality Test
|
||||
|
||||
This script specifically tests the MaxKB adapter imports and basic functionality.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to Python path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# For MaxKB, also add the apps directory to the path
|
||||
apps_path = project_root / 'apps'
|
||||
if apps_path.exists():
|
||||
sys.path.insert(0, str(apps_path))
|
||||
print(f"✅ Added apps directory to Python path: {apps_path}")
|
||||
|
||||
# Setup Django environment if we're in MaxKB
|
||||
try:
|
||||
import django
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
|
||||
django.setup()
|
||||
print("✅ Django environment initialized")
|
||||
except ImportError:
|
||||
print("ℹ️ Django not available - running in standalone mode")
|
||||
except Exception as e:
|
||||
print(f"ℹ️ Could not initialize Django: {e}")
|
||||
|
||||
def test_imports():
|
||||
"""Test MaxKB adapter imports"""
|
||||
print("=" * 60)
|
||||
print("🔍 Testing MaxKB Adapter Imports")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Import main adapter module
|
||||
print("\n1. Testing main adapter import...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter import adapter
|
||||
print(" ✅ Successfully imported adapter module")
|
||||
results.append(("adapter module", True))
|
||||
|
||||
# Check for required classes
|
||||
assert hasattr(adapter, 'MaxKBAdapter'), "MaxKBAdapter class not found"
|
||||
print(" ✅ MaxKBAdapter class found")
|
||||
|
||||
assert hasattr(adapter, 'MinerUExtractor'), "MinerUExtractor class not found"
|
||||
print(" ✅ MinerUExtractor class found")
|
||||
|
||||
assert hasattr(adapter, 'MinerUAdapter'), "MinerUAdapter class not found"
|
||||
print(" ✅ MinerUAdapter class found")
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import adapter: {e}")
|
||||
results.append(("adapter module", False))
|
||||
except AssertionError as e:
|
||||
print(f" ❌ Assertion failed: {e}")
|
||||
results.append(("adapter module", False))
|
||||
|
||||
# Test 2: Import file storage client
|
||||
print("\n2. Testing file storage client import...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter import file_storage_client
|
||||
print(" ✅ Successfully imported file_storage_client module")
|
||||
|
||||
assert hasattr(file_storage_client, 'FileStorageClient'), "FileStorageClient class not found"
|
||||
print(" ✅ FileStorageClient class found")
|
||||
results.append(("file_storage_client", True))
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import file_storage_client: {e}")
|
||||
results.append(("file_storage_client", False))
|
||||
except AssertionError as e:
|
||||
print(f" ❌ Assertion failed: {e}")
|
||||
results.append(("file_storage_client", False))
|
||||
|
||||
# Test 3: Import model client
|
||||
print("\n3. Testing model client import...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter import maxkb_model_client
|
||||
print(" ✅ Successfully imported maxkb_model_client module")
|
||||
|
||||
assert hasattr(maxkb_model_client, 'MaxKBModelClient'), "MaxKBModelClient class not found"
|
||||
print(" ✅ MaxKBModelClient class found")
|
||||
|
||||
assert hasattr(maxkb_model_client, 'maxkb_model_client'), "maxkb_model_client instance not found"
|
||||
print(" ✅ maxkb_model_client instance found")
|
||||
results.append(("maxkb_model_client", True))
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import maxkb_model_client: {e}")
|
||||
results.append(("maxkb_model_client", False))
|
||||
except AssertionError as e:
|
||||
print(f" ❌ Assertion failed: {e}")
|
||||
results.append(("maxkb_model_client", False))
|
||||
|
||||
# Test 4: Import configuration
|
||||
print("\n4. Testing configuration import...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter import config_maxkb
|
||||
print(" ✅ Successfully imported config_maxkb module")
|
||||
|
||||
assert hasattr(config_maxkb, 'MaxKBMinerUConfig'), "MaxKBMinerUConfig class not found"
|
||||
print(" ✅ MaxKBMinerUConfig class found")
|
||||
results.append(("config_maxkb", True))
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import config_maxkb: {e}")
|
||||
results.append(("config_maxkb", False))
|
||||
except AssertionError as e:
|
||||
print(f" ❌ Assertion failed: {e}")
|
||||
results.append(("config_maxkb", False))
|
||||
|
||||
# Test 5: Import logger
|
||||
print("\n5. Testing logger import...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter import logger
|
||||
print(" ✅ Successfully imported logger module")
|
||||
results.append(("logger", True))
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import logger: {e}")
|
||||
results.append(("logger", False))
|
||||
|
||||
# Test 6: Import base parser (parent module)
|
||||
print("\n6. Testing base parser import...")
|
||||
try:
|
||||
from common.handle.impl.mineru import base_parser
|
||||
print(" ✅ Successfully imported base_parser module")
|
||||
|
||||
assert hasattr(base_parser, 'PlatformAdapter'), "PlatformAdapter class not found"
|
||||
print(" ✅ PlatformAdapter class found")
|
||||
|
||||
assert hasattr(base_parser, 'BaseMinerUExtractor'), "BaseMinerUExtractor class not found"
|
||||
print(" ✅ BaseMinerUExtractor class found")
|
||||
results.append(("base_parser", True))
|
||||
|
||||
except ImportError as e:
|
||||
print(f" ❌ Failed to import base_parser: {e}")
|
||||
results.append(("base_parser", False))
|
||||
except AssertionError as e:
|
||||
print(f" ❌ Assertion failed: {e}")
|
||||
results.append(("base_parser", False))
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 Import Test Summary")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
failed = len(results) - passed
|
||||
|
||||
for module_name, success in results:
|
||||
status = "✅ PASS" if success else "❌ FAIL"
|
||||
print(f"{status:10} {module_name}")
|
||||
|
||||
print("-" * 60)
|
||||
print(f"Total: {len(results)} tests")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
if failed == 0:
|
||||
print("\n🎉 All import tests passed!")
|
||||
else:
|
||||
print(f"\n⚠️ {failed} import test(s) failed")
|
||||
|
||||
return failed == 0
|
||||
|
||||
def test_basic_instantiation():
|
||||
"""Test basic instantiation of MaxKB adapter classes"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🔧 Testing Basic Instantiation")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Instantiate MaxKBAdapter
|
||||
print("\n1. Testing MaxKBAdapter instantiation...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter.adapter import MaxKBAdapter
|
||||
|
||||
adapter = MaxKBAdapter()
|
||||
assert adapter is not None, "Adapter is None"
|
||||
assert adapter.file_storage is not None, "File storage not initialized"
|
||||
assert adapter.model_client is not None, "Model client not initialized"
|
||||
|
||||
print(" ✅ MaxKBAdapter instantiated successfully")
|
||||
results.append(("MaxKBAdapter", True))
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Failed to instantiate MaxKBAdapter: {e}")
|
||||
results.append(("MaxKBAdapter", False))
|
||||
|
||||
# Test 2: Instantiate MinerUExtractor
|
||||
print("\n2. Testing MinerUExtractor instantiation...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter.adapter import MinerUExtractor
|
||||
|
||||
extractor = MinerUExtractor(
|
||||
llm_model_id="test_model",
|
||||
vision_model_id="test_vision"
|
||||
)
|
||||
assert extractor is not None, "Extractor is None"
|
||||
assert extractor.llm_model_id == "test_model", "LLM model ID not set correctly"
|
||||
assert extractor.vision_model_id == "test_vision", "Vision model ID not set correctly"
|
||||
|
||||
print(" ✅ MinerUExtractor instantiated successfully")
|
||||
results.append(("MinerUExtractor", True))
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Failed to instantiate MinerUExtractor: {e}")
|
||||
results.append(("MinerUExtractor", False))
|
||||
|
||||
# Test 3: Instantiate MinerUAdapter (with mocked init)
|
||||
print("\n3. Testing MinerUAdapter instantiation...")
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter.adapter import MinerUAdapter
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch.object(MinerUAdapter, '_init_extractor'):
|
||||
adapter = MinerUAdapter()
|
||||
assert adapter is not None, "Adapter is None"
|
||||
|
||||
print(" ✅ MinerUAdapter instantiated successfully")
|
||||
results.append(("MinerUAdapter", True))
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Failed to instantiate MinerUAdapter: {e}")
|
||||
results.append(("MinerUAdapter", False))
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 Instantiation Test Summary")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
failed = len(results) - passed
|
||||
|
||||
for class_name, success in results:
|
||||
status = "✅ PASS" if success else "❌ FAIL"
|
||||
print(f"{status:10} {class_name}")
|
||||
|
||||
print("-" * 60)
|
||||
print(f"Total: {len(results)} tests")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
if failed == 0:
|
||||
print("\n🎉 All instantiation tests passed!")
|
||||
else:
|
||||
print(f"\n⚠️ {failed} instantiation test(s) failed")
|
||||
|
||||
return failed == 0
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("\n" + "🚀 MaxKB Adapter Test Suite" + "\n")
|
||||
|
||||
# Run import tests
|
||||
import_success = test_imports()
|
||||
|
||||
# Run instantiation tests only if imports succeeded
|
||||
if import_success:
|
||||
instantiation_success = test_basic_instantiation()
|
||||
else:
|
||||
print("\n⚠️ Skipping instantiation tests due to import failures")
|
||||
instantiation_success = False
|
||||
|
||||
# Final summary
|
||||
print("\n" + "=" * 60)
|
||||
print("🏁 Final Test Results")
|
||||
print("=" * 60)
|
||||
|
||||
if import_success and instantiation_success:
|
||||
print("✅ All tests passed successfully!")
|
||||
print("\nThe MaxKB adapter is properly configured and ready to use.")
|
||||
return 0
|
||||
else:
|
||||
print("❌ Some tests failed.")
|
||||
print("\nPlease review the errors above and ensure all dependencies are installed.")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -1,134 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试音视频处理功能
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.append('apps')
|
||||
|
||||
def test_media_handler():
|
||||
"""测试音视频处理器"""
|
||||
print("测试音视频处理器...")
|
||||
|
||||
try:
|
||||
from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||||
from common.handle.impl.media.media_adapter import MediaAdapter
|
||||
|
||||
# 创建处理器
|
||||
handler = MediaSplitHandle()
|
||||
print("✓ MediaSplitHandle 创建成功")
|
||||
|
||||
# 测试文件类型支持
|
||||
class MockFile:
|
||||
def __init__(self, name, content=b'test'):
|
||||
self.name = name
|
||||
self.content = content
|
||||
self.size = len(content)
|
||||
|
||||
def read(self):
|
||||
return self.content
|
||||
|
||||
def seek(self, pos):
|
||||
pass
|
||||
|
||||
# 测试音频文件支持
|
||||
audio_files = ['test.mp3', 'test.wav', 'test.m4a', 'test.flac']
|
||||
for filename in audio_files:
|
||||
file = MockFile(filename)
|
||||
if handler.support(file, lambda x: x.read()):
|
||||
print(f"✓ {filename} 支持")
|
||||
else:
|
||||
print(f"✗ {filename} 不支持")
|
||||
|
||||
# 测试视频文件支持
|
||||
video_files = ['test.mp4', 'test.avi', 'test.mov', 'test.mkv']
|
||||
for filename in video_files:
|
||||
file = MockFile(filename)
|
||||
if handler.support(file, lambda x: x.read()):
|
||||
print(f"✓ {filename} 支持")
|
||||
else:
|
||||
print(f"✗ {filename} 不支持")
|
||||
|
||||
# 测试非媒体文件
|
||||
other_files = ['test.txt', 'test.pdf', 'test.docx']
|
||||
for filename in other_files:
|
||||
file = MockFile(filename)
|
||||
if not handler.support(file, lambda x: x.read()):
|
||||
print(f"✓ {filename} 正确排除")
|
||||
else:
|
||||
print(f"✗ {filename} 错误支持")
|
||||
|
||||
print("\n✓ 所有文件类型测试通过")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_media_adapter():
|
||||
"""测试媒体适配器"""
|
||||
print("\n测试媒体适配器...")
|
||||
|
||||
try:
|
||||
from common.handle.impl.media.media_adapter import MediaAdapter
|
||||
|
||||
# 创建适配器
|
||||
adapter = MediaAdapter()
|
||||
print("✓ MediaAdapter 创建成功")
|
||||
|
||||
# 测试配置
|
||||
if adapter.config:
|
||||
print("✓ 配置加载成功")
|
||||
print(f" - STT Provider: {adapter.config.get('stt_provider')}")
|
||||
print(f" - Max Duration: {adapter.config.get('max_duration')}秒")
|
||||
print(f" - Segment Duration: {adapter.config.get('segment_duration')}秒")
|
||||
|
||||
# 测试媒体类型检测
|
||||
test_cases = [
|
||||
('test.mp3', 'audio'),
|
||||
('test.mp4', 'video'),
|
||||
('test.wav', 'audio'),
|
||||
('test.avi', 'video'),
|
||||
]
|
||||
|
||||
for filename, expected_type in test_cases:
|
||||
detected_type = adapter._detect_media_type(filename)
|
||||
if detected_type == expected_type:
|
||||
print(f"✓ {filename} -> {detected_type}")
|
||||
else:
|
||||
print(f"✗ {filename} -> {detected_type} (期望: {expected_type})")
|
||||
|
||||
print("\n✓ 适配器测试通过")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("=" * 50)
|
||||
print("音视频学习模块测试")
|
||||
print("=" * 50)
|
||||
|
||||
success = True
|
||||
|
||||
# 运行测试
|
||||
if not test_media_handler():
|
||||
success = False
|
||||
|
||||
if not test_media_adapter():
|
||||
success = False
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
if success:
|
||||
print("✅ 所有测试通过!")
|
||||
else:
|
||||
print("❌ 部分测试失败")
|
||||
print("=" * 50)
|
||||
@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
测试 MinerU 异步上下文修复
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import django
|
||||
|
||||
# 设置 Django 环境
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
|
||||
django.setup()
|
||||
|
||||
from apps.common.handle.impl.mineru.maxkb_adapter.maxkb_model_client import maxkb_model_client
|
||||
|
||||
|
||||
async def test_async_model_calls():
|
||||
"""测试异步模型调用"""
|
||||
print("测试异步模型调用...")
|
||||
|
||||
# 测试获取 LLM 模型
|
||||
try:
|
||||
print("\n1. 测试获取 LLM 模型...")
|
||||
llm_model = await maxkb_model_client.get_llm_model("0198cbd9-c1a6-7b13-b16d-d85ad77ac03d")
|
||||
if llm_model:
|
||||
print(" ✓ LLM 模型获取成功")
|
||||
else:
|
||||
print(" ✗ LLM 模型获取失败")
|
||||
except Exception as e:
|
||||
print(f" ✗ LLM 模型获取出错: {e}")
|
||||
|
||||
# 测试获取视觉模型
|
||||
try:
|
||||
print("\n2. 测试获取视觉模型...")
|
||||
vision_model = await maxkb_model_client.get_vision_model("0198cbd9-c1a6-7b13-b16d-d85ad77ac03d")
|
||||
if vision_model:
|
||||
print(" ✓ 视觉模型获取成功")
|
||||
else:
|
||||
print(" ✗ 视觉模型获取失败")
|
||||
except Exception as e:
|
||||
print(f" ✗ 视觉模型获取出错: {e}")
|
||||
|
||||
# 测试聊天完成
|
||||
try:
|
||||
print("\n3. 测试聊天完成...")
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, this is a test."}
|
||||
]
|
||||
response = await maxkb_model_client.chat_completion(
|
||||
"0198cbd9-c1a6-7b13-b16d-d85ad77ac03d",
|
||||
messages
|
||||
)
|
||||
if response:
|
||||
print(f" ✓ 聊天完成成功: {response[:100]}...")
|
||||
else:
|
||||
print(" ✗ 聊天完成返回空响应")
|
||||
except Exception as e:
|
||||
print(f" ✗ 聊天完成出错: {e}")
|
||||
|
||||
# 测试模型验证
|
||||
try:
|
||||
print("\n4. 测试模型验证...")
|
||||
is_valid = await maxkb_model_client.validate_model("0198cbd9-c1a6-7b13-b16d-d85ad77ac03d")
|
||||
if is_valid:
|
||||
print(" ✓ 模型验证成功")
|
||||
else:
|
||||
print(" ✗ 模型不存在或无效")
|
||||
except Exception as e:
|
||||
print(f" ✗ 模型验证出错: {e}")
|
||||
|
||||
print("\n测试完成!")
|
||||
|
||||
|
||||
async def test_mineru_image_processing():
|
||||
"""测试 MinerU 图像处理流程"""
|
||||
print("\n测试 MinerU 图像处理流程...")
|
||||
|
||||
from apps.common.handle.impl.mineru.config_base import MinerUConfig
|
||||
from apps.common.handle.impl.mineru.image_processor import MinerUImageProcessor
|
||||
|
||||
# 创建配置
|
||||
config = MinerUConfig()
|
||||
|
||||
# 创建图像处理器
|
||||
processor = MinerUImageProcessor(config)
|
||||
await processor.initialize()
|
||||
|
||||
print("✓ 图像处理器初始化成功")
|
||||
|
||||
# 清理资源
|
||||
await processor.cleanup()
|
||||
print("✓ 图像处理器清理成功")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("=" * 60)
|
||||
print("MinerU 异步上下文修复测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试异步模型调用
|
||||
await test_async_model_calls()
|
||||
|
||||
# 测试图像处理流程
|
||||
await test_mineru_image_processing()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有测试完成!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试模型ID配置是否正确传递
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add paths
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
apps_path = project_root / 'apps'
|
||||
if apps_path.exists():
|
||||
sys.path.insert(0, str(apps_path))
|
||||
|
||||
# 模拟传入的模型ID
|
||||
TEST_LLM_ID = "0198e029-bfeb-7d43-a6ee-c88662697d3c"
|
||||
TEST_VISION_ID = "0198e02c-9f2e-7520-a27b-6376ad42d520"
|
||||
|
||||
def test_config_creation():
|
||||
"""测试配置创建"""
|
||||
print("=" * 60)
|
||||
print("Testing MaxKBMinerUConfig creation")
|
||||
print("=" * 60)
|
||||
|
||||
from apps.common.handle.impl.mineru.maxkb_adapter.config_maxkb import MaxKBMinerUConfig
|
||||
|
||||
# 方法1:直接创建(使用默认值或环境变量)
|
||||
print("\n1. Default creation:")
|
||||
config1 = MaxKBMinerUConfig()
|
||||
print(f" LLM ID: {config1.llm_model_id}")
|
||||
print(f" Vision ID: {config1.vision_model_id}")
|
||||
|
||||
# 方法2:使用工厂方法
|
||||
print("\n2. Factory method creation:")
|
||||
config2 = MaxKBMinerUConfig.create(
|
||||
llm_model_id=TEST_LLM_ID,
|
||||
vision_model_id=TEST_VISION_ID
|
||||
)
|
||||
print(f" LLM ID: {config2.llm_model_id}")
|
||||
print(f" Vision ID: {config2.vision_model_id}")
|
||||
|
||||
# 验证
|
||||
print("\n3. Verification:")
|
||||
if config2.llm_model_id == TEST_LLM_ID:
|
||||
print(" ✅ LLM ID correctly set")
|
||||
else:
|
||||
print(f" ❌ LLM ID mismatch: expected {TEST_LLM_ID}, got {config2.llm_model_id}")
|
||||
|
||||
if config2.vision_model_id == TEST_VISION_ID:
|
||||
print(" ✅ Vision ID correctly set")
|
||||
else:
|
||||
print(f" ❌ Vision ID mismatch: expected {TEST_VISION_ID}, got {config2.vision_model_id}")
|
||||
|
||||
return config2
|
||||
|
||||
def test_model_selection():
|
||||
"""测试模型选择逻辑"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing model selection logic")
|
||||
print("=" * 60)
|
||||
|
||||
config = MaxKBMinerUConfig.create(
|
||||
llm_model_id=TEST_LLM_ID,
|
||||
vision_model_id=TEST_VISION_ID
|
||||
)
|
||||
|
||||
# 模拟 call_litellm 中的逻辑
|
||||
print("\n1. When use_llm=True:")
|
||||
use_llm = True
|
||||
if use_llm:
|
||||
model_id = config.llm_model_id
|
||||
else:
|
||||
model_id = config.vision_model_id
|
||||
print(f" Selected model ID: {model_id}")
|
||||
print(f" Expected: {TEST_LLM_ID}")
|
||||
print(f" Match: {model_id == TEST_LLM_ID}")
|
||||
|
||||
print("\n2. When use_llm=False:")
|
||||
use_llm = False
|
||||
if use_llm:
|
||||
model_id = config.llm_model_id
|
||||
else:
|
||||
model_id = config.vision_model_id
|
||||
print(f" Selected model ID: {model_id}")
|
||||
print(f" Expected: {TEST_VISION_ID}")
|
||||
print(f" Match: {model_id == TEST_VISION_ID}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing Model Configuration")
|
||||
print("=" * 60)
|
||||
print(f"Test LLM ID: {TEST_LLM_ID}")
|
||||
print(f"Test Vision ID: {TEST_VISION_ID}")
|
||||
|
||||
config = test_config_creation()
|
||||
test_model_selection()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed!")
|
||||
print("=" * 60)
|
||||
test_storage.py (131 lines)
@@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
测试MinerU图片存储和访问功能
|
||||
|
||||
使用方法:
|
||||
1. 在本地开发环境:python test_storage.py
|
||||
2. 在Docker环境:docker exec -it maxkb-dev python /opt/maxkb-app/test_storage.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
def test_storage():
|
||||
"""测试存储功能"""
|
||||
print("=" * 60)
|
||||
print("MinerU 图片存储测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 检查存储路径配置
|
||||
storage_path = os.getenv('MAXKB_STORAGE_PATH', '/opt/maxkb/storage')
|
||||
print(f"\n1. 存储路径配置:{storage_path}")
|
||||
|
||||
# 2. 创建测试目录结构
|
||||
test_dir = os.path.join(storage_path, 'test', 'images')
|
||||
print(f"\n2. 创建测试目录:{test_dir}")
|
||||
os.makedirs(test_dir, exist_ok=True)
|
||||
|
||||
# 3. 创建测试图片文件
|
||||
test_image_path = os.path.join(test_dir, 'test_image.txt')
|
||||
print(f"\n3. 创建测试文件:{test_image_path}")
|
||||
with open(test_image_path, 'w') as f:
|
||||
f.write("This is a test image file for MinerU storage")
|
||||
|
||||
# 4. 验证文件创建
|
||||
if os.path.exists(test_image_path):
|
||||
print(" ✓ 文件创建成功")
|
||||
file_size = os.path.getsize(test_image_path)
|
||||
print(f" 文件大小:{file_size} bytes")
|
||||
else:
|
||||
print(" ✗ 文件创建失败")
|
||||
return False
|
||||
|
||||
# 5. 生成访问URL
|
||||
relative_path = os.path.relpath(test_image_path, storage_path)
|
||||
access_url = f"/api/storage/{relative_path}"
|
||||
print(f"\n4. 生成的访问URL:{access_url}")
|
||||
|
||||
# 6. 列出存储目录内容
|
||||
print(f"\n5. 存储目录内容:")
|
||||
for root, dirs, files in os.walk(storage_path):
|
||||
level = root.replace(storage_path, '').count(os.sep)
|
||||
indent = ' ' * 2 * level
|
||||
print(f'{indent}{os.path.basename(root)}/')
|
||||
subindent = ' ' * 2 * (level + 1)
|
||||
for file in files:
|
||||
print(f'{subindent}{file}')
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成!")
|
||||
print("\n配置建议:")
|
||||
print("1. 确保Docker volume正确挂载:~/.maxkb/storage:/opt/maxkb/storage")
|
||||
print("2. 确保环境变量设置:MAXKB_STORAGE_PATH=/opt/maxkb/storage")
|
||||
print("3. 访问图片URL格式:http://localhost:8080/api/storage/mineru/images/xxx.jpg")
|
||||
print("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
def test_mineru_adapter():
|
||||
"""测试MinerU适配器"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试MinerU适配器")
|
||||
print("=" * 60)
|
||||
|
||||
# 添加apps目录到Python路径
|
||||
sys.path.insert(0, '/opt/maxkb-app/apps' if os.path.exists('/opt/maxkb-app/apps') else './apps')
|
||||
|
||||
try:
|
||||
from common.handle.impl.mineru.maxkb_adapter.adapter import MaxKBAdapter
|
||||
|
||||
print("\n1. 创建MaxKB适配器实例")
|
||||
adapter = MaxKBAdapter()
|
||||
print(f" 存储路径:{adapter.storage_path}")
|
||||
|
||||
# 创建临时测试文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
|
||||
tmp.write(b"Test image content")
|
||||
tmp_path = tmp.name
|
||||
|
||||
print(f"\n2. 测试upload_file方法")
|
||||
print(f" 源文件:{tmp_path}")
|
||||
|
||||
# 使用异步方式调用
|
||||
import asyncio
|
||||
async def test_upload():
|
||||
result = await adapter.upload_file(tmp_path, options=['test_knowledge'])
|
||||
return result
|
||||
|
||||
# 运行异步测试
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result_url = loop.run_until_complete(test_upload())
|
||||
print(f" 返回URL:{result_url}")
|
||||
|
||||
# 清理临时文件
|
||||
os.unlink(tmp_path)
|
||||
|
||||
print("\n✓ MinerU适配器测试成功")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"\n✗ 无法导入MinerU适配器:{e}")
|
||||
print(" 请确保在MaxKB环境中运行此测试")
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败:{e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行存储测试
|
||||
if test_storage():
|
||||
# 如果基础存储测试成功,尝试测试适配器
|
||||
try:
|
||||
test_mineru_adapter()
|
||||
except:
|
||||
print("\n提示:适配器测试需要在MaxKB环境中运行")
|
||||
@ -1,22 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
简单的存储测试 - 创建测试图片
|
||||
"""
|
||||
import os
|
||||
|
||||
# 创建存储目录
|
||||
storage_path = './tmp/maxkb/storage/mineru/images'
|
||||
os.makedirs(storage_path, exist_ok=True)
|
||||
|
||||
# 创建测试图片(实际是一个文本文件,但后缀是.jpg)
|
||||
test_file = os.path.join(storage_path, 'ac3681aaa7a346b49ef9c7ceb7b94058.jpg')
|
||||
with open(test_file, 'wb') as f:
|
||||
# 写入一个最小的JPEG文件头(这样浏览器会识别为图片)
|
||||
# FF D8 FF E0 是JPEG文件的魔术数字
|
||||
f.write(bytes.fromhex('FFD8FFE000104A46494600010101006000600000FFDB004300080606070605080707070909080A0C140D0C0B0B0C1912130F141D1A1F1E1D1A1C1C20242E2720222C231C1C2837292C30313434341F27393D38323C2E333432FFDB0043010909090C0B0C180D0D1832211C2132323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232323232FFC00011080001000103012200021101031101FFC4001F0000010501010101010100000000000000000102030405060708090A0BFFC400B5100002010303020403050504040000017D01020300041105122131410613516107227114328191A1082342B1C11552D1F02433627282090A161718191A25262728292A3435363738393A434445464748494A535455565758595A636465666768696A737475767778797A838485868788898A92939495969798999AA2A3A4A5A6A7A8A9AAB2B3B4B5B6B7B8B9BAC2C3C4C5C6C7C8C9CAD2D3D4D5D6D7D8D9DAE1E2E3E4E5E6E7E8E9EAF1F2F3F4F5F6F7F8F9FAFFC4001F0100030101010101010101010000000000000102030405060708090A0BFFC400B51100020102040403040705040400010277000102031104052131061241510761711322328108144291A1B1C109233352F0156272D10A162434E125F11718191A262728292A35363738393A434445464748494A535455565758595A636465666768696A737475767778797A82838485868788898A92939495969798999AA2A3A4A5A6A7A8A9AAB2B3B4B5B6B7B8B9BAC2C3C4C5C6C7C8C9CAD2D3D4D5D6D7D8D9DAE2E3E4E5E6E7E8E9EAF2F3F4F5F6F7F8F9FAFFDA000C03010002110311003F00F9FFD9'))
|
||||
|
||||
print(f"测试文件已创建:{test_file}")
|
||||
print(f"文件大小:{os.path.getsize(test_file)} bytes")
|
||||
print("\n访问URL:")
|
||||
print("http://localhost:8080/storage/mineru/images/ac3681aaa7a346b49ef9c7ceb7b94058.jpg")
|
||||
print("\n如果Django服务正在运行,可以直接在浏览器中访问上述URL")
|
||||
test_url_fix.py (121 lines)
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试URL修复 - 验证platform_adapter是否正确传递
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
# Add paths
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
apps_path = project_root / 'apps'
|
||||
if apps_path.exists():
|
||||
sys.path.insert(0, str(apps_path))
|
||||
|
||||
# Set environment variables for testing
|
||||
os.environ['MAXKB_BASE_URL'] = 'http://xbase.aitravelmaster.com'
|
||||
os.environ['MINERU_API_TYPE'] = 'cloud' # Force cloud mode for testing
|
||||
|
||||
async def test_url_generation():
|
||||
"""Test that URLs are generated correctly"""
|
||||
|
||||
# Import after setting environment
|
||||
from apps.common.handle.impl.mineru.maxkb_adapter.adapter import MaxKBAdapter
|
||||
|
||||
# Create adapter
|
||||
adapter = MaxKBAdapter()
|
||||
|
||||
# Create a test file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.pdf', delete=False) as f:
|
||||
f.write('test')
|
||||
test_file = f.name
|
||||
|
||||
try:
|
||||
# Test upload_file
|
||||
print("Testing MaxKBAdapter.upload_file()...")
|
||||
url = await adapter.upload_file(test_file, ['test_knowledge_id'])
|
||||
|
||||
print(f"\n✅ Generated URL: {url}")
|
||||
|
||||
# Verify URL format
|
||||
if url.startswith('http://') or url.startswith('https://'):
|
||||
print("✅ URL is properly formatted for Cloud API")
|
||||
else:
|
||||
print(f"❌ URL is not valid for Cloud API: {url}")
|
||||
|
||||
# Check if MAXKB_BASE_URL is used
|
||||
base_url = os.environ.get('MAXKB_BASE_URL', '')
|
||||
if base_url and url.startswith(base_url):
|
||||
print(f"✅ URL correctly uses MAXKB_BASE_URL: {base_url}")
|
||||
else:
|
||||
print(f"❌ URL does not use MAXKB_BASE_URL")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(test_file):
|
||||
os.unlink(test_file)
|
||||
|
||||
async def test_api_client_with_adapter():
|
||||
"""Test that MinerUAPIClient receives platform_adapter correctly"""
|
||||
|
||||
from apps.common.handle.impl.mineru.api_client import MinerUAPIClient
|
||||
from apps.common.handle.impl.mineru.maxkb_adapter.adapter import MaxKBAdapter
|
||||
from apps.common.handle.impl.mineru.maxkb_adapter.config_maxkb import MaxKBMinerUConfig
|
||||
|
||||
print("\nTesting MinerUAPIClient with platform_adapter...")
|
||||
|
||||
# Create components
|
||||
adapter = MaxKBAdapter()
|
||||
config = MaxKBMinerUConfig()
|
||||
|
||||
# Create API client with adapter
|
||||
api_client = MinerUAPIClient(config, adapter)
|
||||
|
||||
# Check if adapter is set
|
||||
if api_client.platform_adapter is not None:
|
||||
print("✅ platform_adapter is correctly set in MinerUAPIClient")
|
||||
else:
|
||||
print("❌ platform_adapter is None in MinerUAPIClient")
|
||||
|
||||
# Test _upload_file_to_accessible_url
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.pdf', delete=False) as f:
|
||||
f.write('test')
|
||||
test_file = f.name
|
||||
|
||||
try:
|
||||
# Test upload through API client
|
||||
async with api_client:
|
||||
url = await api_client._upload_file_to_accessible_url(test_file, 'test_src_id')
|
||||
print(f"✅ URL from _upload_file_to_accessible_url: {url}")
|
||||
|
||||
if url.startswith('http://') or url.startswith('https://'):
|
||||
print("✅ API client generates valid URL for Cloud API")
|
||||
else:
|
||||
print(f"❌ API client generates invalid URL: {url}")
|
||||
|
||||
finally:
|
||||
if os.path.exists(test_file):
|
||||
os.unlink(test_file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Testing MinerU Cloud API URL Fix")
|
||||
print("=" * 60)
|
||||
|
||||
# Check environment
|
||||
print("\nEnvironment:")
|
||||
print(f"MAXKB_BASE_URL: {os.environ.get('MAXKB_BASE_URL', 'NOT SET')}")
|
||||
print(f"MINERU_API_TYPE: {os.environ.get('MINERU_API_TYPE', 'NOT SET')}")
|
||||
|
||||
# Run tests
|
||||
asyncio.run(test_url_generation())
|
||||
asyncio.run(test_api_client_with_adapter())
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed!")
|
||||
print("=" * 60)
|
||||
@ -1,94 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单测试URL生成逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
# 设置环境变量
|
||||
os.environ['MAXKB_BASE_URL'] = 'http://xbase.aitravelmaster.com'
|
||||
|
||||
def test_url_generation():
|
||||
"""模拟adapter.py中的upload_file逻辑"""
|
||||
|
||||
# 创建测试文件
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.pdf', delete=False) as f:
|
||||
f.write('test')
|
||||
file_path = f.name
|
||||
|
||||
try:
|
||||
# 模拟upload_file的逻辑
|
||||
storage_path = '/tmp/storage' # 模拟存储路径
|
||||
|
||||
# 创建存储目录
|
||||
sub_dir = 'mineru'
|
||||
storage_dir = os.path.join(storage_path, sub_dir, 'images')
|
||||
os.makedirs(storage_dir, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
file_ext = os.path.splitext(file_path)[1]
|
||||
file_name = f"{uuid.uuid4().hex}{file_ext}"
|
||||
dest_path = os.path.join(storage_dir, file_name)
|
||||
|
||||
# 复制文件
|
||||
shutil.copy2(file_path, dest_path)
|
||||
|
||||
# 生成URL(这是关键部分)
|
||||
relative_path = os.path.relpath(dest_path, storage_path)
|
||||
relative_path = relative_path.replace(os.path.sep, '/')
|
||||
|
||||
# 检查环境变量
|
||||
base_url = os.getenv('MAXKB_BASE_URL', '')
|
||||
print(f"MAXKB_BASE_URL from env: '{base_url}'")
|
||||
print(f"Relative path: {relative_path}")
|
||||
|
||||
if base_url:
|
||||
result_url = f"{base_url.rstrip('/')}/storage/{relative_path}"
|
||||
print(f"✅ Generated full URL: {result_url}")
|
||||
else:
|
||||
result_url = f"/storage/{relative_path}"
|
||||
print(f"⚠️ Generated relative URL: {result_url}")
|
||||
|
||||
# 验证URL格式
|
||||
if result_url.startswith(('http://', 'https://')):
|
||||
print("✅ URL is valid for Cloud API")
|
||||
else:
|
||||
print("❌ URL is NOT valid for Cloud API (must start with http:// or https://)")
|
||||
|
||||
return result_url
|
||||
|
||||
finally:
|
||||
# 清理
|
||||
if os.path.exists(file_path):
|
||||
os.unlink(file_path)
|
||||
# 清理存储目录
|
||||
if os.path.exists('/tmp/storage'):
|
||||
shutil.rmtree('/tmp/storage')
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Testing URL Generation Logic")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# 测试1:有MAXKB_BASE_URL
|
||||
print("Test 1: With MAXKB_BASE_URL set")
|
||||
print("-" * 40)
|
||||
url1 = test_url_generation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
# 测试2:没有MAXKB_BASE_URL
|
||||
print("\nTest 2: Without MAXKB_BASE_URL")
|
||||
print("-" * 40)
|
||||
os.environ['MAXKB_BASE_URL'] = ''
|
||||
url2 = test_url_generation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary:")
|
||||
print(f"With MAXKB_BASE_URL: {url1}")
|
||||
print(f"Without MAXKB_BASE_URL: {url2}")
|
||||
print("=" * 60)
|
||||
@@ -6,6 +6,8 @@ interface TaskTypeInterface {
  GENERATE_PROBLEM: number
  // Sync
  SYNC: number
  // Generate
  GENERATE: number
}
interface StateInterface {
  // Pending
@@ -27,7 +29,8 @@
const TaskType: TaskTypeInterface = {
  EMBEDDING: 1,
  GENERATE_PROBLEM: 2,
  SYNC: 3
  SYNC: 3,
  GENERATE: 4
}
const State: StateInterface = {
  // Pending

@@ -73,12 +73,14 @@ const aggStatus = computed(() => {
  const startedMap = {
    [TaskType.EMBEDDING]: t('views.document.fileStatus.EMBEDDING'),
    [TaskType.GENERATE_PROBLEM]: t('views.document.fileStatus.GENERATE'),
    [TaskType.SYNC]: t('views.document.fileStatus.SYNC')
    [TaskType.SYNC]: t('views.document.fileStatus.SYNC'),
    [TaskType.GENERATE]: t('views.document.fileStatus.GENERATE')
  }
  const taskTypeMap = {
    [TaskType.EMBEDDING]: t('views.knowledge.setting.vectorization'),
    [TaskType.GENERATE_PROBLEM]: t('views.document.generateQuestion.title'),
    [TaskType.SYNC]: t('views.knowledge.setting.sync')
    [TaskType.SYNC]: t('views.knowledge.setting.sync'),
    [TaskType.GENERATE]: t('views.document.fileStatus.GENERATE')
  }
  const stateMap: any = {
    [State.PENDING]: (type: number) => t('views.document.fileStatus.PENDING'),