Commit update

soon 2023-04-21 08:35:21 +08:00 committed by imClumsyPanda
parent d35eb12078
commit 37ceeae6e2
3 changed files with 33 additions and 16 deletions


@@ -117,22 +117,25 @@ class LocalDocQA:
 问题:
 {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-        knowledge_chain.return_source_documents = True
+        if vs_path is None or vs_path =="":# or (not os.path.exists(vs_path))
+            result = self.llm.chat(query)
+        else:
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context", "question"]
+            )
+            self.llm.history = chat_history
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            knowledge_chain = RetrievalQA.from_llm(
+                llm=self.llm,
+                retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+                prompt=prompt
+            )
+            knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+                input_variables=["page_content"], template="{page_content}"
+            )
+            knowledge_chain.return_source_documents = True
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
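A minimal calling sketch for this method, for orientation only (not part of the commit): LocalDocQA.init_cfg and the (result, history) return shape are assumptions for illustration. With a non-empty vs_path the method builds the FAISS-backed RetrievalQA chain as before; the new branch handles an empty or missing vs_path with a direct self.llm.chat(query) call.

# Hedged sketch, not from this commit; init_cfg and the return shape are assumptions.
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()  # assumed to set up self.llm and self.embeddings first

# Retrieval-augmented answer from an existing FAISS vector store.
result, history = local_doc_qa.get_knowledge_based_answer(
    query="这份文档主要讲了什么?",
    vs_path="/path/to/vector_store",
    chat_history=[])

# An empty vs_path now takes the new self.llm.chat(query) branch.
result, history = local_doc_qa.get_knowledge_based_answer(
    query="你好", vs_path="", chat_history=[])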


@@ -19,10 +19,11 @@ llm_model_dict = {
     "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
     "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
     "chatglm-6b": "THUDM/chatglm-6b",
+    "chatyuan": "ClueAI/ChatYuan-large-v2",
 }
 
 # LLM model name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "chatyuan" #"chatglm-6b"
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
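A hedged sketch of how the new dictionary entry is consumed when loading the model; the import paths (configs.model_config, models.chatglm_llm) and the loader call site are assumptions for illustration, while the load_model parameters mirror the signature shown in the next hunk.

# Assumed import paths, for illustration only.
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE
from models.chatglm_llm import ChatGLM

model_name_or_path = llm_model_dict[LLM_MODEL]  # "chatyuan" -> "ClueAI/ChatYuan-large-v2"

llm = ChatGLM()
llm.load_model(model_name_or_path=model_name_or_path,
               llm_device=LLM_DEVICE)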


@@ -82,6 +82,19 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
+    def chat(self,
+             prompt: str) -> str:
+        response, _ = self.model.chat(
+            self.tokenizer,
+            prompt,
+            history=[],#self.history[-self.history_len:] if self.history_len>0 else
+            max_length=self.max_token,
+            temperature=self.temperature,
+        )
+        torch_gc()
+        self.history = self.history+[[None, response]]
+        return response
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
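A hedged usage sketch of the new chat method (not part of the commit; the load_model arguments and the LLM_DEVICE value are assumptions): chat sends a single prompt with an empty history, so previous turns are not fed back to the model, but the reply is still appended to self.history as a [None, response] pair, matching the pattern used in the existing generation path above.

# Hedged sketch; model path and device are assumptions for illustration.
llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b", llm_device=LLM_DEVICE)

answer = llm.chat("用一句话介绍一下你自己")
print(answer)
print(llm.history[-1])  # -> [None, answer]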