From 88ab9a1d214be01503d52be48b0ee255b917c112 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Tue, 25 Apr 2023 20:36:16 +0800
Subject: [PATCH] update webui.py and local_doc_qa.py

---
 chains/local_doc_qa.py | 43 +++++++++++++++++++++++++++++++-----------
 models/chatglm_llm.py  | 10 +++++-----
 webui.py               | 35 +++++++++++++++++++++------------
 3 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 8a45bd2..01bf520 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -1,9 +1,8 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from chains.lib.embeddings import MyEmbeddings
-# from langchain.vectorstores import FAISS
-from chains.lib.vectorstores import FAISSVS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -12,6 +11,7 @@ from configs.model_config import *
 import datetime
 from typing import List
 from textsplitter import ChineseTextSplitter
+from langchain.docstore.document import Document

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -21,7 +21,10 @@ LLM_HISTORY_LEN = 3


 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
@@ -32,6 +35,22 @@ def load_file(filepath):
     return docs


+def get_relevant_documents(self, query: str) -> List[Document]:
+    if self.search_type == "similarity":
+        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
+        for doc in docs:
+            doc[0].metadata["score"] = doc[1]
+        docs = [doc[0] for doc in docs]
+    elif self.search_type == "mmr":
+        docs = self.vectorstore.max_marginal_relevance_search(
+            query, **self.search_kwargs
+        )
+    else:
+        raise ValueError(f"search_type of {self.search_type} not allowed.")
+    return docs
+
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
@@ -52,7 +71,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len

-        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                        model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -99,12 +118,12 @@ class LocalDocQA:
                     print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISSVS.from_documents(docs, self.embeddings)
+                vector_store = FAISS.from_documents(docs, self.embeddings)
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -129,10 +148,13 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vs_r = vector_store.as_retriever(search_type="mmr",
+                                         search_kwargs={"k": self.top_k})
+        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            retriever=vs_r,
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
@@ -140,7 +162,6 @@ class LocalDocQA:
         )

         knowledge_chain.return_source_documents = True
-
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
         return result, self.llm.history
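The retrieval path configured above can be exercised on its own. A minimal sketch, assuming an existing FAISS index directory and the project's default text2vec embedding model; the vs_path value and the query string below are placeholders, not taken from this patch:

    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    # assumed values: point these at a local FAISS store and embedding model
    vs_path = "/path/to/some_FAISS_store"
    embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese",
                                       model_kwargs={'device': "cpu"})

    vector_store = FAISS.load_local(vs_path, embeddings)
    # same retriever configuration that get_knowledge_based_answer now uses
    retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 6})
    for doc in retriever.get_relevant_documents("测试问题"):
        print(doc.metadata["source"], doc.page_content[:80])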
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 7243967..1c020de 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -72,16 +72,16 @@ class ChatGLM(LLM):
               stream=True) -> str:
         if stream:
             self.history = self.history + [[None, ""]]
-            response, _ = self.model.stream_chat(
+            for response, history in self.model.stream_chat(
                 self.tokenizer,
                 prompt,
                 history=self.history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
-            )
-            torch_gc()
-            self.history[-1][-1] = response
-            yield response
+            ):
+                torch_gc()
+                self.history[-1][-1] = response
+                yield response
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
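With stream=True, _call now returns a generator, and each yielded value is the full response accumulated so far (ChatGLM's stream_chat re-emits the whole reply at every step), which is why callers assign to history[-1][-1] rather than appending. A rough consumption sketch, assuming llm is an already-loaded ChatGLM instance:

    printed = 0
    for partial in llm._call("你好"):          # stream=True is the default
        # partial is the full reply so far; print only the newly generated part
        print(partial[printed:], end="", flush=True)
        printed = len(partial)
    print()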
diff --git a/webui.py b/webui.py
index a5f9973..e75a831 100644
--- a/webui.py
+++ b/webui.py
@@ -30,19 +30,28 @@ local_doc_qa = LocalDocQA()


 def get_answer(query, vs_path, history, mode):
-    if vs_path and mode == "知识库问答":
-        resp, history = local_doc_qa.get_knowledge_based_answer(
-            query=query, vs_path=vs_path, chat_history=history)
-        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-{doc.page_content}
-
-所属文件:{doc.metadata["source"]}
-</details>""" for i, doc in enumerate(resp["source_documents"])])
-        history[-1][-1] += source
+    if mode == "知识库问答":
+        if vs_path:
+            for resp, history in local_doc_qa.get_knowledge_based_answer(
+                    query=query, vs_path=vs_path, chat_history=history):
+                # source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
+                # {doc.page_content}
+                #
+                # 所属文件:{doc.metadata["source"]}
+                # </details>""" for i, doc in enumerate(resp["source_documents"])])
+                # history[-1][-1] += source
+                yield history, ""
+        else:
+            history = history + [[query, ""]]
+            for resp in local_doc_qa.llm._call(query):
+                history[-1][-1] = resp + (
+                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+                yield history, ""
     else:
-        resp = local_doc_qa.llm._call(query)
-        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
-        return history, ""
+        history = history + [[query, ""]]
+        for resp in local_doc_qa.llm._call(query):
+            history[-1][-1] = resp
+            yield history, ""


 def update_status(history, status):
@@ -62,7 +71,7 @@ def init_model():
         print(e)
         reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
         if str(e) == "Unknown platform: darwin":
-            print("改报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
+            print("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
                   " https://github.com/imClumsyPanda/langchain-ChatGLM")
         else:
             print(reply)