diff --git a/api.py b/api.py index 4c93492..ae26c51 100644 --- a/api.py +++ b/api.py @@ -54,10 +54,11 @@ async def upload_file(UserFile: UploadFile=File(...)): # print(UserFile.filename) with open(filepath, 'wb') as f: f.write(content) - vs_path = local_doc_qa.init_knowledge_vector_store(filepath) + vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath) response = { - 'msg': 'seccessful', - 'status': 1 + 'msg': 'seccess' if len(files)>0 else 'fail', + 'status': 1 if len(files)>0 else 0, + 'loaded_files': files } except Exception as err: diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index f9b7207..56aea4a 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -53,7 +53,9 @@ class LocalDocQA: self.top_k = top_k def init_knowledge_vector_store(self, - filepath: str or List[str]): + filepath: str or List[str], + vs_path: str or os.PathLike = None): + loaded_files = [] if isinstance(filepath, str): if not os.path.exists(filepath): print("路径不存在") @@ -63,6 +65,7 @@ class LocalDocQA: try: docs = load_file(filepath) print(f"{file} 已成功加载") + loaded_files.append(filepath) except Exception as e: print(e) print(f"{file} 未能成功加载") @@ -74,6 +77,7 @@ class LocalDocQA: try: docs += load_file(fullfilepath) print(f"{file} 已成功加载") + loaded_files.append(fullfilepath) except Exception as e: print(e) print(f"{file} 未能成功加载") @@ -83,14 +87,21 @@ class LocalDocQA: try: docs += load_file(file) print(f"{file} 已成功加载") + loaded_files.append(file) except Exception as e: print(e) print(f"{file} 未能成功加载") - vector_store = FAISS.from_documents(docs, self.embeddings) - vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" + if vs_path and os.path.isdir(vs_path): + vector_store = FAISS.load_local(vs_path, self.embeddings) + vector_store.add_documents(docs) + else: + if not vs_path: + vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" + vector_store = FAISS.from_documents(docs, self.embeddings) + vector_store.save_local(vs_path) - return vs_path if len(docs) > 0 else None + return vs_path if len(docs) > 0 else None, loaded_files def get_knowledge_based_answer(self, query, diff --git a/cli_demo.py b/cli_demo.py index d0bff0c..b594380 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -24,7 +24,7 @@ if __name__ == "__main__": vs_path = None while not vs_path: filepath = input("Input your local knowledge file path 请输入本地知识文件路径:") - vs_path = local_doc_qa.init_knowledge_vector_store(filepath) + vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath) history = [] while True: query = input("Input your question 请输入问题:")