diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 97c4e65..b3eca15 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
 from typing import List, Tuple
 from langchain.docstore.document import Document
 import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
 
 def load_file(filepath):
     if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
+
 def generate_prompt(related_docs: List[str],
                     query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
 
 
 def get_docs_with_score(docs_with_score):
-    docs=[]
+    docs = []
     for doc, score in docs_with_score:
         doc.metadata["score"] = score
         docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     lists = []
     ls1 = [ls[0]]
     for i in range(1, len(ls)):
-        if ls[i-1] + 1 == ls[i]:
+        if ls[i - 1] + 1 == ls[i]:
             ls1.append(ls[i])
         else:
             lists.append(ls1)
@@ -59,49 +65,48 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     return lists
 
 
-
 def similarity_search_with_score_by_vector(
         self, embedding: List[float], k: int = 4,
-    ) -> List[Tuple[Document, float]]:
-        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
-        docs = []
-        id_set = set()
-        for j, i in enumerate(indices[0]):
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            id_set.add(i)
-            docs_len = len(doc.page_content)
-            for k in range(1, max(i, len(docs)-i)):
-                for l in [i+k, i-k]:
-                    if 0 <= l < len(self.index_to_docstore_id):
-                        _id0 = self.index_to_docstore_id[l]
-                        doc0 = self.docstore.search(_id0)
-                        if docs_len + len(doc0.page_content) > self.chunk_size:
-                            break
-                        elif doc0.metadata["source"] == doc.metadata["source"]:
-                            docs_len += len(doc0.page_content)
-                            id_set.add(l)
-        id_list = sorted(list(id_set))
-        id_lists = seperate_list(id_list)
-        for id_seq in id_lists:
-            for id in id_seq:
-                if id == id_seq[0]:
-                    _id = self.index_to_docstore_id[id]
-                    doc = self.docstore.search(_id)
-                else:
-                    _id0 = self.index_to_docstore_id[id]
+) -> List[Tuple[Document, float]]:
+    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+    docs = []
+    id_set = set()
+    for j, i in enumerate(indices[0]):
+        if i == -1:
+            # This happens when not enough docs are returned.
+            continue
+        _id = self.index_to_docstore_id[i]
+        doc = self.docstore.search(_id)
+        id_set.add(i)
+        docs_len = len(doc.page_content)
+        for k in range(1, max(i, len(docs) - i)):
+            for l in [i + k, i - k]:
+                if 0 <= l < len(self.index_to_docstore_id):
+                    _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
-                    doc.page_content += doc0.page_content
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append((doc, scores[0][j]))
-        return docs
-
+                    if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break
+                    elif doc0.metadata["source"] == doc.metadata["source"]:
+                        docs_len += len(doc0.page_content)
+                        id_set.add(l)
+    id_list = sorted(list(id_set))
+    id_lists = seperate_list(id_list)
+    for id_seq in id_lists:
+        for id in id_seq:
+            if id == id_seq[0]:
+                _id = self.index_to_docstore_id[id]
+                doc = self.docstore.search(_id)
+            else:
+                _id0 = self.index_to_docstore_id[id]
+                doc0 = self.docstore.search(_id0)
+                doc.page_content += doc0.page_content
+        if not isinstance(doc, Document):
+            raise ValueError(f"Could not find document for id {_id}, got {doc}")
+        docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
+    return docs
 
 
 class LocalDocQA:
@@ -116,12 +121,10 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
-                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
-        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -174,10 +177,12 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
+            torch_gc(DEVICE)
         else:
             if not vs_path:
                 vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
             vector_store = FAISS.from_documents(docs, self.embeddings)
+            torch_gc(DEVICE)
         vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -188,28 +193,50 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[]):
+                                   chat_history=[],
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
-        vector_store.chunk_size=self.chunk_size
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
-        if self.llm.streaming:
-            for result, history in self.llm._call(prompt=prompt,
-                                                  history=chat_history):
-                history[-1][0] = query
-                response = {"query": query,
-                            "result": result,
-                            "source_documents": related_docs}
-                yield response, history
-        else:
-            result, history = self.llm._call(prompt=prompt,
-                                             history=chat_history)
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
-            return response, history
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "你好"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=False):
+        print(resp["result"])
+    pass
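Note on the change above, not part of the patch: `get_knowledge_based_answer` now yields `(response, history)` pairs, and each returned chunk carries the FAISS distance that `get_docs_with_score` copies into `doc.metadata["score"]`. A minimal sketch of reading the sources back out of a yielded response dict; `format_sources` is an illustrative helper name, not something this patch adds:

```python
def format_sources(response):
    """Render the source chunks and their FAISS scores from a response dict
    yielded by LocalDocQA.get_knowledge_based_answer."""
    lines = []
    for i, doc in enumerate(response["source_documents"]):
        source = doc.metadata["source"]
        score = doc.metadata.get("score")  # set by get_docs_with_score
        lines.append(f"[{i + 1}] {source} (score: {score})\n{doc.page_content}")
    return "\n\n".join(lines)
```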
"__main__": + local_doc_qa = LocalDocQA() + local_doc_qa.init_cfg() + query = "你好" + vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123" + last_print_len = 0 + for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, + vs_path=vs_path, + chat_history=[], + streaming=True): + print(resp["result"][last_print_len:], end="", flush=True) + last_print_len = len(resp["result"]) + for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, + vs_path=vs_path, + chat_history=[], + streaming=False): + print(resp["result"]) + pass diff --git a/cli_demo.py b/cli_demo.py index 232f75c..33d616d 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -32,9 +32,12 @@ if __name__ == "__main__": for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, vs_path=vs_path, chat_history=history, - streaming=True): - print(resp["result"][last_print_len:], end="", flush=True) - last_print_len = len(resp["result"]) + streaming=STREAMING): + if STREAMING: + print(resp["result"][last_print_len:], end="", flush=True) + last_print_len = len(resp["result"]) + else: + print(resp["result"]) if REPLY_WITH_SOURCE: source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n""" diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index a0e95d9..e69ce87 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -4,21 +4,15 @@ from typing import Optional, List from langchain.llms.utils import enforce_stop_tokens from transformers import AutoTokenizer, AutoModel, AutoConfig import torch -from configs.model_config import LLM_DEVICE +from configs.model_config import * from langchain.callbacks.base import CallbackManager from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from typing import Dict, Tuple, Union, Optional +from utils import torch_gc -DEVICE = LLM_DEVICE +DEVICE_ = LLM_DEVICE DEVICE_ID = "0" if torch.cuda.is_available() else None -CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE - - -def torch_gc(): - if torch.cuda.is_available(): - with torch.cuda.device(CUDA_DEVICE): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() +DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: @@ -59,7 +53,6 @@ class ChatGLM(LLM): tokenizer: object = None model: object = None history_len: int = 10 - streaming: bool = True callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) def __init__(self): @@ -72,8 +65,8 @@ class ChatGLM(LLM): def _call(self, prompt: str, history: List[List[str]] = [], - stop: Optional[List[str]] = None) -> str: - if self.streaming: + streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]: + if streaming: for inum, (stream_resp, _) in enumerate(self.model.stream_chat( self.tokenizer, prompt, @@ -81,25 +74,23 @@ class ChatGLM(LLM): max_length=self.max_token, temperature=self.temperature, )): + torch_gc(DEVICE) if inum == 0: history += [[prompt, stream_resp]] else: history[-1] = [prompt, stream_resp] yield stream_resp, history - else: response, _ = self.model.chat( - self.tokenizer, - prompt, - history=history[-self.history_len:] if self.history_len > 0 else [], - max_length=self.max_token, - temperature=self.temperature, + self.tokenizer, + prompt, + history=history[-self.history_len:] if self.history_len > 0 else [], + max_length=self.max_token, + temperature=self.temperature, ) - torch_gc() - if stop is not None: - response = 
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index a0e95d9..e69ce87 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -4,21 +4,15 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
 
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 
 def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    streaming: bool = True
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
     def _call(self,
              prompt: str,
              history: List[List[str]] = [],
-             stop: Optional[List[str]] = None) -> str:
-        if self.streaming:
+             streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
             for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                     self.tokenizer,
                     prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
+                torch_gc(DEVICE)
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
-
         else:
             response, _ = self.model.chat(
-                    self.tokenizer,
-                    prompt,
-                    history=history[-self.history_len:] if self.history_len > 0 else [],
-                    max_length=self.max_token,
-                    temperature=self.temperature,
+                self.tokenizer,
+                prompt,
+                history=history[-self.history_len:] if self.history_len > 0 else [],
+                max_length=self.max_token,
+                temperature=self.temperature,
             )
-            torch_gc()
-            if stop is not None:
-                response = enforce_stop_tokens(response, stop)
-            history = history + [[None, response]]
-            return response, history
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
 
 #     def chat(self,
 #              prompt: str) -> str:
@@ -191,3 +182,16 @@ class ChatGLM(LLM):
                 print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len=0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass
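Note on `ChatGLM._call`, not part of the patch: the `stop` parameter and the `enforce_stop_tokens` call are removed here, and `_call` now yields in both branches instead of returning. A caller that still needs stop-token trimming can drain the generator and apply it to the final response, as in this hedged sketch (`call_with_stop` is an illustrative name; it assumes the `_call(prompt, history, streaming)` signature introduced above):

```python
from langchain.llms.utils import enforce_stop_tokens  # same import chatglm_llm.py keeps


def call_with_stop(llm, prompt, history=None, stop=None):
    """Run ChatGLM._call non-streaming and re-apply stop-token trimming."""
    response, history = "", (history or [])
    for response, history in llm._call(prompt, history=history, streaming=False):
        pass  # with streaming=False the generator yields exactly once
    if stop is not None:
        response = enforce_stop_tokens(response, stop)
    return response, history
```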
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..8508c7d
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
\ No newline at end of file
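Note on the new helper, not part of the patch: `torch_gc` takes the device string that the other modules now build as `f"{DEVICE_}:{DEVICE_ID}"`, empties the CUDA cache for that device, and falls back to the MPS cache on Apple silicon (it assumes a torch build that ships `torch.mps`). A minimal usage sketch; the `"cuda"` literal stands in for the project's `LLM_DEVICE` / `EMBEDDING_DEVICE` config values:

```python
import torch

from utils import torch_gc

# Same DEVICE construction as chains/local_doc_qa.py and models/chatglm_llm.py.
DEVICE_ = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_

# Free cached blocks after a large embedding or generation step.
torch_gc(DEVICE)  # no-op on a CPU-only build without MPS
```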
""" - for i, doc in - enumerate(resp["source_documents"])]) - history[-1][-1] += source - yield history, "" - else: - resp, history = local_doc_qa.get_knowledge_based_answer( - query=query, vs_path=vs_path, chat_history=history) + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=query, + vs_path=vs_path, + chat_history=history, + streaming=streaming): source = "\n\n" source += "".join( [f"""
@@ -54,18 +45,13 @@ def get_answer(query, vs_path, history, mode):
                  for i, doc in
                  enumerate(resp["source_documents"])])
             history[-1][-1] += source
-            return history, ""
+            yield history, ""
     else:
-        if local_doc_qa.llm.streaming:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp + (
-                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-                yield history, ""
-        else:
-            resp, history = local_doc_qa.llm._call(query, history)
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-            return history, ""
+            yield history, ""
 
 
 def update_status(history, status):
@@ -76,7 +62,7 @@ def init_model():
     try:
-        local_doc_qa.init_cfg(streaming=STREAMING)
+        local_doc_qa.init_cfg()
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -98,8 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k,
-                              streaming=STREAMING)
+                              top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
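Follow-up suggestion, not part of the patch: the `DEVICE_` / `DEVICE_ID` / `DEVICE` triplet is now duplicated in chains/local_doc_qa.py and models/chatglm_llm.py. One possible refactor would compute it once next to `torch_gc` in utils/__init__.py; `resolve_device` below is a hypothetical helper sketched under that assumption:

```python
import torch


def resolve_device(base_device: str) -> str:
    """Return "<device>:0" when CUDA is available, otherwise the bare device name,
    mirroring the DEVICE_/DEVICE_ID/DEVICE pattern repeated in this patch."""
    device_id = "0" if torch.cuda.is_available() else None
    return f"{base_device}:{device_id}" if device_id else base_device


# e.g. DEVICE = resolve_device(LLM_DEVICE) in models/chatglm_llm.py
```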