From aae4144476ac2414c23b7bc01958cf20b35f3d13 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Thu, 23 Nov 2023 19:54:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=92=8C=E5=90=91?= =?UTF-8?q?=E9=87=8F=E5=BA=93=E4=B8=AD=E6=96=87=E6=A1=A3=20metadata["sourc?= =?UTF-8?q?e"]=20=E6=94=B9=E4=B8=BA=E7=9B=B8=E5=AF=B9=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=EF=BC=8C=E4=BE=BF=E4=BA=8E=E5=90=91=E9=87=8F=E5=BA=93=E8=BF=81?= =?UTF-8?q?=E7=A7=BB=20(#2153)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复: - 上传知识库文件名称包括子目录时,自动创建子目录 --- configs/basic_config.py.example | 2 +- document_loaders/FilteredCSVloader.py | 3 ++- server/chat/file_chat.py | 6 ++++-- server/chat/knowledge_base_chat.py | 3 +-- server/knowledge_base/kb_doc_api.py | 2 ++ server/knowledge_base/kb_service/base.py | 12 ++++++++++-- server/knowledge_base/kb_service/faiss_kb_service.py | 2 +- server/knowledge_base/utils.py | 4 ---- tests/api/test_kb_api.py | 6 +++--- tests/test_migrate.py | 8 ++++---- 10 files changed, 28 insertions(+), 20 deletions(-) diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example index c3eab3c..a22fb97 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -27,4 +27,4 @@ if not os.path.exists(LOG_PATH): BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat") if os.path.isdir(BASE_TEMP_DIR): shutil.rmtree(BASE_TEMP_DIR) -os.makedirs(BASE_TEMP_DIR) +os.makedirs(BASE_TEMP_DIR, exist_ok=True) diff --git a/document_loaders/FilteredCSVloader.py b/document_loaders/FilteredCSVloader.py index 0f8148d..d9ca508 100644 --- a/document_loaders/FilteredCSVloader.py +++ b/document_loaders/FilteredCSVloader.py @@ -54,6 +54,7 @@ class FilteredCSVLoader(CSVLoader): raise RuntimeError(f"Error loading {self.file_path}") from e return docs + def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: docs = [] csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore @@ -77,4 +78,4 @@ class FilteredCSVLoader(CSVLoader): else: raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.") - return docs \ No newline at end of file + return docs diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index efe89cf..ea3475a 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -37,6 +37,9 @@ def _parse_files_in_thread( filename = file.filename file_path = os.path.join(dir, filename) file_content = file.file.read() # 读取上传文件的内容 + + if not os.path.isdir(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) with open(file_path, "wb") as f: f.write(file_content) kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp") @@ -141,9 +144,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= ) source_documents = [] - doc_path = get_temp_dir(knowledge_id)[0] for inum, doc in enumerate(docs): - filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path) + filename = doc.metadata.get("source") text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n""" source_documents.append(text) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index b607b1a..0ea99a6 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -74,9 +74,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", ) source_documents = [] - doc_path = get_doc_path(knowledge_base_name) for inum, doc in enumerate(docs): - filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path) + filename = doc.metadata.get("source") parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) base_url = request.base_url url = f"{base_url}knowledge_base/download_doc?" + parameters diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index c086d00..df574cd 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -81,6 +81,8 @@ def _save_files_in_thread(files: List[UploadFile], logger.warn(file_status) return dict(code=404, msg=file_status, data=data) + if not os.path.isdir(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) with open(file_path, "wb") as f: f.write(file_content) return dict(code=200, msg=f"成功上传文件 {filename}", data=data) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index fa079f6..0af8de5 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -2,7 +2,7 @@ import operator from abc import ABC, abstractmethod import os - +from pathlib import Path import numpy as np from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document @@ -111,12 +111,20 @@ class KBService(ABC): if docs: custom_docs = True for doc in docs: - doc.metadata.setdefault("source", kb_file.filepath) + doc.metadata.setdefault("source", kb_file.filename) else: docs = kb_file.file2text() custom_docs = False if docs: + # 将 metadata["source"] 改为相对路径 + for doc in docs: + try: + source = doc.metadata.get("source", "") + rel_path = Path(source).relative_to(self.doc_path) + doc.metadata["source"] = str(rel_path.as_posix().strip("/")) + except Exception as e: + print(f"cannot convert absolute path ({source}) to relative path. error is : {e}") self.delete_doc(kb_file) doc_infos = self.do_add_doc(docs, **kwargs) status = add_file_to_db(kb_file, diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index a444f03..231c0e3 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -83,7 +83,7 @@ class FaissKBService(KBService): kb_file: KnowledgeFile, **kwargs): with self.load_vector_store().acquire() as vs: - ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] + ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename] if len(ids) > 0: vs.delete(ids) if not kwargs.get("not_refresh_vs_cache"): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 28863c3..f3781a2 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -298,10 +298,6 @@ class KnowledgeFile: chunk_overlap=chunk_overlap) if self.text_splitter_name == "MarkdownHeaderTextSplitter": docs = text_splitter.split_text(docs[0].page_content) - for doc in docs: - # 如果文档有元数据 - if doc.metadata: - doc.metadata["source"] = os.path.basename(self.filepath) else: docs = text_splitter.split_documents(docs) diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 82ef3a1..c404d22 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -17,9 +17,9 @@ api_base_url = api_address() kb = "kb_for_api_test" test_files = { - "FAQ.MD": str(root_path / "docs" / "FAQ.MD"), - "README.MD": str(root_path / "README.MD"), - "test.txt": get_file_path("samples", "test.txt"), + "wiki/Home.MD": get_file_path("samples", "wiki/Home.md"), + "wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"), + "test_files/test.txt": get_file_path("samples", "test_files/test.txt"), } print("\n\n直接url访问\n") diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 7195dd3..b794b02 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -46,14 +46,14 @@ def test_recreate_vs(): assert len(docs) > 0 pprint(docs[0]) for doc in docs: - assert doc.metadata["source"] == path + assert doc.metadata["source"] == name # list docs base on metadata - docs = kb.list_docs(metadata={"source": path}) + docs = kb.list_docs(metadata={"source": name}) assert len(docs) > 0 for doc in docs: - assert doc.metadata["source"] == path + assert doc.metadata["source"] == name def test_increament(): @@ -74,7 +74,7 @@ def test_increament(): pprint(docs[0]) for doc in docs: - assert doc.metadata["source"] == os.path.join(doc_path, f) + assert doc.metadata["source"] == f def test_prune_db():