知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 (#1928)
This commit is contained in:
parent
65592a45c3
commit
d8e15b57ba
|
|
@ -9,9 +9,10 @@ from typing import AsyncIterable, List, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
from server.chat.utils import History
|
from server.chat.utils import History
|
||||||
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
|
from server.knowledge_base.utils import get_doc_path
|
||||||
import json
|
import json
|
||||||
import os
|
from pathlib import Path
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from server.knowledge_base.kb_doc_api import search_docs
|
from server.knowledge_base.kb_doc_api import search_docs
|
||||||
|
|
||||||
|
|
@ -33,6 +34,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||||
|
request: Request = None,
|
||||||
):
|
):
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
if kb is None:
|
if kb is None:
|
||||||
|
|
@ -72,10 +74,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||||
)
|
)
|
||||||
|
|
||||||
source_documents = []
|
source_documents = []
|
||||||
|
doc_path = get_doc_path(knowledge_base_name)
|
||||||
for inum, doc in enumerate(docs):
|
for inum, doc in enumerate(docs):
|
||||||
filename = os.path.split(doc.metadata["source"])[-1]
|
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
|
||||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||||
url = f"/knowledge_base/download_doc?" + parameters
|
base_url = request.base_url
|
||||||
|
url = f"{base_url}knowledge_base/download_doc?" + parameters
|
||||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||||
source_documents.append(text)
|
source_documents.append(text)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from configs import (
|
from configs import (
|
||||||
EMBEDDING_MODEL,
|
|
||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
OVERLAP_SIZE,
|
OVERLAP_SIZE,
|
||||||
|
|
@ -18,11 +17,12 @@ from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter
|
from langchain.text_splitter import TextSplitter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
|
from server.utils import run_in_thread_pool, get_model_worker_config
|
||||||
import io
|
import io
|
||||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||||
import chardet
|
import chardet
|
||||||
|
|
||||||
|
|
||||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
# 检查是否包含预期外的字符或路径攻击关键字
|
# 检查是否包含预期外的字符或路径攻击关键字
|
||||||
if "../" in knowledge_base_id:
|
if "../" in knowledge_base_id:
|
||||||
|
|
@ -53,8 +53,19 @@ def list_kbs_from_folder():
|
||||||
|
|
||||||
def list_files_from_folder(kb_name: str):
|
def list_files_from_folder(kb_name: str):
|
||||||
doc_path = get_doc_path(kb_name)
|
doc_path = get_doc_path(kb_name)
|
||||||
return [file for file in os.listdir(doc_path)
|
result = []
|
||||||
if os.path.isfile(os.path.join(doc_path, file))]
|
for root, _, files in os.walk(doc_path):
|
||||||
|
tail = os.path.basename(root).lower()
|
||||||
|
if (tail.startswith("temp")
|
||||||
|
or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹
|
||||||
|
continue
|
||||||
|
for file in files:
|
||||||
|
if file.startswith("~$"): # 跳过 ~$ 开头的文件
|
||||||
|
continue
|
||||||
|
path = Path(doc_path) / root / file
|
||||||
|
result.append(path.resolve().relative_to(doc_path).as_posix())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue