语义分割

This commit is contained in:
朱潮 2026-03-22 00:42:57 +08:00
parent f9e9c3c26d
commit 7a547322e3
5 changed files with 236 additions and 55 deletions

36
poetry.lock generated
View File

@ -5302,25 +5302,19 @@ train = ["accelerate (>=0.20.3)", "datasets"]
[[package]]
name = "setuptools"
version = "82.0.1"
description = "Most extensible Python build backend with support for C/C++ extension modules"
version = "70.3.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version >= \"3.13\""
files = [
{file = "setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb"},
{file = "setuptools-82.0.1.tar.gz", hash = "sha256:7d872682c5d01cfde07da7bccc7b65469d3dca203318515ada1de5eda35efbf9"},
{file = "setuptools-70.3.0-py3-none-any.whl", hash = "sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc"},
{file = "setuptools-70.3.0.tar.gz", hash = "sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.13.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.18.*)", "pytest-mypy"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-ruff (>=0.3.2) ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "shellingham"
@ -6347,6 +6341,20 @@ files = [
{file = "wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159"},
]
[[package]]
name = "webrtcvad"
version = "2.0.10"
description = "Python interface to the Google WebRTC Voice Activity Detector (VAD)"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "webrtcvad-2.0.10.tar.gz", hash = "sha256:f1bed2fb25b63fb7b1a55d64090c993c9c9167b28485ae0bcdd81cf6ede96aea"},
]
[package.extras]
dev = ["check-manifest", "memory_profiler", "nose", "psutil", "unittest2", "zest.releaser"]
[[package]]
name = "websockets"
version = "15.0.1"
@ -6983,4 +6991,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<3.15"
content-hash = "1461514ed1f9639f41f43ebb28f2a3fcd2d5a5dde954cd509c0ea7bf181e9bb6"
content-hash = "c9c4f80cdbf7d6bce20f65f40b9adce05c5f4a830299de148fcd8482937bddb0"

View File

@ -39,7 +39,9 @@ dependencies = [
"ragflow-sdk (>=0.23.0,<0.24.0)",
"httpx (>=0.28.1,<0.29.0)",
"wsgidav (>=4.3.3,<5.0.0)",
"websockets (>=15.0.0,<16.0.0)"
"websockets (>=15.0.0,<16.0.0)",
"setuptools (<71)",
"webrtcvad (>=2.0.10,<3.0.0)",
]
[tool.poetry.requires-plugins]

View File

@ -165,7 +165,7 @@ safetensors==0.7.0 ; python_version >= "3.12" and python_version < "3.15"
scikit-learn==1.8.0 ; python_version >= "3.12" and python_version < "3.15"
scipy==1.17.1 ; python_version >= "3.12" and python_version < "3.15"
sentence-transformers==3.4.1 ; python_version >= "3.12" and python_version < "3.15"
setuptools==82.0.1 ; python_version >= "3.13" and python_version < "3.15"
setuptools==70.3.0 ; python_version >= "3.12" and python_version < "3.15"
shellingham==1.5.4 ; python_version >= "3.12" and python_version < "3.15"
six==1.17.0 ; python_version >= "3.12" and python_version < "3.15"
sniffio==1.3.1 ; python_version >= "3.12" and python_version < "3.15"
@ -203,6 +203,7 @@ uvloop==0.22.1 ; python_version >= "3.12" and python_version < "3.15"
watchfiles==1.1.1 ; python_version >= "3.12" and python_version < "3.15"
wcmatch==10.1 ; python_version >= "3.12" and python_version < "3.15"
wcwidth==0.6.0 ; python_version >= "3.12" and python_version < "3.15"
webrtcvad==2.0.10 ; python_version >= "3.12" and python_version < "3.15"
websockets==15.0.1 ; python_version >= "3.12" and python_version < "3.15"
wrapt==1.17.3 ; python_version >= "3.12" and python_version < "3.15"
wsgidav==4.3.3 ; python_version >= "3.12" and python_version < "3.15"

