update torch_gc

parent b03634fb7c
commit 07ff81a119

api.py (12 changed lines)
@@ -16,16 +16,10 @@ from typing_extensions import Annotated
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL)
+                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
+                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")

@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None

@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
 
 

@@ -181,13 +176,13 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
-            torch_gc(DEVICE)
+            torch_gc()
         else:
             if not vs_path:
                 vs_path = os.path.join(VS_ROOT_PATH,
                                        f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
             vector_store = FAISS.from_documents(docs, self.embeddings)
-            torch_gc(DEVICE)
+            torch_gc()
 
         vector_store.save_local(vs_path)
         return vs_path, loaded_files

@@ -206,6 +201,7 @@ class LocalDocQA:
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:

@@ -220,11 +216,13 @@ class LocalDocQA:
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
             yield response, history
+        torch_gc()
 
 
 if __name__ == "__main__":
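
The two hunks above wire cache cleanup into the knowledge-based answer generator: once after the FAISS retrieval, once per streamed result, and once more after the LLM stream is exhausted. A minimal, self-contained sketch of that pattern follows; fake_llm_stream and answer_stream are hypothetical stand-ins for self.llm._call and get_knowledge_based_answer, not the project's actual functions.

import torch


def torch_gc():
    # Same idea as the updated utils.torch_gc(): release cached GPU memory
    # on the current device when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def fake_llm_stream(prompt, history):
    # Hypothetical stand-in for a streaming LLM call yielding partial results.
    for partial in ["Hello", "Hello,", "Hello, world"]:
        history = history + [[prompt, partial]]
        yield partial, history


def answer_stream(query, related_docs, chat_history):
    torch_gc()                      # after vector-store retrieval
    prompt = f"{related_docs}\n{query}"
    for result, history in fake_llm_stream(prompt, chat_history):
        torch_gc()                  # once per streamed chunk
        history[-1][0] = query
        yield {"query": query,
               "result": result,
               "source_documents": related_docs}, history
    torch_gc()                      # after the stream is exhausted


for resp, history in answer_stream("hi", ["chunk 1"], []):
    print(resp["result"])

Calling torch_gc() per chunk keeps the reserved GPU footprint small during long generations, at the cost of some allocator overhead.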

@@ -244,9 +242,4 @@ if __name__ == "__main__":
                    for inum, doc in
                    enumerate(resp["source_documents"])]
     print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
     pass

@@ -3,13 +3,7 @@ from chains.local_doc_qa import LocalDocQA
 import os
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 # Show reply with source text from input document
 REPLY_WITH_SOURCE = True

@@ -49,4 +49,12 @@ PROMPT_TEMPLATE = """已知信息:
 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
-CHUNK_SIZE = 500
+CHUNK_SIZE = 250
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 5
+
+NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
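
The hunk above centralizes the settings that api.py, webui.py and the CLI script previously redefined, and appears to belong to configs/model_config.py, the module those scripts import from. CHUNK_SIZE (the comment 匹配后单段上下文长度 means roughly "context length of a single matched chunk") drops from 500 to 250, and LLM_HISTORY_LEN, VECTOR_SEARCH_TOP_K and NLTK_DATA_PATH are now defined here once. The PROMPT_TEMPLATE context line instructs the model to answer concisely from the supplied material, to say it cannot answer when the material is insufficient, and to reply in Chinese. A minimal consumer-side sketch, paraphrasing the api.py hunk rather than copying any file verbatim:

import nltk

from configs.model_config import (NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K,
                                  LLM_HISTORY_LEN, CHUNK_SIZE)

# Every entry point now points NLTK at the single bundled nltk_data directory
# instead of keeping its own copy of the path.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

print(VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, CHUNK_SIZE)  # expected: 5 3 250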

@@ -69,12 +69,13 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
-                torch_gc(DEVICE)
+                torch_gc()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
+            torch_gc()
         else:
             response, _ = self.model.chat(
                 self.tokenizer,

@@ -83,9 +84,10 @@ class ChatGLM(LLM):
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc(DEVICE)
+            torch_gc()
             history += [[prompt, response]]
             yield response, history
+            torch_gc()
 
     # def chat(self,
     #          prompt: str) -> str:

@@ -1,8 +1,8 @@
 import torch
 
-def torch_gc(DEVICE):
+def torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(DEVICE):
+        # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
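
The hunk above comes from the module imported earlier as from utils import torch_gc. Dropping the DEVICE parameter (and commenting out the with torch.cuda.device(DEVICE): scope) means the helper now clears the caching allocator of whatever CUDA device is current, which is why every call site in this commit changes from torch_gc(DEVICE) to torch_gc(). The hunk is cut off before the MPS branch body, so the following sketch fills that branch under an assumption: newer PyTorch builds expose torch.mps.empty_cache(), and the call is guarded in case it is absent.

import torch


def torch_gc():
    """Best-effort release of cached accelerator memory (sketch of the updated helper)."""
    if torch.cuda.is_available():
        # The explicit device scope was commented out, so this acts on the
        # currently selected CUDA device.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        # Assumption: this branch body is not shown in the hunk. Recent PyTorch
        # releases provide torch.mps.empty_cache(); guard it for older builds.
        if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()

Note that torch.cuda.empty_cache() only returns blocks the caching allocator has already freed, so the extra calls mainly shrink the reserved-memory footprint reported by tools such as nvidia-smi rather than freeing live tensors.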

webui.py (8 changed lines)
@@ -5,13 +5,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 def get_vs_list():