知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 (#1928)

This commit is contained in:
liunux4odoo 2023-10-31 16:59:40 +08:00 committed by GitHub
parent 65592a45c3
commit d8e15b57ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 8 deletions

View File

@@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional
import asyncio import asyncio
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_doc_path
import json import json
import os from pathlib import Path
from urllib.parse import urlencode from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs from server.knowledge_base.kb_doc_api import search_docs
@@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
request: Request = None,
): ):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None: if kb is None:
@@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
) )
source_documents = [] source_documents = []
doc_path = get_doc_path(knowledge_base_name)
for inum, doc in enumerate(docs): for inum, doc in enumerate(docs):
filename = os.path.split(doc.metadata["source"])[-1] filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
url = f"/knowledge_base/download_doc?" + parameters base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text) source_documents.append(text)

View File

@@ -1,6 +1,5 @@
import os import os
from configs import ( from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH, KB_ROOT_PATH,
CHUNK_SIZE, CHUNK_SIZE,
OVERLAP_SIZE, OVERLAP_SIZE,
@@ -18,11 +17,12 @@ from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter from langchain.text_splitter import TextSplitter
from pathlib import Path from pathlib import Path
import json import json
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config from server.utils import run_in_thread_pool, get_model_worker_config
import io import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet import chardet
def validate_kb_name(knowledge_base_id: str) -> bool: def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字 # 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id: if "../" in knowledge_base_id:
@@ -53,8 +53,19 @@ def list_kbs_from_folder():
def list_files_from_folder(kb_name: str): def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name) doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path) result = []
if os.path.isfile(os.path.join(doc_path, file))] for root, _, files in os.walk(doc_path):
tail = os.path.basename(root).lower()
if (tail.startswith("temp")
or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹
continue
for file in files:
if file.startswith("~$"): # 跳过 ~$ 开头的文件
continue
path = Path(doc_path) / root / file
result.append(path.resolve().relative_to(doc_path).as_posix())
return result
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],