# -*- coding: utf-8 -*-
"""
Audio/video learning task processing.
"""
import traceback
|
||
from typing import List, Optional
|
||
from celery import shared_task
|
||
from django.db import transaction
|
||
from django.db.models import QuerySet
|
||
|
||
from common.event.common import embedding_by_data_source
|
||
from common.event import ListenerManagement
|
||
from common.utils.logger import maxkb_logger
|
||
from knowledge.models import Document, Paragraph, TaskType, State
|
||
from oss.models import File, FileSourceType
|
||
from common.handle.impl.media.media_split_handle import MediaSplitHandle
|
||
|
||
|
||
class _MediaSourceFile:
    """Minimal file-like adapter around a stored ``File`` record.

    Exposes the ``name`` / ``size`` attributes and the ``read`` / ``seek``
    methods on the object that is handed to ``MediaSplitHandle.handle``.
    """

    def __init__(self, file_obj):
        self.file_obj = file_obj
        self.name = file_obj.file_name
        self.size = file_obj.file_size

    def read(self):
        """Return the whole file content via ``file_obj.get_bytes()``."""
        return self.file_obj.get_bytes()

    def seek(self, pos):
        """Accept and ignore seek requests; ``read`` always returns everything."""
        pass


def _read_media_buffer(file):
    """``get_buffer`` callback: return the content of the wrapped file."""
    return file.read()


def _ignore_images(image_list):
    """``save_image`` callback: intentionally does nothing.

    Audio/video input normally yields no images; the callback is supplied
    only to keep the handler interface consistent.
    """
    pass


@shared_task(name='media_learning_by_document')
def media_learning_by_document(document_id: str, knowledge_id: str, workspace_id: str,
                               stt_model_id: str, llm_model_id: Optional[str] = None):
    """
    Asynchronously process an audio/video document.

    Steps, in order: mark the document's EMBEDDING task as STARTED, load the
    document and the source file referenced by ``meta['source_file_id']``, run
    the file through ``MediaSplitHandle``, store the extracted paragraphs,
    update the document's ``char_length``, call ``embedding_by_data_source``
    and finally mark the task as SUCCESS.

    Any exception is logged, the task is marked as FAILURE and the original
    exception is re-raised.

    Args:
        document_id: Document ID.
        knowledge_id: Knowledge base ID.
        workspace_id: Workspace ID.
        stt_model_id: Speech-to-text (STT) model ID.
        llm_model_id: LLM model ID (optional).

    Raises:
        ValueError: If the document, its ``source_file_id`` or the source file
            cannot be found, or if the handler extracted no content.
    """
    maxkb_logger.info(f"Starting media learning task for document: {document_id}")

    try:
        # Mark the document as being processed.
        ListenerManagement.update_status(
            QuerySet(Document).filter(id=document_id),
            TaskType.EMBEDDING,
            State.STARTED
        )

        # Load the document.
        document = QuerySet(Document).filter(id=document_id).first()
        if not document:
            raise ValueError(f"Document not found: {document_id}")

        # Resolve the source file. ``meta`` is guarded so that an empty/None
        # value produces the explicit error below instead of an AttributeError.
        source_file_id = (document.meta or {}).get('source_file_id')
        if not source_file_id:
            raise ValueError(f"Source file not found for document: {document_id}")

        source_file = QuerySet(File).filter(id=source_file_id).first()
        if not source_file:
            raise ValueError(f"Source file not found: {source_file_id}")

        maxkb_logger.info(f"Processing media file: {source_file.file_name}")

        # Run the media file through the split handler.
        result = MediaSplitHandle().handle(
            _MediaSourceFile(source_file),
            pattern_list=[],  # no split patterns are needed for audio/video
            with_filter=False,
            limit=0,  # do not limit the number of paragraphs
            get_buffer=_read_media_buffer,
            save_image=_ignore_images,
            workspace_id=workspace_id,
            stt_model_id=stt_model_id,
            llm_model_id=llm_model_id
        )

        # Parse the handler result.
        paragraphs_data = result.get('content', [])
        if not paragraphs_data:
            raise ValueError("No content extracted from media file")

        maxkb_logger.info(f"Extracted {len(paragraphs_data)} paragraphs from media file")

        # Build the paragraph objects; positions are 1-based.
        paragraph_models = [
            Paragraph(
                document_id=document_id,
                content=para_data.get('content', ''),
                title=para_data.get('title', f'段落 {position}'),
                position=position,
                meta=para_data.get('metadata', {})
            )
            for position, para_data in enumerate(paragraphs_data, start=1)
        ]

        # Persist paragraphs and the document's character count together so a
        # failure in either write leaves neither behind.
        with transaction.atomic():
            if paragraph_models:
                QuerySet(Paragraph).bulk_create(paragraph_models)
                maxkb_logger.info(f"Created {len(paragraph_models)} paragraphs for document {document_id}")

            # NOTE(review): save() writes every field of the instance that was
            # loaded before the media processing ran; consider update_fields
            # once auto-updated timestamp fields on Document are confirmed.
            document.char_length = sum(len(p.content) for p in paragraph_models)
            document.save()

        # Trigger the embedding step.
        maxkb_logger.info(f"Starting embedding for document: {document_id}")
        embedding_by_data_source(document_id, knowledge_id, workspace_id)

        # Mark the document as successfully processed.
        ListenerManagement.update_status(
            QuerySet(Document).filter(id=document_id),
            TaskType.EMBEDDING,
            State.SUCCESS
        )

        maxkb_logger.info(f"Media learning completed successfully for document: {document_id}")

    except Exception as e:
        maxkb_logger.error(f"Media learning failed for document {document_id}: {str(e)}")
        maxkb_logger.error(traceback.format_exc())

        # Mark the document as failed. A failure of this bookkeeping write must
        # not replace the original exception, so it is only logged.
        try:
            ListenerManagement.update_status(
                QuerySet(Document).filter(id=document_id),
                TaskType.EMBEDDING,
                State.FAILURE
            )
        except Exception as status_error:
            maxkb_logger.error(
                f"Failed to mark document {document_id} as failed: {str(status_error)}"
            )

        # Re-raise the original processing error.
        raise