update local_doc_qa.py
This commit is contained in:
parent
25ecc7f820
commit
7d4560e599
7
api.py
7
api.py
|
|
@ -54,10 +54,11 @@ async def upload_file(UserFile: UploadFile=File(...)):
|
|||
# print(UserFile.filename)
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(content)
|
||||
vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||
vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||
response = {
|
||||
'msg': 'seccessful',
|
||||
'status': 1
|
||||
'msg': 'seccess' if len(files)>0 else 'fail',
|
||||
'status': 1 if len(files)>0 else 0,
|
||||
'loaded_files': files
|
||||
}
|
||||
|
||||
except Exception as err:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,9 @@ class LocalDocQA:
|
|||
self.top_k = top_k
|
||||
|
||||
def init_knowledge_vector_store(self,
|
||||
filepath: str or List[str]):
|
||||
filepath: str or List[str],
|
||||
vs_path: str or os.PathLike = None):
|
||||
loaded_files = []
|
||||
if isinstance(filepath, str):
|
||||
if not os.path.exists(filepath):
|
||||
print("路径不存在")
|
||||
|
|
@ -63,6 +65,7 @@ class LocalDocQA:
|
|||
try:
|
||||
docs = load_file(filepath)
|
||||
print(f"{file} 已成功加载")
|
||||
loaded_files.append(filepath)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
|
|
@ -74,6 +77,7 @@ class LocalDocQA:
|
|||
try:
|
||||
docs += load_file(fullfilepath)
|
||||
print(f"{file} 已成功加载")
|
||||
loaded_files.append(fullfilepath)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
|
|
@ -83,14 +87,21 @@ class LocalDocQA:
|
|||
try:
|
||||
docs += load_file(file)
|
||||
print(f"{file} 已成功加载")
|
||||
loaded_files.append(file)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
|
||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
||||
vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||
if vs_path and os.path.isdir(vs_path):
|
||||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||
vector_store.add_documents(docs)
|
||||
else:
|
||||
if not vs_path:
|
||||
vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
||||
|
||||
vector_store.save_local(vs_path)
|
||||
return vs_path if len(docs) > 0 else None
|
||||
return vs_path if len(docs) > 0 else None, loaded_files
|
||||
|
||||
def get_knowledge_based_answer(self,
|
||||
query,
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ if __name__ == "__main__":
|
|||
vs_path = None
|
||||
while not vs_path:
|
||||
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
|
||||
vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||
vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||
history = []
|
||||
while True:
|
||||
query = input("Input your question 请输入问题:")
|
||||
|
|
|
|||
Loading…
Reference in New Issue