Compare commits
2 Commits
ba65c44755
...
43a77b3015
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43a77b3015 | ||
|
|
3ee80a637e |
@ -19,14 +19,17 @@ class _StreamTagFilter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
|
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
|
||||||
|
KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.state = "idle" # idle, answer, skip
|
self.state = "idle" # idle, answer, skip
|
||||||
self.found_any_tag = False
|
self.found_any_tag = False
|
||||||
self._pending = "" # buffer for partial tag like "[TOO..."
|
self._pending = "" # buffer for partial tag like "[TOO..."
|
||||||
|
self.answer_ended = False # True when ANSWER block ends (e.g. hit [TOOL_CALL])
|
||||||
|
|
||||||
def feed(self, chunk: str) -> str:
|
def feed(self, chunk: str) -> str:
|
||||||
"""Feed a chunk, return text that should be passed to TTS."""
|
"""Feed a chunk, return text that should be passed to TTS."""
|
||||||
|
self.answer_ended = False
|
||||||
self._pending += chunk
|
self._pending += chunk
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
@ -50,11 +53,18 @@ class _StreamTagFilter:
|
|||||||
|
|
||||||
tag_name = self._pending[bracket_pos + 1:close_pos]
|
tag_name = self._pending[bracket_pos + 1:close_pos]
|
||||||
self._pending = self._pending[close_pos + 1:]
|
self._pending = self._pending[close_pos + 1:]
|
||||||
self.found_any_tag = True
|
|
||||||
|
|
||||||
|
if tag_name not in self.KNOWN_TAGS:
|
||||||
|
if self.state == "answer" or not self.found_any_tag:
|
||||||
|
output.append(f"[{tag_name}]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.found_any_tag = True
|
||||||
if tag_name == "ANSWER":
|
if tag_name == "ANSWER":
|
||||||
self.state = "answer"
|
self.state = "answer"
|
||||||
else:
|
else:
|
||||||
|
if self.state == "answer":
|
||||||
|
self.answer_ended = True
|
||||||
self.state = "skip"
|
self.state = "skip"
|
||||||
|
|
||||||
elif self.state == "skip":
|
elif self.state == "skip":
|
||||||
@ -70,6 +80,9 @@ class _StreamTagFilter:
|
|||||||
tag_name = self._pending[bracket_pos + 1:close_pos]
|
tag_name = self._pending[bracket_pos + 1:close_pos]
|
||||||
self._pending = self._pending[close_pos + 1:]
|
self._pending = self._pending[close_pos + 1:]
|
||||||
|
|
||||||
|
if tag_name not in self.KNOWN_TAGS:
|
||||||
|
continue
|
||||||
|
|
||||||
if tag_name == "ANSWER":
|
if tag_name == "ANSWER":
|
||||||
self.state = "answer"
|
self.state = "answer"
|
||||||
else:
|
else:
|
||||||
@ -115,6 +128,11 @@ class VoiceSession:
|
|||||||
self._current_asr_text = ""
|
self._current_asr_text = ""
|
||||||
# When True, discard TTS audio from SERVER_ACK (comfort speech period)
|
# When True, discard TTS audio from SERVER_ACK (comfort speech period)
|
||||||
self._is_sending_chat_tts_text = False
|
self._is_sending_chat_tts_text = False
|
||||||
|
# Set to True when event 350 fires for chat_tts_text, indicating the TTS segment is done
|
||||||
|
# and next TTS send must use start=True to begin a new session
|
||||||
|
self._tts_segment_done = False
|
||||||
|
# Signaled when event 359 fires (TTS fully completed), used to wait before starting new TTS
|
||||||
|
self._tts_complete_event: asyncio.Event = asyncio.Event()
|
||||||
self._receive_task: Optional[asyncio.Task] = None
|
self._receive_task: Optional[asyncio.Task] = None
|
||||||
self._agent_task: Optional[asyncio.Task] = None
|
self._agent_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
@ -252,10 +270,14 @@ class VoiceSession:
|
|||||||
self._is_sending_chat_tts_text = False
|
self._is_sending_chat_tts_text = False
|
||||||
logger.info(f"[Voice] Comfort/RAG TTS done, resuming audio forwarding")
|
logger.info(f"[Voice] Comfort/RAG TTS done, resuming audio forwarding")
|
||||||
|
|
||||||
|
# Mark TTS segment as done so next send uses start=True
|
||||||
|
if tts_type == "chat_tts_text":
|
||||||
|
self._tts_segment_done = True
|
||||||
|
|
||||||
elif event == 359:
|
elif event == 359:
|
||||||
# TTS fully completed (all segments done)
|
# TTS fully completed (all segments done)
|
||||||
logger.info(f"[Voice] TTS fully completed")
|
logger.info(f"[Voice] TTS fully completed")
|
||||||
# await self._emit_status("idle")
|
self._tts_complete_event.set()
|
||||||
|
|
||||||
elif event in (152, 153):
|
elif event in (152, 153):
|
||||||
logger.info(f"[Voice] Session finished event: {event}")
|
logger.info(f"[Voice] Session finished event: {event}")
|
||||||
@ -314,6 +336,35 @@ class VoiceSession:
|
|||||||
passthrough = tag_filter.feed(chunk)
|
passthrough = tag_filter.feed(chunk)
|
||||||
|
|
||||||
if not passthrough:
|
if not passthrough:
|
||||||
|
# ANSWER block ended (e.g. hit [TOOL_CALL]), flush sentence_buf immediately
|
||||||
|
if tag_filter.answer_ended and sentence_buf:
|
||||||
|
flush = sentence_buf.strip()
|
||||||
|
sentence_buf = ""
|
||||||
|
if flush:
|
||||||
|
flush = self._clean_markdown(flush)
|
||||||
|
if flush:
|
||||||
|
if tts_started and self._tts_segment_done:
|
||||||
|
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery (answer ended)")
|
||||||
|
await self.realtime_client.chat_tts_text(content="", start=False, end=True)
|
||||||
|
self._tts_complete_event.clear()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._tts_complete_event.wait(), timeout=10)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"[Voice] Timeout waiting for TTS complete, proceeding anyway")
|
||||||
|
tts_started = False
|
||||||
|
self._tts_segment_done = False
|
||||||
|
logger.info(f"[Voice] TTS delivery done, starting new session (answer ended)")
|
||||||
|
|
||||||
|
logger.info(f"[Voice] Sending TTS sentence (answer ended): '{flush[:80]}'")
|
||||||
|
await self.realtime_client.chat_tts_text(
|
||||||
|
content=flush,
|
||||||
|
start=not tts_started,
|
||||||
|
end=False,
|
||||||
|
)
|
||||||
|
if not tts_started:
|
||||||
|
await self._emit_status("speaking")
|
||||||
|
tts_started = True
|
||||||
|
self._tts_segment_done = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sentence_buf += passthrough
|
sentence_buf += passthrough
|
||||||
@ -331,26 +382,59 @@ class VoiceSession:
|
|||||||
if sentence:
|
if sentence:
|
||||||
sentence = self._clean_markdown(sentence)
|
sentence = self._clean_markdown(sentence)
|
||||||
if sentence:
|
if sentence:
|
||||||
|
# If previous TTS segment completed (e.g. gap during tool call),
|
||||||
|
# close old session, wait for TTS delivery to finish, then restart
|
||||||
|
if tts_started and self._tts_segment_done:
|
||||||
|
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery")
|
||||||
|
await self.realtime_client.chat_tts_text(content="", start=False, end=True)
|
||||||
|
self._tts_complete_event.clear()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._tts_complete_event.wait(), timeout=10)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"[Voice] Timeout waiting for TTS complete, proceeding anyway")
|
||||||
|
tts_started = False
|
||||||
|
self._tts_segment_done = False
|
||||||
|
logger.info(f"[Voice] TTS delivery done, starting new session")
|
||||||
|
|
||||||
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
|
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
|
||||||
await self.realtime_client.chat_tts_text(
|
await self.realtime_client.chat_tts_text(
|
||||||
content=sentence,
|
content=sentence,
|
||||||
start=not tts_started,
|
start=not tts_started,
|
||||||
end=False,
|
end=False,
|
||||||
)
|
)
|
||||||
|
if not tts_started:
|
||||||
|
await self._emit_status("speaking")
|
||||||
tts_started = True
|
tts_started = True
|
||||||
|
self._tts_segment_done = False
|
||||||
|
|
||||||
# Handle remaining text in buffer (last sentence without ending punctuation)
|
# Handle remaining text in buffer (last sentence without ending punctuation)
|
||||||
remaining = sentence_buf.strip()
|
remaining = sentence_buf.strip()
|
||||||
if remaining:
|
if remaining:
|
||||||
remaining = self._clean_markdown(remaining)
|
remaining = self._clean_markdown(remaining)
|
||||||
if remaining:
|
if remaining:
|
||||||
|
# If previous TTS segment completed, close and wait before restart
|
||||||
|
if tts_started and self._tts_segment_done:
|
||||||
|
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery (remaining)")
|
||||||
|
await self.realtime_client.chat_tts_text(content="", start=False, end=True)
|
||||||
|
self._tts_complete_event.clear()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._tts_complete_event.wait(), timeout=10)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"[Voice] Timeout waiting for TTS complete, proceeding anyway")
|
||||||
|
tts_started = False
|
||||||
|
self._tts_segment_done = False
|
||||||
|
logger.info(f"[Voice] TTS delivery done, starting new session for remaining")
|
||||||
|
|
||||||
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
|
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
|
||||||
await self.realtime_client.chat_tts_text(
|
await self.realtime_client.chat_tts_text(
|
||||||
content=remaining,
|
content=remaining,
|
||||||
start=not tts_started,
|
start=not tts_started,
|
||||||
end=False,
|
end=False,
|
||||||
)
|
)
|
||||||
|
if not tts_started:
|
||||||
|
await self._emit_status("speaking")
|
||||||
tts_started = True
|
tts_started = True
|
||||||
|
self._tts_segment_done = False
|
||||||
|
|
||||||
# Send TTS end signal
|
# Send TTS end signal
|
||||||
if tts_started:
|
if tts_started:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user