知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 (#1928)

This commit is contained in:
liunux4odoo 2023-10-31 16:59:40 +08:00 committed by GitHub
parent 65592a45c3
commit d8e15b57ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 8 deletions

View File

@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_doc_path
import json
import os
from pathlib import Path
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
)
source_documents = []
doc_path = get_doc_path(knowledge_base_name)
for inum, doc in enumerate(docs):
filename = os.path.split(doc.metadata["source"])[-1]
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
url = f"/knowledge_base/download_doc?" + parameters
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)

View File

@ -1,6 +1,5 @@
import os
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
@ -18,11 +17,12 @@ from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
from server.utils import run_in_thread_pool, get_model_worker_config
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet
def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id:
@ -53,8 +53,19 @@ def list_kbs_from_folder():
def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
result = []
for root, _, files in os.walk(doc_path):
tail = os.path.basename(root).lower()
if (tail.startswith("temp")
or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹
continue
for file in files:
if file.startswith("~$"): # 跳过 ~$ 开头的文件
continue
path = Path(doc_path) / root / file
result.append(path.resolve().relative_to(doc_path).as_posix())
return result
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],