数据库和向量库中文档 metadata["source"] 改为相对路径,便于向量库迁移 (#2153)
修复: - 上传知识库文件名称包括子目录时,自动创建子目录
This commit is contained in:
parent
7a85fe74e9
commit
aae4144476
|
|
@ -27,4 +27,4 @@ if not os.path.exists(LOG_PATH):
|
|||
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
|
||||
if os.path.isdir(BASE_TEMP_DIR):
|
||||
shutil.rmtree(BASE_TEMP_DIR)
|
||||
os.makedirs(BASE_TEMP_DIR)
|
||||
os.makedirs(BASE_TEMP_DIR, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class FilteredCSVLoader(CSVLoader):
|
|||
raise RuntimeError(f"Error loading {self.file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
|
|
@ -77,4 +78,4 @@ class FilteredCSVLoader(CSVLoader):
|
|||
else:
|
||||
raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
|
||||
|
||||
return docs
|
||||
return docs
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ def _parse_files_in_thread(
|
|||
filename = file.filename
|
||||
file_path = os.path.join(dir, filename)
|
||||
file_content = file.file.read() # 读取上传文件的内容
|
||||
|
||||
if not os.path.isdir(os.path.dirname(file_path)):
|
||||
os.makedirs(os.path.dirname(file_path))
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
|
||||
|
|
@ -141,9 +144,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
|||
)
|
||||
|
||||
source_documents = []
|
||||
doc_path = get_temp_dir(knowledge_id)[0]
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
|
||||
filename = doc.metadata.get("source")
|
||||
text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
|
||||
|
|
|
|||
|
|
@ -74,9 +74,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
)
|
||||
|
||||
source_documents = []
|
||||
doc_path = get_doc_path(knowledge_base_name)
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
|
||||
filename = doc.metadata.get("source")
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||
base_url = request.base_url
|
||||
url = f"{base_url}knowledge_base/download_doc?" + parameters
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ def _save_files_in_thread(files: List[UploadFile],
|
|||
logger.warn(file_status)
|
||||
return dict(code=404, msg=file_status, data=data)
|
||||
|
||||
if not os.path.isdir(os.path.dirname(file_path)):
|
||||
os.makedirs(os.path.dirname(file_path))
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import operator
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
|
|
@ -111,12 +111,20 @@ class KBService(ABC):
|
|||
if docs:
|
||||
custom_docs = True
|
||||
for doc in docs:
|
||||
doc.metadata.setdefault("source", kb_file.filepath)
|
||||
doc.metadata.setdefault("source", kb_file.filename)
|
||||
else:
|
||||
docs = kb_file.file2text()
|
||||
custom_docs = False
|
||||
|
||||
if docs:
|
||||
# 将 metadata["source"] 改为相对路径
|
||||
for doc in docs:
|
||||
try:
|
||||
source = doc.metadata.get("source", "")
|
||||
rel_path = Path(source).relative_to(self.doc_path)
|
||||
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
|
||||
except Exception as e:
|
||||
print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
|
||||
self.delete_doc(kb_file)
|
||||
doc_infos = self.do_add_doc(docs, **kwargs)
|
||||
status = add_file_to_db(kb_file,
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class FaissKBService(KBService):
|
|||
kb_file: KnowledgeFile,
|
||||
**kwargs):
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename]
|
||||
if len(ids) > 0:
|
||||
vs.delete(ids)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
|
|
|
|||
|
|
@ -298,10 +298,6 @@ class KnowledgeFile:
|
|||
chunk_overlap=chunk_overlap)
|
||||
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
|
||||
docs = text_splitter.split_text(docs[0].page_content)
|
||||
for doc in docs:
|
||||
# 如果文档有元数据
|
||||
if doc.metadata:
|
||||
doc.metadata["source"] = os.path.basename(self.filepath)
|
||||
else:
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ api_base_url = api_address()
|
|||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
"wiki/Home.MD": get_file_path("samples", "wiki/Home.md"),
|
||||
"wiki/开发环境部署.MD": get_file_path("samples", "wiki/开发环境部署.md"),
|
||||
"test_files/test.txt": get_file_path("samples", "test_files/test.txt"),
|
||||
}
|
||||
|
||||
print("\n\n直接url访问\n")
|
||||
|
|
|
|||
|
|
@ -46,14 +46,14 @@ def test_recreate_vs():
|
|||
assert len(docs) > 0
|
||||
pprint(docs[0])
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == path
|
||||
assert doc.metadata["source"] == name
|
||||
|
||||
# list docs base on metadata
|
||||
docs = kb.list_docs(metadata={"source": path})
|
||||
docs = kb.list_docs(metadata={"source": name})
|
||||
assert len(docs) > 0
|
||||
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == path
|
||||
assert doc.metadata["source"] == name
|
||||
|
||||
|
||||
def test_increament():
|
||||
|
|
@ -74,7 +74,7 @@ def test_increament():
|
|||
pprint(docs[0])
|
||||
|
||||
for doc in docs:
|
||||
assert doc.metadata["source"] == os.path.join(doc_path, f)
|
||||
assert doc.metadata["source"] == f
|
||||
|
||||
|
||||
def test_prune_db():
|
||||
|
|
|
|||
Loading…
Reference in New Issue