add streaming option in configs/model_config.py

parent 0e8cc0d16c
commit 4df9d76f8a

@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
 from typing import List, Tuple
 from langchain.docstore.document import Document
 import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
 
 def load_file(filepath):
     if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@ def load_file(filepath):
     docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
 
 def generate_prompt(related_docs: List[str],
                     query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
 
 
 def get_docs_with_score(docs_with_score):
-    docs=[]
+    docs = []
     for doc, score in docs_with_score:
         doc.metadata["score"] = score
         docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     lists = []
     ls1 = [ls[0]]
     for i in range(1, len(ls)):
-        if ls[i-1] + 1 == ls[i]:
+        if ls[i - 1] + 1 == ls[i]:
             ls1.append(ls[i])
         else:
             lists.append(ls1)
@@ -59,49 +65,48 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     return lists
 
 
 def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
     id_set = set()
     for j, i in enumerate(indices[0]):
         if i == -1:
             # This happens when not enough docs are returned.
             continue
         _id = self.index_to_docstore_id[i]
         doc = self.docstore.search(_id)
         id_set.add(i)
         docs_len = len(doc.page_content)
-        for k in range(1, max(i, len(docs)-i)):
-            for l in [i+k, i-k]:
+        for k in range(1, max(i, len(docs) - i)):
+            for l in [i + k, i - k]:
                 if 0 <= l < len(self.index_to_docstore_id):
                     _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
                     if docs_len + len(doc0.page_content) > self.chunk_size:
                         break
                     elif doc0.metadata["source"] == doc.metadata["source"]:
                         docs_len += len(doc0.page_content)
                         id_set.add(l)
     id_list = sorted(list(id_set))
     id_lists = seperate_list(id_list)
     for id_seq in id_lists:
         for id in id_seq:
             if id == id_seq[0]:
                 _id = self.index_to_docstore_id[id]
                 doc = self.docstore.search(_id)
             else:
                 _id0 = self.index_to_docstore_id[id]
                 doc0 = self.docstore.search(_id0)
                 doc.page_content += doc0.page_content
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
     return docs
 
 
 class LocalDocQA:
@@ -116,12 +121,10 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
-                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
-        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -174,10 +177,12 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
+            torch_gc(DEVICE)
         else:
             if not vs_path:
                 vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
             vector_store = FAISS.from_documents(docs, self.embeddings)
+            torch_gc(DEVICE)
 
         vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -188,28 +193,50 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[]):
+                                   chat_history=[],
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
-        vector_store.chunk_size=self.chunk_size
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
 
-        if self.llm.streaming:
-            for result, history in self.llm._call(prompt=prompt,
-                                                  history=chat_history):
-                history[-1][0] = query
-                response = {"query": query,
-                            "result": result,
-                            "source_documents": related_docs}
-                yield response, history
-        else:
-            result, history = self.llm._call(prompt=prompt,
-                                             history=chat_history)
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
-            return response, history
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "你好"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=False):
+        print(resp["result"])
+    pass
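
Note: with this change `get_knowledge_based_answer` is a generator in both modes, so callers drain it the same way whether `streaming` is True or False. A minimal consumer sketch; the `answer_once` helper is illustrative only and not part of the commit:

import os

def answer_once(local_doc_qa, query, vs_path, streaming=True):
    # Illustrative helper (assumes the generator yields at least once).
    # With streaming=True the intermediate yields carry partial "result" text;
    # with streaming=False there is a single yield with the full answer.
    resp, history = None, []
    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                 vs_path=vs_path,
                                                                 chat_history=[],
                                                                 streaming=streaming):
        pass
    sources = [os.path.split(doc.metadata["source"])[-1]
               for doc in resp["source_documents"]]
    return resp["result"], sources, history
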
@@ -32,9 +32,12 @@ if __name__ == "__main__":
     for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                  vs_path=vs_path,
                                                                  chat_history=history,
-                                                                 streaming=True):
-        print(resp["result"][last_print_len:], end="", flush=True)
-        last_print_len = len(resp["result"])
+                                                                 streaming=STREAMING):
+        if STREAMING:
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
+        else:
+            print(resp["result"])
     if REPLY_WITH_SOURCE:
         source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                        # f"""相关度:{doc.metadata['score']}\n\n"""
@@ -4,21 +4,15 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
 
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-
-
 def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    streaming: bool = True
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               history: List[List[str]] = [],
-              stop: Optional[List[str]] = None) -> str:
-        if self.streaming:
+              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
             for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                 self.tokenizer,
                 prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
                 max_length=self.max_token,
                 temperature=self.temperature,
             )):
+                torch_gc(DEVICE)
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
 
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
-            torch_gc()
-            if stop is not None:
-                response = enforce_stop_tokens(response, stop)
-            history = history + [[None, response]]
-            return response, history
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
 
     # def chat(self,
     #          prompt: str) -> str:
@@ -191,3 +182,16 @@ class ChatGLM(LLM):
             print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len=0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass
@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
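
The new `torch_gc(DEVICE)` helper takes the device string that the configs assemble (see the `DEVICE_` / `DEVICE_ID` / `DEVICE` lines above). A minimal usage sketch; the `EMBEDDING_DEVICE` fallback chain shown here is an assumption for illustration, not taken from this commit:

import torch
from utils import torch_gc  # the module added above

# Assumed shape of the config values in configs/model_config.py:
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() \
    else "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE_ = EMBEDDING_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_  # e.g. "cuda:0", "mps" or "cpu"

# Call after memory-heavy steps (vector store builds, model.chat / stream_chat)
# to release cached GPU memory:
torch_gc(DEVICE)
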
webui.py (41 lines changed)
@@ -29,23 +29,14 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def get_answer(query, vs_path, history, mode):
+def get_answer(query, vs_path, history, mode,
+               streaming: bool = STREAMING):
     if mode == "知识库问答" and vs_path:
-        if local_doc_qa.llm.streaming:
-            for resp, history in local_doc_qa.get_knowledge_based_answer(
-                    query=query, vs_path=vs_path, chat_history=history):
-                source = "\n\n"
-                source += "".join(
-                    [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
-                     f"""{doc.page_content}\n"""
-                     f"""</details>"""
-                     for i, doc in
-                     enumerate(resp["source_documents"])])
-                history[-1][-1] += source
-                yield history, ""
-        else:
-            resp, history = local_doc_qa.get_knowledge_based_answer(
-                query=query, vs_path=vs_path, chat_history=history)
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query,
+                vs_path=vs_path,
+                chat_history=history,
+                streaming=streaming):
             source = "\n\n"
             source += "".join(
                 [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
@@ -54,18 +45,13 @@ def get_answer(query, vs_path, history, mode):
                  for i, doc in
                  enumerate(resp["source_documents"])])
             history[-1][-1] += source
-            return history, ""
+            yield history, ""
     else:
-        if local_doc_qa.llm.streaming:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp + (
-                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-                yield history, ""
-        else:
-            resp, history = local_doc_qa.llm._call(query, history)
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-            return history, ""
+            yield history, ""
 
 
 def update_status(history, status):
@@ -76,7 +62,7 @@ def update_status(history, status):
 
 def init_model():
     try:
-        local_doc_qa.init_cfg(streaming=STREAMING)
+        local_doc_qa.init_cfg()
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -98,8 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k,
-                              streaming=STREAMING)
+                              top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e: