feat: 支持向量模型
This commit is contained in:
parent
e600b91de2
commit
75b9b17e2e
@ -39,6 +39,7 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.lock import try_lock, un_lock
|
from common.util.lock import try_lock, un_lock
|
||||||
from common.util.rsa_util import rsa_long_decrypt
|
from common.util.rsa_util import rsa_long_decrypt
|
||||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||||
|
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||||
from setting.models import Model
|
from setting.models import Model
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
raise AppApiException(500, "文档id不正确")
|
raise AppApiException(500, "文档id不正确")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
    """Post hook: trigger embedding of one paragraph, then pass the record through.

    :param chat_record: value handed back to the caller unchanged
    :param paragraph_id: id sent as the sender of ``embedding_by_paragraph_signal``
    :param dataset_id: id used to resolve the embedding model
    :return: ``chat_record``
    """
    embedding_model = get_embedding_model_by_dataset_id(dataset_id)
    # Send the vectorization (embedding) event
    ListenerManagement.embedding_by_paragraph_signal.send(
        paragraph_id,
        embedding_model=embedding_model,
    )
    return chat_record
|
||||||
|
|
||||||
@post(post_function=post_embedding_paragraph)
|
@post(post_function=post_embedding_paragraph)
|
||||||
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
||||||
# 添加标注
|
# 添加标注
|
||||||
chat_record.save()
|
chat_record.save()
|
||||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id
|
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
|
||||||
|
|
||||||
class Operate(serializers.Serializer):
|
class Operate(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
||||||
|
|||||||
@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
|
|||||||
|
|
||||||
|
|
||||||
def poxy(poxy_function):
    """Decorator that runs ``poxy_function`` on ``work_thread_pool``.

    The returned wrapper takes one positional argument plus arbitrary
    keyword arguments and forwards them to ``work_thread_pool.submit``.

    :param poxy_function: callable to execute on the pool
    :return: wrapper that schedules the call and returns whatever ``submit``
        returns (a ``Future`` when the pool is a ``concurrent.futures``
        executor -- NOTE(review): pool type is declared outside this block)
    """
    from functools import wraps

    @wraps(poxy_function)  # keep the wrapped function's name/docstring for logs and debugging
    def inner(args, **keywords):
        # Hand the submit() result back instead of dropping it: an exception
        # raised inside the worker is only reachable through that object, so
        # discarding it makes failures invisible to every caller.
        return work_thread_pool.submit(poxy_function, args, **keywords)

    return inner
|
||||||
|
|
||||||
|
|
||||||
def embedding_poxy(poxy_function):
    """Decorator that runs ``poxy_function`` on ``embedding_thread_pool``.

    The returned wrapper takes one positional argument plus arbitrary
    keyword arguments (e.g. ``embedding_model=...``) and forwards them to
    ``embedding_thread_pool.submit``.

    :param poxy_function: callable to execute on the embedding pool
    :return: wrapper that schedules the call and returns the ``Future``
        produced by ``submit``
    """
    from functools import wraps

    @wraps(poxy_function)  # keep the wrapped function's name/docstring for logs and debugging
    def inner(args, **keywords):
        # Return the Future instead of dropping it: an exception raised in the
        # worker thread is stored on the Future and is otherwise never seen.
        return embedding_thread_pool.submit(poxy_function, args, **keywords)

    return inner
|
||||||
|
|||||||
@ -85,8 +85,8 @@ class ListenerManagement:
|
|||||||
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embedding_by_problem(args, embedding_model: Embeddings):
    """Save one problem entry to the vector store.

    :param args: mapping expanded into keyword arguments of ``save``
    :param embedding_model: passed to ``save`` as its ``embedding`` argument
    """
    vector_store = VectorStore.get_embedding_vector()
    vector_store.save(**args, embedding=embedding_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@embedding_poxy
|
@embedding_poxy
|
||||||
@ -165,7 +165,7 @@ class ListenerManagement:
|
|||||||
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
||||||
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
||||||
for document in document_list:
|
for document in document_list:
|
||||||
ListenerManagement.embedding_by_document(document.id, embedding_model)
|
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@ -145,5 +145,5 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
|||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_dataset_id(dataset_id: str):
    """Return the embedding model configured for a dataset.

    The dataset row is fetched together with its ``embedding_mode`` relation
    and the model is obtained through ``EmbeddingModelManage.get_model``,
    keyed by the dataset id.

    :param dataset_id: dataset primary key; callers pass both serializer
        values and model attribute values, so it is normalised to ``str``
    :return: whatever ``EmbeddingModelManage.get_model`` returns
    :raises ValueError: from the loader when no dataset with this id exists
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    # Always use the string form as the key so that a UUID object and its
    # string representation resolve to the same entry.
    model_key = str(dataset_id)

    def load_model(_id):
        # first() yields None for an unknown id; fail with a clear message
        # instead of an AttributeError on ``None.embedding_mode``.
        if dataset is None:
            raise ValueError(f"dataset {model_key} does not exist")
        return get_model(dataset.embedding_mode)

    # NOTE(review): presumably get_model only invokes the loader when it has
    # no model for this key yet -- confirm against EmbeddingModelManage.
    return EmbeddingModelManage.get_model(model_key, load_model)
|
||||||
|
|||||||
@ -745,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
def re_embedding(self, with_valid=True):
    """Send the dataset embedding signal for this serializer's dataset.

    :param with_valid: when true, validate the serializer first and raise
        on invalid data
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    embedding_model = get_embedding_model_by_dataset_id(self.data.get('id'))
    ListenerManagement.embedding_by_dataset_signal.send(
        self.data.get('id'),
        embedding_model=embedding_model,
    )
|
||||||
|
|
||||||
def list_application(self, with_valid=True):
|
def list_application(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
|
|||||||
@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
||||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
|
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
|
||||||
|
get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -392,7 +393,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
problem_paragraph_mapping_list) > 0 else None
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
# 向量化
|
# 向量化
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||||
else:
|
else:
|
||||||
document.status = Status.error
|
document.status = Status.error
|
||||||
document.save()
|
document.save()
|
||||||
@ -405,6 +407,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
class Operate(ApiMixin, serializers.Serializer):
|
class Operate(ApiMixin, serializers.Serializer):
|
||||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||||
"文档id"))
|
"文档id"))
|
||||||
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_params_api():
|
def get_request_params_api():
|
||||||
@ -530,7 +533,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
document_id = self.data.get("document_id")
|
document_id = self.data.get("document_id")
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def delete(self):
|
def delete(self):
|
||||||
@ -599,8 +603,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(result, document_id, dataset_id):
    """Post hook: trigger embedding of one document, then pass ``result`` through.

    :param result: value handed back to the caller unchanged
    :param document_id: id sent as the sender of ``embedding_by_document_signal``
    :param dataset_id: id used to resolve the embedding model
    :return: ``result``
    """
    embedding_model = get_embedding_model_by_dataset_id(dataset_id)
    ListenerManagement.embedding_by_document_signal.send(
        document_id,
        embedding_model=embedding_model,
    )
    return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -646,7 +651,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_id = str(document_model.id)
|
document_id = str(document_model.id)
|
||||||
return DocumentSerializers.Operate(
|
return DocumentSerializers.Operate(
|
||||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
with_valid=True), document_id
|
with_valid=True), document_id, dataset_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_sync_handler(dataset_id):
|
def get_sync_handler(dataset_id):
|
||||||
@ -803,9 +808,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(document_list, dataset_id=None):
    """Post hook: trigger embedding for every document in ``document_list``.

    :param document_list: iterable of mappings, each providing an ``'id'``
    :param dataset_id: id used to resolve the embedding model. Optional so
        that a caller supplying only an empty document list still works
        (NOTE(review): the empty-result path of the decorated method appears
        to return just the list -- confirm how ``post`` forwards results)
    :return: ``document_list`` unchanged
    """
    # Nothing to embed: skip the model lookup, which needs a valid dataset id.
    if not document_list:
        return document_list
    # The model depends only on the dataset, so resolve it once rather than
    # once per document.
    model = get_embedding_model_by_dataset_id(dataset_id)
    for document_dict in document_list:
        ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
    return document_list
|
||||||
|
|
||||||
@post(post_function=post_embedding)
|
@post(post_function=post_embedding)
|
||||||
@ -846,7 +852,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return [],
|
return [],
|
||||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||||
return native_search(query_set, select_string=get_file_content(
|
return native_search(query_set, select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
|
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
|
||||||
|
with_search_one=False), dataset_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _batch_sync(document_id_list: List[str]):
|
def _batch_sync(document_id_list: List[str]):
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from common.util.common import post
|
|||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||||
ProblemParagraphManage
|
ProblemParagraphManage, get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||||
from embedding.models import SourceType
|
from embedding.models import SourceType
|
||||||
|
|
||||||
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
paragraph_id=self.data.get('paragraph_id'),
|
paragraph_id=self.data.get('paragraph_id'),
|
||||||
dataset_id=self.data.get('dataset_id'))
|
dataset_id=self.data.get('dataset_id'))
|
||||||
problem_paragraph_mapping.save()
|
problem_paragraph_mapping.save()
|
||||||
|
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||||
'is_active': True,
|
'is_active': True,
|
||||||
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'document_id': self.data.get('document_id'),
|
'document_id': self.data.get('document_id'),
|
||||||
'paragraph_id': self.data.get('paragraph_id'),
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
'dataset_id': self.data.get('dataset_id'),
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
})
|
}, embedding_model=model)
|
||||||
|
|
||||||
return ProblemSerializers.Operate(
|
return ProblemSerializers.Operate(
|
||||||
data={'dataset_id': self.data.get('dataset_id'),
|
data={'dataset_id': self.data.get('dataset_id'),
|
||||||
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
problem_id=problem.id)
|
problem_id=problem.id)
|
||||||
problem_paragraph_mapping.save()
|
problem_paragraph_mapping.save()
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
|
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||||
'is_active': True,
|
'is_active': True,
|
||||||
'source_type': SourceType.PROBLEM,
|
'source_type': SourceType.PROBLEM,
|
||||||
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'document_id': self.data.get('document_id'),
|
'document_id': self.data.get('document_id'),
|
||||||
'paragraph_id': self.data.get('paragraph_id'),
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
'dataset_id': self.data.get('dataset_id'),
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
})
|
}, embedding_model=model)
|
||||||
|
|
||||||
def un_association(self, with_valid=True):
|
def un_association(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -454,13 +456,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
raise AppApiException(500, "段落id不存在")
|
raise AppApiException(500, "段落id不存在")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(paragraph, instance, dataset_id):
    """Post hook: update the paragraph's embedding after an edit.

    When ``instance`` carries a non-None ``'is_active'`` value, only the
    enable/disable signal is sent; otherwise the paragraph is re-embedded
    with the dataset's embedding model.

    :param paragraph: mapping providing the paragraph ``'id'``
    :param instance: mapping of the edited fields
    :param dataset_id: id used to resolve the embedding model
    :return: ``paragraph`` unchanged
    """
    is_active = instance.get('is_active') if 'is_active' in instance else None

    if is_active is None:
        # No activation change requested: re-embed the paragraph.
        embedding_model = get_embedding_model_by_dataset_id(dataset_id)
        ListenerManagement.embedding_by_paragraph_signal.send(
            paragraph.get('id'),
            embedding_model=embedding_model,
        )
        return paragraph

    if is_active:
        toggle_signal = ListenerManagement.enable_embedding_by_paragraph_signal
    else:
        toggle_signal = ListenerManagement.disable_embedding_by_paragraph_signal
    toggle_signal.send(paragraph.get('id'))
    return paragraph
|
||||||
|
|
||||||
@post(post_embedding)
|
@post(post_embedding)
|
||||||
@ -508,7 +511,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
|
|
||||||
_paragraph.save()
|
_paragraph.save()
|
||||||
update_document_char_length(self.data.get('document_id'))
|
update_document_char_length(self.data.get('document_id'))
|
||||||
return self.one(), instance
|
return self.one(), instance, self.data.get('dataset_id')
|
||||||
|
|
||||||
def get_problem_list(self):
|
def get_problem_list(self):
|
||||||
ProblemParagraphMapping(ProblemParagraphMapping)
|
ProblemParagraphMapping(ProblemParagraphMapping)
|
||||||
@ -582,7 +585,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 修改长度
|
# 修改长度
|
||||||
update_document_char_length(document_id)
|
update_document_char_length(document_id)
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
|
||||||
return ParagraphSerializers.Operate(
|
return ParagraphSerializers.Operate(
|
||||||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
with_valid=True)
|
with_valid=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user