perf: Optimize word segmentation retrieval (#2767)
This commit is contained in:
parent
6fde8ec80f
commit
2991f0b640
@ -12,9 +12,6 @@ from typing import List
|
|||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
import jieba.posseg
|
import jieba.posseg
|
||||||
from jieba import analyse
|
|
||||||
|
|
||||||
from common.util.split_model import group_by
|
|
||||||
|
|
||||||
jieba_word_list_cache = [chr(item) for item in range(38, 84)]
|
jieba_word_list_cache = [chr(item) for item in range(38, 84)]
|
||||||
|
|
||||||
@ -80,37 +77,12 @@ def get_key_by_word_dict(key, word_dict):
|
|||||||
|
|
||||||
|
|
||||||
def to_ts_vector(text: str):
|
def to_ts_vector(text: str):
|
||||||
# 获取不分词的数据
|
|
||||||
word_list = get_word_list(text)
|
|
||||||
# 获取关键词关系
|
|
||||||
word_dict = to_word_dict(word_list, text)
|
|
||||||
# 替换字符串
|
|
||||||
text = replace_word(word_dict, text)
|
|
||||||
# 分词
|
# 分词
|
||||||
filter_word = jieba.analyse.extract_tags(text, topK=100)
|
result = jieba.lcut(text)
|
||||||
result = jieba.lcut(text, HMM=True, use_paddle=True)
|
return " ".join(result)
|
||||||
# 过滤标点符号
|
|
||||||
result = [item for item in result if filter_word.__contains__(item) and len(item) < 10]
|
|
||||||
result_ = [{'word': get_key_by_word_dict(result[index], word_dict), 'index': index} for index in
|
|
||||||
range(len(result))]
|
|
||||||
result_group = group_by(result_, lambda r: r['word'])
|
|
||||||
return " ".join(
|
|
||||||
[f"{key.lower()}:{','.join([str(item['index'] + 1) for item in result_group[key]][:20])}" for key in
|
|
||||||
result_group if
|
|
||||||
not remove_chars.__contains__(key) and len(key.strip()) >= 0])
|
|
||||||
|
|
||||||
|
|
||||||
def to_query(text: str):
|
def to_query(text: str):
|
||||||
# 获取不分词的数据
|
extract_tags = jieba.lcut(text)
|
||||||
word_list = get_word_list(text)
|
result = " ".join(extract_tags)
|
||||||
# 获取关键词关系
|
|
||||||
word_dict = to_word_dict(word_list, text)
|
|
||||||
# 替换字符串
|
|
||||||
text = replace_word(word_dict, text)
|
|
||||||
extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
|
|
||||||
result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if
|
|
||||||
not remove_chars.__contains__(word)])
|
|
||||||
# 删除词库
|
|
||||||
for word in word_list:
|
|
||||||
jieba.del_word(word)
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -12,7 +12,9 @@ import uuid
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
import jieba
|
||||||
|
from django.contrib.postgres.search import SearchVector
|
||||||
|
from django.db.models import QuerySet, Value
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from common.db.search import generate_sql_by_query_dict
|
from common.db.search import generate_sql_by_query_dict
|
||||||
@ -68,7 +70,8 @@ class PGVector(BaseVectorStore):
|
|||||||
source_id=text_list[index].get('source_id'),
|
source_id=text_list[index].get('source_id'),
|
||||||
source_type=text_list[index].get('source_type'),
|
source_type=text_list[index].get('source_type'),
|
||||||
embedding=embeddings[index],
|
embedding=embeddings[index],
|
||||||
search_vector=to_ts_vector(text_list[index]['text'])) for index in
|
search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
|
||||||
|
index in
|
||||||
range(0, len(texts))]
|
range(0, len(texts))]
|
||||||
if not is_the_task_interrupted():
|
if not is_the_task_interrupted():
|
||||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user