diff --git a/README.md b/README.md
index c42da6a..e4b34a5 100644
--- a/README.md
+++ b/README.md
@@ -14,8 +14,14 @@
 
 ![实现原理图](img/langchain+chatglm.png)
 
+从文档处理角度来看,实现流程如下:
+
+![实现原理图2](img/langchain+chatglm2.png)
+
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
+
 📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
 
 ## 变更日志
@@ -166,6 +172,6 @@ Web UI 可以实现如下功能:
 - [ ] 实现调用 API 的 Web UI Demo
 
 ## 项目交流群
-![二维码](img/qr_code_9.jpg)
+![二维码](img/qr_code_10.jpg)
 
 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 84f820c..6b3d1e2 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
 from typing import List, Tuple
 from langchain.docstore.document import Document
 import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
 
 def load_file(filepath):
     if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
+
 def generate_prompt(related_docs: List[str],
                     query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@
 
 
 def get_docs_with_score(docs_with_score):
-    docs=[]
+    docs = []
     for doc, score in docs_with_score:
         doc.metadata["score"] = score
         docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     lists = []
     ls1 = [ls[0]]
     for i in range(1, len(ls)):
-        if ls[i-1] + 1 == ls[i]:
+        if ls[i - 1] + 1 == ls[i]:
             ls1.append(ls[i])
         else:
             lists.append(ls1)
@@ -59,49 +65,52 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     return lists
 
 
-
 def similarity_search_with_score_by_vector(
         self, embedding: List[float], k: int = 4,
-    ) -> List[Tuple[Document, float]]:
-    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
-    docs = []
-    id_set = set()
-    for j, i in enumerate(indices[0]):
-        if i == -1:
-            # This happens when not enough docs are returned.
-            continue
-        _id = self.index_to_docstore_id[i]
-        doc = self.docstore.search(_id)
-        id_set.add(i)
-        docs_len = len(doc.page_content)
-        for k in range(1, max(i, len(docs)-i)):
-            for l in [i+k, i-k]:
-                if 0 <= l < len(self.index_to_docstore_id):
-                    _id0 = self.index_to_docstore_id[l]
-                    doc0 = self.docstore.search(_id0)
-                    if docs_len + len(doc0.page_content) > self.chunk_size:
-                        break
-                    elif doc0.metadata["source"] == doc.metadata["source"]:
-                        docs_len += len(doc0.page_content)
-                        id_set.add(l)
-    id_list = sorted(list(id_set))
-    id_lists = seperate_list(id_list)
-    for id_seq in id_lists:
-        for id in id_seq:
-            if id == id_seq[0]:
-                _id = self.index_to_docstore_id[id]
-                doc = self.docstore.search(_id)
-            else:
-                _id0 = self.index_to_docstore_id[id]
+) -> List[Tuple[Document, float]]:
+    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+    docs = []
+    id_set = set()
+    for j, i in enumerate(indices[0]):
+        if i == -1:
+            # This happens when not enough docs are returned.
+            continue
+        _id = self.index_to_docstore_id[i]
+        doc = self.docstore.search(_id)
+        id_set.add(i)
+        docs_len = len(doc.page_content)
+        for k in range(1, max(i, len(docs) - i)):
+            break_flag = False
+            for l in [i + k, i - k]:
+                if 0 <= l < len(self.index_to_docstore_id):
+                    _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
-                doc.page_content += doc0.page_content
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append((doc, scores[0][j]))
-    return docs
-
+                    if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break_flag=True
+                        break
+                    elif doc0.metadata["source"] == doc.metadata["source"]:
+                        docs_len += len(doc0.page_content)
+                        id_set.add(l)
+            if break_flag:
+                break
+    id_list = sorted(list(id_set))
+    id_lists = seperate_list(id_list)
+    for id_seq in id_lists:
+        for id in id_seq:
+            if id == id_seq[0]:
+                _id = self.index_to_docstore_id[id]
+                doc = self.docstore.search(_id)
+            else:
+                _id0 = self.index_to_docstore_id[id]
+                doc0 = self.docstore.search(_id0)
+                doc.page_content += doc0.page_content
+        if not isinstance(doc, Document):
+            raise ValueError(f"Could not find document for id {_id}, got {doc}")
+        docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
+    return docs
 
 
 class LocalDocQA:
@@ -172,10 +181,12 @@ class LocalDocQA:
             if vs_path and os.path.isdir(vs_path):
                 vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
+                torch_gc(DEVICE)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
                 vector_store = FAISS.from_documents(docs, self.embeddings)
+                torch_gc(DEVICE)
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -187,29 +198,54 @@ class LocalDocQA:
                                    query,
                                    vs_path,
                                    chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
