maxkb/apps/knowledge/tasks/media_learning.py
朱潮 dd0360fb6f
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled
modify file status
2025-08-29 09:29:02 +08:00

152 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
音视频学习任务处理
"""
import traceback
from typing import List, Optional
from celery import shared_task
from django.db import transaction
from django.db.models import QuerySet
from common.event.common import embedding_by_data_source
from common.event import ListenerManagement
from common.utils.logger import maxkb_logger
from knowledge.models import Document, Paragraph, TaskType, State
from oss.models import File, FileSourceType
from common.handle.impl.media.media_split_handle import MediaSplitHandle
@shared_task(name='media_learning_by_document')
def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id: str,
                               stt_model_id: str, llm_model_id: Optional[str] = None):
    """
    Process an audio/video document: split it into paragraphs, then embed them.

    Loads the document and the source file referenced by its
    ``meta['source_file_id']``, runs the file through ``MediaSplitHandle``,
    stores the returned paragraphs, records their total character count on
    the document and finally calls ``embedding_by_data_source``. The
    document's ``TaskType.EMBEDDING`` status is set to ``State.STARTED`` at
    the beginning and to ``State.SUCCESS`` or ``State.FAILURE`` at the end.

    Args:
        document_id: ID of the document to process.
        knowledge_id: ID of the knowledge base; forwarded to the embedding step.
        workspace_id: ID of the workspace; forwarded to the handler and the
            embedding step.
        stt_model_id: ID of the STT model; forwarded to the handler.
        llm_model_id: Optional ID of the LLM model; forwarded to the handler.

    Raises:
        ValueError: If the document, its ``source_file_id`` entry, the source
            file record, or the extracted content is missing.
        Exception: Any error raised along the way is logged, the status is
            set to ``State.FAILURE`` and the error is re-raised.
    """

    def set_embedding_state(state):
        # Every status transition in this task targets the same document and
        # the same task type, so only the state varies.
        ListenerManagement.update_status(
            QuerySet(Document).filter(id=document_id),
            TaskType.EMBEDDING,
            state
        )

    class StoredMediaFile:
        """Expose a stored ``File`` record through name/size/read/seek."""

        def __init__(self, file_obj):
            self.file_obj = file_obj
            self.name = file_obj.file_name
            self.size = file_obj.file_size

        def read(self):
            return self.file_obj.get_bytes()

        def seek(self, pos):
            # No-op: read() always goes through get_bytes(), whatever the position.
            pass

    def read_buffer(file):
        # Callback handed to the handler to obtain the file content.
        return file.read()

    def discard_images(image_list):
        # Media files usually have no images to save; this exists only to
        # satisfy the handler's callback interface.
        pass

    maxkb_logger.info(f"Starting media learning task for document: {document_id}")
    try:
        # Mark the document as being processed.
        set_embedding_state(State.STARTED)

        # Look up the document itself.
        document = QuerySet(Document).filter(id=document_id).first()
        if not document:
            raise ValueError(f"Document not found: {document_id}")

        # Resolve the source file referenced by the document metadata.
        source_file_id = document.meta.get('source_file_id')
        if not source_file_id:
            raise ValueError(f"Source file not found for document: {document_id}")
        source_file = QuerySet(File).filter(id=source_file_id).first()
        if not source_file:
            raise ValueError(f"Source file not found: {source_file_id}")
        maxkb_logger.info(f"Processing media file: {source_file.file_name}")

        # Run the audio/video file through the media split handler.
        split_result = MediaSplitHandle().handle(
            StoredMediaFile(source_file),
            pattern_list=[],  # no split patterns are used for audio/video
            with_filter=False,
            limit=0,  # do not limit the number of paragraphs
            get_buffer=read_buffer,
            save_image=discard_images,
            workspace_id=workspace_id,
            stt_model_id=stt_model_id,
            llm_model_id=llm_model_id
        )

        # Pull the paragraph payloads out of the handler result.
        paragraphs_data = split_result.get('content', [])
        if not paragraphs_data:
            raise ValueError("No content extracted from media file")
        maxkb_logger.info(f"Extracted {len(paragraphs_data)} paragraphs from media file")

        # Persist the paragraphs and the document's character count together.
        with transaction.atomic():
            # Positions are 1-based; a missing title falls back to a
            # numbered "段落 N" label.
            paragraph_models = [
                Paragraph(
                    document_id=document_id,
                    content=para_data.get('content', ''),
                    title=para_data.get('title', f'段落 {idx + 1}'),
                    position=idx + 1,
                    meta=para_data.get('metadata', {})
                )
                for idx, para_data in enumerate(paragraphs_data)
            ]

            # Insert all paragraphs in one batch.
            if paragraph_models:
                QuerySet(Paragraph).bulk_create(paragraph_models)
                maxkb_logger.info(f"Created {len(paragraph_models)} paragraphs for document {document_id}")

            # Record the total number of characters on the document.
            document.char_length = sum(len(p.content) for p in paragraph_models)
            document.save()

        # Kick off vectorization of the new paragraphs.
        maxkb_logger.info(f"Starting embedding for document: {document_id}")
        embedding_by_data_source(document_id, knowledge_id, workspace_id)

        # Mark the document as successfully processed.
        set_embedding_state(State.SUCCESS)
        maxkb_logger.info(f"Media learning completed successfully for document: {document_id}")
    except Exception as e:
        maxkb_logger.error(f"Media learning failed for document {document_id}: {str(e)}")
        maxkb_logger.error(traceback.format_exc())
        # Mark the document as failed, then let the error propagate.
        set_embedding_state(State.FAILURE)
        raise