View File

@ -4,13 +4,15 @@ import struct
import uuid
from typing import Optional, Callable, Awaitable
import webrtcvad
from services.streaming_asr_client import StreamingASRClient
from services.streaming_tts_client import StreamingTTSClient
from services.voice_utils import (
StreamTagFilter,
clean_markdown,
stream_v3_agent,
SENTENCE_END_RE,
TTSSentenceSplitter,
)
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT
@ -63,7 +65,8 @@ class VoiceLiteSession:
self._last_asr_emit_time: float = 0
self._utterance_lock = asyncio.Lock()
# VAD (Voice Activity Detection) state
# VAD (Voice Activity Detection) via webrtcvad
self._vad = webrtcvad.Vad(2) # aggressiveness 0-3 (2 = balanced)
self._vad_speaking = False # Whether user is currently speaking
self._vad_silence_start: float = 0 # When silence started
self._vad_finish_task: Optional[asyncio.Task] = None
@ -105,23 +108,52 @@ class VoiceLiteSession:
await self._asr_client.close()
# VAD configuration
VAD_ENERGY_THRESHOLD = 500 # RMS energy threshold for voice detection
VAD_SILENCE_DURATION = 1.5 # Seconds of silence before sending finish
VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers
VAD_SOURCE_RATE = 24000 # Input audio sample rate
VAD_TARGET_RATE = 16000 # webrtcvad supported sample rate
VAD_FRAME_DURATION_MS = 30 # Frame duration for webrtcvad (10, 20, or 30 ms)
_audio_chunk_count = 0
@staticmethod
def _calc_rms(pcm_data: bytes) -> float:
"""Calculate RMS energy of 16-bit PCM audio."""
if len(pcm_data) < 2:
return 0.0
def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
"""Downsample 16-bit PCM from 24kHz to 16kHz (ratio 3:2).
Takes every 2 out of 3 samples (simple decimation).
"""
n_samples = len(pcm_data) // 2
if n_samples == 0:
return b''
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
if not samples:
return 0.0
sum_sq = sum(s * s for s in samples)
return (sum_sq / n_samples) ** 0.5
# Pick samples at indices 0, 1.5, 3, 4.5, ... -> floor(i * 3/2) for output index i
out_len = (n_samples * 2) // 3
resampled = []
for i in range(out_len):
src_idx = (i * 3) // 2
if src_idx < n_samples:
resampled.append(samples[src_idx])
return struct.pack(f'<{len(resampled)}h', *resampled)
def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
    """Run webrtcvad over 24 kHz PCM; True if any frame contains speech.

    The input is first decimated to 16 kHz (a rate webrtcvad supports),
    then scanned in fixed-size frames of VAD_FRAME_DURATION_MS.
    """
    audio_16k = self._resample_24k_to_16k(pcm_data)
    # Bytes per VAD frame: samples-per-frame * 2 bytes per 16-bit sample.
    bytes_per_frame = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2
    if len(audio_16k) < bytes_per_frame:
        return False
    speech_frames = 0
    offset = 0
    while offset + bytes_per_frame <= len(audio_16k):
        frame = audio_16k[offset:offset + bytes_per_frame]
        try:
            if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
                speech_frames += 1
        except Exception:
            # Malformed frame — treat as non-speech rather than aborting.
            pass
        offset += bytes_per_frame
    # Voice detected if at least one frame was classified as speech.
    return speech_frames > 0
async def handle_audio(self, audio_data: bytes) -> None:
"""Forward user audio to ASR with VAD gating. Lazy-connect on speech start."""
@ -129,8 +161,7 @@ class VoiceLiteSession:
return
self._audio_chunk_count += 1
rms = self._calc_rms(audio_data)
has_voice = rms > self.VAD_ENERGY_THRESHOLD
has_voice = self._webrtcvad_detect(audio_data)
now = asyncio.get_event_loop().time()
if has_voice:
@ -142,7 +173,7 @@ class VoiceLiteSession:
if not self._vad_speaking:
# Speech just started — connect ASR
self._vad_speaking = True
logger.info(f"[VoiceLite] VAD: speech started (rms={rms:.0f}), connecting ASR...")
logger.info(f"[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
try:
await self._connect_asr()
# Send buffered pre-speech audio
@ -320,8 +351,8 @@ class VoiceLiteSession:
await self._emit_status("thinking")
accumulated_text = []
sentence_buf = ""
tag_filter = StreamTagFilter()
splitter = TTSSentenceSplitter()
tts_client = StreamingTTSClient(speaker=self._speaker)
speaking = False
@ -340,26 +371,20 @@ class VoiceLiteSession:
passthrough = tag_filter.feed(chunk)
if not passthrough:
if tag_filter.answer_ended and sentence_buf:
flush = clean_markdown(sentence_buf.strip())
sentence_buf = ""
if flush:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, flush)
if tag_filter.answer_ended:
for sentence in splitter.flush():
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, sentence)
continue
sentence_buf += passthrough
while True:
match = SENTENCE_END_RE.search(sentence_buf)
if not match:
break
end_pos = match.end()
sentence = clean_markdown(sentence_buf[:end_pos].strip())
sentence_buf = sentence_buf[end_pos:]
# Feed raw passthrough to splitter (preserve newlines for splitting),
# apply clean_markdown on output sentences
for sentence in splitter.feed(passthrough):
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
@ -367,12 +392,13 @@ class VoiceLiteSession:
await self._send_tts(tts_client, sentence)
# Handle remaining text
remaining = clean_markdown(sentence_buf.strip())
if remaining:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, remaining)
for sentence in splitter.flush():
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, sentence)
# Log full agent result (not sent to frontend, already streamed)
full_result = "".join(accumulated_text)

View File

@ -7,6 +7,150 @@ logger = logging.getLogger('app')
SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')

# Emoji pattern: matches Unicode emoji without touching CJK characters
_EMOJI_RE = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map
    "\U0001F1E0-\U0001F1FF"  # flags
    "\U0001F900-\U0001F9FF"  # supplemental symbols
    "\U0001FA00-\U0001FA6F"  # chess symbols
    "\U0001FA70-\U0001FAFF"  # symbols extended-A
    "\U00002702-\U000027B0"  # dingbats
    "\U00002600-\U000026FF"  # misc symbols
    "\U0000FE00-\U0000FE0F"  # variation selectors
    "\U0000200D"             # zero width joiner
    "\U000024C2"             # enclosed M
    "\U00002B50\U00002B55"   # star, circle
    "\U000023CF\U000023E9-\U000023F3\U000023F8-\U000023FA"  # media controls
    "\U0001F170-\U0001F251"  # enclosed alphanumeric supplement
    "]+",
    flags=re.UNICODE,
)

# Strong sentence-ending punctuation (excluding \n which is handled separately)
_STRONG_PUNCT_RE = re.compile(r'[。!?;.!?;~]')

# Soft punctuation (usable as split points when buffer is getting long)
_SOFT_PUNCT_RE = re.compile(r'[,:、)) \t]')


class TTSSentenceSplitter:
    """
    Intelligent sentence splitter for TTS streaming.

    Rules (in priority order):
      1. Split on newlines unconditionally (LLM paragraph boundaries);
         a too-short paragraph is merged with the following one.
      2. Split on strong punctuation (e.g. 。!? . ! ?) only if the
         accumulated text is at least MIN_LENGTH characters.
      3. Once the buffer reaches SOFT_THRESHOLD, also split on soft
         punctuation (e.g. , : 、).
      4. Once the buffer reaches MAX_LENGTH, force a split at the best
         available break position (last space/comma in the window), or at
         MAX_LENGTH when no break character exists.

    Emoji are stripped from output (TTS cannot pronounce them), and runs of
    spaces/tabs are collapsed. flush() returns any remaining text regardless
    of length.
    """

    MIN_LENGTH = 10       # Don't emit sentences shorter than this
    SOFT_THRESHOLD = 30   # Start considering soft punctuation splits
    MAX_LENGTH = 80       # Force split even without punctuation

    def __init__(self):
        # Text received but not yet emitted as a sentence.
        self._buf = ""

    def _clean_for_tts(self, text: str) -> str:
        """Remove emoji and collapse horizontal whitespace."""
        text = _EMOJI_RE.sub("", text)
        text = re.sub(r'[ \t]+', ' ', text)
        return text.strip()

    def feed(self, chunk: str) -> list[str]:
        """Feed a text chunk; return the list of ready sentences (may be empty)."""
        self._buf += chunk
        results: list[str] = []
        while self._buf:
            buf_len = len(self._buf)

            # 0. Newline split — highest priority.
            nl_pos = self._buf.find('\n')
            if nl_pos >= 0:
                before = self._buf[:nl_pos]
                rest = self._buf[nl_pos:].lstrip('\n')
                cleaned = self._clean_for_tts(before)
                if len(cleaned) >= self.MIN_LENGTH:
                    # Long enough — emit as a sentence.
                    self._buf = rest
                    results.append(cleaned)
                    continue
                elif not rest:
                    # Nothing after the newline yet; wait for more text.
                    break
                else:
                    # Too short — merge with the next paragraph.
                    self._buf = before + rest
                    continue

            # 1. Strong punctuation: take the first split point that yields
            #    a long-enough sentence.
            best_end = -1
            for match in _STRONG_PUNCT_RE.finditer(self._buf):
                end_pos = match.end()
                if len(self._buf[:end_pos].strip()) >= self.MIN_LENGTH:
                    best_end = end_pos
                    break
                # Segment before this punct too short — keep scanning.
            if best_end > 0:
                sentence = self._clean_for_tts(self._buf[:best_end])
                self._buf = self._buf[best_end:]
                if sentence:
                    results.append(sentence)
                continue

            # 2. Buffer getting long: try a soft punctuation split.
            if buf_len >= self.SOFT_THRESHOLD:
                best_soft = -1
                for m in _SOFT_PUNCT_RE.finditer(self._buf):
                    pos = m.end()
                    if pos >= self.MIN_LENGTH:
                        best_soft = pos
                        if pos >= self.SOFT_THRESHOLD:
                            break
                if best_soft >= self.MIN_LENGTH:
                    sentence = self._clean_for_tts(self._buf[:best_soft])
                    self._buf = self._buf[best_soft:]
                    if sentence:
                        results.append(sentence)
                    continue

            # 3. Buffer too long: force split, preferring the last break
            #    character (space or comma) inside [MIN_LENGTH, MAX_LENGTH).
            if buf_len >= self.MAX_LENGTH:
                split_at = self.MAX_LENGTH
                window = self._buf[self.MIN_LENGTH:self.MAX_LENGTH]
                # NOTE: explicit character set — rfind('') would return
                # len(window) and overshoot the window.
                best_break = max(window.rfind(ch) for ch in (' ', ',', ',', '、'))
                if best_break >= 0:
                    split_at = self.MIN_LENGTH + best_break + 1
                sentence = self._clean_for_tts(self._buf[:split_at])
                self._buf = self._buf[split_at:]
                if sentence:
                    results.append(sentence)
                continue

            # Not enough text yet; wait for more.
            break
        return results

    def flush(self) -> list[str]:
        """Flush the remaining buffer. Call at end of stream."""
        results: list[str] = []
        if self._buf.strip():
            sentence = self._clean_for_tts(self._buf)
            if sentence:
                results.append(sentence)
        self._buf = ""
        return results
class StreamTagFilter:
"""