diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 6cedca7..1b968cc 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -117,26 +117,23 @@ class LocalDocQA:
 
 问题:
 {question}"""
-        if vs_path is None or vs_path =="":# or (not os.path.exists(vs_path))
-            result = self.llm.chat(query)
-        else:
-            prompt = PromptTemplate(
-                template=prompt_template,
-                input_variables=["context", "question"]
-            )
-            self.llm.history = chat_history
-            vector_store = FAISS.load_local(vs_path, self.embeddings)
-            knowledge_chain = RetrievalQA.from_llm(
-                llm=self.llm,
-                retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-                prompt=prompt
-            )
-            knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-                input_variables=["page_content"], template="{page_content}"
-            )
-
-            knowledge_chain.return_source_documents = True
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+        self.llm.history = chat_history
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        knowledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            prompt=prompt
+        )
+        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+            input_variables=["page_content"], template="{page_content}"
+        )
+        knowledge_chain.return_source_documents = True
 
         result = knowledge_chain({"query": query})
+        self.llm.history[-1][0] = query
         return result, self.llm.history
 
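For context, a minimal calling sketch against the patched method. The setup helpers
(init_cfg, init_knowledge_vector_store) belong to the same LocalDocQA class but sit
outside this hunk, so their exact signatures here are assumptions. The behavioral
points the sketch relies on are visible in the diff itself: vs_path is now mandatory
(the old fallback to a bare self.llm.chat(query) for an empty path is gone), and
self.llm.history[-1][0] = query rewrites the last history turn to record the user's
original question instead of the context-stuffed prompt.

    # Sketch only; init_cfg()/init_knowledge_vector_store() signatures are assumed
    # from elsewhere in this file and are not part of this diff.
    from chains.local_doc_qa import LocalDocQA

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg()  # load embedding model and LLM with repo defaults
    vs_path = local_doc_qa.init_knowledge_vector_store("docs/sample.md")

    # After this patch a valid FAISS store path is required; passing None or ""
    # would fail inside FAISS.load_local() rather than fall back to plain chat.
    resp, history = local_doc_qa.get_knowledge_based_answer(
        query="What does this project use for text embeddings?",
        vs_path=vs_path,
        chat_history=[],
    )
    print(resp["result"])
    for doc in resp["source_documents"]:  # populated via return_source_documents=True
        print(doc.page_content)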