From d8e15b57bad9287ed0c1ac31340baec9d016ab07 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:59:40 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=AD=90=E7=9B=AE=E5=BD=95=EF=BC=88=E4=B8=8D=E5=8C=85=E6=8B=AC?= =?UTF-8?q?temp=E5=92=8Ctmp=E5=BC=80=E5=A4=B4=E7=9A=84=E7=9B=AE=E5=BD=95?= =?UTF-8?q?=EF=BC=89=EF=BC=8C=E6=96=87=E4=BB=B6=E7=9B=B8=E5=AF=B9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E6=80=BB=E9=95=BF=E5=BA=A6=E4=B8=8D=E5=8F=AF=E8=B6=85?= =?UTF-8?q?=E8=BF=87255=20(#1928)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat/knowledge_base_chat.py | 12 ++++++++---- server/knowledge_base/utils.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index c977db1..cffb3d3 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional import asyncio from langchain.prompts.chat import ChatPromptTemplate from server.chat.utils import History -from server.knowledge_base.kb_service.base import KBService, KBServiceFactory +from server.knowledge_base.kb_service.base import KBServiceFactory +from server.knowledge_base.utils import get_doc_path import json -import os +from pathlib import Path from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs @@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", ) source_documents = [] + doc_path = get_doc_path(knowledge_base_name) for inum, doc in enumerate(docs): - filename = os.path.split(doc.metadata["source"])[-1] + filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path) parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) - url = f"/knowledge_base/download_doc?" + parameters + base_url = request.base_url + url = f"{base_url}knowledge_base/download_doc?" + parameters text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(text) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 2b4cde9..688db55 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,6 +1,5 @@ import os from configs import ( - EMBEDDING_MODEL, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE, @@ -18,11 +17,12 @@ from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter from pathlib import Path import json -from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config +from server.utils import run_in_thread_pool, get_model_worker_config import io from typing import List, Union, Callable, Dict, Optional, Tuple, Generator import chardet + def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 if "../" in knowledge_base_id: @@ -53,8 +53,19 @@ def list_kbs_from_folder(): def list_files_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) - return [file for file in os.listdir(doc_path) - if os.path.isfile(os.path.join(doc_path, file))] + result = [] + for root, _, files in os.walk(doc_path): + tail = os.path.basename(root).lower() + if (tail.startswith("temp") + or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹 + continue + for file in files: + if file.startswith("~$"): # 跳过 ~$ 开头的文件 + continue + path = Path(doc_path) / root / file + result.append(path.resolve().relative_to(doc_path).as_posix()) + + return result LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],