realtime接口支持16khz输出

This commit is contained in:
朱潮 2026-03-23 10:37:17 +08:00
parent 2d2e1dbcdf
commit 30dc697071
8 changed files with 88 additions and 17 deletions

View File

@ -83,7 +83,7 @@ class AgentConfig:
safe_dict = self.to_dict().copy() safe_dict = self.to_dict().copy()
if 'api_key' in safe_dict and isinstance(safe_dict['api_key'], str) and safe_dict['api_key'].startswith('sk-'): if 'api_key' in safe_dict and isinstance(safe_dict['api_key'], str) and safe_dict['api_key'].startswith('sk-'):
safe_dict['api_key'] = safe_dict['api_key'][:8] + '***' + safe_dict['api_key'][-6:] safe_dict['api_key'] = safe_dict['api_key'][:8] + '***' + safe_dict['api_key'][-6:]
logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}") logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False, default=str)}")
@classmethod @classmethod
async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None): async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):

View File

@ -288,9 +288,9 @@ async def init_agent(config: AgentConfig):
checkpointer=checkpointer, checkpointer=checkpointer,
shell_env={ shell_env={
k: v for k, v in { k: v for k, v in {
"ASSISTANT_ID": config.bot_id, "ASSISTANT_ID": str(config.bot_id),
"USER_IDENTIFIER": config.user_identifier, "USER_IDENTIFIER": str(config.user_identifier) if config.user_identifier else None,
"TRACE_ID": config.trace_id, "TRACE_ID": str(config.trace_id) if config.trace_id else None,
**(config.shell_env or {}), **(config.shell_env or {}),
}.items() if v is not None }.items() if v is not None
} }

View File

