From 1c51d6cafc8af948fc37459fe35a619a9150c712 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Fri, 14 Apr 2023 00:06:45 +0800 Subject: [PATCH] update cli_demo.py --- chains/local_doc_qa.py | 11 ++++++----- cli_demo.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 94be74b..25d411c 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -28,7 +28,8 @@ class LocalDocQA: embedding_device=EMBEDDING_DEVICE, llm_history_len: int = LLM_HISTORY_LEN, llm_model: str = LLM_MODEL, - llm_device=LLM_DEVICE + llm_device=LLM_DEVICE, + top_k=VECTOR_SEARCH_TOP_K, ): self.llm = ChatGLM() self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], @@ -38,6 +39,7 @@ class LocalDocQA: self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], ) self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name, device=embedding_device) + self.top_k = top_k def init_knowledge_vector_store(self, filepath: str): @@ -65,15 +67,14 @@ class LocalDocQA: print(f"{file} 未能成功加载") vector_store = FAISS.from_documents(docs, self.embeddings) - vs_path = f"""./vector_store/{os.path.splitext(file)}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" + vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" vector_store.save_local(vs_path) return vs_path def get_knowledge_based_answer(self, query, vs_path, - chat_history=[], - top_k=VECTOR_SEARCH_TOP_K): + chat_history=[],): prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 @@ -90,7 +91,7 @@ class LocalDocQA: vector_store = FAISS.load_local(vs_path, self.embeddings) knowledge_chain = RetrievalQA.from_llm( llm=self.llm, - retriever=vector_store.as_retriever(search_kwargs={"k": top_k}), + retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}), prompt=prompt ) knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate( diff --git a/cli_demo.py b/cli_demo.py index 093164d..cda072d 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -15,7 +15,8 @@ if __name__ == "__main__": local_doc_qa.init_cfg(llm_model=LLM_MODEL, embedding_model=EMBEDDING_MODEL, embedding_device=EMBEDDING_DEVICE, - llm_history_len=LLM_HISTORY_LEN) + llm_history_len=LLM_HISTORY_LEN, + top_k=VECTOR_SEARCH_TOP_K) vs_path = None while not vs_path: filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")