update torch_gc

parent b03634fb7c
commit 07ff81a119

api.py (12)
@@ -16,16 +16,10 @@ from typing_extensions import Annotated
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL)
+                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
+                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")
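The api.py hunk above drops the module-local copies of VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN and the hard-coded nltk_data path and imports them from configs.model_config instead. A minimal sketch of the consumer side after this change (the constant and module names come from the diff; the surrounding script is illustrative only):

import nltk
from configs.model_config import (NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K,
                                  LLM_HISTORY_LEN)

# point nltk at the bundled data directory instead of rebuilding the path locally
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

# the shared defaults now live in one place (5 and 3 in this commit)
print(f"top_k={VECTOR_SEARCH_TOP_K}, history_len={LLM_HISTORY_LEN}")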
@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
 
 
@@ -181,13 +176,13 @@ class LocalDocQA:
             if vs_path and os.path.isdir(vs_path):
                 vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
-                torch_gc(DEVICE)
+                torch_gc()
             else:
                 if not vs_path:
                     vs_path = os.path.join(VS_ROOT_PATH,
                                            f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
                 vector_store = FAISS.from_documents(docs, self.embeddings)
-                torch_gc(DEVICE)
+                torch_gc()
 
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -206,6 +201,7 @@ class LocalDocQA:
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:
@@ -220,11 +216,13 @@ class LocalDocQA:
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
            history[-1][0] = query
            response = {"query": query,
                        "result": result,
                        "source_documents": related_docs}
            yield response, history
+            torch_gc()
 
 
 if __name__ == "__main__":
@@ -244,9 +242,4 @@ if __name__ == "__main__":
                        for inum, doc in
                        enumerate(resp["source_documents"])]
         print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
     pass
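The LocalDocQA hunks above add torch_gc() after the vector search and around each streamed answer, and delete a commented-out example call at the bottom of the file. A hedged usage sketch of the generator interface implied by that diff; only the method name, keyword arguments and the shape of the yielded (response, history) pair come from the code above, while the setup lines are placeholders:

from chains.local_doc_qa import LocalDocQA

local_doc_qa = LocalDocQA()              # placeholder: model/embedding initialisation is not shown in this diff
query = "What does this project do?"     # placeholder question
vs_path = "/path/to/vector_store"        # placeholder FAISS store path

for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                              vs_path=vs_path,
                                                              chat_history=[],
                                                              streaming=False):
    # resp carries "query", "result" and "source_documents"; history mirrors the chat history
    print(resp["result"])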
@@ -3,13 +3,7 @@ from chains.local_doc_qa import LocalDocQA
 import os
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 # Show reply with source text from input document
 REPLY_WITH_SOURCE = True
@@ -49,4 +49,12 @@ PROMPT_TEMPLATE = """已知信息:
 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
-CHUNK_SIZE = 500
+CHUNK_SIZE = 250
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 5
+
+NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
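The second hunk above evidently edits the configs.model_config module, given that the constants it adds are exactly the names the api.py hunk imports from there: CHUNK_SIZE drops from 500 to 250 (the comment 匹配后单段上下文长度 means roughly "length of a single context segment after matching"), and LLM_HISTORY_LEN, VECTOR_SEARCH_TOP_K and NLTK_DATA_PATH move here so every entry point shares one definition. A small hedged sketch of what the CHUNK_SIZE comment describes; clip_context is a hypothetical helper for illustration, not code from the repo:

from configs.model_config import CHUNK_SIZE, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN

def clip_context(segment: str) -> str:
    # hypothetical usage: cap a single matched context segment at CHUNK_SIZE characters
    return segment[:CHUNK_SIZE]

print(f"chunk_size={CHUNK_SIZE}, top_k={VECTOR_SEARCH_TOP_K}, history_len={LLM_HISTORY_LEN}")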
@@ -69,12 +69,13 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
-                torch_gc(DEVICE)
+                torch_gc()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
+                torch_gc()
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
@@ -83,9 +84,10 @@ class ChatGLM(LLM):
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc(DEVICE)
+            torch_gc()
             history += [[prompt, response]]
             yield response, history
+            torch_gc()
 
     # def chat(self,
     #          prompt: str) -> str:
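Both ChatGLM hunks follow the same pattern: call the now argument-free torch_gc() at the start of each generation step and again after each yield, so cached GPU memory is released while the caller is still processing the chunk. A hypothetical wrapper sketching that placement (stream_with_gc is not part of the repo; only torch_gc and the ordering come from the diff):

from utils import torch_gc

def stream_with_gc(token_iter):
    # mirror the placement in ChatGLM._call: gc before handing a chunk back,
    # and again once the consumer resumes the generator
    for chunk in token_iter:
        torch_gc()
        yield chunk
        torch_gc()

Freeing the cache on every step presumably adds a little per-token overhead, but it keeps the steady-state memory footprint of a long streaming session lower.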
@@ -1,10 +1,10 @@
 import torch
 
-def torch_gc(DEVICE):
+def torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        # with torch.cuda.device(DEVICE):
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
         try:
             from torch.mps import empty_cache
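This is the core of the commit: torch_gc() loses its DEVICE parameter, the old with torch.cuda.device(DEVICE): context is left commented out, and on CUDA the function now simply calls empty_cache() and ipc_collect(). The hunk is cut off inside the MPS branch; a hedged sketch of the whole function after this change, with the MPS tail completed under the assumption that it just calls torch.mps.empty_cache() and is guarded because that API only exists on newer PyTorch builds:

import torch


def torch_gc():
    if torch.cuda.is_available():
        # release unused cached CUDA memory; no explicit DEVICE context any more
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        # assumed completion of the truncated branch: best-effort cache release on Apple Silicon
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            print(e)

Callers throughout the diff now invoke torch_gc() with no arguments after embedding or generation steps, as the other hunks show.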
webui.py (8)

@@ -5,13 +5,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 def get_vs_list():