update local_doc_qa.py
This commit is contained in:
parent
e8b2ddea51
commit
c613c41d0e
|
|
@ -10,6 +10,7 @@ from langchain.docstore.document import Document
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import torch_gc
|
from utils import torch_gc
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from pypinyin import lazy_pinyin
|
||||||
|
|
||||||
|
|
||||||
DEVICE_ = EMBEDDING_DEVICE
|
DEVICE_ = EMBEDDING_DEVICE
|
||||||
|
|
@ -76,14 +77,14 @@ def similarity_search_with_score_by_vector(
|
||||||
doc = self.docstore.search(_id)
|
doc = self.docstore.search(_id)
|
||||||
id_set.add(i)
|
id_set.add(i)
|
||||||
docs_len = len(doc.page_content)
|
docs_len = len(doc.page_content)
|
||||||
for k in range(1, max(i, store_len-i)):
|
for k in range(1, max(i, store_len - i)):
|
||||||
break_flag = False
|
break_flag = False
|
||||||
for l in [i + k, i - k]:
|
for l in [i + k, i - k]:
|
||||||
if 0 <= l < len(self.index_to_docstore_id):
|
if 0 <= l < len(self.index_to_docstore_id):
|
||||||
_id0 = self.index_to_docstore_id[l]
|
_id0 = self.index_to_docstore_id[l]
|
||||||
doc0 = self.docstore.search(_id0)
|
doc0 = self.docstore.search(_id0)
|
||||||
if docs_len + len(doc0.page_content) > self.chunk_size:
|
if docs_len + len(doc0.page_content) > self.chunk_size:
|
||||||
break_flag=True
|
break_flag = True
|
||||||
break
|
break
|
||||||
elif doc0.metadata["source"] == doc.metadata["source"]:
|
elif doc0.metadata["source"] == doc.metadata["source"]:
|
||||||
docs_len += len(doc0.page_content)
|
docs_len += len(doc0.page_content)
|
||||||
|
|
@ -166,7 +167,7 @@ class LocalDocQA:
|
||||||
if len(failed_files) > 0:
|
if len(failed_files) > 0:
|
||||||
print("以下文件未能成功加载:")
|
print("以下文件未能成功加载:")
|
||||||
for file in failed_files:
|
for file in failed_files:
|
||||||
print(file,end="\n")
|
print(file, end="\n")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
docs = []
|
docs = []
|
||||||
|
|
@ -187,7 +188,7 @@ class LocalDocQA:
|
||||||
else:
|
else:
|
||||||
if not vs_path:
|
if not vs_path:
|
||||||
vs_path = os.path.join(VS_ROOT_PATH,
|
vs_path = os.path.join(VS_ROOT_PATH,
|
||||||
f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
|
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
|
||||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
vector_store = FAISS.from_documents(docs, self.embeddings)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,4 +13,5 @@ gradio==3.28.3
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
peft
|
peft
|
||||||
|
pypinyin
|
||||||
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue