refactor: enhance text-to-speech processing by splitting content into chunks and merging audio segments
This commit is contained in:
parent
06e759a320
commit
16088975fa
@ -6,9 +6,11 @@ from django.core.files.uploadedfile import InMemoryUploadedFile
|
|||||||
|
|
||||||
from application.flow.i_step_node import NodeResult
|
from application.flow.i_step_node import NodeResult
|
||||||
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
|
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
|
||||||
|
from common.utils.common import _remove_empty_lines
|
||||||
from knowledge.models import FileSourceType
|
from knowledge.models import FileSourceType
|
||||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||||
from oss.serializers.file import FileSerializer
|
from oss.serializers.file import FileSerializer
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
|
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
|
||||||
@ -41,32 +43,72 @@ class BaseTextToSpeechNode(ITextToSpeechNode):
|
|||||||
|
|
||||||
def execute(self, tts_model_id, chat_id,
            content, model_params_setting=None,
            max_length=1024, **kwargs) -> NodeResult:
    """Synthesize ``content`` to speech chunk by chunk and upload the merged audio.

    The text is passed through ``_remove_empty_lines`` and cut into pieces of
    at most ``max_length`` characters. Each piece is sent to the TTS model,
    the decoded pieces are concatenated in order, and the result is exported
    as a single MP3 that is uploaded through ``FileSerializer``.

    Args:
        tts_model_id: Identifier of the text-to-speech model to load.
        chat_id: Chat identifier stored in the uploaded file's metadata.
        content: Text to synthesize.
        model_params_setting: Optional keyword settings forwarded to
            ``get_model_instance_by_model_workspace_id``; ``None`` is treated
            as "no settings".
        max_length: Maximum number of characters sent per TTS request.
        **kwargs: Not used by this method.

    Returns:
        A ``NodeResult`` whose ``answer`` is an ``<audio>`` tag pointing at
        the uploaded file and whose ``result`` is a one-element list holding
        the file's id, name and URL.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        # range() would otherwise fail with an unrelated message for 0 and
        # silently yield no chunks for negative values.
        raise ValueError(f'max_length must be a positive integer, got {max_length!r}')

    # Record the complete input once. Writing it inside the chunk loop left
    # only the last chunk in the context (and nothing at all for empty text).
    self.context['content'] = content

    # Split the text into fixed-size pieces.
    # NOTE(review): slicing by character count can cut a word or sentence in
    # half; consider breaking on punctuation if the seams are audible.
    cleaned_content = _remove_empty_lines(content)
    content_chunks = [cleaned_content[start:start + max_length]
                      for start in range(0, len(cleaned_content), max_length)]

    # The model does not depend on the chunk, so build it once. The default
    # ``None`` for model_params_setting cannot be unpacked with ``**``.
    workspace_id = self.workflow_manage.get_body().get('workspace_id')
    model = get_model_instance_by_model_workspace_id(
        tts_model_id, workspace_id, **(model_params_setting or {}))

    # Generate each chunk's audio and append it to the running result.
    combined_audio = AudioSegment.empty()
    for chunk in content_chunks:
        audio_byte = model.text_to_speech(chunk)
        # The buffer is only needed while the segment is being decoded.
        with io.BytesIO(audio_byte) as chunk_buffer:
            combined_audio += AudioSegment.from_file(chunk_buffer)

    # Serialize the merged audio to MP3 bytes.
    with io.BytesIO() as output_buffer:
        combined_audio.export(output_buffer, format="mp3")
        combined_bytes = output_buffer.getvalue()

    # Store the merged audio file.
    file_name = 'combined_audio.mp3'
    file = bytes_to_uploaded_file(combined_bytes, file_name)

    application = self.workflow_manage.work_flow_post_handler.chat_info.application
    meta = {
        'debug': False if application.id else True,
        'chat_id': chat_id,
        'application_id': str(application.id) if application.id else None,
    }

    file_url = FileSerializer(data={
        'file': file,
        'meta': meta,
        'source_id': meta['application_id'],
        'source_type': FileSourceType.APPLICATION.value
    }).upload()

    # Build the audio tag returned as the node's answer.
    audio_label = f'<audio src="{file_url}" controls style="width: 300px; height: 43px"></audio>'
    file_id = file_url.split('/')[-1]
    audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}]

    return NodeResult({
        'answer': audio_label,
        'result': audio_list
    }, {})
|
||||||
|
|
||||||
def get_details(self, index: int, **kwargs):
|
def get_details(self, index: int, **kwargs):
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -14,6 +14,8 @@ import gzip
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
|
import requests
|
||||||
import uuid_utils.compat as uuid
|
import uuid_utils.compat as uuid
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user