@ -93,11 +93,13 @@ async def voice_realtime(websocket: WebSocket):
continue continue
voice_mode = msg.get("voice_mode") or VOICE_DEFAULT_MODE voice_mode = msg.get("voice_mode") or VOICE_DEFAULT_MODE
client_sample_rate = msg.get("sample_rate", 24000)
session_kwargs = dict( session_kwargs = dict(
bot_id=bot_id, bot_id=bot_id,
session_id=msg.get("session_id"), session_id=msg.get("session_id"),
user_identifier=msg.get("user_identifier"), user_identifier=msg.get("user_identifier"),
sample_rate=client_sample_rate,
on_audio=on_audio, on_audio=on_audio,
on_asr_text=on_asr_text, on_asr_text=on_asr_text,
on_agent_result=on_agent_result, on_agent_result=on_agent_result,

View File

@ -37,6 +37,22 @@ class StreamingTTSClient:
self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
async def synthesize(self, text: str):
    """Stream TTS audio for *text* as float32 PCM.

    Delegates to the shared SSE pipeline with int16->float32
    conversion enabled. Yields 24kHz float32 PCM audio chunks.
    """
    async for piece in self._synthesize_internal(text, raw_int16=False):
        yield piece
async def synthesize_raw(self, text: str):
    """Stream TTS audio for *text* as raw int16 PCM.

    Same SSE pipeline as synthesize(), but skips the float32
    conversion. Yields 24kHz int16 PCM audio chunks.
    """
    async for piece in self._synthesize_internal(text, raw_int16=True):
        yield piece
async def _synthesize_internal(self, text: str, raw_int16: bool = False):
""" """
Synthesize text to audio via SSE streaming. Synthesize text to audio via SSE streaming.
Yields 24kHz float32 PCM audio chunks. Yields 24kHz float32 PCM audio chunks.
@ -97,7 +113,7 @@ class StreamingTTSClient:
if line == "": if line == "":
# Blank line = end of one SSE event # Blank line = end of one SSE event
if current_data: if current_data:
async for audio in self._process_sse_data(current_data): async for audio in self._process_sse_data(current_data, raw_int16=raw_int16):
chunk_count += 1 chunk_count += 1
yield audio yield audio
current_event = "" current_event = ""
@ -118,7 +134,7 @@ class StreamingTTSClient:
# Handle remaining data without trailing blank line # Handle remaining data without trailing blank line
if current_data: if current_data:
async for audio in self._process_sse_data(current_data): async for audio in self._process_sse_data(current_data, raw_int16=raw_int16):
chunk_count += 1 chunk_count += 1
yield audio yield audio
@ -127,7 +143,7 @@ class StreamingTTSClient:
except Exception as e: except Exception as e:
logger.error(f"[TTS] Error: {e}", exc_info=True) logger.error(f"[TTS] Error: {e}", exc_info=True)
async def _process_sse_data(self, data_str: str): async def _process_sse_data(self, data_str: str, raw_int16: bool = False):
"""Parse SSE data field and yield audio chunks if present.""" """Parse SSE data field and yield audio chunks if present."""
data_str = data_str.rstrip("\n") data_str = data_str.rstrip("\n")
if not data_str: if not data_str:
@ -143,9 +159,13 @@ class StreamingTTSClient:
if code == 0 and data.get("data"): if code == 0 and data.get("data"):
# Audio data chunk # Audio data chunk
pcm_raw = base64.b64decode(data["data"]) pcm_raw = base64.b64decode(data["data"])
pcm_f32 = convert_pcm_s16_to_f32(pcm_raw) if raw_int16:
if pcm_f32: if pcm_raw:
yield pcm_f32 yield pcm_raw
else:
pcm_f32 = convert_pcm_s16_to_f32(pcm_raw)
if pcm_f32:
yield pcm_f32
elif code == 20000000: elif code == 20000000:
# End of stream # End of stream

View File

@ -27,6 +27,7 @@ class VoiceLiteSession:
bot_id: str, bot_id: str,
session_id: Optional[str] = None, session_id: Optional[str] = None,
user_identifier: Optional[str] = None, user_identifier: Optional[str] = None,
sample_rate: int = 24000,
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None, on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None, on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None, on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
@ -38,6 +39,7 @@ class VoiceLiteSession:
self.bot_id = bot_id self.bot_id = bot_id
self.session_id = session_id or str(uuid.uuid4()) self.session_id = session_id or str(uuid.uuid4())
self.user_identifier = user_identifier or "" self.user_identifier = user_identifier or ""
self._client_sample_rate = sample_rate
self._bot_config: dict = {} self._bot_config: dict = {}
self._speaker: str = "" self._speaker: str = ""
@ -110,7 +112,7 @@ class VoiceLiteSession:
await self._asr_client.close() await self._asr_client.close()
# VAD configuration # VAD configuration
VAD_SILENCE_DURATION = 1.5 # Seconds of silence before sending finish VAD_SILENCE_DURATION = 3.0 # Seconds of silence before sending finish
VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers
VAD_SOURCE_RATE = 24000 # Input audio sample rate VAD_SOURCE_RATE = 24000 # Input audio sample rate
VAD_TARGET_RATE = 16000 # webrtcvad supported sample rate VAD_TARGET_RATE = 16000 # webrtcvad supported sample rate
@ -139,6 +141,43 @@ class VoiceLiteSession:
resampled.append(samples[src_idx]) resampled.append(samples[src_idx])
return struct.pack(f'<{len(resampled)}h', *resampled) return struct.pack(f'<{len(resampled)}h', *resampled)
@staticmethod
def _resample_16k_to_24k(pcm_data: bytes) -> bytes:
"""Upsample 16-bit PCM from 16kHz to 24kHz (ratio 2:3).
For every 2 input samples, produces 3 output samples using linear interpolation.
"""
n_samples = len(pcm_data) // 2
if n_samples == 0:
return b''
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
out_len = (n_samples * 3) // 2
resampled = []
for i in range(out_len):
src_pos = (i * 2) / 3
src_idx = int(src_pos)
frac = src_pos - src_idx
if src_idx + 1 < n_samples:
val = int(samples[src_idx] * (1 - frac) + samples[src_idx + 1] * frac)
elif src_idx < n_samples:
val = samples[src_idx]
else:
break
resampled.append(max(-32768, min(32767, val)))
return struct.pack(f'<{len(resampled)}h', *resampled)
def _resample_input(self, audio_data: bytes) -> bytes:
"""Resample incoming audio to 24kHz if needed."""
if self._client_sample_rate == 16000:
return self._resample_16k_to_24k(audio_data)
return audio_data
def _resample_output(self, audio_data: bytes) -> bytes:
"""Resample outgoing audio from 24kHz to client sample rate if needed."""
if self._client_sample_rate == 16000:
return self._resample_24k_to_16k(audio_data)
return audio_data
def _webrtcvad_detect(self, pcm_data: bytes) -> bool: def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
"""Run webrtcvad on audio data. Returns True if voice is detected in any frame.""" """Run webrtcvad on audio data. Returns True if voice is detected in any frame."""
resampled = self._resample_24k_to_16k(pcm_data) resampled = self._resample_24k_to_16k(pcm_data)
@ -164,6 +203,9 @@ class VoiceLiteSession:
if not self._running: if not self._running:
return return
# Resample to 24kHz if client sends lower sample rate
audio_data = self._resample_input(audio_data)
self._audio_chunk_count += 1 self._audio_chunk_count += 1
has_voice = self._webrtcvad_detect(audio_data) has_voice = self._webrtcvad_detect(audio_data)
now = asyncio.get_event_loop().time() now = asyncio.get_event_loop().time()
@ -435,9 +477,15 @@ class VoiceLiteSession:
async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None: async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
"""Synthesize a sentence and emit audio chunks.""" """Synthesize a sentence and emit audio chunks."""
logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'") logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
async for audio_chunk in tts_client.synthesize(sentence): if self._client_sample_rate != 24000:
if self._on_audio: # Client needs non-24kHz: use raw int16 pipeline to allow resampling
await self._on_audio(audio_chunk) async for audio_chunk in tts_client.synthesize_raw(sentence):
if self._on_audio:
await self._on_audio(self._resample_output(audio_chunk))
else:
async for audio_chunk in tts_client.synthesize(sentence):
if self._on_audio:
await self._on_audio(audio_chunk)
async def _emit_status(self, status: str) -> None: async def _emit_status(self, status: str) -> None:
if self._on_status: if self._on_status:

View File

@ -17,6 +17,7 @@ class VoiceSession:
bot_id: str, bot_id: str,
session_id: Optional[str] = None, session_id: Optional[str] = None,
user_identifier: Optional[str] = None, user_identifier: Optional[str] = None,
sample_rate: int = 24000,
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None, on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None, on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None, on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,

View File

@ -16,7 +16,7 @@ import urllib.parse
def get_config(): def get_config():
"""获取配置下面的MASTERKEY和ASSISTANT_ID是从环境变量自动获取的不需要用户提供""" """获取配置下面的MASTERKEY和ASSISTANT_ID是从环境变量自动获取的不需要用户提供"""
masterkey = os.environ.get("MASTERKEY", "master") masterkey = os.environ.get("MASTERKEY", "master")
bot_id = os.environ.get("ASSISTANT_ID", "") bot_id = str(os.environ.get("ASSISTANT_ID", ""))
if not masterkey: if not masterkey:
print("ERROR: MASTERKEY environment variable is required") print("ERROR: MASTERKEY environment variable is required")
sys.exit(1) sys.exit(1)

View File

@ -110,7 +110,7 @@ RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS",
# ============================================================ # ============================================================
# New API 基础 URL支付后端 # New API 基础 URL支付后端
NEW_API_BASE_URL = os.getenv("NEW_API_BASE_URL", "http://116.62.16.218:3000") NEW_API_BASE_URL = os.getenv("NEW_API_BASE_URL", "http://100.77.70.35:3001")
# New API 请求超时(秒) # New API 请求超时(秒)
NEW_API_TIMEOUT = int(os.getenv("NEW_API_TIMEOUT", "30")) NEW_API_TIMEOUT = int(os.getenv("NEW_API_TIMEOUT", "30"))
@ -133,7 +133,7 @@ VOLCENGINE_TTS_SAMPLE_RATE = int(os.getenv("VOLCENGINE_TTS_SAMPLE_RATE", "24000"
# ============================================================ # ============================================================
VOICE_DEFAULT_MODE = os.getenv("VOICE_DEFAULT_MODE", "lite") # "realtime" | "lite" VOICE_DEFAULT_MODE = os.getenv("VOICE_DEFAULT_MODE", "lite") # "realtime" | "lite"
# Silence timeout (seconds) - ASR considers user done speaking after this # Silence timeout (seconds) - ASR considers user done speaking after this
VOICE_LITE_SILENCE_TIMEOUT = float(os.getenv("VOICE_LITE_SILENCE_TIMEOUT", "1.5")) VOICE_LITE_SILENCE_TIMEOUT = float(os.getenv("VOICE_LITE_SILENCE_TIMEOUT", "3.0"))
# ============================================================ # ============================================================
# Single Agent Mode Configuration # Single Agent Mode Configuration