Commit update

commit 37ceeae6e2
parent d35eb12078
@@ -117,22 +117,25 @@ class LocalDocQA:
 问题:
 {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-
-        knowledge_chain.return_source_documents = True
-
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
+        if vs_path is None or vs_path == "":  # or (not os.path.exists(vs_path))
+            result = self.llm.chat(query)
+        else:
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context", "question"]
+            )
+            self.llm.history = chat_history
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            knowledge_chain = RetrievalQA.from_llm(
+                llm=self.llm,
+                retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+                prompt=prompt
+            )
+            knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+                input_variables=["page_content"], template="{page_content}"
+            )
+
+            knowledge_chain.return_source_documents = True
+
+            result = knowledge_chain({"query": query})
+            self.llm.history[-1][0] = query
 
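For orientation, a minimal usage sketch of the changed branch (not part of the commit). The caller-facing method name get_knowledge_based_answer and the init_cfg() setup call are assumptions taken from the surrounding project layout; the hunk itself only shows the body inside class LocalDocQA that now branches on vs_path.

# Hypothetical usage; names outside the hunk (get_knowledge_based_answer,
# init_cfg) are assumptions, not confirmed by this diff.
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()  # assumed: sets up self.llm, self.embeddings, self.top_k

# With a saved FAISS index: retrieval-augmented answer via RetrievalQA.
resp = local_doc_qa.get_knowledge_based_answer(
    query="What does the document cover?",
    vs_path="vector_store/my_docs",  # hypothetical path
    chat_history=[],
)

# With an empty vs_path: the new fallback calls self.llm.chat(query) directly,
# skipping retrieval entirely.
resp = local_doc_qa.get_knowledge_based_answer(
    query="Hello",
    vs_path="",
    chat_history=[],
)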
@@ -19,10 +19,11 @@ llm_model_dict = {
     "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
     "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
     "chatglm-6b": "THUDM/chatglm-6b",
+    "chatyuan": "ClueAI/ChatYuan-large-v2",
 }
 
 # LLM model name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "chatyuan"  # "chatglm-6b"
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
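As a quick sanity check of the config change (illustrative only; the import path configs.model_config is assumed from the project layout and not shown in this diff):

from configs.model_config import llm_model_dict, LLM_MODEL  # assumed module path

model_name_or_path = llm_model_dict[LLM_MODEL]
print(model_name_or_path)  # "ClueAI/ChatYuan-large-v2" now that LLM_MODEL = "chatyuan"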
@@ -82,6 +82,19 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
+    def chat(self,
+             prompt: str) -> str:
+        response, _ = self.model.chat(
+            self.tokenizer,
+            prompt,
+            history=[],  # self.history[-self.history_len:] if self.history_len>0 else
+            max_length=self.max_token,
+            temperature=self.temperature,
+        )
+        torch_gc()
+        self.history = self.history+[[None, response]]
+        return response
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
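A hedged sketch of calling the new chat() helper. Constructing ChatGLM() with no arguments and relying on the load_model defaults are assumptions beyond what the hunk shows.

# Illustrative only; ChatGLM() construction and load_model defaults assumed.
llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b")

# chat() passes history=[] to the underlying model, so each call is stateless,
# but it still appends [None, response] to llm.history for the caller.
reply = llm.chat("What can you do?")
print(reply)
print(llm.history[-1])  # [None, reply]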