This commit is contained in:
parent
f45855c34b
commit
83cd69e5b7
@ -243,13 +243,16 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
language = get_language()
|
language = get_language()
|
||||||
if self.data.get('type') == 'csv':
|
if self.data.get('type') == 'csv':
|
||||||
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
|
file = open(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
|
||||||
|
"rb")
|
||||||
content = file.read()
|
content = file.read()
|
||||||
file.close()
|
file.close()
|
||||||
return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
|
return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
|
||||||
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
|
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
|
||||||
elif self.data.get('type') == 'excel':
|
elif self.data.get('type') == 'excel':
|
||||||
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
|
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
|
||||||
|
f'excel_template_{to_locale(language)}.xlsx'), "rb")
|
||||||
content = file.read()
|
content = file.read()
|
||||||
file.close()
|
file.close()
|
||||||
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
|
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
|
||||||
@ -261,7 +264,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
language = get_language()
|
language = get_language()
|
||||||
if self.data.get('type') == 'csv':
|
if self.data.get('type') == 'csv':
|
||||||
file = open(
|
file = open(
|
||||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
|
os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
|
||||||
|
f'table_template_{to_locale(language)}.csv'),
|
||||||
"rb")
|
"rb")
|
||||||
content = file.read()
|
content = file.read()
|
||||||
file.close()
|
file.close()
|
||||||
@ -1180,7 +1184,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
if not QuerySet(Document).filter(id=document_id).exists():
|
if not QuerySet(Document).filter(id=document_id).exists():
|
||||||
raise AppApiException(500, _('document id not exist'))
|
raise AppApiException(500, _('document id not exist'))
|
||||||
|
|
||||||
def generate_related(self, model_id, prompt, with_valid=True):
|
def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
document_id = self.data.get('document_id')
|
document_id = self.data.get('document_id')
|
||||||
@ -1192,7 +1196,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
State.PENDING)
|
State.PENDING)
|
||||||
ListenerManagement.get_aggregation_document_status(document_id)()
|
ListenerManagement.get_aggregation_document_status(document_id)()
|
||||||
try:
|
try:
|
||||||
generate_related_by_document_id.delay(document_id, model_id, prompt)
|
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
|
||||||
except AlreadyQueued as e:
|
except AlreadyQueued as e:
|
||||||
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
|
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
|
||||||
|
|
||||||
@ -1205,17 +1209,23 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_id_list = instance.get("document_id_list")
|
document_id_list = instance.get("document_id_list")
|
||||||
model_id = instance.get("model_id")
|
model_id = instance.get("model_id")
|
||||||
prompt = instance.get("prompt")
|
prompt = instance.get("prompt")
|
||||||
|
state_list = instance.get('state_list')
|
||||||
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
|
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
|
||||||
TaskType.GENERATE_PROBLEM,
|
TaskType.GENERATE_PROBLEM,
|
||||||
State.PENDING)
|
State.PENDING)
|
||||||
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
|
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
|
||||||
TaskType.GENERATE_PROBLEM,
|
reversed_status=Reverse('status'),
|
||||||
|
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
|
||||||
|
1),
|
||||||
|
).filter(task_type_status__in=state_list, document_id__in=document_id_list)
|
||||||
|
.values('id'),
|
||||||
|
TaskType.EMBEDDING,
|
||||||
State.PENDING)
|
State.PENDING)
|
||||||
ListenerManagement.get_aggregation_document_status_by_query_set(
|
ListenerManagement.get_aggregation_document_status_by_query_set(
|
||||||
QuerySet(Document).filter(id__in=document_id_list))()
|
QuerySet(Document).filter(id__in=document_id_list))()
|
||||||
try:
|
try:
|
||||||
for document_id in document_id_list:
|
for document_id in document_id_list:
|
||||||
generate_related_by_document_id.delay(document_id, model_id, prompt)
|
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
|
||||||
except AlreadyQueued as e:
|
except AlreadyQueued as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -3,11 +3,12 @@ import traceback
|
|||||||
|
|
||||||
from celery_once import QueueOnce
|
from celery_once import QueueOnce
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
from django.db.models.functions import Reverse, Substr
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from common.config.embedding_config import ModelManage
|
from common.config.embedding_config import ModelManage
|
||||||
from common.event import ListenerManagement
|
from common.event import ListenerManagement
|
||||||
from common.util.page_utils import page
|
from common.util.page_utils import page, page_desc
|
||||||
from dataset.models import Paragraph, Document, Status, TaskType, State
|
from dataset.models import Paragraph, Document, Status, TaskType, State
|
||||||
from dataset.task.tools import save_problem
|
from dataset.task.tools import save_problem
|
||||||
from ops import celery_app
|
from ops import celery_app
|
||||||
@ -64,7 +65,11 @@ def get_is_the_task_interrupted(document_id):
|
|||||||
|
|
||||||
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
|
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
|
||||||
name='celery:generate_related_by_document')
|
name='celery:generate_related_by_document')
|
||||||
def generate_related_by_document_id(document_id, model_id, prompt):
|
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
|
||||||
|
if state_list is None:
|
||||||
|
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
|
||||||
|
State.REVOKE.value,
|
||||||
|
State.REVOKED.value, State.IGNORED.value]
|
||||||
try:
|
try:
|
||||||
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
|
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
|
||||||
if is_the_task_interrupted():
|
if is_the_task_interrupted():
|
||||||
@ -78,7 +83,12 @@ def generate_related_by_document_id(document_id, model_id, prompt):
|
|||||||
generate_problem = get_generate_problem(llm_model, prompt,
|
generate_problem = get_generate_problem(llm_model, prompt,
|
||||||
ListenerManagement.get_aggregation_document_status(
|
ListenerManagement.get_aggregation_document_status(
|
||||||
document_id), is_the_task_interrupted)
|
document_id), is_the_task_interrupted)
|
||||||
page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
|
query_set = QuerySet(Paragraph).annotate(
|
||||||
|
reversed_status=Reverse('status'),
|
||||||
|
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
|
||||||
|
1),
|
||||||
|
).filter(task_type_status__in=state_list, document_id=document_id)
|
||||||
|
page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
|
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||||
|
|
||||||
@ -18,3 +21,15 @@ class XinferenceImage(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
stream_usage=True,
|
stream_usage=True,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
if self.usage_metadata is None or self.usage_metadata == {}:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
return self.usage_metadata.get('input_tokens', 0)
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
if self.usage_metadata is None or self.usage_metadata == {}:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
return self.get_last_generation_info().get('output_tokens', 0)
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
from urllib.parse import urlparse, ParseResult
|
from urllib.parse import urlparse, ParseResult
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||||
|
|
||||||
@ -33,3 +36,15 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
**optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
if self.usage_metadata is None or self.usage_metadata == {}:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
return self.usage_metadata.get('input_tokens', 0)
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
if self.usage_metadata is None or self.usage_metadata == {}:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
return self.get_last_generation_info().get('output_tokens', 0)
|
||||||
|
|||||||
@ -48,6 +48,16 @@
|
|||||||
type="textarea"
|
type="textarea"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state">
|
||||||
|
<el-radio-group v-model="state" class="radio-block">
|
||||||
|
<el-radio value="error" size="large" class="mb-16">{{
|
||||||
|
$t('views.document.form.selectVectorization.error')
|
||||||
|
}}</el-radio>
|
||||||
|
<el-radio value="all" size="large">{{
|
||||||
|
$t('views.document.form.selectVectorization.all')
|
||||||
|
}}</el-radio>
|
||||||
|
</el-radio-group>
|
||||||
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</div>
|
</div>
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false)
|
|||||||
const modelOptions = ref<any>(null)
|
const modelOptions = ref<any>(null)
|
||||||
const idList = ref<string[]>([])
|
const idList = ref<string[]>([])
|
||||||
const apiType = ref('') // 文档document或段落paragraph
|
const apiType = ref('') // 文档document或段落paragraph
|
||||||
|
const state = ref<'all' | 'error'>('error')
|
||||||
|
const stateMap = {
|
||||||
|
all: ['0', '1', '2', '3', '4', '5', 'n'],
|
||||||
|
error: ['0', '1', '3', '4', '5', 'n']
|
||||||
|
}
|
||||||
const FormRef = ref()
|
const FormRef = ref()
|
||||||
const userId = user.userInfo?.id as string
|
const userId = user.userInfo?.id as string
|
||||||
const form = ref(prompt.get(userId))
|
const form = ref(prompt.get(userId))
|
||||||
@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
|
|||||||
// 保存提示词
|
// 保存提示词
|
||||||
prompt.save(user.userInfo?.id as string, form.value)
|
prompt.save(user.userInfo?.id as string, form.value)
|
||||||
if (apiType.value === 'paragraph') {
|
if (apiType.value === 'paragraph') {
|
||||||
const data = { ...form.value, paragraph_id_list: idList.value }
|
const data = {
|
||||||
|
...form.value,
|
||||||
|
paragraph_id_list: idList.value,
|
||||||
|
state_list: stateMap[state.value]
|
||||||
|
}
|
||||||
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
|
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
|
||||||
MsgSuccess(t('views.document.generateQuestion.successMessage'))
|
MsgSuccess(t('views.document.generateQuestion.successMessage'))
|
||||||
emit('refresh')
|
emit('refresh')
|
||||||
dialogVisible.value = false
|
dialogVisible.value = false
|
||||||
})
|
})
|
||||||
} else if (apiType.value === 'document') {
|
} else if (apiType.value === 'document') {
|
||||||
const data = { ...form.value, document_id_list: idList.value }
|
const data = {
|
||||||
|
...form.value,
|
||||||
|
document_id_list: idList.value,
|
||||||
|
state_list: stateMap[state.value]
|
||||||
|
}
|
||||||
documentApi.batchGenerateRelated(id, data, loading).then(() => {
|
documentApi.batchGenerateRelated(id, data, loading).then(() => {
|
||||||
MsgSuccess(t('views.document.generateQuestion.successMessage'))
|
MsgSuccess(t('views.document.generateQuestion.successMessage'))
|
||||||
emit('refresh')
|
emit('refresh')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user