feat: 细分段落chunk增加召回命中率 (#841)
This commit is contained in:
parent
203c3e5cde
commit
53434f9d24
18
apps/common/chunk/__init__.py
Normal file
18
apps/common/chunk/__init__.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2024/7/23 17:03
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from common.chunk.impl.mark_chunk_handle import MarkChunkHandle
|
||||||
|
|
||||||
|
handles = [MarkChunkHandle()]
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_chunk(text: str):
|
||||||
|
chunk_list = [text]
|
||||||
|
for handle in handles:
|
||||||
|
chunk_list = handle.handle(chunk_list)
|
||||||
|
return chunk_list
|
||||||
16
apps/common/chunk/i_chunk_handle.py
Normal file
16
apps/common/chunk/i_chunk_handle.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: i_chunk_handle.py
|
||||||
|
@date:2024/7/23 16:51
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class IChunkHandle(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def handle(self, chunk_list: List[str]):
|
||||||
|
pass
|
||||||
24
apps/common/chunk/impl/mark_chunk_handle.py
Normal file
24
apps/common/chunk/impl/mark_chunk_handle.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: mark_chunk_handle.py
|
||||||
|
@date:2024/7/23 16:52
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from common.chunk.i_chunk_handle import IChunkHandle
|
||||||
|
|
||||||
|
split_chunk_pattern = "!|。|\n|;|;"
|
||||||
|
|
||||||
|
|
||||||
|
class MarkChunkHandle(IChunkHandle):
|
||||||
|
def handle(self, chunk_list: List[str]):
|
||||||
|
result = []
|
||||||
|
for chunk in chunk_list:
|
||||||
|
base_chunk = re.split(split_chunk_pattern, chunk)
|
||||||
|
base_chunk = [chunk.strip() for chunk in base_chunk if len(chunk.strip()) > 0]
|
||||||
|
result = [*result, *base_chunk]
|
||||||
|
return result
|
||||||
@ -19,9 +19,7 @@ SELECT
|
|||||||
paragraph."id" AS paragraph_id,
|
paragraph."id" AS paragraph_id,
|
||||||
paragraph.dataset_id AS dataset_id,
|
paragraph.dataset_id AS dataset_id,
|
||||||
1 AS source_type,
|
1 AS source_type,
|
||||||
concat_ws('
|
concat_ws(E'\n',paragraph.title,paragraph."content") AS "text",
|
||||||
',concat_ws('
|
|
||||||
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
|
|
||||||
paragraph.is_active AS is_active
|
paragraph.is_active AS is_active
|
||||||
FROM
|
FROM
|
||||||
paragraph paragraph
|
paragraph paragraph
|
||||||
|
|||||||
@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 4.2.14 on 2024-07-23 18:14
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('embedding', '0002_embedding_search_vector'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterUniqueTogether(
|
||||||
|
name='embedding',
|
||||||
|
unique_together=set(),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -50,4 +50,3 @@ class Embedding(models.Model):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "embedding"
|
db_table = "embedding"
|
||||||
unique_together = ['source_id', 'source_type']
|
|
||||||
|
|||||||
@ -8,16 +8,31 @@
|
|||||||
"""
|
"""
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import reduce
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
from common.chunk import text_to_chunk
|
||||||
from common.util.common import sub_array
|
from common.util.common import sub_array
|
||||||
from embedding.models import SourceType, SearchMode
|
from embedding.models import SourceType, SearchMode
|
||||||
|
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_data(data: Dict):
|
||||||
|
if str(data.get('source_type')) == SourceType.PARAGRAPH.value:
|
||||||
|
text = data.get('text')
|
||||||
|
chunk_list = text_to_chunk(text)
|
||||||
|
return [{**data, 'text': chunk} for chunk in chunk_list]
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_data_list(data_list: List[Dict]):
|
||||||
|
result = [chunk_data(data) for data in data_list]
|
||||||
|
return reduce(lambda x, y: [*x, *y], result, [])
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorStore(ABC):
|
class BaseVectorStore(ABC):
|
||||||
vector_exists = False
|
vector_exists = False
|
||||||
|
|
||||||
@ -64,7 +79,12 @@ class BaseVectorStore(ABC):
|
|||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
self.save_pre_handler()
|
self.save_pre_handler()
|
||||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'dataset_id': dataset_id,
|
||||||
|
'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
|
||||||
|
chunk_list = chunk_data(data)
|
||||||
|
result = sub_array(chunk_list)
|
||||||
|
for child_array in result:
|
||||||
|
self._batch_save(child_array, embedding)
|
||||||
|
|
||||||
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
|
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
|
||||||
# 获取锁
|
# 获取锁
|
||||||
@ -77,7 +97,8 @@ class BaseVectorStore(ABC):
|
|||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
self.save_pre_handler()
|
self.save_pre_handler()
|
||||||
result = sub_array(data_list)
|
chunk_list = chunk_data_list(data_list)
|
||||||
|
result = sub_array(chunk_list)
|
||||||
for child_array in result:
|
for child_array in result:
|
||||||
self._batch_save(child_array, embedding)
|
self._batch_save(child_array, embedding)
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
18
apps/setting/migrations/0006_alter_model_status.py
Normal file
18
apps/setting/migrations/0006_alter_model_status.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 4.2.14 on 2024-07-23 18:14
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('setting', '0005_model_permission_type'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='model',
|
||||||
|
name='status',
|
||||||
|
field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中'), ('PAUSE_DOWNLOAD', '暂停下载')], default='SUCCESS', max_length=20, verbose_name='设置类型'),
|
||||||
|
),
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue
Block a user