-        vector_store.chunk_size=self.chunk_size
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
-        if streaming:
-            for result, history in self.llm._call(prompt=prompt,
-                                                  history=chat_history):
-                history[-1][0] = query
-                response = {"query": query,
-                            "result": result,
-                            "source_documents": related_docs}
-                yield response, history
-        else:
-            result, history = self.llm._call(prompt=prompt,
-                                             history=chat_history)
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
-            return response, history
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "本项目使用的embedding模型是什么,消耗多少显存"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                   # f"""相关度:{doc.metadata['score']}\n\n"""
+                   for inum, doc in
+                   enumerate(resp["source_documents"])]
+    print("\n\n" + "\n\n".join(source_text))
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=False):
+    #     print(resp["result"])
+    pass
diff --git a/cli_demo.py b/cli_demo.py
index 232f75c..33d616d 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -32,9 +32,12 @@ if __name__ == "__main__":
         for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                       vs_path=vs_path,
                                                                       chat_history=history,
-                                                                      streaming=True):
-            print(resp["result"][last_print_len:], end="", flush=True)
-            last_print_len = len(resp["result"])
+                                                                      streaming=STREAMING):
+            if STREAMING:
+                print(resp["result"][last_print_len:], end="", flush=True)
+                last_print_len = len(resp["result"])
+            else:
+                print(resp["result"])
         if REPLY_WITH_SOURCE:
             source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                            # f"""相关度:{doc.metadata['score']}\n\n"""
                            for inum, doc in
                            enumerate(resp["source_documents"])]
diff --git a/configs/model_config.py b/configs/model_config.py
index 3147b6a..ab3e848 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -27,6 +27,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM streaming reponse
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
 
@@ -38,14 +41,10 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
 UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
 
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
-PROMPT_TEMPLATE = """已知信息在下方"="包裹的段落,基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-====================================已知信息=====================================================
 {context}
-================================================================================================
-问题:"{question}"
-答案:"""
+PROMPT_TEMPLATE = """已知信息:
+{context}
+
+根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
 CHUNK_SIZE = 500
\ No newline at end of file
diff --git a/img/langchain+chatglm2.png b/img/langchain+chatglm2.png
new file mode 100644
index 0000000..d98b810
Binary files /dev/null and b/img/langchain+chatglm2.png differ
diff --git a/img/qr_code_10.jpg b/img/qr_code_10.jpg
new file mode 100644
index 0000000..348c0fe
Binary files /dev/null and b/img/qr_code_10.jpg differ
diff --git a/img/qr_code_9.jpg b/img/qr_code_9.jpg
deleted file mode 100644
index 0f3be2d..0000000
Binary files a/img/qr_code_9.jpg and /dev/null differ
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index a0e95d9..e69ce87 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -4,21 +4,15 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
 
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 
 def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    streaming: bool = True
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               history: List[List[str]] = [],
-              stop: Optional[List[str]] = None) -> str:
-        if self.streaming:
+              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
             for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                     self.tokenizer,
                     prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
+                torch_gc(DEVICE)
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
-
         else:
             response, _ = self.model.chat(
-                self.tokenizer,
-                prompt,
-                history=history[-self.history_len:] if self.history_len > 0 else [],
-                max_length=self.max_token,
-                temperature=self.temperature,
+            self.tokenizer,
+            prompt,
+            history=history[-self.history_len:] if self.history_len > 0 else [],
+            max_length=self.max_token,
+            temperature=self.temperature,
             )
-            torch_gc()
-            if stop is not None:
-                response = enforce_stop_tokens(response, stop)
-            history = history + [[None, response]]
-            return response, history
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
 
     # def chat(self,
     #          prompt: str) -> str:
@@ -191,3 +182,16 @@
             print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len=0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..8508c7d
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
\ No newline at end of file
diff --git a/webui.py b/webui.py
index aaeb734..6c2a29c 100644
--- a/webui.py
+++ b/webui.py
@@ -29,28 +29,28 @@ llm_model_dict_list = list(llm_model_dict.keys())
 
 local_doc_qa = LocalDocQA()
 
 
-def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
-            for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=query, vs_path=vs_path, chat_history=history):
-                source = "\n\n"
-                source += "".join(
-                    [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
-                     f"""{doc.page_content}\n"""
-                     f"""</details>"""
-                     for i, doc in
-                     enumerate(resp["source_documents"])])
-                history[-1][-1] += source
-                yield history, ""
-        else:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp + (
-                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-                yield history, ""
+def get_answer(query, vs_path, history, mode,
+               streaming: bool = STREAMING):
+    if mode == "知识库问答" and vs_path:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query,
+                vs_path=vs_path,
+                chat_history=history,
+                streaming=streaming):
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            yield history, ""
+    else:
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            yield history, ""
 
 
@@ -84,7 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e: