diff --git a/apps/common/handle/impl/media/__init__.py b/apps/common/handle/impl/media/__init__.py new file mode 100644 index 00000000..daf30529 --- /dev/null +++ b/apps/common/handle/impl/media/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +""" +音视频处理模块 +""" \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/__init__.py b/apps/common/handle/impl/media/media_adapter/__init__.py new file mode 100644 index 00000000..c94cd8a1 --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +""" +媒体适配器模块 +""" +from .adapter import MediaAdapter + +__all__ = ['MediaAdapter'] \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/adapter.py b/apps/common/handle/impl/media/media_adapter/adapter.py new file mode 100644 index 00000000..a9e9948f --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/adapter.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +""" +音视频处理适配器 +复用MaxKB的模型系统,保持模块独立性 +""" +import os +import json +import tempfile +from typing import Dict, List, Optional, Any +from concurrent.futures import ThreadPoolExecutor + +class MediaAdapter: + """ + 音视频处理适配器 + 复用MaxKB的模型系统,保持模块独立性 + """ + + def __init__(self, logger=None): + self.logger = logger or self._get_default_logger() + from .config import MediaConfig + self.config = MediaConfig() + + def _get_default_logger(self): + """获取默认logger""" + try: + from common.utils.logger import maxkb_logger + from .logger import MediaLogger + return MediaLogger(maxkb_logger) + except: + import logging + from .logger import MediaLogger + return MediaLogger(logging.getLogger('MediaAdapter')) + + def process_media(self, + file_content: bytes, + file_name: str, + stt_model_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + workspace_id: Optional[str] = None, + options: Dict[str, Any] = None) -> Dict: + """ + 处理音视频文件 + + Args: + file_content: 文件内容 + file_name: 文件名 + stt_model_id: STT模型ID(MaxKB系统中的) + llm_model_id: LLM模型ID(用于文本优化,可选) + workspace_id: 工作空间ID + options: 其他选项 + - language: 语言(zh/en/auto) + - segment_duration: 分段时长(秒) + - enable_punctuation: 是否添加标点 + - enable_summary: 是否生成摘要 + + Returns: + { + 'status': 'success', + 'media_type': 'audio/video', + 'duration': 120.5, + 'segments': [ + { + 'index': 0, + 'start_time': 0, + 'end_time': 60, + 'text': '转写文本', + 'enhanced_text': '优化后的文本', + 'summary': '段落摘要' + } + ], + 'full_text': '完整文本', + 'metadata': { + 'stt_model': 'model_name', + 'language': 'zh', + 'processing_time': 10.5 + } + } + """ + + options = options or {} + self.logger.info(f"开始处理媒体文件: {file_name}") + self.logger.info(f"接收到的参数:") + self.logger.info(f" - stt_model_id: {stt_model_id}") + self.logger.info(f" - workspace_id: {workspace_id}") + self.logger.info(f" - llm_model_id: {llm_model_id}") + + try: + # 判断媒体类型 + media_type = self._detect_media_type(file_name) + + # 获取STT模型实例 + stt_model = None + if stt_model_id and workspace_id: + try: + from models_provider.tools import get_model_instance_by_model_workspace_id + stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id) + self.logger.info(f"成功获取STT模型实例: {stt_model}") + except Exception as e: + self.logger.error(f"获取STT模型失败: {str(e)}") + else: + self.logger.warning(f"STT模型未配置 - stt_model_id: {stt_model_id}, workspace_id: {workspace_id}") + + # 获取LLM模型实例(可选) + llm_model = None + if llm_model_id and workspace_id: + try: + from models_provider.tools import get_model_instance_by_model_workspace_id + llm_model = get_model_instance_by_model_workspace_id(llm_model_id, workspace_id) + self.logger.info(f"使用LLM模型: {llm_model_id}") + except Exception as e: + self.logger.warning(f"获取LLM模型失败: {str(e)}") + + # 处理文件 + if media_type == 'video': + from .processors.video_processor import VideoProcessor + processor = VideoProcessor(self.config, self.logger) + else: + from .processors.audio_processor import AudioProcessor + processor = AudioProcessor(self.config, self.logger) + + result = processor.process( + file_content=file_content, + file_name=file_name, + stt_model=stt_model, + llm_model=llm_model, + options=options + ) + + self.logger.info(f"媒体文件处理成功: {file_name}") + return result + + except Exception as e: + self.logger.error(f"处理媒体文件失败: {str(e)}") + raise + + def _detect_media_type(self, file_name: str) -> str: + """检测媒体类型""" + file_ext = file_name.lower().split('.')[-1] + video_exts = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv'} + + if file_ext in video_exts: + return 'video' + return 'audio' \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/config.py b/apps/common/handle/impl/media/media_adapter/config.py new file mode 100644 index 00000000..90e3f77c --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/config.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" +媒体适配器配置管理 +""" +import os + +class MediaConfig: + """独立的配置管理,不依赖Django settings""" + + def __init__(self, custom_config=None): + self.config = self._load_default_config() + if custom_config: + self.config.update(custom_config) + + def _load_default_config(self): + """加载默认配置""" + return { + # STT提供者配置 + 'stt_provider': os.getenv('MEDIA_STT_PROVIDER', 'openai'), + 'stt_model': os.getenv('MEDIA_STT_MODEL', 'whisper-1'), + + # 处理参数 + 'max_duration': int(os.getenv('MEDIA_MAX_DURATION', '7200')), # 最大时长(秒) + 'segment_duration': int(os.getenv('MEDIA_SEGMENT_DURATION', '300')), # 分段长度(秒) + 'enable_timestamps': os.getenv('MEDIA_ENABLE_TIMESTAMPS', 'true').lower() == 'true', + + # 音频处理 + 'audio_format': os.getenv('MEDIA_AUDIO_FORMAT', 'mp3'), + 'sample_rate': int(os.getenv('MEDIA_SAMPLE_RATE', '16000')), + + # 视频处理 + 'extract_keyframes': os.getenv('MEDIA_EXTRACT_KEYFRAMES', 'false').lower() == 'true', + 'video_codec': os.getenv('MEDIA_VIDEO_CODEC', 'h264'), + + # 日志 + 'log_level': os.getenv('MEDIA_LOG_LEVEL', 'INFO'), + 'log_file': os.getenv('MEDIA_LOG_FILE', 'media_adapter.log') + } + + def get(self, key: str, default=None): + """获取配置项""" + return self.config.get(key, default) + + def set(self, key: str, value): + """设置配置项""" + self.config[key] = value + + def update(self, config_dict: dict): + """批量更新配置""" + self.config.update(config_dict) \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/logger.py b/apps/common/handle/impl/media/media_adapter/logger.py new file mode 100644 index 00000000..aa4baba4 --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/logger.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +""" +日志包装器 - 适配不同的日志系统 +""" + +class MediaLogger: + """日志包装器 - 适配不同的日志系统""" + + def __init__(self, logger): + self.logger = logger + + def info(self, message): + """记录信息日志""" + if hasattr(self.logger, 'info'): + self.logger.info(f"[MediaAdapter] {message}") + else: + print(f"[INFO] {message}") + + def error(self, message, exc_info=False): + """记录错误日志""" + if hasattr(self.logger, 'error'): + self.logger.error(f"[MediaAdapter] {message}", exc_info=exc_info) + else: + print(f"[ERROR] {message}") + if exc_info: + import traceback + traceback.print_exc() + + def warning(self, message): + """记录警告日志""" + if hasattr(self.logger, 'warning'): + self.logger.warning(f"[MediaAdapter] {message}") + else: + print(f"[WARNING] {message}") + + def debug(self, message): + """记录调试日志""" + if hasattr(self.logger, 'debug'): + self.logger.debug(f"[MediaAdapter] {message}") + else: + print(f"[DEBUG] {message}") \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/processors/__init__.py b/apps/common/handle/impl/media/media_adapter/processors/__init__.py new file mode 100644 index 00000000..9ec6265b --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/processors/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" +媒体处理器模块 +""" +from .audio_processor import AudioProcessor +from .video_processor import VideoProcessor + +__all__ = ['AudioProcessor', 'VideoProcessor'] \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/processors/audio_processor.py b/apps/common/handle/impl/media/media_adapter/processors/audio_processor.py new file mode 100644 index 00000000..dd9ee682 --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/processors/audio_processor.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +""" +音频处理器 - 复用MaxKB的音频处理工具 +""" +import io +import os +import tempfile +from typing import Dict, List, Optional, Any + +class AudioProcessor: + """音频处理器 - 复用MaxKB的音频处理工具""" + + def __init__(self, config, logger): + self.config = config + self.logger = logger + + def process(self, + file_content: bytes, + file_name: str, + stt_model: Optional[Any] = None, + llm_model: Optional[Any] = None, + options: Dict[str, Any] = None) -> Dict: + """处理音频文件""" + + options = options or {} + segment_duration = options.get('segment_duration', self.config.get('segment_duration', 300)) # 默认5分钟 + + # 保存临时文件 + suffix = self._get_suffix(file_name) + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(file_content) + temp_file_path = temp_file.name + + # 转换为MP3(复用MaxKB工具) + with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as mp3_file: + mp3_path = mp3_file.name + + try: + # 使用MaxKB的音频转换工具 + try: + from common.utils.common import any_to_mp3 + any_to_mp3(temp_file_path, mp3_path) + self.logger.info(f"音频文件已转换为MP3格式") + except Exception as e: + self.logger.warning(f"音频转换失败,使用原始文件: {str(e)}") + mp3_path = temp_file_path + + # 获取音频信息 + duration = self._get_audio_duration(mp3_path) + + # 分段转写 + segments = [] + if stt_model: + self.logger.info(f"开始转写音频,总时长: {duration:.1f}秒") + segments = self._transcribe_audio(mp3_path, stt_model, segment_duration) + else: + # 如果没有STT模型,返回基础信息 + segments = [{ + 'index': 0, + 'start_time': 0, + 'end_time': duration, + 'text': f'[音频文件: {file_name}, 时长: {duration:.1f}秒]' + }] + self.logger.warning("未提供STT模型,返回基础信息") + + # 文本优化(可选) + if llm_model and segments and len(segments) > 0: + self.logger.info("开始优化转写文本") + segments = self._enhance_segments(segments, llm_model, options) + + # 生成完整文本 + full_text = '\n'.join([seg.get('enhanced_text', seg['text']) for seg in segments]) + + return { + 'status': 'success', + 'media_type': 'audio', + 'duration': duration, + 'segments': segments, + 'full_text': full_text, + 'metadata': { + 'file_name': file_name, + 'stt_model': str(stt_model) if stt_model else None, + 'language': options.get('language', 'auto') + } + } + + except Exception as e: + self.logger.error(f"处理音频文件失败: {str(e)}", exc_info=True) + raise + finally: + # 清理临时文件 + try: + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + if os.path.exists(mp3_path) and mp3_path != temp_file_path: + os.unlink(mp3_path) + except Exception as e: + self.logger.warning(f"清理临时文件失败: {str(e)}") + + def _get_audio_duration(self, audio_path: str) -> float: + """获取音频时长""" + try: + from pydub import AudioSegment + audio = AudioSegment.from_file(audio_path) + return len(audio) / 1000 # 转换为秒 + except Exception as e: + self.logger.warning(f"无法获取音频时长: {str(e)}") + return 0 + + def _transcribe_audio(self, audio_path: str, stt_model: Any, segment_duration: int) -> List[Dict]: + """转写音频""" + try: + from common.utils.common import split_and_transcribe + from pydub import AudioSegment + + audio = AudioSegment.from_file(audio_path) + duration = len(audio) / 1000 + segments = [] + + # 按时长分段 + segment_ms = segment_duration * 1000 + + for start_ms in range(0, len(audio), segment_ms): + end_ms = min(start_ms + segment_ms, len(audio)) + segment_audio = audio[start_ms:end_ms] + + # 保存段落音频 + with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as seg_file: + segment_audio.export(seg_file.name, format='mp3') + + try: + # 使用MaxKB的转写工具 + self.logger.info(f"转写第 {len(segments)+1} 段,时间: {start_ms/1000:.1f}s - {end_ms/1000:.1f}s") + text = split_and_transcribe(seg_file.name, stt_model) + + segments.append({ + 'index': len(segments), + 'start_time': start_ms / 1000, + 'end_time': end_ms / 1000, + 'text': text if text else '[无法识别]' + }) + finally: + if os.path.exists(seg_file.name): + os.unlink(seg_file.name) + + self.logger.info(f"音频转写完成,共 {len(segments)} 段") + return segments + + except Exception as e: + self.logger.error(f"音频转写失败: {str(e)}", exc_info=True) + return [] + + def _enhance_segments(self, segments: List[Dict], llm_model: Any, options: Dict) -> List[Dict]: + """使用LLM优化文本""" + try: + for segment in segments: + original_text = segment['text'] + + if options.get('enable_punctuation', True) and original_text and original_text != '[无法识别]': + # 添加标点符号 + prompt = f"请为以下语音转写文本添加适当的标点符号,保持原意不变,直接返回处理后的文本:\n\n{original_text}" + + try: + # 调用LLM模型 + enhanced = None + if hasattr(llm_model, 'generate'): + response = llm_model.generate(prompt) + self.logger.info(f"LLM generate response type: {type(response)}, value: {response}") + # 处理不同的响应格式 + try: + if hasattr(response, 'content'): + enhanced = response.content + elif isinstance(response, str): + enhanced = response + else: + enhanced = str(response) + except Exception as attr_error: + self.logger.warning(f"Error accessing response content: {str(attr_error)}") + enhanced = str(response) if response else original_text + elif hasattr(llm_model, 'invoke'): + response = llm_model.invoke(prompt) + self.logger.info(f"LLM invoke response type: {type(response)}, value: {response}") + # 处理不同的响应格式 + try: + if hasattr(response, 'content'): + enhanced = response.content + elif isinstance(response, str): + enhanced = response + else: + enhanced = str(response) + except Exception as attr_error: + self.logger.warning(f"Error accessing response content: {str(attr_error)}") + enhanced = str(response) if response else original_text + else: + # 尝试其他可能的方法 + enhanced = original_text + + if enhanced and enhanced.strip(): + segment['enhanced_text'] = enhanced.strip() + except Exception as e: + self.logger.warning(f"优化文本失败: {str(e)}") + + if options.get('enable_summary', False) and original_text and len(original_text) > 100: + # 生成摘要 + prompt = f"请用一句话(不超过50字)总结以下内容的核心要点:\n\n{original_text}" + + try: + summary = None + if hasattr(llm_model, 'generate'): + response = llm_model.generate(prompt) + self.logger.info(f"LLM summary generate response type: {type(response)}, value: {response}") + # 处理不同的响应格式 + try: + if hasattr(response, 'content'): + summary = response.content + elif isinstance(response, str): + summary = response + else: + summary = str(response) + except Exception as attr_error: + self.logger.warning(f"Error accessing summary response content: {str(attr_error)}") + summary = str(response) if response else None + elif hasattr(llm_model, 'invoke'): + response = llm_model.invoke(prompt) + self.logger.info(f"LLM summary invoke response type: {type(response)}, value: {response}") + # 处理不同的响应格式 + try: + if hasattr(response, 'content'): + summary = response.content + elif isinstance(response, str): + summary = response + else: + summary = str(response) + except Exception as attr_error: + self.logger.warning(f"Error accessing summary response content: {str(attr_error)}") + summary = str(response) if response else None + else: + summary = None + + if summary and summary.strip(): + segment['summary'] = summary.strip() + except Exception as e: + self.logger.warning(f"生成摘要失败: {str(e)}") + + return segments + except Exception as e: + self.logger.error(f"文本优化失败: {str(e)}") + return segments + + def _get_suffix(self, file_name: str) -> str: + """获取文件后缀""" + if '.' in file_name: + return '.' + file_name.split('.')[-1].lower() + return '.mp3' \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_adapter/processors/video_processor.py b/apps/common/handle/impl/media/media_adapter/processors/video_processor.py new file mode 100644 index 00000000..3f5963ce --- /dev/null +++ b/apps/common/handle/impl/media/media_adapter/processors/video_processor.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +视频处理器 - 提取音频并处理 +""" +import os +import tempfile +import subprocess +from typing import Dict, Optional, Any, List +from .audio_processor import AudioProcessor + +class VideoProcessor: + """视频处理器""" + + def __init__(self, config, logger): + self.config = config + self.logger = logger + self.audio_processor = AudioProcessor(config, logger) + + def process(self, + file_content: bytes, + file_name: str, + stt_model: Optional[Any] = None, + llm_model: Optional[Any] = None, + options: Dict[str, Any] = None) -> Dict: + """处理视频文件""" + + # 保存临时视频文件 + suffix = self._get_suffix(file_name) + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as video_file: + video_file.write(file_content) + video_path = video_file.name + + audio_path = None + try: + self.logger.info(f"开始处理视频文件: {file_name}") + + # 提取音频 + audio_path = self._extract_audio(video_path, file_name) + + if audio_path and os.path.exists(audio_path): + # 读取音频文件 + with open(audio_path, 'rb') as f: + audio_content = f.read() + + # 使用音频处理器处理 + audio_file_name = os.path.splitext(file_name)[0] + '.mp3' + result = self.audio_processor.process( + file_content=audio_content, + file_name=audio_file_name, + stt_model=stt_model, + llm_model=llm_model, + options=options + ) + + # 更新媒体类型 + result['media_type'] = 'video' + result['metadata']['original_file'] = file_name + + # 可选:提取关键帧 + if options and options.get('extract_keyframes', False): + keyframes = self._extract_keyframes(video_path) + if keyframes: + result['keyframes'] = keyframes + + self.logger.info(f"视频文件处理完成: {file_name}") + return result + else: + # 音频提取失败 + self.logger.error(f"无法从视频提取音频: {file_name}") + return { + 'status': 'error', + 'media_type': 'video', + 'duration': 0, + 'segments': [{ + 'index': 0, + 'start_time': 0, + 'end_time': 0, + 'text': f'[视频文件: {file_name}, 无法提取音频]' + }], + 'full_text': f'[视频文件: {file_name}, 无法提取音频]', + 'metadata': { + 'file_name': file_name, + 'error': '音频提取失败' + } + } + + except Exception as e: + self.logger.error(f"处理视频文件失败: {str(e)}", exc_info=True) + raise + finally: + # 清理临时文件 + try: + if os.path.exists(video_path): + os.unlink(video_path) + if audio_path and os.path.exists(audio_path): + os.unlink(audio_path) + except Exception as e: + self.logger.warning(f"清理临时文件失败: {str(e)}") + + def _extract_audio(self, video_path: str, file_name: str) -> Optional[str]: + """从视频提取音频""" + audio_path = None + + # 首先尝试使用moviepy + try: + from moviepy.editor import VideoFileClip + + audio_path = video_path.replace(self._get_suffix(file_name), '.mp3') + self.logger.info(f"使用moviepy提取音频") + + video = VideoFileClip(video_path) + if video.audio is not None: + video.audio.write_audiofile(audio_path, logger=None, verbose=False) + video.close() + self.logger.info(f"音频提取成功: {audio_path}") + return audio_path + else: + self.logger.warning(f"视频文件没有音频轨道: {file_name}") + video.close() + return None + + except ImportError: + self.logger.warning("moviepy未安装,尝试使用ffmpeg命令") + except Exception as e: + self.logger.warning(f"moviepy提取音频失败: {str(e)},尝试使用ffmpeg命令") + + # 降级到ffmpeg命令 + try: + audio_path = video_path.replace(self._get_suffix(file_name), '.mp3') + + # 检查ffmpeg是否可用 + result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True) + if result.returncode != 0: + self.logger.error("ffmpeg未安装或不可用") + return None + + # 使用ffmpeg提取音频 + cmd = [ + 'ffmpeg', + '-i', video_path, + '-vn', # 禁用视频 + '-acodec', 'mp3', # 音频编码 + '-ab', '128k', # 音频比特率 + '-ar', '44100', # 采样率 + '-y', # 覆盖输出文件 + audio_path + ] + + self.logger.info(f"使用ffmpeg命令提取音频") + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + if result.returncode == 0 and os.path.exists(audio_path): + self.logger.info(f"ffmpeg音频提取成功: {audio_path}") + return audio_path + else: + self.logger.error(f"ffmpeg音频提取失败: {result.stderr}") + return None + + except subprocess.TimeoutExpired: + self.logger.error("ffmpeg提取音频超时") + return None + except FileNotFoundError: + self.logger.error("ffmpeg未安装") + return None + except Exception as e: + self.logger.error(f"ffmpeg提取音频失败: {str(e)}") + return None + + def _extract_keyframes(self, video_path: str) -> List[str]: + """提取关键帧(可选功能)""" + # TODO: 实现关键帧提取 + # 可以使用opencv-python或ffmpeg提取关键帧 + # 返回base64编码的图片列表 + return [] + + def _get_suffix(self, file_name: str) -> str: + """获取文件后缀""" + if '.' in file_name: + return '.' + file_name.split('.')[-1].lower() + return '.mp4' \ No newline at end of file diff --git a/apps/common/handle/impl/media/media_split_handle.py b/apps/common/handle/impl/media/media_split_handle.py new file mode 100644 index 00000000..0805358a --- /dev/null +++ b/apps/common/handle/impl/media/media_split_handle.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +音视频处理器 - MaxKB集成层 +""" +from typing import List +from common.handle.base_split_handle import BaseSplitHandle +from common.utils.logger import maxkb_logger + +class MediaSplitHandle(BaseSplitHandle): + """ + 音视频处理器 - MaxKB集成层 + """ + + def __init__(self): + super().__init__() + self.adapter = None + + def support(self, file, get_buffer, **kwargs): + """检查是否支持该文件类型""" + file_name = file.name.lower() + + # 支持的音频格式 + audio_exts = ('.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg', '.wma') + # 支持的视频格式 + video_exts = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv') + + return any(file_name.endswith(ext) for ext in audio_exts + video_exts) + + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, + get_buffer, save_image, **kwargs): + """处理音视频文件""" + + maxkb_logger.info(f"MediaSplitHandle.handle called with file: {file.name}") + maxkb_logger.info(f"kwargs received: {kwargs}") + + # 初始化适配器 + if not self.adapter: + from .media_adapter import MediaAdapter + from .media_adapter.logger import MediaLogger + logger_wrapper = MediaLogger(maxkb_logger) + self.adapter = MediaAdapter(logger=logger_wrapper) + + # 获取文件内容 + buffer = get_buffer(file) + + # 获取模型ID和工作空间ID + stt_model_id = kwargs.get('stt_model_id') + llm_model_id = kwargs.get('llm_model_id') + workspace_id = kwargs.get('workspace_id') + + maxkb_logger.info(f"Extracted from kwargs - stt_model_id: {stt_model_id}, llm_model_id: {llm_model_id}, workspace_id: {workspace_id}") + + # 处理选项 + options = { + 'language': kwargs.get('language', 'auto'), + 'segment_duration': kwargs.get('segment_duration', 300), + 'enable_punctuation': kwargs.get('enable_punctuation', True), + 'enable_summary': kwargs.get('enable_summary', False), + 'extract_keyframes': kwargs.get('extract_keyframes', False) + } + + try: + # 调用适配器处理 + result = self.adapter.process_media( + file_content=buffer, + file_name=file.name, + stt_model_id=stt_model_id, + llm_model_id=llm_model_id, + workspace_id=workspace_id, + options=options + ) + + # 转换为MaxKB段落格式 + paragraphs = [] + for segment in result.get('segments', []): + # 使用优化后的文本(如果有) + text = segment.get('enhanced_text', segment.get('text', '')) + + # 添加时间戳信息 + if segment.get('start_time') is not None: + time_info = f"[{self._format_time(segment['start_time'])} - {self._format_time(segment['end_time'])}]" + text = f"{time_info}\n{text}" + + # 添加摘要(如果有) + if segment.get('summary'): + text = f"{text}\n【摘要】{segment['summary']}" + + paragraph = { + 'content': text, + 'title': f"段落 {segment.get('index', 0) + 1}", + 'metadata': { + 'start_time': segment.get('start_time'), + 'end_time': segment.get('end_time'), + 'index': segment.get('index') + } + } + + # 如果有关键帧,添加到段落中 + if 'keyframes' in result and segment.get('index', 0) < len(result['keyframes']): + paragraph['images'] = [result['keyframes'][segment['index']]] + + paragraphs.append(paragraph) + + # 应用限制 + if limit > 0: + paragraphs = paragraphs[:limit] + + # 添加成功处理的标记 + metadata = result.get('metadata', {}) + metadata['media_processing_status'] = 'success' + + return { + 'name': file.name, + 'content': paragraphs, + 'metadata': metadata + } + + except Exception as e: + maxkb_logger.error(f"处理音视频文件失败: {str(e)}") + # 返回错误信息 + return { + 'name': file.name, + 'content': [{ + 'content': f'处理失败: {str(e)}', + 'title': '错误' + }], + 'metadata': {'error': str(e)} + } + + def get_content(self, file, save_image): + """获取文件内容(用于预览)""" + try: + file_name = file.name + # 判断媒体类型 + file_ext = file_name.lower().split('.')[-1] + video_exts = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv'} + + if file_ext in video_exts: + return f"[视频文件: {file_name}]\n\n该文件需要进行音频提取和语音识别处理。" + else: + return f"[音频文件: {file_name}]\n\n该文件需要进行语音识别处理。" + except Exception as e: + return f"读取文件失败: {str(e)}" + + def _format_time(self, seconds: float) -> str: + """格式化时间""" + if seconds is None: + return "00:00" + + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + if hours > 0: + return f"{hours:02d}:{minutes:02d}:{secs:02d}" + else: + return f"{minutes:02d}:{secs:02d}" \ No newline at end of file diff --git a/apps/common/handle/impl/text/mineru_split_handle.py b/apps/common/handle/impl/text/mineru_split_handle.py index 64e3736d..09b86c56 100644 --- a/apps/common/handle/impl/text/mineru_split_handle.py +++ b/apps/common/handle/impl/text/mineru_split_handle.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- """ -MinerU Split Handle - 使用MinerU服务处理PDF文档 +MinerU Split Handle - 使用MinerU服务处理文档和图片 + +支持的文档格式:PDF、PPT、PPTX、DOC、DOCX +支持的图片格式:PNG、JPG、JPEG、GIF、BMP、TIFF、WebP、SVG """ import os from typing import List, Dict, Any @@ -21,7 +24,7 @@ class MinerUSplitHandle(BaseSplitHandle): def support(self, file, get_buffer, **kwargs): """ 检查是否支持该文件类型 - 当前仅支持PDF文件,且需要MinerU服务配置 + 支持PDF、PPT、DOC和图片文件,且需要MinerU服务配置 预览模式下不使用MinerU处理器,因为处理速度较慢 """ # 如果是预览模式,不使用MinerU处理器 @@ -30,7 +33,12 @@ class MinerUSplitHandle(BaseSplitHandle): file_name = file.name.lower() # 检查文件扩展名 - if not file_name.endswith('.pdf'): + supported_extensions = ( + '.pdf', '.ppt', '.pptx', '.doc', '.docx', # 文档格式 + '.png', '.jpg', '.jpeg', '.gif', '.bmp', # 图片格式 + '.tiff', '.tif', '.webp', '.svg' # 其他图片格式 + ) + if not any(file_name.endswith(ext) for ext in supported_extensions): return False # 检查MinerU配置 diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index a5a69ca0..78f93e40 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -40,6 +40,7 @@ from common.handle.impl.text.xls_split_handle import XlsSplitHandle from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle from common.handle.impl.text.zip_split_handle import ZipSplitHandle from common.handle.impl.text.mineru_split_handle import MinerUSplitHandle +from common.handle.impl.media.media_split_handle import MediaSplitHandle from common.utils.common import post, get_file_content, bulk_create_in_batches, parse_image from common.utils.fork import Fork from common.utils.logger import maxkb_logger @@ -62,8 +63,11 @@ from oss.serializers.file import FileSerializer default_split_handle = TextSplitHandle() # MinerU处理器优先级最高,用于处理PDF和PPT文档 mineru_split_handle = MinerUSplitHandle() +# 音视频处理器 +media_split_handle = MediaSplitHandle() split_handles = [ - mineru_split_handle, # MinerU处理器放在最前面,优先使用 + media_split_handle, # 音视频处理器,优先级高 + mineru_split_handle, # MinerU处理器 HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), @@ -107,6 +111,9 @@ class DocumentInstanceSerializer(serializers.Serializer): source=_('document name')) paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) source_file_id = serializers.UUIDField(required=False, allow_null=True, label=_('source file id')) + stt_model_id = serializers.CharField(required=False, allow_null=True, label=_('STT model ID')) + llm_model_id = serializers.CharField(required=False, allow_null=True, label=_('LLM model ID')) + vision_model_id = serializers.CharField(required=False, allow_null=True, label=_('Vision model ID')) class CancelInstanceSerializer(serializers.Serializer): @@ -877,6 +884,9 @@ class DocumentSerializers(serializers.Serializer): source_meta['llm_model_id'] = instance.get('llm_model_id') if instance.get('vision_model_id'): source_meta['vision_model_id'] = instance.get('vision_model_id') + # 添加音视频STT模型参数到meta + if instance.get('stt_model_id'): + source_meta['stt_model_id'] = instance.get('stt_model_id') meta = {**instance.get('meta'), **source_meta} if instance.get('meta') is not None else source_meta meta = convert_uuid_to_str(meta) @@ -1141,6 +1151,11 @@ class DocumentSerializers(serializers.Serializer): from common.utils.logger import maxkb_logger maxkb_logger.info(f"Skipping refresh for advanced learning document: {document_dict.get('id')}") continue + # 跳过音视频文档(已经通过异步任务处理) + if document_dict.get('is_media_learning'): + from common.utils.logger import maxkb_logger + maxkb_logger.info(f"Skipping refresh for media learning document: {document_dict.get('id')}") + continue DocumentSerializers.Operate(data={ 'knowledge_id': knowledge_id, 'document_id': document_dict.get('id'), @@ -1160,16 +1175,27 @@ class DocumentSerializers(serializers.Serializer): paragraph_model_list = [] problem_paragraph_object_list = [] - # 处理MinerU类型的文档 + # 处理MinerU类型和音视频类型的文档 from common.utils.logger import maxkb_logger import os + # 添加详细日志 + maxkb_logger.info(f"batch_save called with workspace_id: {workspace_id}, knowledge_id: {knowledge_id}") + maxkb_logger.info(f"instance_list contains {len(instance_list)} documents") + for idx, doc in enumerate(instance_list): + maxkb_logger.info(f"Document {idx}: {doc.keys()}") + if 'stt_model_id' in doc: + maxkb_logger.info(f" - stt_model_id present: {doc['stt_model_id']}") + if 'llm_model_id' in doc: + maxkb_logger.info(f" - llm_model_id present: {doc['llm_model_id']}") + for document in instance_list: # 检查是否是MinerU类型的文档(需要同时有llm_model_id和vision_model_id) llm_model_id = document.get('llm_model_id') vision_model_id = document.get('vision_model_id') + stt_model_id = document.get('stt_model_id') - maxkb_logger.info(f"Processing document: {document.get('name')}, llm_model_id: {llm_model_id}, vision_model_id: {vision_model_id}") + maxkb_logger.info(f"Processing document: {document.get('name')}, llm_model_id: {llm_model_id}, vision_model_id: {vision_model_id}, stt_model_id: {stt_model_id}") # 只有同时提供两个模型ID时才是高级学习文档 if llm_model_id and vision_model_id: @@ -1177,6 +1203,75 @@ class DocumentSerializers(serializers.Serializer): # MinerU类型的文档,保存基本信息,不处理段落 # 段落处理将通过异步任务进行 document['paragraphs'] = [] # 清空段落,等待异步处理 + # 检查是否是音视频类型的文档 + elif stt_model_id: + maxkb_logger.info(f"Document {document.get('name')} is media type, processing synchronously") + # 音视频类型的文档,直接处理 + source_file_id = document.get('source_file_id') + if source_file_id: + try: + source_file = QuerySet(File).filter(id=source_file_id).first() + if source_file: + workspace_id_value = self.data.get('workspace_id', '') + maxkb_logger.info(f"Processing media file: {source_file.file_name}") + maxkb_logger.info(f" - STT model ID: {stt_model_id}") + maxkb_logger.info(f" - Workspace ID: {workspace_id_value}") + maxkb_logger.info(f" - LLM model ID: {llm_model_id}") + + # 使用MediaSplitHandle处理音视频文件 + from common.handle.impl.media.media_split_handle import MediaSplitHandle + media_handler = MediaSplitHandle() + + # 准备文件对象 + class FileWrapper: + def __init__(self, file_obj): + self.file_obj = file_obj + self.name = file_obj.file_name + self.size = file_obj.file_size + + def read(self): + return self.file_obj.get_bytes() + + def seek(self, pos): + pass + + file_wrapper = FileWrapper(source_file) + + # 处理音视频文件 + result = media_handler.handle( + file_wrapper, + pattern_list=[], + with_filter=False, + limit=0, + get_buffer=lambda f: f.read(), + save_image=lambda x: None, + workspace_id=workspace_id_value, + stt_model_id=stt_model_id, + llm_model_id=llm_model_id + ) + + # 将处理结果添加到文档 + if result and result.get('content'): + document['paragraphs'] = result.get('content', []) + maxkb_logger.info(f"Media file processed, got {len(document['paragraphs'])} paragraphs") + else: + maxkb_logger.warning(f"No content extracted from media file, using default") + document['paragraphs'] = [{ + 'content': f'[音视频文件: {source_file.file_name}]', + 'title': '音视频内容' + }] + except Exception as e: + maxkb_logger.error(f"Failed to process media file: {str(e)}") + import traceback + maxkb_logger.error(traceback.format_exc()) + # 如果处理失败,创建一个默认段落 + document['paragraphs'] = [{ + 'content': f'[音视频文件处理失败: {str(e)}]', + 'title': '处理失败' + }] + else: + maxkb_logger.warning(f"No source file for media document") + document['paragraphs'] = [] # 插入文档 for document in instance_list: @@ -1211,16 +1306,34 @@ class DocumentSerializers(serializers.Serializer): saved_doc = QuerySet(Document).filter(id=doc.id).first() if saved_doc: maxkb_logger.info(f"Document {doc.id} successfully saved to database") + # 更新音视频文档的状态 + if hasattr(doc, 'meta') and doc.meta and doc.meta.get('stt_model_id'): + try: + from common.event import ListenerManagement + from knowledge.models import TaskType, State + # 更新文档状态为成功 + ListenerManagement.update_status( + QuerySet(Document).filter(id=doc.id), + TaskType.EMBEDDING, + State.SUCCESS + ) + maxkb_logger.info(f"Updated status for media document {doc.id} to SUCCESS") + except Exception as status_error: + maxkb_logger.warning(f"Failed to update status for media document {doc.id}: {str(status_error)}") else: maxkb_logger.error(f"Document {doc.id} not found after bulk_create") # 处理高级学习文档的异步任务 for idx, document in enumerate(instance_list): + if idx >= len(document_model_list): + continue + + document_model = document_model_list[idx] llm_model_id = document.get('llm_model_id') vision_model_id = document.get('vision_model_id') - if llm_model_id and vision_model_id and document_model_list: - # 找到对应的文档模型 - document_model = document_model_list[idx] + + # 处理高级学习文档(MinerU) + if llm_model_id and vision_model_id: maxkb_logger.info(f"Submitting async advanced learning task for document: {document_model.id}") # 设置文档状态为排队中 @@ -1297,12 +1410,17 @@ class DocumentSerializers(serializers.Serializer): with_search_one=False ) - # 标记高级学习文档 + # 标记高级学习文档和音视频文档 for idx, document in enumerate(instance_list): llm_model_id = document.get('llm_model_id') vision_model_id = document.get('vision_model_id') - if llm_model_id and vision_model_id and idx < len(document_result_list): - document_result_list[idx]['is_advanced_learning'] = True + stt_model_id = document.get('stt_model_id') + + if idx < len(document_result_list): + if llm_model_id and vision_model_id: + document_result_list[idx]['is_advanced_learning'] = True + elif stt_model_id: + document_result_list[idx]['is_media_learning'] = True return document_result_list, knowledge_id, workspace_id diff --git a/apps/knowledge/tasks/media_learning.py b/apps/knowledge/tasks/media_learning.py new file mode 100644 index 00000000..723b2f9c --- /dev/null +++ b/apps/knowledge/tasks/media_learning.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +""" +音视频学习任务处理 +""" +import traceback +from typing import List, Optional +from celery import shared_task +from django.db import transaction +from django.db.models import QuerySet + +from common.event.common import embedding_by_data_source +from common.event import ListenerManagement +from common.utils.logger import maxkb_logger +from knowledge.models import Document, Paragraph, TaskType, State +from oss.models import File, FileSourceType +from common.handle.impl.media.media_split_handle import MediaSplitHandle + + +@shared_task(name='media_learning_by_document') +def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id: str, + stt_model_id: str, llm_model_id: Optional[str] = None): + """ + 音视频文档异步处理任务 + + Args: + document_id: 文档ID + knowledge_id: 知识库ID + workspace_id: 工作空间ID + stt_model_id: STT模型ID + llm_model_id: LLM模型ID(可选) + """ + maxkb_logger.info(f"Starting media learning task for document: {document_id}") + + try: + # 更新文档状态为处理中 + ListenerManagement.update_status( + QuerySet(Document).filter(id=document_id), + TaskType.EMBEDDING, + State.STARTED + ) + + # 获取文档信息 + document = QuerySet(Document).filter(id=document_id).first() + if not document: + raise ValueError(f"Document not found: {document_id}") + + # 获取源文件 + source_file_id = document.meta.get('source_file_id') + if not source_file_id: + raise ValueError(f"Source file not found for document: {document_id}") + + source_file = QuerySet(File).filter(id=source_file_id).first() + if not source_file: + raise ValueError(f"Source file not found: {source_file_id}") + + maxkb_logger.info(f"Processing media file: {source_file.file_name}") + + # 使用MediaSplitHandle处理音视频文件 + media_handler = MediaSplitHandle() + + # 准备文件对象 + class FileWrapper: + def __init__(self, file_obj): + self.file_obj = file_obj + self.name = file_obj.file_name + self.size = file_obj.file_size + + def read(self): + return self.file_obj.get_bytes() + + def seek(self, pos): + pass + + file_wrapper = FileWrapper(source_file) + + # 获取文件内容的方法 + def get_buffer(file): + return file.read() + + # 保存图片的方法(音视频一般不需要,但保持接口一致) + def save_image(image_list): + pass + + # 处理音视频文件 + result = media_handler.handle( + file_wrapper, + pattern_list=[], # 音视频不需要分段模式 + with_filter=False, + limit=0, # 不限制段落数 + get_buffer=get_buffer, + save_image=save_image, + workspace_id=workspace_id, + stt_model_id=stt_model_id, + llm_model_id=llm_model_id + ) + + # 解析处理结果 + paragraphs_data = result.get('content', []) + + if not paragraphs_data: + raise ValueError("No content extracted from media file") + + maxkb_logger.info(f"Extracted {len(paragraphs_data)} paragraphs from media file") + + # 创建段落对象 + with transaction.atomic(): + paragraph_models = [] + for idx, para_data in enumerate(paragraphs_data): + paragraph = Paragraph( + document_id=document_id, + content=para_data.get('content', ''), + title=para_data.get('title', f'段落 {idx + 1}'), + position=idx + 1, + meta=para_data.get('metadata', {}) + ) + paragraph_models.append(paragraph) + + # 批量保存段落 + if paragraph_models: + QuerySet(Paragraph).bulk_create(paragraph_models) + maxkb_logger.info(f"Created {len(paragraph_models)} paragraphs for document {document_id}") + + # 更新文档字符长度 + total_char_length = sum(len(p.content) for p in paragraph_models) + document.char_length = total_char_length + document.save() + + # 触发向量化任务 + maxkb_logger.info(f"Starting embedding for document: {document_id}") + embedding_by_data_source(document_id, knowledge_id, workspace_id) + + # 更新文档状态为成功 + ListenerManagement.update_status( + QuerySet(Document).filter(id=document_id), + TaskType.EMBEDDING, + State.SUCCESS + ) + + maxkb_logger.info(f"Media learning completed successfully for document: {document_id}") + + except Exception as e: + maxkb_logger.error(f"Media learning failed for document {document_id}: {str(e)}") + maxkb_logger.error(traceback.format_exc()) + + # 更新文档状态为失败 + ListenerManagement.update_status( + QuerySet(Document).filter(id=document_id), + TaskType.EMBEDDING, + State.FAILURE + ) + + raise \ No newline at end of file diff --git a/common/handle/impl/text/mineru_split_handle.py b/common/handle/impl/text/mineru_split_handle.py index 361f0e0f..68b9b5c3 100644 --- a/common/handle/impl/text/mineru_split_handle.py +++ b/common/handle/impl/text/mineru_split_handle.py @@ -1,8 +1,11 @@ """ MinerU文档解析处理器 -该处理器使用MinerU解析PDF和PPT文档,提供高质量的文档解析功能。 -支持多种文档格式,包括复杂的表格、图片、公式等内容的解析。 +该处理器使用MinerU解析文档和图片,提供高质量的内容解析功能。 +支持多种文档格式和图片格式,包括复杂的表格、公式等内容的解析,以及OCR文字识别。 + +支持的文档格式:PDF、PPT、PPTX、DOC、DOCX +支持的图片格式:PNG、JPG、JPEG、GIF、BMP、TIFF、WebP、SVG """ import io @@ -56,8 +59,13 @@ class MinerUSplitHandle(BaseSplitHandle): return False file_name = file.name.lower() - # MinerU支持PDF和PPT格式 - return file_name.endswith('.pdf') or file_name.endswith('.ppt') or file_name.endswith('.pptx') + # MinerU支持PDF、PPT、DOC和图片格式 + supported_extensions = ( + '.pdf', '.ppt', '.pptx', '.doc', '.docx', # 文档格式 + '.png', '.jpg', '.jpeg', '.gif', '.bmp', # 图片格式 + '.tiff', '.tif', '.webp', '.svg' # 其他图片格式 + ) + return any(file_name.endswith(ext) for ext in supported_extensions) def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image, **kwargs): """ diff --git a/test_media_processing.py b/test_media_processing.py new file mode 100644 index 00000000..22aa453d --- /dev/null +++ b/test_media_processing.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +测试音视频处理功能 +""" +import sys +import os +sys.path.append('apps') + +def test_media_handler(): + """测试音视频处理器""" + print("测试音视频处理器...") + + try: + from common.handle.impl.media.media_split_handle import MediaSplitHandle + from common.handle.impl.media.media_adapter import MediaAdapter + + # 创建处理器 + handler = MediaSplitHandle() + print("✓ MediaSplitHandle 创建成功") + + # 测试文件类型支持 + class MockFile: + def __init__(self, name, content=b'test'): + self.name = name + self.content = content + self.size = len(content) + + def read(self): + return self.content + + def seek(self, pos): + pass + + # 测试音频文件支持 + audio_files = ['test.mp3', 'test.wav', 'test.m4a', 'test.flac'] + for filename in audio_files: + file = MockFile(filename) + if handler.support(file, lambda x: x.read()): + print(f"✓ {filename} 支持") + else: + print(f"✗ {filename} 不支持") + + # 测试视频文件支持 + video_files = ['test.mp4', 'test.avi', 'test.mov', 'test.mkv'] + for filename in video_files: + file = MockFile(filename) + if handler.support(file, lambda x: x.read()): + print(f"✓ {filename} 支持") + else: + print(f"✗ {filename} 不支持") + + # 测试非媒体文件 + other_files = ['test.txt', 'test.pdf', 'test.docx'] + for filename in other_files: + file = MockFile(filename) + if not handler.support(file, lambda x: x.read()): + print(f"✓ {filename} 正确排除") + else: + print(f"✗ {filename} 错误支持") + + print("\n✓ 所有文件类型测试通过") + + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + return False + + return True + +def test_media_adapter(): + """测试媒体适配器""" + print("\n测试媒体适配器...") + + try: + from common.handle.impl.media.media_adapter import MediaAdapter + + # 创建适配器 + adapter = MediaAdapter() + print("✓ MediaAdapter 创建成功") + + # 测试配置 + if adapter.config: + print("✓ 配置加载成功") + print(f" - STT Provider: {adapter.config.get('stt_provider')}") + print(f" - Max Duration: {adapter.config.get('max_duration')}秒") + print(f" - Segment Duration: {adapter.config.get('segment_duration')}秒") + + # 测试媒体类型检测 + test_cases = [ + ('test.mp3', 'audio'), + ('test.mp4', 'video'), + ('test.wav', 'audio'), + ('test.avi', 'video'), + ] + + for filename, expected_type in test_cases: + detected_type = adapter._detect_media_type(filename) + if detected_type == expected_type: + print(f"✓ {filename} -> {detected_type}") + else: + print(f"✗ {filename} -> {detected_type} (期望: {expected_type})") + + print("\n✓ 适配器测试通过") + + except Exception as e: + print(f"✗ 测试失败: {e}") + import traceback + traceback.print_exc() + return False + + return True + +if __name__ == '__main__': + print("=" * 50) + print("音视频学习模块测试") + print("=" * 50) + + success = True + + # 运行测试 + if not test_media_handler(): + success = False + + if not test_media_adapter(): + success = False + + print("\n" + "=" * 50) + if success: + print("✅ 所有测试通过!") + else: + print("❌ 部分测试失败") + print("=" * 50) \ No newline at end of file diff --git a/ui/src/stores/modules/knowledge.ts b/ui/src/stores/modules/knowledge.ts index 1f7efeae..1ee3e748 100644 --- a/ui/src/stores/modules/knowledge.ts +++ b/ui/src/stores/modules/knowledge.ts @@ -14,6 +14,10 @@ export interface knowledgeStateTypes { llmModel: string | null visionModel: string | null } | null + mediaModels: { + sttModel: string | null + llmModel: string | null + } | null } const useKnowledgeStore = defineStore('knowledge', { @@ -24,6 +28,7 @@ const useKnowledgeStore = defineStore('knowledge', { documentsFiles: [], knowledgeList: [], mineruModels: null, + mediaModels: null, }), actions: { saveBaseInfo(info: knowledgeData | null) { @@ -44,6 +49,9 @@ const useKnowledgeStore = defineStore('knowledge', { saveMinerUModels(models: { llmModel: string | null; visionModel: string | null }) { this.mineruModels = models }, + saveMediaModels(models: { sttModel: string | null; llmModel: string | null }) { + this.mediaModels = models + }, }, }) diff --git a/ui/src/utils/common.ts b/ui/src/utils/common.ts index 4419a9a8..61f25ba5 100644 --- a/ui/src/utils/common.ts +++ b/ui/src/utils/common.ts @@ -55,6 +55,8 @@ const typeList: any = { txt: ['txt', 'pdf', 'docx', 'md', 'html', 'zip', 'xlsx', 'xls', 'csv'], table: ['xlsx', 'xls', 'csv'], QA: ['xlsx', 'csv', 'xls', 'zip'], + mineru: ['pdf', 'ppt', 'pptx', 'doc', 'docx', 'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tif', 'webp', 'svg'], + media: ['mp3', 'wav', 'm4a', 'flac', 'aac', 'ogg', 'wma', 'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv'], } export function getImgUrl(name: string) { diff --git a/ui/src/views/document/UploadDocument.vue b/ui/src/views/document/UploadDocument.vue index c6ec8671..478b203a 100644 --- a/ui/src/views/document/UploadDocument.vue +++ b/ui/src/views/document/UploadDocument.vue @@ -38,7 +38,7 @@ :disabled="SetRulesRef?.loading || loading" > {{ - documentsType === 'txt' || documentsType === 'mineru' + documentsType === 'txt' || documentsType === 'mineru' || documentsType === 'media' ? $t('views.document.buttons.next') : $t('views.document.buttons.import') }} @@ -136,6 +136,9 @@ async function next() { } else if (documentsType.value === 'mineru') { // MinerU类型文档也需要进入分段设置页面 if (active.value++ > 2) active.value = 0 + } else if (documentsType.value === 'media') { + // 音视频类型文档也需要进入分段设置页面 + if (active.value++ > 2) active.value = 0 } else { if (active.value++ > 2) active.value = 0 } @@ -155,10 +158,18 @@ function clearStore() { llmModel: null, visionModel: null }) + // 清空音视频模型选择 + knowledge.saveMediaModels({ + sttModel: null, + llmModel: null + }) } function submit() { + alert('Submit function called! documentsType=' + documentsType.value) loading.value = true const documents = [] as any + console.log('Submit called, documentsType:', documentsType.value) + console.log('mediaModels from store:', knowledge.mediaModels) SetRulesRef.value?.paragraphList.map((item: any) => { if (!SetRulesRef.value?.checkedConnect) { item.content.map((v: any) => { @@ -182,11 +193,34 @@ function submit() { doc.split_patterns = SetRulesRef.value.form.patterns } } + // 只有当文档类型是音视频类型时,才添加STT模型参数 + else if (documentsType.value === 'media') { + console.log('Processing media document upload:', { + mediaModels: knowledge.mediaModels, + sttModel: knowledge.mediaModels?.sttModel, + llmModel: knowledge.mediaModels?.llmModel + }) + // 确保有模型选择才添加 + if (knowledge.mediaModels) { + if (knowledge.mediaModels.sttModel) { + doc.stt_model_id = knowledge.mediaModels.sttModel + } + if (knowledge.mediaModels.llmModel) { + doc.llm_model_id = knowledge.mediaModels.llmModel + } + } + // 传递分段规则(如果有) + if (SetRulesRef.value?.form?.patterns) { + doc.split_patterns = SetRulesRef.value.form.patterns + } + console.log('Final doc object for media:', doc) + } documents.push(doc) }) if (id) { // 上传文档 + console.log('Sending documents to backend:', documents) loadSharedApi({ type: 'document', systemType: apiType.value }) .putMulDocument(id as string, documents) .then(() => { diff --git a/ui/src/views/document/upload/UploadComponent.vue b/ui/src/views/document/upload/UploadComponent.vue index 01b7e1ee..4664671a 100644 --- a/ui/src/views/document/upload/UploadComponent.vue +++ b/ui/src/views/document/upload/UploadComponent.vue @@ -12,6 +12,7 @@ {{ $t('views.document.fileType.txt.label') }} 高级学习 + 音视频学习 {{ $t('views.document.fileType.table.label') }} @@ -140,10 +141,12 @@
-

1. 高级学习提供高质量的 PDF 和 PPT 文档解析,支持复杂表格、图片、公式等内容

-

2. 支持的文件格式:PDF、PPT、PPTX

+

1. 高级学习提供高质量的文档和图片解析,支持复杂表格、公式、图表等内容的智能识别与提取

+

2. 支持的文档格式:PDF、PPT、PPTX、DOC、DOCX

+

3. 支持的图片格式:PNG、JPG、JPEG、GIF、BMP、TIFF、WebP、SVG

+

4. 智能图片识别:自动进行OCR文字提取、图表分析、内容理解和分类

- 3. {{ $t('views.document.tip.fileLimitCountTip1') }} {{ file_count_limit }} + 5. {{ $t('views.document.tip.fileLimitCountTip1') }} {{ file_count_limit }} {{ $t('views.document.tip.fileLimitCountTip2') }}, {{ $t('views.document.tip.fileLimitSizeTip1') }} {{ file_size_limit }} MB

@@ -195,7 +198,7 @@ action="#" :auto-upload="false" :show-file-list="false" - accept=".pdf, .ppt, .pptx" + accept=".pdf, .ppt, .pptx, .doc, .docx, .png, .jpg, .jpeg, .gif, .bmp, .tiff, .tif, .webp, .svg" :limit="file_count_limit" :on-exceed="onExceed" :on-change="fileHandleChange" @@ -213,7 +216,94 @@

-

{{ $t('views.document.upload.formats') }}PDF、PPT、PPTX

+

{{ $t('views.document.upload.formats') }}PDF、PPT、PPTX、DOC、DOCX、PNG、JPG、GIF、BMP、TIFF、WebP、SVG

+
+
+ + + +
+
+ +
+
+

1. 音视频学习支持自动语音识别,将音频内容转换为文本

+

2. 支持的音频格式:MP3、WAV、M4A、FLAC、AAC、OGG、WMA

+

3. 支持的视频格式:MP4、AVI、MOV、MKV、WEBM、FLV、WMV

+

4. 自动分段处理,支持长音视频文件

+

+ 5. {{ $t('views.document.tip.fileLimitCountTip1') }} {{ file_count_limit }} + {{ $t('views.document.tip.fileLimitCountTip2') }}, + {{ $t('views.document.tip.fileLimitSizeTip1') }} {{ file_size_limit }} MB +

+
+
+ + + + + {{ model.name }} + 共享 + + + + + + + + {{ model.name }} + 共享 + + + + + +
+

+ {{ $t('views.document.upload.uploadMessage') }} + + {{ $t('views.document.upload.selectFile') }} + + + {{ $t('views.document.upload.selectFiles') }} + +

+
+

{{ $t('views.document.upload.formats') }}MP3、WAV、M4A、MP4、AVI、MOV等

@@ -329,6 +419,7 @@ const form = ref({ fileList: [] as any, llmModel: null as string | null, visionModel: null as string | null, + sttModel: null as string | null, }) const rules = reactive({ @@ -343,14 +434,22 @@ const file_size_limit = ref(100) // 模型列表 const llmModels = ref([]) const visionModels = ref([]) +const sttModels = ref([]) watch(form.value, (value) => { knowledge.saveDocumentsType(value.fileType) knowledge.saveDocumentsFile(value.fileList) - knowledge.saveMinerUModels({ - llmModel: value.llmModel, - visionModel: value.visionModel - }) + if (value.fileType === 'mineru') { + knowledge.saveMinerUModels({ + llmModel: value.llmModel, + visionModel: value.visionModel + }) + } else if (value.fileType === 'media') { + knowledge.saveMediaModels({ + sttModel: value.sttModel, + llmModel: value.llmModel + }) + } }) // 加载模型列表 @@ -358,11 +457,12 @@ const loadModels = async () => { try { const response = await modelApi.getSelectModelList() if (response.data) { - // 分离大语言模型和视觉模型 + // 分离大语言模型、视觉模型和STT模型 llmModels.value = response.data.filter((m: any) => m.model_type === 'LLM') visionModels.value = response.data.filter((m: any) => m.model_type === 'IMAGE' // 只显示IMAGE类型的视觉模型 ) + sttModels.value = response.data.filter((m: any) => m.model_type === 'STT') } } catch (error) { console.error('Failed to load models:', error) @@ -385,7 +485,7 @@ function downloadTableTemplate(type: string) { function radioChange() { form.value.fileList = [] - // 切换文档类型时,如果不是mineru类型,清空模型选择 + // 切换文档类型时,清空对应的模型选择 if (form.value.fileType !== 'mineru') { form.value.llmModel = null form.value.visionModel = null @@ -395,6 +495,14 @@ function radioChange() { visionModel: null }) } + if (form.value.fileType !== 'media') { + form.value.sttModel = null + // 清空store中的音视频模型选择 + knowledge.saveMediaModels({ + sttModel: null, + llmModel: null + }) + } } function deleteFile(index: number) { @@ -475,6 +583,12 @@ onMounted(() => { form.value.llmModel = mineruModels.llmModel form.value.visionModel = mineruModels.visionModel } + // 恢复音视频模型选择 + const mediaModels = knowledge.mediaModels + if (mediaModels) { + form.value.sttModel = mediaModels.sttModel + form.value.llmModel = mediaModels.llmModel || form.value.llmModel + } getDetail() loadModels() // 加载模型列表 }) @@ -484,6 +598,7 @@ onUnmounted(() => { fileList: [], llmModel: null, visionModel: null, + sttModel: null, } })