fix: replace sync of a web knowledge base did not overwrite the local knowledge base
parent f19a4d9bd2
commit dbaafee224
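The bug: the replace-style web sync looked up existing documents by the raw crawled URL (meta__source_url=child_link.url), and crawled links often carry surrounding whitespace, so the lookup missed documents that were already stored and the sync inserted duplicates instead of overwriting local content. The fix strips the URL on both lookup and insert, and routes the replace sync through a new Celery task, sync_replace_web_dataset, whose handler syncs an existing document in place and only inserts when the document is genuinely absent. A minimal sketch of the lookup mismatch, with hypothetical URLs:

    # Hypothetical values: a stored source_url versus a freshly crawled link.
    stored = 'https://example.com/docs/page'
    crawled = ' https://example.com/docs/page '  # crawled hrefs may carry padding

    assert crawled != stored          # exact match misses the stored document,
                                      # so the old code always fell into "insert"
    assert crawled.strip() == stored  # after the fix the lookup hits and the
                                      # existing document is overwritten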
@@ -38,7 +38,7 @@ from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type,
 from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
     get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
-from dataset.task import sync_web_dataset
+from dataset.task import sync_web_dataset, sync_replace_web_dataset
 from embedding.models import SearchMode
 from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
 from setting.models import AuthOperate
@@ -602,7 +602,9 @@ class DataSetSerializers(serializers.ModelSerializer):
                     document_name = child_link.tag.text if child_link.tag is not None and len(
                         child_link.tag.text.strip()) > 0 else child_link.url
                     paragraphs = get_split_model('web.md').parse(response.content)
-                    first = QuerySet(Document).filter(meta__source_url=child_link.url, dataset=dataset).first()
+                    print(child_link.url.strip())
+                    first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
+                                                      dataset=dataset).first()
                     if first is not None:
                         # If the document already exists, sync it in place
                         DocumentSerializers.Sync(data={'document_id': first.id}).sync()
@@ -610,7 +612,8 @@ class DataSetSerializers(serializers.ModelSerializer):
                         # Otherwise insert a new document
                         DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
                             {'name': document_name, 'paragraphs': paragraphs,
-                             'meta': {'source_url': child_link.url, 'selector': dataset.meta.get('selector')},
+                             'meta': {'source_url': child_link.url.strip(),
+                                      'selector': dataset.meta.get('selector')},
                              'type': Type.web}, with_valid=True)
             except Exception as e:
                 logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
@@ -624,7 +627,7 @@ class DataSetSerializers(serializers.ModelSerializer):
         """
         url = dataset.meta.get('source_url')
         selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
-        sync_web_dataset.delay(str(dataset.id), url, selector)
+        sync_replace_web_dataset.delay(str(dataset.id), url, selector)

     def complete_sync(self, dataset):
         """
@@ -14,7 +14,7 @@ from typing import List
 from celery_once import QueueOnce

 from common.util.fork import ForkManage, Fork
-from dataset.task.tools import get_save_handler, get_sync_web_document_handler
+from dataset.task.tools import get_save_handler, get_sync_web_document_handler, get_sync_handler

 from ops import celery_app
@@ -34,6 +34,18 @@ def sync_web_dataset(dataset_id: str, url: str, selector: str):
         max_kb_error.error(f'Error syncing web knowledge base {dataset_id}: {str(e)}{traceback.format_exc()}')


+@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:sync_web_dataset')
+def sync_replace_web_dataset(dataset_id: str, url: str, selector: str):
+    try:
+        max_kb.info(f"Start ---> syncing web knowledge base: {dataset_id}")
+        ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
+                                                                                 get_sync_handler(dataset_id
+                                                                                                  ))
+        max_kb.info(f"End ---> finished syncing web knowledge base: {dataset_id}")
+    except Exception as e:
+        max_kb_error.error(f'Error syncing web knowledge base {dataset_id}: {str(e)}{traceback.format_exc()}')
+
+
 @celery_app.task(name='celery:sync_web_document')
 def sync_web_document(dataset_id, source_url_list: List[str], selector: str):
     handler = get_sync_web_document_handler(dataset_id)
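The new task keeps base=QueueOnce with once={'keys': ['dataset_id']}, so only one replace sync per dataset can be queued or running at a time; celery_once rejects a duplicate dispatch for the same dataset_id. A hedged sketch of dispatching it, with a hypothetical dataset id, URL, and selector:

    from celery_once import AlreadyQueued

    try:
        sync_replace_web_dataset.delay('dataset-uuid', 'https://example.com/docs', 'body')
    except AlreadyQueued:
        pass  # a replace sync for this dataset is already pending; skip the duplicate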
@@ -11,9 +11,11 @@ import logging
 import re
 import traceback

+from django.db.models import QuerySet
+
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
-from dataset.models import Type, Document, Status
+from dataset.models import Type, Document, DataSet, Status

 max_kb_error = logging.getLogger("max_kb_error")
 max_kb = logging.getLogger("max_kb")
@@ -38,6 +40,34 @@ def get_save_handler(dataset_id, selector):
     return handler


+def get_sync_handler(dataset_id):
+    from dataset.serializers.document_serializers import DocumentSerializers
+    dataset = QuerySet(DataSet).filter(id=dataset_id).first()
+
+    def handler(child_link: ChildLink, response: Fork.Response):
+        if response.status == 200:
+            try:
+
+                document_name = child_link.tag.text if child_link.tag is not None and len(
+                    child_link.tag.text.strip()) > 0 else child_link.url
+                paragraphs = get_split_model('web.md').parse(response.content)
+                first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
+                                                  dataset=dataset).first()
+                if first is not None:
+                    # If the document already exists, sync it in place
+                    DocumentSerializers.Sync(data={'document_id': first.id}).sync()
+                else:
+                    # Otherwise insert a new document
+                    DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
+                        {'name': document_name, 'paragraphs': paragraphs,
+                         'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
+                         'type': Type.web}, with_valid=True)
+            except Exception as e:
+                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+
+    return handler
+
+
 def get_sync_web_document_handler(dataset_id):
     from dataset.serializers.document_serializers import DocumentSerializers
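get_sync_handler builds the closure that ForkManage invokes for every crawled page: on a 200 response it splits the page with the web.md model, looks the document up by the stripped URL, and either syncs it in place or inserts it. A hedged sketch of driving the handler by hand, using stand-in objects that carry only the attributes the handler reads (real calls receive ChildLink and Fork.Response from common.util.fork):

    from types import SimpleNamespace

    from dataset.task.tools import get_sync_handler

    handler = get_sync_handler('dataset-uuid')  # hypothetical dataset id

    link = SimpleNamespace(url='https://example.com/docs/page ',  # note the stray space
                           tag=None)  # no anchor text, so the URL becomes the document name
    page = SimpleNamespace(status=200,  # non-200 responses are skipped
                           content='# Title\n\nSome body text.')
    handler(link, page)  # strips the URL, then syncs the existing document or inserts a new one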