realtime接口支持16khz输出
This commit is contained in:
parent
2d2e1dbcdf
commit
30dc697071
@ -83,7 +83,7 @@ class AgentConfig:
|
||||
safe_dict = self.to_dict().copy()
|
||||
if 'api_key' in safe_dict and isinstance(safe_dict['api_key'], str) and safe_dict['api_key'].startswith('sk-'):
|
||||
safe_dict['api_key'] = safe_dict['api_key'][:8] + '***' + safe_dict['api_key'][-6:]
|
||||
logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}")
|
||||
logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False, default=str)}")
|
||||
|
||||
@classmethod
|
||||
async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
|
||||
|
||||
@ -288,9 +288,9 @@ async def init_agent(config: AgentConfig):
|
||||
checkpointer=checkpointer,
|
||||
shell_env={
|
||||
k: v for k, v in {
|
||||
"ASSISTANT_ID": config.bot_id,
|
||||
"USER_IDENTIFIER": config.user_identifier,
|
||||
"TRACE_ID": config.trace_id,
|
||||
"ASSISTANT_ID": str(config.bot_id),
|
||||
"USER_IDENTIFIER": str(config.user_identifier) if config.user_identifier else None,
|
||||
"TRACE_ID": str(config.trace_id) if config.trace_id else None,
|
||||
**(config.shell_env or {}),
|
||||
}.items() if v is not None
|
||||
}
|
||||
|
||||
@ -93,11 +93,13 @@ async def voice_realtime(websocket: WebSocket):
|
||||
continue
|
||||
|
||||
voice_mode = msg.get("voice_mode") or VOICE_DEFAULT_MODE
|
||||
client_sample_rate = msg.get("sample_rate", 24000)
|
||||
|
||||
session_kwargs = dict(
|
||||
bot_id=bot_id,
|
||||
session_id=msg.get("session_id"),
|
||||
user_identifier=msg.get("user_identifier"),
|
||||
sample_rate=client_sample_rate,
|
||||
on_audio=on_audio,
|
||||
on_asr_text=on_asr_text,
|
||||
on_agent_result=on_agent_result,
|
||||
|
||||
@ -37,6 +37,22 @@ class StreamingTTSClient:
|
||||
self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
|
||||
|
||||
async def synthesize(self, text: str):
|
||||
"""
|
||||
Synthesize text to audio via SSE streaming.
|
||||
Yields 24kHz float32 PCM audio chunks.
|
||||
"""
|
||||
async for chunk in self._synthesize_internal(text, raw_int16=False):
|
||||
yield chunk
|
||||
|
||||
async def synthesize_raw(self, text: str):
|
||||
"""
|
||||
Synthesize text to audio via SSE streaming.
|
||||
Yields 24kHz int16 PCM audio chunks (no float32 conversion).
|
||||
"""
|
||||
async for chunk in self._synthesize_internal(text, raw_int16=True):
|
||||
yield chunk
|
||||
|
||||
async def _synthesize_internal(self, text: str, raw_int16: bool = False):
|
||||
"""
|
||||
Synthesize text to audio via SSE streaming.
|
||||
Yields 24kHz float32 PCM audio chunks.
|
||||
@ -97,7 +113,7 @@ class StreamingTTSClient:
|
||||
if line == "":
|
||||
# Blank line = end of one SSE event
|
||||
if current_data:
|
||||
async for audio in self._process_sse_data(current_data):
|
||||
async for audio in self._process_sse_data(current_data, raw_int16=raw_int16):
|
||||
chunk_count += 1
|
||||
yield audio
|
||||
current_event = ""
|
||||
@ -118,7 +134,7 @@ class StreamingTTSClient:
|
||||
|
||||
# Handle remaining data without trailing blank line
|
||||
if current_data:
|
||||
async for audio in self._process_sse_data(current_data):
|
||||
async for audio in self._process_sse_data(current_data, raw_int16=raw_int16):
|
||||
chunk_count += 1
|
||||
yield audio
|
||||
|
||||
@ -127,7 +143,7 @@ class StreamingTTSClient:
|
||||
except Exception as e:
|
||||
logger.error(f"[TTS] Error: {e}", exc_info=True)
|
||||
|
||||
async def _process_sse_data(self, data_str: str):
|
||||
async def _process_sse_data(self, data_str: str, raw_int16: bool = False):
|
||||
"""Parse SSE data field and yield audio chunks if present."""
|
||||
data_str = data_str.rstrip("\n")
|
||||
if not data_str:
|
||||
@ -143,6 +159,10 @@ class StreamingTTSClient:
|
||||
if code == 0 and data.get("data"):
|
||||
# Audio data chunk
|
||||
pcm_raw = base64.b64decode(data["data"])
|
||||
if raw_int16:
|
||||
if pcm_raw:
|
||||
yield pcm_raw
|
||||
else:
|
||||
pcm_f32 = convert_pcm_s16_to_f32(pcm_raw)
|
||||
if pcm_f32:
|
||||
yield pcm_f32
|
||||
|
||||
@ -27,6 +27,7 @@ class VoiceLiteSession:
|
||||
bot_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
user_identifier: Optional[str] = None,
|
||||
sample_rate: int = 24000,
|
||||
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
|
||||
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
@ -38,6 +39,7 @@ class VoiceLiteSession:
|
||||
self.bot_id = bot_id
|
||||
self.session_id = session_id or str(uuid.uuid4())
|
||||
self.user_identifier = user_identifier or ""
|
||||
self._client_sample_rate = sample_rate
|
||||
|
||||
self._bot_config: dict = {}
|
||||
self._speaker: str = ""
|
||||
@ -110,7 +112,7 @@ class VoiceLiteSession:
|
||||
await self._asr_client.close()
|
||||
|
||||
# VAD configuration
|
||||
VAD_SILENCE_DURATION = 1.5 # Seconds of silence before sending finish
|
||||
VAD_SILENCE_DURATION = 3.0 # Seconds of silence before sending finish
|
||||
VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers
|
||||
VAD_SOURCE_RATE = 24000 # Input audio sample rate
|
||||
VAD_TARGET_RATE = 16000 # webrtcvad supported sample rate
|
||||
@ -139,6 +141,43 @@ class VoiceLiteSession:
|
||||
resampled.append(samples[src_idx])
|
||||
return struct.pack(f'<{len(resampled)}h', *resampled)
|
||||
|
||||
@staticmethod
|
||||
def _resample_16k_to_24k(pcm_data: bytes) -> bytes:
|
||||
"""Upsample 16-bit PCM from 16kHz to 24kHz (ratio 2:3).
|
||||
|
||||
For every 2 input samples, produces 3 output samples using linear interpolation.
|
||||
"""
|
||||
n_samples = len(pcm_data) // 2
|
||||
if n_samples == 0:
|
||||
return b''
|
||||
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
|
||||
out_len = (n_samples * 3) // 2
|
||||
resampled = []
|
||||
for i in range(out_len):
|
||||
src_pos = (i * 2) / 3
|
||||
src_idx = int(src_pos)
|
||||
frac = src_pos - src_idx
|
||||
if src_idx + 1 < n_samples:
|
||||
val = int(samples[src_idx] * (1 - frac) + samples[src_idx + 1] * frac)
|
||||
elif src_idx < n_samples:
|
||||
val = samples[src_idx]
|
||||
else:
|
||||
break
|
||||
resampled.append(max(-32768, min(32767, val)))
|
||||
return struct.pack(f'<{len(resampled)}h', *resampled)
|
||||
|
||||
def _resample_input(self, audio_data: bytes) -> bytes:
|
||||
"""Resample incoming audio to 24kHz if needed."""
|
||||
if self._client_sample_rate == 16000:
|
||||
return self._resample_16k_to_24k(audio_data)
|
||||
return audio_data
|
||||
|
||||
def _resample_output(self, audio_data: bytes) -> bytes:
|
||||
"""Resample outgoing audio from 24kHz to client sample rate if needed."""
|
||||
if self._client_sample_rate == 16000:
|
||||
return self._resample_24k_to_16k(audio_data)
|
||||
return audio_data
|
||||
|
||||
def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
|
||||
"""Run webrtcvad on audio data. Returns True if voice is detected in any frame."""
|
||||
resampled = self._resample_24k_to_16k(pcm_data)
|
||||
@ -164,6 +203,9 @@ class VoiceLiteSession:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# Resample to 24kHz if client sends lower sample rate
|
||||
audio_data = self._resample_input(audio_data)
|
||||
|
||||
self._audio_chunk_count += 1
|
||||
has_voice = self._webrtcvad_detect(audio_data)
|
||||
now = asyncio.get_event_loop().time()
|
||||
@ -435,6 +477,12 @@ class VoiceLiteSession:
|
||||
async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
|
||||
"""Synthesize a sentence and emit audio chunks."""
|
||||
logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
|
||||
if self._client_sample_rate != 24000:
|
||||
# Client needs non-24kHz: use raw int16 pipeline to allow resampling
|
||||
async for audio_chunk in tts_client.synthesize_raw(sentence):
|
||||
if self._on_audio:
|
||||
await self._on_audio(self._resample_output(audio_chunk))
|
||||
else:
|
||||
async for audio_chunk in tts_client.synthesize(sentence):
|
||||
if self._on_audio:
|
||||
await self._on_audio(audio_chunk)
|
||||
|
||||
@ -17,6 +17,7 @@ class VoiceSession:
|
||||
bot_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
user_identifier: Optional[str] = None,
|
||||
sample_rate: int = 24000,
|
||||
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
|
||||
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
|
||||
@ -16,7 +16,7 @@ import urllib.parse
|
||||
def get_config():
|
||||
"""获取配置,下面的MASTERKEY和ASSISTANT_ID是从环境变量自动获取的,不需要用户提供"""
|
||||
masterkey = os.environ.get("MASTERKEY", "master")
|
||||
bot_id = os.environ.get("ASSISTANT_ID", "")
|
||||
bot_id = str(os.environ.get("ASSISTANT_ID", ""))
|
||||
if not masterkey:
|
||||
print("ERROR: MASTERKEY environment variable is required")
|
||||
sys.exit(1)
|
||||
|
||||
@ -110,7 +110,7 @@ RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS",
|
||||
# ============================================================
|
||||
|
||||
# New API 基础 URL(支付后端)
|
||||
NEW_API_BASE_URL = os.getenv("NEW_API_BASE_URL", "http://116.62.16.218:3000")
|
||||
NEW_API_BASE_URL = os.getenv("NEW_API_BASE_URL", "http://100.77.70.35:3001")
|
||||
|
||||
# New API 请求超时(秒)
|
||||
NEW_API_TIMEOUT = int(os.getenv("NEW_API_TIMEOUT", "30"))
|
||||
@ -133,7 +133,7 @@ VOLCENGINE_TTS_SAMPLE_RATE = int(os.getenv("VOLCENGINE_TTS_SAMPLE_RATE", "24000"
|
||||
# ============================================================
|
||||
VOICE_DEFAULT_MODE = os.getenv("VOICE_DEFAULT_MODE", "lite") # "realtime" | "lite"
|
||||
# Silence timeout (seconds) - ASR considers user done speaking after this
|
||||
VOICE_LITE_SILENCE_TIMEOUT = float(os.getenv("VOICE_LITE_SILENCE_TIMEOUT", "1.5"))
|
||||
VOICE_LITE_SILENCE_TIMEOUT = float(os.getenv("VOICE_LITE_SILENCE_TIMEOUT", "3.0"))
|
||||
|
||||
# ============================================================
|
||||
# Single Agent Mode Configuration
|
||||
|
||||
Loading…
Reference in New Issue
Block a user