知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 (#1928)
This commit is contained in:
parent
65592a45c3
commit
d8e15b57ba
|
|
@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional
|
|||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.knowledge_base.utils import get_doc_path
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
|
@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
request: Request = None,
|
||||
):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
|
|
@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
)
|
||||
|
||||
source_documents = []
|
||||
doc_path = get_doc_path(knowledge_base_name)
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = os.path.split(doc.metadata["source"])[-1]
|
||||
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||
url = f"/knowledge_base/download_doc?" + parameters
|
||||
base_url = request.base_url
|
||||
url = f"{base_url}knowledge_base/download_doc?" + parameters
|
||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
from configs import (
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
|
|
@ -18,11 +17,12 @@ from langchain.docstore.document import Document
|
|||
from langchain.text_splitter import TextSplitter
|
||||
from pathlib import Path
|
||||
import json
|
||||
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
|
||||
from server.utils import run_in_thread_pool, get_model_worker_config
|
||||
import io
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
import chardet
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
# 检查是否包含预期外的字符或路径攻击关键字
|
||||
if "../" in knowledge_base_id:
|
||||
|
|
@ -53,8 +53,19 @@ def list_kbs_from_folder():
|
|||
|
||||
def list_files_from_folder(kb_name: str):
|
||||
doc_path = get_doc_path(kb_name)
|
||||
return [file for file in os.listdir(doc_path)
|
||||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
result = []
|
||||
for root, _, files in os.walk(doc_path):
|
||||
tail = os.path.basename(root).lower()
|
||||
if (tail.startswith("temp")
|
||||
or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹
|
||||
continue
|
||||
for file in files:
|
||||
if file.startswith("~$"): # 跳过 ~$ 开头的文件
|
||||
continue
|
||||
path = Path(doc_path) / root / file
|
||||
result.append(path.resolve().relative_to(doc_path).as_posix())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||
|
|
|
|||
Loading…
Reference in New Issue