From 07ff81a119b6bda98c516f88ae8600acdcd26d63 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Thu, 4 May 2023 20:48:36 +0800
Subject: [PATCH] update torch_gc

---
 api.py                  | 12 +++---------
 chains/local_doc_qa.py  | 19 ++++++-------------
 cli_demo.py             |  8 +-------
 configs/model_config.py | 10 +++++++++-
 models/chatglm_llm.py   |  6 ++++--
 utils/__init__.py       |  8 ++++----
 webui.py                |  8 +-------
 7 files changed, 28 insertions(+), 43 deletions(-)

diff --git a/api.py b/api.py
index 94a49d5..edb6754 100644
--- a/api.py
+++ b/api.py
@@ -16,16 +16,10 @@ from typing_extensions import Annotated
 
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL)
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
+                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index b896607..6176bc2 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
 
 
@@ -181,13 +176,13 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
-            torch_gc(DEVICE)
+            torch_gc()
         else:
             if not vs_path:
                 vs_path = os.path.join(VS_ROOT_PATH,
                                        f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
             vector_store = FAISS.from_documents(docs, self.embeddings)
-            torch_gc(DEVICE)
+            torch_gc()
 
         vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -206,6 +201,7 @@ class LocalDocQA:
 
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:
@@ -220,11 +216,13 @@ class LocalDocQA:
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
             yield response, history
+        torch_gc()
 
 
 if __name__ == "__main__":
@@ -244,9 +242,4 @@ if __name__ == "__main__":
                        for inum, doc in enumerate(resp["source_documents"])]
         print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
 
     pass
diff --git a/cli_demo.py b/cli_demo.py
index 33d616d..64aaeb0 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -3,13 +3,7 @@ from chains.local_doc_qa import LocalDocQA
 import os
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 # Show reply with source text from input document
 REPLY_WITH_SOURCE = True
diff --git a/configs/model_config.py b/configs/model_config.py
index bda83a2..38a8970 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -49,4 +49,12 @@ PROMPT_TEMPLATE = """已知信息:
 根据上述已知信息，简洁和专业的来回答用户的问题。如果无法从中得到答案，请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”，不允许在答案中添加编造成分，答案请使用中文。 问题是：{question}"""
 
 # 匹配后单段上下文长度
-CHUNK_SIZE = 500
+CHUNK_SIZE = 250
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 5
+
+NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
\ No newline at end of file
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 2ca7790..d33b62e 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -69,12 +69,13 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
-                torch_gc(DEVICE)
+                torch_gc()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
+            torch_gc()
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
@@ -83,9 +84,10 @@
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc(DEVICE)
+            torch_gc()
             history += [[prompt, response]]
             yield response, history
+            torch_gc()
 
     # def chat(self,
     #          prompt: str) -> str:
diff --git a/utils/__init__.py b/utils/__init__.py
index 4499cf3..ea1acf8 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,10 +1,10 @@
 import torch
 
-def torch_gc(DEVICE):
+def torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        # with torch.cuda.device(DEVICE):
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
         try:
             from torch.mps import empty_cache
diff --git a/webui.py b/webui.py
index 2531059..24cc2ff 100644
--- a/webui.py
+++ b/webui.py
@@ -5,13 +5,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 def get_vs_list():
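
Reviewer note (not part of the patch): the central change is that torch_gc() no longer takes a DEVICE argument; it simply empties the cache of whichever backend is actually available, so call sites drop the device plumbing. Below is a minimal sketch of the post-patch call pattern. The torch_gc() body mirrors the utils/__init__.py hunk above; the surrounding answer_one_query() wrapper and its names are purely illustrative and not code from this repository.

import torch


def torch_gc():
    # Free cached accelerator memory on whichever backend is active;
    # no DEVICE argument is needed after this patch.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except ImportError:
            # torch.mps.empty_cache only exists in newer PyTorch builds
            pass


def answer_one_query(llm, prompt, chat_history):
    # Illustrative caller (hypothetical name): clear the cache after each
    # streamed chunk and once more when generation finishes, mirroring the
    # torch_gc() placement added in chains/local_doc_qa.py and models/chatglm_llm.py.
    for result, history in llm._call(prompt=prompt, history=chat_history, streaming=True):
        torch_gc()
        yield result, history
    torch_gc()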