This commit is contained in:
tongque 2024-05-01 11:51:23 +08:00
commit 93a959b3ef
31 changed files with 479 additions and 167 deletions

View File

@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
message_list = serializers.ListField(required=True, child=MessageField(required=True), message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表")) error_messages=ErrMessage.list("对话列表"))
# 大语言模型 # 大语言模型
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
# 段落列表 # 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id # 对话id

View File

@ -59,8 +59,12 @@ def event_content(response,
# 获取token # 获取token
if is_ai_chat: if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list) try:
response_token = chat_model.get_num_tokens(all_text) request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
except Exception as e:
request_token = 0
response_token = 0
else: else:
request_token = 0 request_token = 0
response_token = 0 response_token = 0
@ -126,6 +130,26 @@ class BaseChatStep(IChatStep):
result.append({'role': 'ai', 'content': answer_text}) result.append({'role': 'ai', 'content': answer_text})
return result return result
@staticmethod
def get_stream_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return iter(directly_return_chunk_list), False
elif len(paragraph_list) == 0 and no_references_setting.get(
'status') == 'designated_answer':
return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
if chat_model is None:
return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False
else:
return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage], def execute_stream(self, message_list: List[BaseMessage],
chat_id, chat_id,
problem_text, problem_text,
@ -136,29 +160,8 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None, padding_problem_text: str = None,
client_id=None, client_type=None, client_id=None, client_type=None,
no_references_setting=None): no_references_setting=None):
is_ai_chat = False chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
# 调用模型 no_references_setting)
if chat_model is None:
chat_result = iter(
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1() chat_record_id = uuid.uuid1()
r = StreamingHttpResponse( r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@ -169,6 +172,27 @@ class BaseChatStep(IChatStep):
r['Cache-Control'] = 'no-cache' r['Cache-Control'] = 'no-cache'
return r return r
@staticmethod
def get_block_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessage(content=paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return directly_return_chunk_list[0], False
elif len(paragraph_list) == 0 and no_references_setting.get(
'status') == 'designated_answer':
return AIMessage(no_references_setting.get('value')), False
if chat_model is None:
return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False
else:
return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage], def execute_block(self, message_list: List[BaseMessage],
chat_id, chat_id,
problem_text, problem_text,
@ -178,28 +202,8 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None): client_id=None, client_type=None, no_references_setting=None):
is_ai_chat = False
# 调用模型 # 调用模型
if chat_model is None: chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1() chat_record_id = uuid.uuid1()
if is_ai_chat: if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list) request_token = chat_model.get_num_tokens_from_messages(message_list)

View File

@ -28,7 +28,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
padding_problem_text: str = None, padding_problem_text: str = None,
no_references_setting=None, no_references_setting=None,
**kwargs) -> List[BaseMessage]: **kwargs) -> List[BaseMessage]:
prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get( prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
'value') 'value')
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
start_index = len(history_chat_record) - dialogue_number start_index = len(history_chat_record) - dialogue_number

View File

@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答")) error_messages=ErrMessage.list("历史对答"))
# 大语言模型 # 大语言模型
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer return self.InstanceSerializer

View File

@ -22,6 +22,10 @@ prompt = (
class BaseResetProblemStep(IResetProblemStep): class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs) -> str: **kwargs) -> str:
if chat_model is None:
self.context['message_tokens'] = 0
self.context['answer_tokens'] = 0
return problem_text
start_index = len(history_chat_record) - 3 start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in for index in
@ -35,8 +39,14 @@ class BaseResetProblemStep(IResetProblemStep):
response.content.index('<data>') + 6:response.content.index('</data>')] response.content.index('<data>') + 6:response.content.index('</data>')]
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
padding_problem = padding_problem_data padding_problem = padding_problem_data
self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list) try:
self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem) request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(padding_problem)
except Exception as e:
request_token = 0
response_token = 0
self.context['message_tokens'] = request_token
self.context['answer_tokens'] = response_token
return padding_problem return padding_problem
def get_details(self, manage, **kwargs): def get_details(self, manage, **kwargs):

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.13 on 2024-04-29 13:33
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0004_applicationaccesstoken_show_source'),
]
operations = [
migrations.AlterField(
model_name='chat',
name='abstract',
field=models.CharField(max_length=1024, verbose_name='摘要'),
),
migrations.AlterField(
model_name='chatrecord',
name='answer_text',
field=models.CharField(max_length=40960, verbose_name='答案'),
),
]

