Change document metadata["source"] in the database and vector store to a relative path, to make vector store migration easier (#2153)

Fixes:
- Automatically create subdirectories when an uploaded knowledge base file name contains a subdirectory
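For context, a minimal sketch of the idea behind the change (illustration only; the directory and file below are made up): a "source" stored relative to the knowledge base's doc directory stays valid after the vector store is copied to a machine with a different base path, whereas an absolute path does not.

    from pathlib import Path

    # hypothetical per-machine doc directory and an absolute source recorded by a loader
    doc_path = "/home/alice/knowledge_base/samples/content"
    source = "/home/alice/knowledge_base/samples/content/wiki/Home.md"

    # the relative form is what gets written to the database and vector store
    rel_source = Path(source).relative_to(doc_path).as_posix()
    print(rel_source)  # wiki/Home.md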
liunux4odoo 2023-11-23 19:54:00 +08:00 committed by GitHub
parent 7a85fe74e9
commit aae4144476
10 changed files with 28 additions and 20 deletions

View File

@@ -27,4 +27,4 @@ if not os.path.exists(LOG_PATH):
 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
 if os.path.isdir(BASE_TEMP_DIR):
     shutil.rmtree(BASE_TEMP_DIR)
-os.makedirs(BASE_TEMP_DIR)
+os.makedirs(BASE_TEMP_DIR, exist_ok=True)
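A small aside on the exist_ok flag (illustration only, not part of the commit): with exist_ok=True, os.makedirs no longer raises FileExistsError when the temp directory already exists, for example when several processes start up at the same time.

    import os
    import tempfile

    base_temp_dir = os.path.join(tempfile.gettempdir(), "chatchat")
    os.makedirs(base_temp_dir, exist_ok=True)  # creates it if missing
    os.makedirs(base_temp_dir, exist_ok=True)  # calling again is a no-op instead of an error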

View File

@@ -54,6 +54,7 @@ class FilteredCSVLoader(CSVLoader):
             raise RuntimeError(f"Error loading {self.file_path}") from e
         return docs
 
+
     def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
         docs = []
         csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore

View File

@@ -37,6 +37,9 @@ def _parse_files_in_thread(
             filename = file.filename
             file_path = os.path.join(dir, filename)
             file_content = file.file.read()  # read the content of the uploaded file
+            if not os.path.isdir(os.path.dirname(file_path)):
+                os.makedirs(os.path.dirname(file_path))
             with open(file_path, "wb") as f:
                 f.write(file_content)
             kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
@@ -141,9 +144,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
         )
 
         source_documents = []
-        doc_path = get_temp_dir(knowledge_id)[0]
         for inum, doc in enumerate(docs):
-            filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
+            filename = doc.metadata.get("source")
             text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
             source_documents.append(text)

View File

@@ -74,9 +74,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
 
         source_documents = []
-        doc_path = get_doc_path(knowledge_base_name)
         for inum, doc in enumerate(docs):
-            filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
+            filename = doc.metadata.get("source")
             parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
             base_url = request.base_url
             url = f"{base_url}knowledge_base/download_doc?" + parameters

View File

@@ -81,6 +81,8 @@ def _save_files_in_thread(files: List[UploadFile],
                 logger.warn(file_status)
                 return dict(code=404, msg=file_status, data=data)
 
+            if not os.path.isdir(os.path.dirname(file_path)):
+                os.makedirs(os.path.dirname(file_path))
             with open(file_path, "wb") as f:
                 f.write(file_content)
             return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
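Both upload paths above now create the missing parent directory before writing, which is what makes file names like "wiki/Home.md" work. A minimal sketch of the pattern with a hypothetical target path (using exist_ok=True here, which additionally tolerates two uploads racing to create the same new subdirectory):

    import os

    file_path = "/tmp/chatchat/kb_for_api_test/content/wiki/Home.md"  # hypothetical upload target

    # create "wiki/" and any missing ancestors before writing the uploaded bytes
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(b"file content goes here")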

View File

@@ -2,7 +2,7 @@ import operator
 from abc import ABC, abstractmethod
 import os
+from pathlib import Path
 
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
@@ -111,12 +111,20 @@ class KBService(ABC):
         if docs:
             custom_docs = True
             for doc in docs:
-                doc.metadata.setdefault("source", kb_file.filepath)
+                doc.metadata.setdefault("source", kb_file.filename)
         else:
             docs = kb_file.file2text()
             custom_docs = False
 
         if docs:
+            # change metadata["source"] to a relative path
+            for doc in docs:
+                try:
+                    source = doc.metadata.get("source", "")
+                    rel_path = Path(source).relative_to(self.doc_path)
+                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
+                except Exception as e:
+                    print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
             self.delete_doc(kb_file)
             doc_infos = self.do_add_doc(docs, **kwargs)
             status = add_file_to_db(kb_file,
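The try/except in the added block matters because Path.relative_to refuses paths that are not under doc_path, for example a source that is already relative or that was produced on another machine; in that case the stored value is left unchanged. A minimal sketch with a made-up doc_path:

    from pathlib import Path

    doc_path = "/data/knowledge_base/samples/content"  # hypothetical kb doc directory

    for source in ["/data/knowledge_base/samples/content/wiki/Home.md", "wiki/Home.md"]:
        try:
            print(Path(source).relative_to(doc_path).as_posix())  # -> wiki/Home.md
        except ValueError as err:
            # already relative (or outside doc_path): keep the original value as-is
            print(f"kept as-is: {source} ({err})")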

View File

@@ -83,7 +83,7 @@ class FaissKBService(KBService):
                       kb_file: KnowledgeFile,
                       **kwargs):
         with self.load_vector_store().acquire() as vs:
-            ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
+            ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename]
             if len(ids) > 0:
                 vs.delete(ids)
             if not kwargs.get("not_refresh_vs_cache"):
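Deletion has to compare against exactly what add_doc now stores, so both sides switch to the relative kb_file.filename. A small stand-in for the FAISS docstore lookup (plain dict, made-up ids and contents):

    from langchain.docstore.document import Document

    # stand-in for vs.docstore._dict: id -> Document, with relative "source" values
    docstore = {
        "id-1": Document(page_content="...", metadata={"source": "wiki/Home.md"}),
        "id-2": Document(page_content="...", metadata={"source": "test_files/test.txt"}),
    }

    filename = "wiki/Home.md"  # kb_file.filename, already a relative path
    ids = [k for k, v in docstore.items() if v.metadata.get("source") == filename]
    print(ids)  # ['id-1']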

View File

@@ -298,10 +298,6 @@ class KnowledgeFile:
                                                        chunk_overlap=chunk_overlap)
 
         if self.text_splitter_name == "MarkdownHeaderTextSplitter":
             docs = text_splitter.split_text(docs[0].page_content)
-            for doc in docs:
-                # if the document has metadata
-                if doc.metadata:
-                    doc.metadata["source"] = os.path.basename(self.filepath)
         else:
             docs = text_splitter.split_documents(docs)

View File

@@ -17,9 +17,9 @@ api_base_url = api_address()
 kb = "kb_for_api_test"
 
 test_files = {
-    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
-    "README.MD": str(root_path / "README.MD"),
-    "test.txt": get_file_path("samples", "test.txt"),
+    "wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
+    "wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
+    "test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
 }
 
 print("\n\n直接url访问\n")

View File

@@ -46,14 +46,14 @@ def test_recreate_vs():
     assert len(docs) > 0
     pprint(docs[0])
     for doc in docs:
-        assert doc.metadata["source"] == path
+        assert doc.metadata["source"] == name
 
     # list docs base on metadata
-    docs = kb.list_docs(metadata={"source": path})
+    docs = kb.list_docs(metadata={"source": name})
     assert len(docs) > 0
     for doc in docs:
-        assert doc.metadata["source"] == path
+        assert doc.metadata["source"] == name
 
 
 def test_increament():
@@ -74,7 +74,7 @@ def test_increament():
     pprint(docs[0])
     for doc in docs:
-        assert doc.metadata["source"] == os.path.join(doc_path, f)
+        assert doc.metadata["source"] == f
 
 
 def test_prune_db():