diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index da83da7..02e0958 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -10,6 +10,7 @@ from langchain.docstore.document import Document import numpy as np from utils import torch_gc from tqdm import tqdm +from pypinyin import lazy_pinyin DEVICE_ = EMBEDDING_DEVICE @@ -76,14 +77,14 @@ def similarity_search_with_score_by_vector( doc = self.docstore.search(_id) id_set.add(i) docs_len = len(doc.page_content) - for k in range(1, max(i, store_len-i)): + for k in range(1, max(i, store_len - i)): break_flag = False for l in [i + k, i - k]: if 0 <= l < len(self.index_to_docstore_id): _id0 = self.index_to_docstore_id[l] doc0 = self.docstore.search(_id0) if docs_len + len(doc0.page_content) > self.chunk_size: - break_flag=True + break_flag = True break elif doc0.metadata["source"] == doc.metadata["source"]: docs_len += len(doc0.page_content) @@ -166,7 +167,7 @@ class LocalDocQA: if len(failed_files) > 0: print("以下文件未能成功加载:") for file in failed_files: - print(file,end="\n") + print(file, end="\n") else: docs = [] @@ -187,7 +188,7 @@ class LocalDocQA: else: if not vs_path: vs_path = os.path.join(VS_ROOT_PATH, - f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") + f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") vector_store = FAISS.from_documents(docs, self.embeddings) torch_gc() diff --git a/requirements.txt b/requirements.txt index 56ccfd8..1efce0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ gradio==3.28.3 fastapi uvicorn peft +pypinyin #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2