diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 84f820c..97c4e65 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -116,10 +116,12 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
+                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
+        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -186,9 +188,7 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   chat_history=[]):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_size=self.chunk_size
@@ -197,7 +197,7 @@ class LocalDocQA:
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
 
-        if streaming:
+        if self.llm.streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
                 history[-1][0] = query
diff --git a/configs/model_config.py b/configs/model_config.py
index 1afc537..20bea5f 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -27,6 +27,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM streaming response
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
 
diff --git a/webui.py b/webui.py
index aaeb734..4409d07 100644
--- a/webui.py
+++ b/webui.py
@@ -30,8 +30,8 @@ local_doc_qa = LocalDocQA()
 
 
 def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
+    if mode == "知识库问答" and vs_path:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.get_knowledge_based_answer(
                     query=query, vs_path=vs_path, chat_history=history):
                 source = "\n\n"
@@ -44,14 +44,28 @@ def get_answer(query, vs_path, history, mode):
                 history[-1][-1] += source
                 yield history, ""
         else:
+            resp, history = local_doc_qa.get_knowledge_based_answer(
+                query=query, vs_path=vs_path, chat_history=history)
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            return history, ""
+    else:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.llm._call(query, history):
                 history[-1][-1] = resp + (
                     "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 yield history, ""
-    else:
-        for resp, history in local_doc_qa.llm._call(query, history):
-            history[-1][-1] = resp
-            yield history, ""
+        else:
+            resp, history = local_doc_qa.llm._call(query, history)
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            return history, ""
 
 
 def update_status(history, status):
@@ -62,7 +76,7 @@
 
 def init_model():
     try:
-        local_doc_qa.init_cfg()
+        local_doc_qa.init_cfg(streaming=STREAMING)
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -84,7 +98,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,
+                              streaming=STREAMING)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e: