update cli_demo.py
This commit is contained in:
parent
5bd664829e
commit
1c51d6cafc
|
|
@ -28,7 +28,8 @@ class LocalDocQA:
|
|||
embedding_device=EMBEDDING_DEVICE,
|
||||
llm_history_len: int = LLM_HISTORY_LEN,
|
||||
llm_model: str = LLM_MODEL,
|
||||
llm_device=LLM_DEVICE
|
||||
llm_device=LLM_DEVICE,
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
):
|
||||
self.llm = ChatGLM()
|
||||
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
||||
|
|
@ -38,6 +39,7 @@ class LocalDocQA:
|
|||
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
||||
self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
||||
device=embedding_device)
|
||||
self.top_k = top_k
|
||||
|
||||
def init_knowledge_vector_store(self,
|
||||
filepath: str):
|
||||
|
|
@ -65,15 +67,14 @@ class LocalDocQA:
|
|||
print(f"{file} 未能成功加载")
|
||||
|
||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
||||
vs_path = f"""./vector_store/{os.path.splitext(file)}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||
vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||
vector_store.save_local(vs_path)
|
||||
return vs_path
|
||||
|
||||
def get_knowledge_based_answer(self,
|
||||
query,
|
||||
vs_path,
|
||||
chat_history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K):
|
||||
chat_history=[],):
|
||||
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
||||
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class LocalDocQA:
|
|||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
||||
knowledge_chain = RetrievalQA.from_llm(
|
||||
llm=self.llm,
|
||||
retriever=vector_store.as_retriever(search_kwargs={"k": top_k}),
|
||||
retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
|
||||
prompt=prompt
|
||||
)
|
||||
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ if __name__ == "__main__":
|
|||
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_device=EMBEDDING_DEVICE,
|
||||
llm_history_len=LLM_HISTORY_LEN)
|
||||
llm_history_len=LLM_HISTORY_LEN,
|
||||
top_k=VECTOR_SEARCH_TOP_K)
|
||||
vs_path = None
|
||||
while not vs_path:
|
||||
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
|
||||
|
|
|
|||
Loading…
Reference in New Issue