perf: Optimize word segmentation retrieval (#2767)

This commit is contained in:
shaohuzhang1 2025-04-01 19:11:16 +08:00 committed by GitHub
parent 6fde8ec80f
commit 2991f0b640
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 34 deletions

View File

@ -12,9 +12,6 @@ from typing import List
import jieba import jieba
import jieba.posseg import jieba.posseg
from jieba import analyse
from common.util.split_model import group_by
jieba_word_list_cache = [chr(item) for item in range(38, 84)] jieba_word_list_cache = [chr(item) for item in range(38, 84)]
@ -80,37 +77,12 @@ def get_key_by_word_dict(key, word_dict):
def to_ts_vector(text: str): def to_ts_vector(text: str):
# 获取不分词的数据
word_list = get_word_list(text)
# 获取关键词关系
word_dict = to_word_dict(word_list, text)
# 替换字符串
text = replace_word(word_dict, text)
# 分词 # 分词
filter_word = jieba.analyse.extract_tags(text, topK=100) result = jieba.lcut(text)
result = jieba.lcut(text, HMM=True, use_paddle=True) return " ".join(result)
# 过滤标点符号
result = [item for item in result if filter_word.__contains__(item) and len(item) < 10]
result_ = [{'word': get_key_by_word_dict(result[index], word_dict), 'index': index} for index in
range(len(result))]
result_group = group_by(result_, lambda r: r['word'])
return " ".join(
[f"{key.lower()}:{','.join([str(item['index'] + 1) for item in result_group[key]][:20])}" for key in
result_group if
not remove_chars.__contains__(key) and len(key.strip()) >= 0])
def to_query(text: str): def to_query(text: str):
# 获取不分词的数据 extract_tags = jieba.lcut(text)
word_list = get_word_list(text) result = " ".join(extract_tags)
# 获取关键词关系
word_dict = to_word_dict(word_list, text)
# 替换字符串
text = replace_word(word_dict, text)
extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if
not remove_chars.__contains__(word)])
# 删除词库
for word in word_list:
jieba.del_word(word)
return result return result

View File

@ -12,7 +12,9 @@ import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List from typing import Dict, List
from django.db.models import QuerySet import jieba
from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from common.db.search import generate_sql_by_query_dict from common.db.search import generate_sql_by_query_dict
@ -68,7 +70,8 @@ class PGVector(BaseVectorStore):
source_id=text_list[index].get('source_id'), source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'), source_type=text_list[index].get('source_type'),
embedding=embeddings[index], embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
index in
range(0, len(texts))] range(0, len(texts))]
if not is_the_task_interrupted(): if not is_the_task_interrupted():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None