diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index a15be3d..6cedca7 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -117,22 +117,25 @@ class LocalDocQA:
 
 问题:
 {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
+        if vs_path is None or vs_path == "":  # or (not os.path.exists(vs_path))
+            result = self.llm.chat(query)  # no knowledge base selected: answer with the LLM alone
+        else:
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context", "question"]
+            )
+            self.llm.history = chat_history
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            knowledge_chain = RetrievalQA.from_llm(
+                llm=self.llm,
+                retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+                prompt=prompt
+            )
+            knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+                input_variables=["page_content"], template="{page_content}"
+            )
 
-        knowledge_chain.return_source_documents = True
+            knowledge_chain.return_source_documents = True
 
-        result = knowledge_chain({"query": query})
+            result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
diff --git a/configs/model_config.py b/configs/model_config.py
index 79baa2e..c7e724d 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -19,10 +19,11 @@ llm_model_dict = {
     "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
     "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
     "chatglm-6b": "THUDM/chatglm-6b",
+    "chatyuan": "ClueAI/ChatYuan-large-v2",
 }
 
 # LLM model name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "chatyuan"  # "chatglm-6b"
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index c3d1a21..20e98a5 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -82,6 +82,19 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
+    def chat(self,
+             prompt: str) -> str:  # single-turn chat used when no knowledge base is selected
+        response, _ = self.model.chat(
+            self.tokenizer,
+            prompt,
+            history=[],  # self.history[-self.history_len:] if self.history_len > 0 else []
+            max_length=self.max_token,
+            temperature=self.temperature,
+        )
+        torch_gc()
+        self.history = self.history+[[None, response]]
+        return response
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
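A minimal usage sketch of the two code paths after this change (not part of the patch). It assumes the repo's existing `LocalDocQA.init_cfg()` defaults and that `get_knowledge_based_answer` still ends with `return result, self.llm.history` as in the unchanged remainder of the method; the vector-store path below is a hypothetical example.

```python
from chains.local_doc_qa import LocalDocQA

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()  # picks up LLM_MODEL, now "chatyuan" by default

# vs_path is None or "" -> new branch: plain LLM reply via ChatGLM.chat(), no FAISS lookup
result, history = local_doc_qa.get_knowledge_based_answer(
    query="你好", vs_path=None, chat_history=[])

# vs_path set -> original RetrievalQA path, unchanged
result, history = local_doc_qa.get_knowledge_based_answer(
    query="文档的主要内容是什么?",
    vs_path="vector_store/demo",  # hypothetical path
    chat_history=history)
```

Note that on the chat-only branch `result` is the plain string returned by `ChatGLM.chat`, while the retrieval branch still returns the `RetrievalQA` dict containing `result` and `source_documents`, so callers need to handle both shapes.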