add streaming option in configs/model_config.py

parent 2ebcd1369e
commit 0e8cc0d16c
@@ -116,10 +116,12 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
+                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
+        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
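With this hunk, streaming is configured once in init_cfg and forwarded to the ChatGLM wrapper instead of being chosen per request. A minimal caller sketch, assuming the constants come from configs/model_config.py and that LocalDocQA is importable (the chains.local_doc_qa path is an assumption):

from configs.model_config import LLM_MODEL, LLM_HISTORY_LEN, VECTOR_SEARCH_TOP_K, STREAMING
from chains.local_doc_qa import LocalDocQA  # assumed module path

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                      llm_history_len=LLM_HISTORY_LEN,
                      top_k=VECTOR_SEARCH_TOP_K,
                      streaming=STREAMING)  # forwarded to self.llm.streaming before load_model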
@@ -186,9 +188,7 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   chat_history=[]):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_size=self.chunk_size
@@ -197,7 +197,7 @@ class LocalDocQA:
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)

-        if streaming:
+        if self.llm.streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
                 history[-1][0] = query
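The streaming decision now lives on self.llm, so callers select the consumption pattern from that same flag. A sketch of both patterns, mirroring the webui.py hunks below (query, vs_path and history are placeholders):

if local_doc_qa.llm.streaming:
    # streaming: get_knowledge_based_answer yields (resp, history) pairs as tokens arrive
    for resp, history in local_doc_qa.get_knowledge_based_answer(
            query=query, vs_path=vs_path, chat_history=history):
        print(history[-1][-1])
else:
    # non-streaming: a single (resp, history) result is returned
    resp, history = local_doc_qa.get_knowledge_based_answer(
        query=query, vs_path=vs_path, chat_history=history)
    print(resp["source_documents"])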
configs/model_config.py

@@ -27,6 +27,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"

+# LLM streaming response
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
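STREAMING defaults to on; disabling token-by-token output is a one-line change here, since webui.py reads the flag at startup (see the init_model and reinit_model hunks below):

# configs/model_config.py
STREAMING = False  # return complete answers instead of streaming them token by token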
webui.py (31 lines changed)

@@ -30,8 +30,8 @@ local_doc_qa = LocalDocQA()


 def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
+    if mode == "知识库问答" and vs_path:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.get_knowledge_based_answer(
                     query=query, vs_path=vs_path, chat_history=history):
                 source = "\n\n"
@@ -44,14 +44,28 @@ def get_answer(query, vs_path, history, mode):
                 history[-1][-1] += source
                 yield history, ""
         else:
+            resp, history = local_doc_qa.get_knowledge_based_answer(
+                query=query, vs_path=vs_path, chat_history=history)
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            return history, ""
+    else:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.llm._call(query, history):
                 history[-1][-1] = resp + (
                     "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 yield history, ""
         else:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp
-                yield history, ""
+            resp, history = local_doc_qa.llm._call(query, history)
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            return history, ""


 def update_status(history, status):
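Note that get_answer keeps `yield` in its streaming branches, so Python treats it as a generator function even when STREAMING is False; the `return history, ""` then surfaces as the StopIteration value rather than as an ordinary return. A hedged sketch of driving it directly, outside the web UI (query, vs_path, history and mode are placeholders):

gen = get_answer(query, vs_path, history, mode)
while True:
    try:
        history, _ = next(gen)        # streaming: incremental chatbot updates
    except StopIteration as stop:
        if stop.value is not None:    # non-streaming: the final (history, "") lands here
            history, _ = stop.value
        break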
@@ -62,7 +76,7 @@ def update_status(history, status):

 def init_model():
     try:
-        local_doc_qa.init_cfg()
+        local_doc_qa.init_cfg(streaming=STREAMING)
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -84,7 +98,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,
+                              streaming=STREAMING)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
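Both init_model and reinit_model now pass streaming=STREAMING, so the flag has to be in scope in webui.py; in this sketch it is assumed to come from the config module alongside the other constants:

from configs.model_config import STREAMING  # assumed import; a pre-existing wildcard import of the module would also cover it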