feat: 优化应用对话

This commit is contained in:
shaohuzhang1 2024-01-18 18:36:24 +08:00
parent af7c28868d
commit 04d3ec0524
3 changed files with 129 additions and 78 deletions

View File

@ -41,8 +41,6 @@ def event_content(response,
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n" 'content': chunk.content, 'is_end': False}) + "\n\n"
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
# 获取token # 获取token
request_token = chat_model.get_num_tokens_from_messages(message_list) request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text) response_token = chat_model.get_num_tokens(all_text)
@ -56,6 +54,8 @@ def event_content(response,
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text) all_text, manage, step, padding_problem_text)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
except Exception as e: except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,

View File

@ -176,12 +176,22 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record_id = serializers.UUIDField(required=True) chat_record_id = serializers.UUIDField(required=True)
def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
def one(self, with_valid=True): def one(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id') chat_record = self.get_chat_record()
chat_id = self.data.get('chat_id') if chat_record is None:
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() raise AppApiException(500, "对话不存在")
dataset_list = [] dataset_list = []
paragraph_list = [] paragraph_list = []
if len(chat_record.paragraph_id_list) > 0: if len(chat_record.paragraph_id_list) > 0:
@ -200,7 +210,7 @@ class ChatRecordSerializer(serializers.Serializer):
return { return {
**ChatRecordSerializerModel(chat_record).data, **ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get( 'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None, 'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list, 'dataset_list': dataset_list,
'paragraph_list': paragraph_list} 'paragraph_list': paragraph_list}

View File

@ -160,7 +160,7 @@
</div> </div>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, nextTick, computed, watch } from 'vue' import { ref, nextTick, computed, watch, reactive } from 'vue'
import { useRoute } from 'vue-router' import { useRoute } from 'vue-router'
import LogOperationButton from './LogOperationButton.vue' import LogOperationButton from './LogOperationButton.vue'
import OperationButton from './OperationButton.vue' import OperationButton from './OperationButton.vue'
@ -172,6 +172,7 @@ import { randomId } from '@/utils/utils'
import useStore from '@/stores' import useStore from '@/stores'
import MdRenderer from '@/components/markdown-renderer/MdRenderer.vue' import MdRenderer from '@/components/markdown-renderer/MdRenderer.vue'
import { MdPreview } from 'md-editor-v3' import { MdPreview } from 'md-editor-v3'
import { MsgError } from '@/utils/message'
defineOptions({ name: 'AiChat' }) defineOptions({ name: 'AiChat' })
const route = useRoute() const route = useRoute()
const { const {
@ -288,44 +289,25 @@ function getChartOpenId() {
}) })
} }
} }
/**
function chatMessage() { * 获取一个递归函数,处理流式数据
loading.value = true * @param chat 每一条对话记录
if (!chartOpenId.value) { * @param reader 流数据
getChartOpenId() * @param stream 是否是流式数据
} else { */
const problem_text = inputValue.value const getWrite = (chat: any, reader: any, stream: boolean) => {
const id = randomId()
chatList.value.push({
id: id,
problem_text: problem_text,
answer_text: '',
buffer: [],
write_ed: false,
is_stop: false,
record_id: '',
vote_status: '-1'
})
inputValue.value = ''
nextTick(() => {
scrollDiv.value.setScrollTop(Number.MAX_SAFE_INTEGER)
})
applicationApi.postChatMessage(chartOpenId.value, problem_text).then((response) => {
const row = chatList.value.find((item) => item.id === id)
if (row) {
ChatManagement.addChatRecord(row, 50, loading)
ChatManagement.write(id)
const reader = response.body.getReader()
let tempResult = '' let tempResult = ''
/*eslint no-constant-condition: ["error", { "checkLoops": false }]*/ /**
const write = ({ done, value }: { done: boolean; value: any }) => { *
* @param done 是否结束
* @param value
*/
const write_stream = ({ done, value }: { done: boolean; value: any }) => {
try { try {
if (done) { if (done) {
ChatManagement.close(id) ChatManagement.close(chat.id)
return return
} }
const decoder = new TextDecoder('utf-8') const decoder = new TextDecoder('utf-8')
let str = decoder.decode(value, { stream: true }) let str = decoder.decode(value, { stream: true })
// start chunk chunkdata:{xxx}\n\n data:{ -> xxx}\n\n fetchchunkdata: \n\n // start chunk chunkdata:{xxx}\n\n data:{ -> xxx}\n\n fetchchunkdata: \n\n
@ -334,7 +316,7 @@ function chatMessage() {
str = tempResult str = tempResult
tempResult = '' tempResult = ''
} else { } else {
return reader.read().then(write) return reader.read().then(write_stream)
} }
// end // end
if (str && str.startsWith('data:')) { if (str && str.startsWith('data:')) {
@ -342,10 +324,10 @@ function chatMessage() {
if (split) { if (split) {
for (const index in split) { for (const index in split) {
const chunk = JSON?.parse(split[index].replace('data:', '')) const chunk = JSON?.parse(split[index].replace('data:', ''))
row.record_id = chunk.id chat.record_id = chunk.id
const content = chunk?.content const content = chunk?.content
if (content) { if (content) {
ChatManagement.append(id, content) ChatManagement.append(chat.id, content)
} }
if (chunk.is_end) { if (chunk.is_end) {
// //
@ -355,24 +337,82 @@ function chatMessage() {
} }
} }
} catch (e) { } catch (e) {
console.log(e) return Promise.reject(e)
// console
} }
return reader.read().then(write) return reader.read().then(write_stream)
} }
reader /**
.read() * 处理 json 响应
.then(write) * @param param0
.then((ok: any) => { */
getSourceDetail(row) const write_json = ({ done, value }: { done: boolean; value: any }) => {
if (done) {
const result_block = JSON.parse(tempResult)
if (result_block.code === 500) {
return Promise.reject(result_block.message)
} else {
if (result_block.content) {
ChatManagement.append(chat.id, result_block.content)
}
}
ChatManagement.close(chat.id)
return
}
if (value) {
const decoder = new TextDecoder('utf-8')
tempResult += decoder.decode(value)
}
return reader.read().then(write_json)
}
return stream ? write_stream : write_json
}
function chatMessage() {
loading.value = true
if (!chartOpenId.value) {
getChartOpenId()
} else {
const problem_text = inputValue.value
const chat = reactive({
id: randomId(),
problem_text: problem_text,
answer_text: '',
buffer: [],
write_ed: false,
is_stop: false,
record_id: '',
vote_status: '-1'
}) })
.finally((ok: any) => { chatList.value.push(chat)
ChatManagement.close(id) inputValue.value = ''
nextTick(() => {
//
scrollDiv.value.setScrollTop(Number.MAX_SAFE_INTEGER)
})
//
applicationApi
.postChatMessage(chartOpenId.value, problem_text)
.then((response) => {
ChatManagement.addChatRecord(chat, 50, loading)
ChatManagement.write(chat.id)
const reader = response.body.getReader()
//
const write = getWrite(
chat,
reader,
response.headers.get('Content-Type') !== 'application/json'
)
return reader.read().then(write)
})
.then(() => {
return getSourceDetail(chat)
})
.finally(() => {
ChatManagement.close(chat.id)
}) })
.catch((e: any) => { .catch((e: any) => {
ChatManagement.close(id) MsgError(e)
}) ChatManagement.close(chat.id)
}
}) })
} }
} }
@ -382,12 +422,13 @@ function regenerationChart(item: chatType) {
chatMessage() chatMessage()
} }
function getSourceDetail(row: chatType) { function getSourceDetail(row: any) {
logApi.getRecordDetail(id, row.id, row.record_id, loading).then((res) => { logApi.getRecordDetail(id, chartOpenId.value, row.record_id, loading).then((res) => {
const obj = { row, ...res.data } Object.keys(res.data).forEach((key) => {
const index = chatList.value.findIndex((v) => v.id === row.id) row[key] = res.data[key]
chatList.value.splice(index, 1, obj)
}) })
})
return true
} }
/** /**