数据库和向量库中文档 metadata["source"] 改为相对路径,便于向量库迁移 (#2153)

修复:
- 上传知识库文件名称包括子目录时,自动创建子目录
This commit is contained in:
liunux4odoo 2023-11-23 19:54:00 +08:00 committed by GitHub
parent 7a85fe74e9
commit aae4144476
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 28 additions and 20 deletions

View File

@ -27,4 +27,4 @@ if not os.path.exists(LOG_PATH):
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
if os.path.isdir(BASE_TEMP_DIR):
shutil.rmtree(BASE_TEMP_DIR)
os.makedirs(BASE_TEMP_DIR)
os.makedirs(BASE_TEMP_DIR, exist_ok=True)

View File

@ -54,6 +54,7 @@ class FilteredCSVLoader(CSVLoader):
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
@ -77,4 +78,4 @@ class FilteredCSVLoader(CSVLoader):
else:
raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
return docs
return docs

View File

@ -37,6 +37,9 @@ def _parse_files_in_thread(
filename = file.filename
file_path = os.path.join(dir, filename)
file_content = file.file.read() # 读取上传文件的内容
if not os.path.isdir(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, "wb") as f:
f.write(file_content)
kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
@ -141,9 +144,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
)
source_documents = []
doc_path = get_temp_dir(knowledge_id)[0]
for inum, doc in enumerate(docs):
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
filename = doc.metadata.get("source")
        text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
source_documents.append(text)

View File

@ -74,9 +74,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
)
source_documents = []
doc_path = get_doc_path(knowledge_base_name)
for inum, doc in enumerate(docs):
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters

View File

@ -81,6 +81,8 @@ def _save_files_in_thread(files: List[UploadFile],
logger.warn(file_status)
return dict(code=404, msg=file_status, data=data)
if not os.path.isdir(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, "wb") as f:
f.write(file_content)
        return dict(code=200, msg=f"成功上传文件 {filename}", data=data)

View File

@ -2,7 +2,7 @@ import operator
from abc import ABC, abstractmethod
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
@ -111,12 +111,20 @@ class KBService(ABC):
if docs:
custom_docs = True
for doc in docs:
doc.metadata.setdefault("source", kb_file.filepath)
doc.metadata.setdefault("source", kb_file.filename)
else:
docs = kb_file.file2text()
custom_docs = False
if docs:
# 将 metadata["source"] 改为相对路径
for doc in docs:
try:
source = doc.metadata.get("source", "")
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
except Exception as e:
print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
self.delete_doc(kb_file)
doc_infos = self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file,

View File

@ -83,7 +83,7 @@ class FaissKBService(KBService):
kb_file: KnowledgeFile,
**kwargs):
with self.load_vector_store().acquire() as vs:
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename]
if len(ids) > 0:
vs.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):

View File

@ -298,10 +298,6 @@ class KnowledgeFile:
chunk_overlap=chunk_overlap)
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
for doc in docs:
# 如果文档有元数据
if doc.metadata:
doc.metadata["source"] = os.path.basename(self.filepath)
else:
docs = text_splitter.split_documents(docs)

View File

@ -17,9 +17,9 @@ api_base_url = api_address()
kb = "kb_for_api_test"
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"test.txt": get_file_path("samples", "test.txt"),
"wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
"wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
"test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
}
print("\n\n直接url访问\n")

View File

@ -46,14 +46,14 @@ def test_recreate_vs():
assert len(docs) > 0
pprint(docs[0])
for doc in docs:
assert doc.metadata["source"] == path
assert doc.metadata["source"] == name
# list docs base on metadata
docs = kb.list_docs(metadata={"source": path})
docs = kb.list_docs(metadata={"source": name})
assert len(docs) > 0
for doc in docs:
assert doc.metadata["source"] == path
assert doc.metadata["source"] == name
def test_increament():
@ -74,7 +74,7 @@ def test_increament():
pprint(docs[0])
for doc in docs:
assert doc.metadata["source"] == os.path.join(doc_path, f)
assert doc.metadata["source"] == f
def test_prune_db():