add stream support to cli_demo.py
commit b4aefca555
parent 88ab9a1d21
@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -34,23 +33,21 @@ def load_file(filepath):
     docs = loader.load_and_split(text_splitter=textsplitter)
     return docs


-def get_relevant_documents(self, query: str) -> List[Document]:
-    if self.search_type == "similarity":
-        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
-        for doc in docs:
-            doc[0].metadata["score"] = doc[1]
-        docs = [doc[0] for doc in docs]
-    elif self.search_type == "mmr":
-        docs = self.vectorstore.max_marginal_relevance_search(
-            query, **self.search_kwargs
-        )
-    else:
-        raise ValueError(f"search_type of {self.search_type} not allowed.")
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt
+
+
+def get_docs_with_score(docs_with_score):
+    docs = []
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
     return docs


 class LocalDocQA:
     llm: object = None
     embeddings: object = None
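To illustrate how the two new helpers fit together, here is a minimal sketch (not part of the commit). It assumes generate_prompt and get_docs_with_score are imported from the module patched above (its path is not shown in this view); FakeDoc is a hypothetical stand-in for the document objects returned by the vector store, which expose page_content and metadata, and the template string is a placeholder rather than the project's real PROMPT_TEMPLATE.

from dataclasses import dataclass, field

TEMPLATE = "已知内容:\n{context}\n\n问题:\n{question}"  # placeholder template for the sketch

@dataclass
class FakeDoc:  # hypothetical stand-in for the vector store's document type
    page_content: str
    metadata: dict = field(default_factory=dict)

# (doc, score) pairs such as FAISS.similarity_search_with_score would return
pairs = [(FakeDoc("ChatGLM-6B 是一个开源的中英双语对话模型。"), 0.12)]

docs = get_docs_with_score(pairs)  # attaches each score to doc.metadata["score"]
print(generate_prompt(docs, "ChatGLM-6B 是什么?", prompt_template=TEMPLATE))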
@@ -73,8 +70,6 @@ class LocalDocQA:

         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
-        # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-        #                                                                    device=embedding_device)
         self.top_k = top_k

     def init_knowledge_vector_store(self,
@@ -134,34 +129,30 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-已知内容:
-{context}
-
-问题:
-{question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming=True):
+        self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        vs_r = vector_store.as_retriever(search_type="mmr",
-                                         search_kwargs={"k": self.top_k})
-        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vs_r,
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-
-        knowledge_chain.return_source_documents = True
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)
+
+        if streaming:
+            for result, history in self.llm._call(prompt=prompt,
+                                                  history=chat_history):
+                history[-1] = list(history[-1])
+                history[-1][0] = query
+                response = {"query": query,
+                            "result": result,
+                            "source_documents": related_docs}
+                yield response, history
+        else:
+            result, history = self.llm._call(prompt=prompt,
+                                             history=chat_history)
+            history[-1] = list(history[-1])
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            return response, history
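One consequence of this rewrite worth noting: because get_knowledge_based_answer now contains a yield, calling it always returns a generator object, even when streaming=False, so the else branch's return surfaces as the generator's StopIteration value rather than as a plain tuple. A caller sketch covering both modes (illustrative only, not part of the commit; local_doc_qa, query, vs_path, history and streaming are assumed to be set up as in cli_demo.py below):

gen = local_doc_qa.get_knowledge_based_answer(query=query,
                                              vs_path=vs_path,
                                              chat_history=history,
                                              streaming=streaming)
if streaming:
    for resp, history in gen:        # resp["result"] is the cumulative answer so far
        answer = resp["result"]
else:
    try:
        next(gen)                    # the non-streaming branch yields nothing ...
    except StopIteration as done:
        resp, history = done.value   # ... its `return response, history` lands here
    answer = resp["result"]
print(answer)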
cli_demo.py
@@ -28,10 +28,16 @@ if __name__ == "__main__":
     history = []
     while True:
         query = input("Input your question 请输入问题:")
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                     vs_path=vs_path,
+                                                                     chat_history=history,
+                                                                     streaming=True):
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
         if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
+            source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                           # f"""相关度:{doc.metadata['score']}\n\n"""
+                           for inum, doc in
+                           enumerate(resp["source_documents"])]
+            print("\n\n" + "\n\n".join(source_text))
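The last_print_len bookkeeping exists because each streamed resp["result"] is the full answer generated so far rather than a delta, so the loop prints only the newly produced suffix. A self-contained sketch of that pattern, with a dummy generator standing in for the model:

def fake_stream():
    # stands in for the streaming responses: each item is the cumulative text so far
    for cumulative in ["你", "你好", "你好,世界"]:
        yield {"result": cumulative}

last_print_len = 0
for resp in fake_stream():
    print(resp["result"][last_print_len:], end="", flush=True)  # only the new suffix
    last_print_len = len(resp["result"])
print()  # final newline once the stream ends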
models/chatglm_llm.py

@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional

 DEVICE = LLM_DEVICE
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
-    history = []
+    # history = []
     tokenizer: object = None
    model: object = None
     history_len: int = 10
+    streaming: bool = True
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

     def __init__(self):
         super().__init__()
@@ -68,46 +71,45 @@ class ChatGLM(LLM):

     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None,
-              stream=True) -> str:
-        if stream:
-            self.history = self.history + [[None, ""]]
-            for response, history in self.model.stream_chat(
+              history: List[List[str]] = [],
+              stop: Optional[List[str]] = None) -> str:
+        if self.streaming:
+            history = history + [[None, ""]]
+            for stream_resp, history in self.model.stream_chat(
                 self.tokenizer,
                 prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             ):
-                torch_gc()
-                self.history[-1][-1] = response
-                yield response
+                yield stream_resp, history
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
             torch_gc()
             if stop is not None:
                 response = enforce_stop_tokens(response, stop)
-            self.history = self.history + [[None, response]]
-            return response
+            history = history + [[None, response]]
+            return response, history

-    def chat(self,
-             prompt: str) -> str:
-        response, _ = self.model.chat(
-            self.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        torch_gc()
-        self.history = self.history + [[None, response]]
-        return response
+    # def chat(self,
+    #          prompt: str) -> str:
+    #     response, _ = self.model.chat(
+    #         self.tokenizer,
+    #         prompt,
+    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
+    #         max_length=self.max_token,
+    #         temperature=self.temperature,
+    #     )
+    #     torch_gc()
+    #     self.history = self.history + [[None, response]]
+    #     return response

     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
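For reference, a minimal sketch of driving the updated _call directly in streaming mode (not part of the commit; it assumes llm is a ChatGLM instance whose tokenizer and model have already been set up via load_model()):

llm.streaming = True
history = []
for partial, history in llm._call(prompt="你好", history=history):
    pass           # partial is the cumulative response so far; history is the updated chat history
print(partial)     # the complete answer once the stream is exhausted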
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
         else:
             from accelerate import dispatch_model

-            model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
+            model = (
+                AutoModel.from_pretrained(
+                    model_name_or_path,
+                    trust_remote_code=True,
+                    config=model_config,
+                    **kwargs)
+                .half())
             # 可传入device_map自定义每张卡的部署情况
             if device_map is None:
                 device_map = auto_configure_device_map(num_gpus)
@@ -160,7 +168,8 @@ class ChatGLM(LLM):
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True)
+                    trust_remote_code=True,
+                    **kwargs)
                 .float()
                 .to(llm_device)
             )
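Finally, a hedged initialization sketch for the wrapper (not part of the commit; only the model_name_or_path parameter appears in the load_model signature shown above, so no other arguments are assumed):

from models.chatglm_llm import ChatGLM

llm = ChatGLM()                                          # streaming defaults to True (see the class attributes above)
llm.load_model(model_name_or_path="THUDM/chatglm-6b")    # builds the tokenizer and model used by _call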