diff --git a/README.md b/README.md
index c42da6a..e4b34a5 100644
--- a/README.md
+++ b/README.md
@@ -14,8 +14,14 @@

+从文档处理角度来看,实现流程如下:
+
+
+
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
+🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
+
📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
## 变更日志
@@ -166,6 +172,6 @@ Web UI 可以实现如下功能:
- [ ] 实现调用 API 的 Web UI Demo
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 84f820c..6b3d1e2 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
from typing import List, Tuple
from langchain.docstore.document import Document
import numpy as np
+from utils import torch_gc
# return top-k text chunk from vector store
VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
# LLM input history length
LLM_HISTORY_LEN = 3
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
def load_file(filepath):
if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@ def load_file(filepath):
docs = loader.load_and_split(text_splitter=textsplitter)
return docs
+
def generate_prompt(related_docs: List[str],
query: str,
prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
def get_docs_with_score(docs_with_score):
- docs=[]
+ docs = []
for doc, score in docs_with_score:
doc.metadata["score"] = score
docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
lists = []
ls1 = [ls[0]]
for i in range(1, len(ls)):
- if ls[i-1] + 1 == ls[i]:
+ if ls[i - 1] + 1 == ls[i]:
ls1.append(ls[i])
else:
lists.append(ls1)
@@ -59,49 +65,52 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
return lists
-
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
- ) -> List[Tuple[Document, float]]:
- scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
- docs = []
- id_set = set()
- for j, i in enumerate(indices[0]):
- if i == -1:
- # This happens when not enough docs are returned.
- continue
- _id = self.index_to_docstore_id[i]
- doc = self.docstore.search(_id)
- id_set.add(i)
- docs_len = len(doc.page_content)
- for k in range(1, max(i, len(docs)-i)):
- for l in [i+k, i-k]:
- if 0 <= l < len(self.index_to_docstore_id):
- _id0 = self.index_to_docstore_id[l]
- doc0 = self.docstore.search(_id0)
- if docs_len + len(doc0.page_content) > self.chunk_size:
- break
- elif doc0.metadata["source"] == doc.metadata["source"]:
- docs_len += len(doc0.page_content)
- id_set.add(l)
- id_list = sorted(list(id_set))
- id_lists = seperate_list(id_list)
- for id_seq in id_lists:
- for id in id_seq:
- if id == id_seq[0]:
- _id = self.index_to_docstore_id[id]
- doc = self.docstore.search(_id)
- else:
- _id0 = self.index_to_docstore_id[id]
+) -> List[Tuple[Document, float]]:
+ scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+ docs = []
+ id_set = set()
+ for j, i in enumerate(indices[0]):
+ if i == -1:
+ # This happens when not enough docs are returned.
+ continue
+ _id = self.index_to_docstore_id[i]
+ doc = self.docstore.search(_id)
+ id_set.add(i)
+ docs_len = len(doc.page_content)
+ for k in range(1, max(i, len(docs) - i)):
+ break_flag = False
+ for l in [i + k, i - k]:
+ if 0 <= l < len(self.index_to_docstore_id):
+ _id0 = self.index_to_docstore_id[l]
doc0 = self.docstore.search(_id0)
- doc.page_content += doc0.page_content
- if not isinstance(doc, Document):
- raise ValueError(f"Could not find document for id {_id}, got {doc}")
- docs.append((doc, scores[0][j]))
- return docs
-
+ if docs_len + len(doc0.page_content) > self.chunk_size:
+ break_flag=True
+ break
+ elif doc0.metadata["source"] == doc.metadata["source"]:
+ docs_len += len(doc0.page_content)
+ id_set.add(l)
+ if break_flag:
+ break
+ id_list = sorted(list(id_set))
+ id_lists = seperate_list(id_list)
+ for id_seq in id_lists:
+ for id in id_seq:
+ if id == id_seq[0]:
+ _id = self.index_to_docstore_id[id]
+ doc = self.docstore.search(_id)
+ else:
+ _id0 = self.index_to_docstore_id[id]
+ doc0 = self.docstore.search(_id0)
+ doc.page_content += doc0.page_content
+ if not isinstance(doc, Document):
+ raise ValueError(f"Could not find document for id {_id}, got {doc}")
+ docs.append((doc, scores[0][j]))
+ torch_gc(DEVICE)
+ return docs
class LocalDocQA:
@@ -172,10 +181,12 @@ class LocalDocQA:
if vs_path and os.path.isdir(vs_path):
vector_store = FAISS.load_local(vs_path, self.embeddings)
vector_store.add_documents(docs)
+ torch_gc(DEVICE)
else:
if not vs_path:
vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
vector_store = FAISS.from_documents(docs, self.embeddings)
+ torch_gc(DEVICE)
vector_store.save_local(vs_path)
return vs_path, loaded_files
@@ -187,29 +198,54 @@ class LocalDocQA:
query,
vs_path,
chat_history=[],
- streaming=True):
- self.llm.streaming = streaming
+ streaming: bool = STREAMING):
vector_store = FAISS.load_local(vs_path, self.embeddings)
FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
- vector_store.chunk_size=self.chunk_size
+ vector_store.chunk_size = self.chunk_size
related_docs_with_score = vector_store.similarity_search_with_score(query,
k=self.top_k)
related_docs = get_docs_with_score(related_docs_with_score)
prompt = generate_prompt(related_docs, query)
- if streaming:
- for result, history in self.llm._call(prompt=prompt,
- history=chat_history):
- history[-1][0] = query
- response = {"query": query,
- "result": result,
- "source_documents": related_docs}
- yield response, history
- else:
- result, history = self.llm._call(prompt=prompt,
- history=chat_history)
+ # if streaming:
+ # for result, history in self.llm._stream_call(prompt=prompt,
+ # history=chat_history):
+ # history[-1][0] = query
+ # response = {"query": query,
+ # "result": result,
+ # "source_documents": related_docs}
+ # yield response, history
+ # else:
+ for result, history in self.llm._call(prompt=prompt,
+ history=chat_history,
+ streaming=streaming):
history[-1][0] = query
response = {"query": query,
"result": result,
"source_documents": related_docs}
- return response, history
+ yield response, history
+
+
+if __name__ == "__main__":
+ local_doc_qa = LocalDocQA()
+ local_doc_qa.init_cfg()
+ query = "本项目使用的embedding模型是什么,消耗多少显存"
+ vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
+ last_print_len = 0
+ for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+ vs_path=vs_path,
+ chat_history=[],
+ streaming=True):
+ print(resp["result"][last_print_len:], end="", flush=True)
+ last_print_len = len(resp["result"])
+ source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+ # f"""相关度:{doc.metadata['score']}\n\n"""
+ for inum, doc in
+ enumerate(resp["source_documents"])]
+ print("\n\n" + "\n\n".join(source_text))
+ # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+ # vs_path=vs_path,
+ # chat_history=[],
+ # streaming=False):
+ # print(resp["result"])
+ pass
diff --git a/cli_demo.py b/cli_demo.py
index 232f75c..33d616d 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -32,9 +32,12 @@ if __name__ == "__main__":
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path,
chat_history=history,
- streaming=True):
- print(resp["result"][last_print_len:], end="", flush=True)
- last_print_len = len(resp["result"])
+ streaming=STREAMING):
+ if STREAMING:
+ print(resp["result"][last_print_len:], end="", flush=True)
+ last_print_len = len(resp["result"])
+ else:
+ print(resp["result"])
if REPLY_WITH_SOURCE:
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n"""
diff --git a/configs/model_config.py b/configs/model_config.py
index 3147b6a..ab3e848 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -27,6 +27,9 @@ llm_model_dict = {
# LLM model name
LLM_MODEL = "chatglm-6b"
+# LLM streaming reponse
+STREAMING = True
+
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
@@ -38,14 +41,10 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
# 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
-PROMPT_TEMPLATE = """已知信息在下方"="包裹的段落,基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-====================================已知信息=====================================================
+PROMPT_TEMPLATE = """已知信息:
{context}
-================================================================================================
-问题:"{question}"
-答案:"""
+根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
# 匹配后单段上下文长度
CHUNK_SIZE = 500
\ No newline at end of file
diff --git a/img/langchain+chatglm2.png b/img/langchain+chatglm2.png
new file mode 100644
index 0000000..d98b810
Binary files /dev/null and b/img/langchain+chatglm2.png differ
diff --git a/img/qr_code_10.jpg b/img/qr_code_10.jpg
new file mode 100644
index 0000000..348c0fe
Binary files /dev/null and b/img/qr_code_10.jpg differ
diff --git a/img/qr_code_9.jpg b/img/qr_code_9.jpg
deleted file mode 100644
index 0f3be2d..0000000
Binary files a/img/qr_code_9.jpg and /dev/null differ
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index a0e95d9..e69ce87 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -4,21 +4,15 @@ from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-
-def torch_gc():
- if torch.cuda.is_available():
- with torch.cuda.device(CUDA_DEVICE):
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
tokenizer: object = None
model: object = None
history_len: int = 10
- streaming: bool = True
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
def _call(self,
prompt: str,
history: List[List[str]] = [],
- stop: Optional[List[str]] = None) -> str:
- if self.streaming:
+ streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]:
+ if streaming:
for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
self.tokenizer,
prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
max_length=self.max_token,
temperature=self.temperature,
)):
+ torch_gc(DEVICE)
if inum == 0:
history += [[prompt, stream_resp]]
else:
history[-1] = [prompt, stream_resp]
yield stream_resp, history
-
else:
response, _ = self.model.chat(
- self.tokenizer,
- prompt,
- history=history[-self.history_len:] if self.history_len > 0 else [],
- max_length=self.max_token,
- temperature=self.temperature,
+ self.tokenizer,
+ prompt,
+ history=history[-self.history_len:] if self.history_len > 0 else [],
+ max_length=self.max_token,
+ temperature=self.temperature,
)
- torch_gc()
- if stop is not None:
- response = enforce_stop_tokens(response, stop)
- history = history + [[None, response]]
- return response, history
+ torch_gc(DEVICE)
+ history += [[prompt, response]]
+ yield response, history
# def chat(self,
# prompt: str) -> str:
@@ -191,3 +182,16 @@ class ChatGLM(LLM):
print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+ llm = ChatGLM()
+ llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+ llm_device=LLM_DEVICE, )
+ last_print_len=0
+ for resp, history in llm._call("你好", streaming=True):
+ print(resp[last_print_len:], end="", flush=True)
+ last_print_len = len(resp)
+ for resp, history in llm._call("你好", streaming=False):
+ print(resp)
+ pass
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..8508c7d
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+ if torch.cuda.is_available():
+ with torch.cuda.device(DEVICE):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ elif torch.backends.mps.is_available():
+ torch.mps.empty_cache()
\ No newline at end of file
diff --git a/webui.py b/webui.py
index aaeb734..6c2a29c 100644
--- a/webui.py
+++ b/webui.py
@@ -29,28 +29,28 @@ llm_model_dict_list = list(llm_model_dict.keys())
local_doc_qa = LocalDocQA()
-def get_answer(query, vs_path, history, mode):
- if mode == "知识库问答":
- if vs_path:
- for resp, history in local_doc_qa.get_knowledge_based_answer(
- query=query, vs_path=vs_path, chat_history=history):
- source = "\n\n"
- source += "".join(
- [f""" 出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}
\n"""
- f"""{doc.page_content}\n"""
- f""" """
- for i, doc in
- enumerate(resp["source_documents"])])
- history[-1][-1] += source
- yield history, ""
- else:
- for resp, history in local_doc_qa.llm._call(query, history):
- history[-1][-1] = resp + (
- "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
- yield history, ""
+def get_answer(query, vs_path, history, mode,
+ streaming: bool = STREAMING):
+ if mode == "知识库问答" and vs_path:
+ for resp, history in local_doc_qa.get_knowledge_based_answer(
+ query=query,
+ vs_path=vs_path,
+ chat_history=history,
+ streaming=streaming):
+ source = "\n\n"
+ source += "".join(
+ [f""" 出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}
\n"""
+ f"""{doc.page_content}\n"""
+ f""" """
+ for i, doc in
+ enumerate(resp["source_documents"])])
+ history[-1][-1] += source
+ yield history, ""
else:
- for resp, history in local_doc_qa.llm._call(query, history):
- history[-1][-1] = resp
+ for resp, history in local_doc_qa.llm._call(query, history,
+ streaming=streaming):
+ history[-1][-1] = resp + (
+ "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
@@ -84,7 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
embedding_model=embedding_model,
llm_history_len=llm_history_len,
use_ptuning_v2=use_ptuning_v2,
- top_k=top_k)
+ top_k=top_k,)
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
print(model_status)
except Exception as e: