diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index c977db1..cffb3d3 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional import asyncio from langchain.prompts.chat import ChatPromptTemplate from server.chat.utils import History -from server.knowledge_base.kb_service.base import KBService, KBServiceFactory +from server.knowledge_base.kb_service.base import KBServiceFactory +from server.knowledge_base.utils import get_doc_path import json -import os +from pathlib import Path from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs @@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", ) source_documents = [] + doc_path = get_doc_path(knowledge_base_name) for inum, doc in enumerate(docs): - filename = os.path.split(doc.metadata["source"])[-1] + filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path) parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) - url = f"/knowledge_base/download_doc?" + parameters + base_url = request.base_url + url = f"{base_url}knowledge_base/download_doc?" + parameters text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(text) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 2b4cde9..688db55 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,6 +1,5 @@ import os from configs import ( - EMBEDDING_MODEL, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE, @@ -18,11 +17,12 @@ from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter from pathlib import Path import json -from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config +from server.utils import run_in_thread_pool, get_model_worker_config import io from typing import List, Union, Callable, Dict, Optional, Tuple, Generator import chardet + def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 if "../" in knowledge_base_id: @@ -53,8 +53,19 @@ def list_kbs_from_folder(): def list_files_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) - return [file for file in os.listdir(doc_path) - if os.path.isfile(os.path.join(doc_path, file))] + result = [] + for root, _, files in os.walk(doc_path): + tail = os.path.basename(root).lower() + if (tail.startswith("temp") + or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹 + continue + for file in files: + if file.startswith("~$"): # 跳过 ~$ 开头的文件 + continue + path = Path(doc_path) / root / file + result.append(path.resolve().relative_to(doc_path).as_posix()) + + return result LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],