update webui.py and local_doc_qa.py
parent daafe8d5fa
commit 88ab9a1d21
local_doc_qa.py

@@ -1,9 +1,8 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from chains.lib.embeddings import MyEmbeddings
-# from langchain.vectorstores import FAISS
-from chains.lib.vectorstores import FAISSVS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -12,6 +11,7 @@ from configs.model_config import *
 import datetime
 from typing import List
 from textsplitter import ChineseTextSplitter
+from langchain.docstore.document import Document

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -21,7 +21,10 @@ LLM_HISTORY_LEN = 3


 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
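With this hunk, Markdown files get their own branch in load_file: they go through UnstructuredFileLoader in "elements" mode and skip the Chinese text splitter that PDF files still use. A minimal sketch of that loading path in isolation (the file path below is illustrative, not from the project, and the `unstructured` package that backs UnstructuredFileLoader must be installed):

# Hedged sketch: exercising the new Markdown branch of load_file on its own.
from langchain.document_loaders import UnstructuredFileLoader

loader = UnstructuredFileLoader("docs/readme.md", mode="elements")
docs = loader.load()   # one Document per structural element (title, list item, paragraph, ...)
print(len(docs), docs[0].metadata)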
@@ -32,6 +35,22 @@ def load_file(filepath):
     return docs


+def get_relevant_documents(self, query: str) -> List[Document]:
+    if self.search_type == "similarity":
+        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
+        for doc in docs:
+            doc[0].metadata["score"] = doc[1]
+        docs = [doc[0] for doc in docs]
+    elif self.search_type == "mmr":
+        docs = self.vectorstore.max_marginal_relevance_search(
+            query, **self.search_kwargs
+        )
+    else:
+        raise ValueError(f"search_type of {self.search_type} not allowed.")
+    return docs
+
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
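The module-level get_relevant_documents above mirrors VectorStoreRetriever.get_relevant_documents but, for "similarity" search, copies each relevance score into the document's metadata. The line that would actually install it is still commented out further down in this diff (# VectorStoreRetriever.get_relevant_documents = get_relevant_documents), so the sketch below is an assumption about the intended wiring, not current behavior:

# Hedged sketch: how the override above could be installed, assuming the
# module-level get_relevant_documents defined in this file is in scope.
# This assignment is still commented out in the commit itself.
from langchain.vectorstores.base import VectorStoreRetriever

VectorStoreRetriever.get_relevant_documents = get_relevant_documents

# Once patched, a similarity retriever would expose scores via metadata:
# docs = vector_store.as_retriever(search_kwargs={"k": 6}).get_relevant_documents("query")
# print(docs[0].metadata["score"])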
@@ -52,7 +71,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len

-        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -99,12 +118,12 @@ class LocalDocQA:
                 print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISSVS.from_documents(docs, self.embeddings)
+                vector_store = FAISS.from_documents(docs, self.embeddings)

             vector_store.save_local(vs_path)
             return vs_path, loaded_files
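Both branches now use the stock langchain FAISS wrapper: an existing index is reopened with load_local and extended with add_documents, a fresh one is built with from_documents, and either way save_local persists it under vs_path. A minimal round-trip sketch, assuming `embeddings` is the HuggingFaceEmbeddings instance configured earlier, `docs` is a list of Documents, and the folder name is illustrative:

# Hedged sketch of the FAISS round-trip used above.
from langchain.vectorstores import FAISS

vector_store = FAISS.from_documents(docs, embeddings)   # build a new index
vector_store.save_local("vector_store/demo")            # persist to disk

reloaded = FAISS.load_local("vector_store/demo", embeddings)
reloaded.add_documents(docs)                            # append more chunks
reloaded.save_local("vector_store/demo")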
@@ -129,10 +148,13 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vs_r = vector_store.as_retriever(search_type="mmr",
+                                         search_kwargs={"k": self.top_k})
+        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            retriever=vs_r,
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
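The retriever handed to RetrievalQA now uses maximal marginal relevance (search_type="mmr"), so the top_k chunks are picked for diversity as well as similarity. The same wiring outside the class, as a sketch in which `llm`, `embeddings`, `prompt`, `vs_path`, and `top_k` are assumed to already exist:

# Hedged sketch of the retriever/chain wiring introduced above.
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS

vector_store = FAISS.load_local(vs_path, embeddings)
retriever = vector_store.as_retriever(search_type="mmr",
                                      search_kwargs={"k": top_k})
chain = RetrievalQA.from_llm(llm=llm, retriever=retriever, prompt=prompt)
chain.return_source_documents = True
result = chain({"query": "your question"})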
@@ -140,7 +162,6 @@ class LocalDocQA:
         )

         knowledge_chain.return_source_documents = True

         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
         return result, self.llm.history
models/chatglm_llm.py

@@ -72,16 +72,16 @@ class ChatGLM(LLM):
               stream=True) -> str:
         if stream:
             self.history = self.history + [[None, ""]]
-            response, _ = self.model.stream_chat(
+            for response, history in self.model.stream_chat(
                 self.tokenizer,
                 prompt,
                 history=self.history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
-            )
-            torch_gc()
-            self.history[-1][-1] = response
-            yield response
+            ):
+                torch_gc()
+                self.history[-1][-1] = response
+                yield response
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
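With the loop above, the stream=True branch of ChatGLM._call becomes a generator: every partial response from model.stream_chat is written back into self.history and yielded, so callers can render text as it is produced. A sketch of consuming it, assuming `llm` is an already-loaded ChatGLM instance:

# Hedged sketch: consuming the now-streaming _call.
for partial in llm._call("你好"):
    # each yielded value is the cumulative response generated so far, not a delta
    print(partial)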
webui.py (35 changed lines)
@@ -30,19 +30,28 @@ local_doc_qa = LocalDocQA()


 def get_answer(query, vs_path, history, mode):
-    if vs_path and mode == "知识库问答":
-        resp, history = local_doc_qa.get_knowledge_based_answer(
-            query=query, vs_path=vs_path, chat_history=history)
-        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-{doc.page_content}
-
-<b>所属文件:</b>{doc.metadata["source"]}
-</details>""" for i, doc in enumerate(resp["source_documents"])])
-        history[-1][-1] += source
+    if mode == "知识库问答":
+        if vs_path:
+            for resp, history in local_doc_qa.get_knowledge_based_answer(
+                    query=query, vs_path=vs_path, chat_history=history):
+                # source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
+                # {doc.page_content}
+                #
+                # <b>所属文件:</b>{doc.metadata["source"]}
+                # </details>""" for i, doc in enumerate(resp["source_documents"])])
+                # history[-1][-1] += source
+                yield history, ""
+        else:
+            history = history + [[query, ""]]
+            for resp in local_doc_qa.llm._call(query):
+                history[-1][-1] = resp + (
+                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+                yield history, ""
     else:
-        resp = local_doc_qa.llm._call(query)
-        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
-    return history, ""
+        history = history + [[query, ""]]
+        for resp in local_doc_qa.llm._call(query):
+            history[-1][-1] = resp
+            yield history, ""


 def update_status(history, status):
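get_answer is likewise a generator now: rather than returning one final (history, "") pair, it yields the updated chat history for every streamed chunk, which is what lets the Gradio chatbot refresh incrementally. A sketch of driving it directly, assuming local_doc_qa is initialized and vs_path points at a saved FAISS index:

# Hedged sketch: consuming get_answer outside Gradio; inside webui.py the
# same generator feeds Gradio's streaming chatbot callback.
history = []
for history, _ in get_answer("你的问题", vs_path, history, mode="知识库问答"):
    print(history[-1][-1])   # the partial answer accumulated so far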
@@ -62,7 +71,7 @@ def init_model():
         print(e)
         reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
         if str(e) == "Unknown platform: darwin":
-            print("改报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
+            print("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
                   " https://github.com/imClumsyPanda/langchain-ChatGLM")
         else:
             print(reply)