View File

@ -73,7 +73,7 @@ class ApplicationDatasetMapping(AppModelMixin):
class Chat(AppModelMixin): class Chat(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE) application = models.ForeignKey(Application, on_delete=models.CASCADE)
abstract = models.CharField(max_length=256, verbose_name="摘要") abstract = models.CharField(max_length=1024, verbose_name="摘要")
client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
class Meta: class Meta:
@ -96,7 +96,7 @@ class ChatRecord(AppModelMixin):
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE) default=VoteChoices.UN_VOTE)
problem_text = models.CharField(max_length=1024, verbose_name="问题") problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=4096, verbose_name="答案") answer_text = models.CharField(max_length=40960, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
const = models.IntegerField(verbose_name="总费用", default=0) const = models.IntegerField(verbose_name="总费用", default=0)

View File

@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
class ModelDatasetAssociation(serializers.Serializer): class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型id"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid( error_messages=ErrMessage.uuid(
"知识库id")), "知识库id")),
@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
model_id = self.data.get('model_id') model_id = self.data.get('model_id')
user_id = self.data.get('user_id') user_id = self.data.get('user_id')
if not QuerySet(Model).filter(id=model_id).exists(): if model_id is not None and len(model_id) > 0:
raise AppApiException(500, f'模型不存在【{model_id}') if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
dataset_id_list = list(set(self.data.get('dataset_id_list'))) dataset_id_list = list(set(self.data.get('dataset_id_list')))
exist_dataset_id_list = [str(dataset.id) for dataset in exist_dataset_id_list = [str(dataset.id) for dataset in
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1, max_length=256, min_length=1,
error_messages=ErrMessage.char("应用描述")) error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
error_messages=ErrMessage.char("开场白")) error_messages=ErrMessage.char("开场白"))
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用名称")) error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("应用描述")) error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False, multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话")) error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
@ -494,22 +498,21 @@ class ApplicationSerializer(serializers.Serializer):
application_id = self.data.get("application_id") application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id) application = QuerySet(Application).get(id=application_id)
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
model = QuerySet(Model).filter( application.model_id = None
id=instance.get('model_id') if 'model_id' in instance else application.model_id, else:
user_id=application.user_id).first() model = QuerySet(Model).filter(
if model is None: id=instance.get('model_id'),
raise AppApiException(500, "模型不存在") user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization', 'dataset_setting', 'model_setting', 'problem_optimization',
'api_key_is_active', 'icon'] 'api_key_is_active', 'icon']
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue': if update_key == 'multiple_rounds_dialogue':
application.__setattr__('dialogue_number', application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3)
0 if not instance.get(update_key) else ModelProvideConstants[
model.provider].value.get_dialogue_number())
else: else:
application.__setattr__(update_key, instance.get(update_key)) application.__setattr__(update_key, instance.get(update_key))
application.save() application.save()

View File

@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.rsa_util import decrypt from common.util.rsa_util import rsa_long_decrypt
from common.util.split_model import flat_map from common.util.split_model import flat_map
from dataset.models import Paragraph, Document from dataset.models import Paragraph, Document
from setting.models import Model, Status from setting.models import Model, Status
@ -138,7 +138,7 @@ def get_post_handler(chat_info: ChatInfo):
class ChatMessageSerializer(serializers.Serializer): class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"), max_length=1024)
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
chat_cache.set(chat_id, chat_cache.set(chat_id,
chat_info, timeout=60 * 30) chat_info, timeout=60 * 30)
model = chat_info.application.model model = chat_info.application.model
if model is None:
return chat_info
model = QuerySet(Model).filter(id=model.id).first() model = QuerySet(Model).filter(id=model.id).first()
if model is None: if model is None:
raise AppApiException(500, "模型不存在") return chat_info
if model.status == Status.ERROR: if model.status == Status.ERROR:
raise AppApiException(500, "当前模型不可用") raise AppApiException(500, "当前模型不可用")
if model.status == Status.DOWNLOAD: if model.status == Status.DOWNLOAD:
@ -223,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer):
# 对话模型 # 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads( json.loads(
decrypt(model.credential)), rsa_long_decrypt(model.credential)),
streaming=True) streaming=True)
# 数据集id列表 # 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in dataset_id_list = [str(row.dataset_id) for row in

View File

@ -35,7 +35,7 @@ from common.util.common import post
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.paragraph_serializers import ParagraphSerializers from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model from setting.models import Model
@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer):
if model is not None: if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads( json.loads(
decrypt(model.credential)), rsa_long_decrypt(
model.credential)),
streaming=True) streaming=True)
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
@ -213,7 +214,8 @@ class ChatSerializers(serializers.Serializer):
id = serializers.UUIDField(required=False, allow_null=True, id = serializers.UUIDField(required=False, allow_null=True,
error_messages=ErrMessage.uuid("应用id")) error_messages=ErrMessage.uuid("应用id"))
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.uuid("模型id"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, multiple_rounds_dialogue = serializers.BooleanField(required=True,
error_messages=ErrMessage.boolean("多轮会话")) error_messages=ErrMessage.boolean("多轮会话"))
@ -246,14 +248,18 @@ class ChatSerializers(serializers.Serializer):
def open(self): def open(self):
user_id = self.is_valid(raise_exception=True) user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() model_id = self.data.get('model_id')
if model is None: if model_id is not None and len(model_id) > 0:
raise AppApiException(500, "模型不存在") model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list') dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
application = Application(id=None, dialogue_number=3, model=model, application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'), dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'), model_setting=self.data.get('model_setting'),

View File

@ -62,18 +62,6 @@ def get_key_pair_by_sql():
return system_setting.meta return system_setting.meta
# def get_key_pair():
# if not os.path.exists("/opt/maxkb/conf/receiver.pem"):
# kv = generate()
# private_file_out = open("/opt/maxkb/conf/private.pem", "wb")
# private_file_out.write(kv.get('value'))
# private_file_out.close()
# receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb")
# receiver_file_out.write(kv.get('key'))
# receiver_file_out.close()
# return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()}
def encrypt(msg, public_key: str | None = None): def encrypt(msg, public_key: str | None = None):
""" """
加密 加密
@ -100,3 +88,53 @@ def decrypt(msg, pri_key: str | None = None):
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8") return decrypt_data.decode("utf-8")
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
"""
超长文本加密
:param message: 需要加密的字符串
:param public_key 公钥
:param length: 1024bit的证书用100 2048bit的证书用 200
:return: 加密后的数据
"""
# 读取公钥
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
passphrase=secret_code))
# 处理Plaintext is too long. 分段加密
if len(message) <= length:
# 对编码的数据进行加密并通过base64进行编码
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
else:
rsa_text = []
# 对编码后的数据进行切片,原因:加密长度不能过长
for i in range(0, len(message), length):
cont = message[i:i + length]
# 对切片后的数据进行加密并新增到text后面
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
# 加密完进行拼接
cipher_text = b''.join(rsa_text)
# base64进行编码
result = base64.b64encode(cipher_text)
return result.decode()
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
"""
超长文本解密默认不加密
:param message: 需要解密的数据
:param pri_key: 秘钥
:param length : 1024bit的证书用1282048bit证书用256位
:return: 解密后的数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
base64_de = base64.b64decode(message)
res = []
for i in range(0, len(base64_de), length):
res.append(cipher.decrypt(base64_de[i:i + length], 0))
return b"".join(res).decode()

View File

@ -164,6 +164,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
elif target_dataset.type == Type.base.value and dataset.type == Type.web.value: elif target_dataset.type == Type.base.value and dataset.type == Type.web.value:
document_list.update(dataset_id=target_dataset_id, type=Type.base, document_list.update(dataset_id=target_dataset_id, type=Type.base,
meta={}) meta={})
else:
document_list.update(dataset_id=target_dataset_id)
paragraph_list.update(dataset_id=target_dataset_id) paragraph_list.update(dataset_id=target_dataset_id)
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list], [problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
@ -713,6 +715,19 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list) ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
return True return True
def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
if with_valid:
BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
hit_handling_method = instance.get('hit_handling_method')
if hit_handling_method is None:
raise AppApiException(500, '命中处理方式必填')
if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return':
raise AppApiException(500, '命中处理方式必须为directly_return|optimization')
self.is_valid(raise_exception=True)
document_id_list = instance.get("id_list")
hit_handling_method = instance.get('hit_handling_method')
QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method)
class FileBufferHandle: class FileBufferHandle:
buffer = None buffer = None

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file document_api.py
@date2024/4/28 13:56
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class DocumentApi(ApiMixin):
class BatchEditHitHandlingApi(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="主键id列表",
description="主键id列表"),
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
description="directly_return|optimization")
}
)

View File

@ -14,6 +14,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'), path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()), path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()), path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()), path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(), path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"), name="document_operate"),

View File

@ -19,6 +19,7 @@ from common.response import result
from common.util.common import query_params_to_single_dict from common.util.common import query_params_to_single_dict
from dataset.serializers.common_serializers import BatchSerializer from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer
from dataset.swagger_api.document_api import DocumentApi
class WebDocument(APIView): class WebDocument(APIView):
@ -71,6 +72,24 @@ class Document(APIView):
d.is_valid(raise_exception=True) d.is_valid(raise_exception=True)
return result.success(d.list()) return result.success(d.list())
class BatchEditHitHandling(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量修改文档命中处理方式",
operation_id="批量修改文档命中处理方式",
request_body=
DocumentApi.BatchEditHitHandlingApi.get_request_body_api(),
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
return result.success(
DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data))
class Batch(APIView): class Batch(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -18,27 +18,30 @@ def update_embedding_search_vector(embedding, paragraph_list):
def save_keywords(apps, schema_editor): def save_keywords(apps, schema_editor):
document = apps.get_model("dataset", "Document") try:
embedding = apps.get_model("embedding", "Embedding") document = apps.get_model("dataset", "Document")
paragraph = apps.get_model('dataset', 'Paragraph') embedding = apps.get_model("embedding", "Embedding")
db_alias = schema_editor.connection.alias paragraph = apps.get_model('dataset', 'Paragraph')
document_list = document.objects.using(db_alias).all() db_alias = schema_editor.connection.alias
for document in document_list: document_list = document.objects.using(db_alias).all()
document.status = Status.embedding for document in document_list:
document.save() document.status = Status.embedding
paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() document.save()
embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
'paragraph') embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding 'paragraph')
in embedding_list] embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
child_array = sub_array(embedding_update_list, 50) in embedding_list]
for c in child_array: child_array = sub_array(embedding_update_list, 50)
try: for c in child_array:
embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) try:
except Exception as e: embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
print(e) except Exception as e:
document.status = Status.success print(e)
document.save() document.status = Status.success
document.save()
except Exception as e:
print(e)
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-28 18:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('setting', '0003_model_meta_model_status'),
]
operations = [
migrations.AlterField(
model_name='model',
name='credential',
field=models.CharField(max_length=102400, verbose_name='模型认证信息'),
),
]

View File

@ -42,7 +42,7 @@ class Model(AppModelMixin):
provider = models.CharField(max_length=128, verbose_name='供应商') provider = models.CharField(max_length=128, verbose_name='供应商')
credential = models.CharField(max_length=5120, verbose_name="模型认证信息") credential = models.CharField(max_length=102400, verbose_name="模型认证信息")
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)

View File

@ -18,7 +18,7 @@ from rest_framework import serializers
from application.models import Application from application.models import Application
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from setting.models.model_management import Model, Status from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer):
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name) model_name)
source_model_credential = json.loads(decrypt(model.credential)) source_model_credential = json.loads(rsa_long_decrypt(model.credential))
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None: if credential is not None:
for k in source_encryption_model_credential.keys(): for k in source_encryption_model_credential.keys():
@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer):
model_name = self.data.get('model_name') model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=encrypt(model_credential_str), credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name) provider=provider, model_type=model_type, model_name=model_name)
model.save() model.save()
if status == Status.DOWNLOAD: if status == Status.DOWNLOAD:
@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer):
@staticmethod @staticmethod
def model_to_dict(model: Model): def model_to_dict(model: Model):
credential = json.loads(decrypt(model.credential)) credential = json.loads(rsa_long_decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'model_name': model.model_name,
'status': model.status, 'status': model.status,
@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer):
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential': if update_key == 'credential':
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model.__setattr__(update_key, encrypt(model_credential_str)) model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else: else:
model.__setattr__(update_key, instance.get(update_key)) model.__setattr__(update_key, instance.get(update_key))
model.save() model.save()

View File

@ -73,7 +73,7 @@ const postDataset: (data: datasetData, loading?: Ref<boolean>) => Promise<Result
data, data,
loading loading
) => { ) => {
return post(`${prefix}`, data, undefined, loading) return post(`${prefix}`, data, undefined, loading, 1000 * 60 * 5)
} }
/** /**

View File

@ -10,7 +10,7 @@ const prefix = '/dataset'
* @param file:file,limit:number,patterns:array,with_filter:boolean * @param file:file,limit:number,patterns:array,with_filter:boolean
*/ */
const postSplitDocument: (data: any) => Promise<Result<any>> = (data) => { const postSplitDocument: (data: any) => Promise<Result<any>> = (data) => {
return post(`${prefix}/document/split`, data) return post(`${prefix}/document/split`, data, undefined, undefined, 1000 * 60 * 60)
} }
/** /**
@ -80,7 +80,7 @@ const postDocument: (
data: any, data: any,
loading?: Ref<boolean> loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => { ) => Promise<Result<any>> = (dataset_id, data, loading) => {
return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading) return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading, 1000 * 60 * 5)
} }
/** /**
@ -206,6 +206,20 @@ const putMigrateMulDocument: (
) )
} }
/**
*
* @param dataset_id id
* @param data {id_list:[],hit_handling_method:'directly_return|optimization'}
* @param loading
* @returns
*/
const batchEditHitHandling: (
dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading)
}
export default { export default {
postSplitDocument, postSplitDocument,
getDocument, getDocument,
@ -219,5 +233,6 @@ export default {
putDocumentRefresh, putDocumentRefresh,
delMulSyncDocument, delMulSyncDocument,
postWebDocument, postWebDocument,
putMigrateMulDocument putMigrateMulDocument,
batchEditHitHandling
} }

View File

@ -56,7 +56,7 @@ export class ChatRecordManage {
this.chat.answer_text = this.chat.answer_text =
this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')
} else if (this.is_close) { } else if (this.is_close) {
this.chat.answer_text = this.chat.answer_text + this.chat.buffer.join('') this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('')
this.chat.write_ed = true this.chat.write_ed = true
this.write_ed = true this.write_ed = true
if (this.loading) { if (this.loading) {

View File

@ -5,6 +5,7 @@
v-model="dialogVisible" v-model="dialogVisible"
destroy-on-close destroy-on-close
append-to-body append-to-body
align-center
> >
<div class="paragraph-source-height"> <div class="paragraph-source-height">
<el-scrollbar> <el-scrollbar>
@ -63,7 +64,7 @@
</el-dialog> </el-dialog>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, watch, nextTick } from 'vue' import { ref, watch, onBeforeUnmount } from 'vue'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import { arraySort } from '@/utils/utils' import { arraySort } from '@/utils/utils'
const emit = defineEmits(['refresh']) const emit = defineEmits(['refresh'])
@ -86,12 +87,15 @@ const open = (data: any, id?: string) => {
detail.value.paragraph_list = arraySort(detail.value.paragraph_list, 'similarity', true) detail.value.paragraph_list = arraySort(detail.value.paragraph_list, 'similarity', true)
dialogVisible.value = true dialogVisible.value = true
} }
onBeforeUnmount(() => {
dialogVisible.value = false
})
defineExpose({ open }) defineExpose({ open })
</script> </script>
<style lang="scss"> <style lang="scss">
.paragraph-source { .paragraph-source {
padding: 0; padding: 0;
.el-dialog__header { .el-dialog__header {
padding: 24px 24px 0 24px; padding: 24px 24px 0 24px;
} }
@ -102,4 +106,9 @@ defineExpose({ open })
height: calc(100vh - 260px); height: calc(100vh - 260px);
} }
} }
@media only screen and (max-width: 768px) {
.paragraph-source {
width: 90% !important;
}
}
</style> </style>

View File

@ -224,7 +224,7 @@ const chartOpenId = ref('')
const chatList = ref<any[]>([]) const chatList = ref<any[]>([])
const isDisabledChart = computed( const isDisabledChart = computed(
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id))) () => !(inputValue.value.trim() && (props.appId || props.data?.name))
) )
const isMdArray = (val: string) => val.match(/^-\s.*/m) const isMdArray = (val: string) => val.match(/^-\s.*/m)
const prologueList = computed(() => { const prologueList = computed(() => {
@ -274,12 +274,7 @@ function openParagraph(row: any, id?: string) {
} }
function quickProblemHandle(val: string) { function quickProblemHandle(val: string) {
if (!props.log && !loading.value && props.data?.name && props.data?.model_id) { if (!loading.value && props.data?.name) {
// inputValue.value = val
// nextTick(() => {
// quickInputRef.value?.focus()
// })
handleDebounceClick(val) handleDebounceClick(val)
} }
} }
@ -509,16 +504,14 @@ function regenerationChart(item: chatType) {
} }
function getSourceDetail(row: any) { function getSourceDetail(row: any) {
logApi logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading) const exclude_keys = ['answer_text', 'id']
.then((res) => { Object.keys(res.data).forEach((key) => {
const exclude_keys = ['answer_text', 'id'] if (!exclude_keys.includes(key)) {
Object.keys(res.data).forEach((key) => { row[key] = res.data[key]
if (!exclude_keys.includes(key)) { }
row[key] = res.data[key]
}
})
}) })
})
return true return true
} }

View File

@ -119,9 +119,15 @@ const promise: (
export const get: ( export const get: (
url: string, url: string,
params?: unknown, params?: unknown,
loading?: NProgress | Ref<boolean> loading?: NProgress | Ref<boolean>,
) => Promise<Result<any>> = (url: string, params: unknown, loading?: NProgress | Ref<boolean>) => { timeout?: number
return promise(request({ url: url, method: 'get', params }), loading) ) => Promise<Result<any>> = (
url: string,
params: unknown,
loading?: NProgress | Ref<boolean>,
timeout?: number
) => {
return promise(request({ url: url, method: 'get', params, timeout: timeout }), loading)
} }
/** /**
@ -136,9 +142,10 @@ export const post: (
url: string, url: string,
data?: unknown, data?: unknown,
params?: unknown, params?: unknown,
loading?: NProgress | Ref<boolean> loading?: NProgress | Ref<boolean>,
) => Promise<Result<any> | any> = (url, data, params, loading) => { timeout?: number
return promise(request({ url: url, method: 'post', data, params }), loading) ) => Promise<Result<any> | any> = (url, data, params, loading, timeout) => {
return promise(request({ url: url, method: 'post', data, params, timeout }), loading)
} }
/**| /**|
@ -153,9 +160,10 @@ export const put: (
url: string, url: string,
data?: unknown, data?: unknown,
params?: unknown, params?: unknown,
loading?: NProgress | Ref<boolean> loading?: NProgress | Ref<boolean>,
) => Promise<Result<any>> = (url, data, params, loading) => { timeout?: number
return promise(request({ url: url, method: 'put', data, params }), loading) ) => Promise<Result<any>> = (url, data, params, loading, timeout) => {
return promise(request({ url: url, method: 'put', data, params, timeout }), loading)
} }
/** /**
@ -169,9 +177,10 @@ export const del: (
url: string, url: string,
params?: unknown, params?: unknown,
data?: unknown, data?: unknown,
loading?: NProgress | Ref<boolean> loading?: NProgress | Ref<boolean>,
) => Promise<Result<any>> = (url, params, data, loading) => { timeout?: number
return promise(request({ url: url, method: 'delete', params, data }), loading) ) => Promise<Result<any>> = (url, params, data, loading, timeout) => {
return promise(request({ url: url, method: 'delete', params, data, timeout }), loading)
} }
/** /**

View File

@ -35,7 +35,7 @@
accept="image/*" accept="image/*"
:on-change="onChange" :on-change="onChange"
> >
<el-button icon="Upload">上传</el-button> <el-button icon="Upload" :disabled="radioType !== 'custom'">上传</el-button>
</el-upload> </el-upload>
</div> </div>
<div class="el-upload__tip info mt-16"> <div class="el-upload__tip info mt-16">

View File

@ -48,7 +48,7 @@
<el-form-item label="AI 模型" prop="model_id"> <el-form-item label="AI 模型" prop="model_id">
<template #label> <template #label>
<div class="flex-between"> <div class="flex-between">
<span>AI 模型 <span class="danger">*</span></span> <span>AI 模型 </span>
</div> </div>
</template> </template>
<el-select <el-select
@ -56,6 +56,7 @@
placeholder="请选择 AI 模型" placeholder="请选择 AI 模型"
class="w-full" class="w-full"
popper-class="select-model" popper-class="select-model"
:clearable="true"
> >
<el-option-group <el-option-group
v-for="(value, label) in modelOptions" v-for="(value, label) in modelOptions"
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }], name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
model_id: [ model_id: [
{ {
required: true, required: false,
message: '请选择模型', message: '请选择模型',
trigger: 'change' trigger: 'change'
} }

View File

@ -251,7 +251,7 @@ defineExpose({ open })
padding: 0 !important; padding: 0 !important;
} }
.dialog-max-height { .dialog-max-height {
height: calc(100vh - 180px); height: 550px;
} }
.custom-slider { .custom-slider {
.el-input-number.is-without-controls .el-input__wrapper { .el-input-number.is-without-controls .el-input__wrapper {

View File

@ -0,0 +1,96 @@
<template>
<el-dialog
title="设置"
v-model="dialogVisible"
:close-on-click-modal="false"
:close-on-press-escape="false"
:destroy-on-close="true"
width="400"
>
<el-form
label-position="top"
ref="webFormRef"
:rules="rules"
:model="form"
require-asterisk-position="right"
>
<el-form-item>
<template #label>
<div class="flex align-center">
<span class="mr-4">命中处理方式</span>
<el-tooltip
effect="dark"
content="用户提问时,命中文档下的分段时按照设置的方式进行处理。"
placement="right"
>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-radio-group v-model="form.hit_handling_method">
<template v-for="(value, key) of hitHandlingMethod" :key="key">
<el-radio :value="key">{{ value }}</el-radio>
</template>
</el-radio-group>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
<el-button type="primary" @click="submit(webFormRef)" :loading="loading"> 确定 </el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, reactive, watch } from 'vue'
import { useRoute } from 'vue-router'
import type { FormInstance, FormRules } from 'element-plus'
import documentApi from '@/api/document'
import { MsgSuccess } from '@/utils/message'
import { hitHandlingMethod } from '../utils'
const route = useRoute()
const {
params: { id }
} = route as any
const emit = defineEmits(['refresh'])
const webFormRef = ref()
const loading = ref<boolean>(false)
const documentList = ref<Array<string>>([])
const form = ref<any>({
hit_handling_method: 'optimization'
})
const rules = reactive({
source_url: [{ required: true, message: '请输入文档地址', trigger: 'blur' }]
})
const dialogVisible = ref<boolean>(false)
const open = (list: Array<string>) => {
documentList.value = list
dialogVisible.value = true
}
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
const obj = {
hit_handling_method: form.value.hit_handling_method,
id_list: documentList.value
}
documentApi.batchEditHitHandling(id, obj, loading).then((res: any) => {
MsgSuccess('设置成功')
emit('refresh')
dialogVisible.value = false
})
}
})
}
defineExpose({ open })
</script>
<style lang="scss" scoped></style>

View File

@ -21,10 +21,13 @@
>同步文档</el-button >同步文档</el-button
> >
<el-button @click="openDatasetDialog()" :disabled="multipleSelection.length === 0" <el-button @click="openDatasetDialog()" :disabled="multipleSelection.length === 0"
>批量迁移</el-button >迁移</el-button
>
<el-button @click="openBatchEditDocument" :disabled="multipleSelection.length === 0"
>设置</el-button
> >
<el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0" <el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0"
>批量删除</el-button >删除</el-button
> >
</div> </div>
@ -212,6 +215,10 @@
</div> </div>
<ImportDocumentDialog ref="ImportDocumentDialogRef" :title="title" @refresh="refresh" /> <ImportDocumentDialog ref="ImportDocumentDialogRef" :title="title" @refresh="refresh" />
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" /> <SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
<BatchEditDocumentDialog
ref="batchEditDocumentDialogRef"
@refresh="refresh"
></BatchEditDocumentDialog>
<!-- 选择知识库 --> <!-- 选择知识库 -->
<SelectDatasetDialog ref="SelectDatasetDialogRef" @refresh="refresh" /> <SelectDatasetDialog ref="SelectDatasetDialogRef" @refresh="refresh" />
</div> </div>
@ -225,6 +232,7 @@ import documentApi from '@/api/document'
import ImportDocumentDialog from './component/ImportDocumentDialog.vue' import ImportDocumentDialog from './component/ImportDocumentDialog.vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue' import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import SelectDatasetDialog from './component/SelectDatasetDialog.vue' import SelectDatasetDialog from './component/SelectDatasetDialog.vue'
import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue'
import { numberFormat } from '@/utils/utils' import { numberFormat } from '@/utils/utils'
import { datetimeFormat } from '@/utils/time' import { datetimeFormat } from '@/utils/time'
import { hitHandlingMethod } from './utils' import { hitHandlingMethod } from './utils'
@ -257,7 +265,7 @@ onBeforeRouteLeave((to: any, from: any) => {
}) })
const beforePagination = computed(() => common.paginationConfig[storeKey]) const beforePagination = computed(() => common.paginationConfig[storeKey])
const beforeSearch = computed(() => common.search[storeKey]) const beforeSearch = computed(() => common.search[storeKey])
const batchEditDocumentDialogRef = ref<InstanceType<typeof BatchEditDocumentDialog>>()
const SyncWebDialogRef = ref() const SyncWebDialogRef = ref()
const loading = ref(false) const loading = ref(false)
let interval: any let interval: any
@ -317,6 +325,13 @@ const handleSelectionChange = (val: any[]) => {
multipleSelection.value = val multipleSelection.value = val
} }
function openBatchEditDocument() {
const arr: string[] = multipleSelection.value.map((v) => v.id)
if (batchEditDocumentDialogRef) {
batchEditDocumentDialogRef?.value?.open(arr)
}
}
/** /**
* 初始化轮询 * 初始化轮询
*/ */
@ -356,9 +371,9 @@ function refreshDocument(row: any) {
.catch(() => {}) .catch(() => {})
} }
} else { } else {
// documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => {
// getList() getList()
// }) })
} }
} }

View File

@ -1,7 +1,7 @@
<template> <template>
<el-drawer v-model="visible" size="60%" @close="closeHandle" class="chat-record-drawer"> <el-drawer v-model="visible" size="60%" @close="closeHandle" class="chat-record-drawer">
<template #header> <template #header>
<h4>{{ currentAbstract }}</h4> <h4 class="single-line">{{ currentAbstract }}</h4>
</template> </template>
<div <div
v-loading="paginationConfig.current_page === 1 && loading" v-loading="paginationConfig.current_page === 1 && loading"
@ -120,6 +120,11 @@ defineExpose({
}) })
</script> </script>
<style lang="scss"> <style lang="scss">
.single-line {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.chat-record-drawer { .chat-record-drawer {
.el-drawer__body { .el-drawer__body {
background: var(--app-layout-bg-color); background: var(--app-layout-bg-color);