Merge branch 'dev' into chatglm2cpp
commit 1f8924eeac
@@ -174,8 +174,7 @@ embedding/*
 pyrightconfig.json
 loader/tmp_files
-<<<<<<< HEAD
 flagged/*
-=======
-flagged/*
->>>>>>> 19147c3 (Update .gitignore)
+ptuning-v2/*.json
+ptuning-v2/*.bin
 
@@ -226,6 +226,10 @@ Web UI 可以实现如下功能:
 - [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
 - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
 - [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
+- [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
+- [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b)
+- [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
+- [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1)
 - [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
 - [x] 增加更多 Embedding 模型支持
 - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
@@ -251,7 +255,7 @@ Web UI 可以实现如下功能:
 - [x] VUE 前端
 
 ## 项目交流群
-<img src="img/qr_code_42.jpg" alt="二维码" width="300" height="300" />
+<img src="img/qr_code_45.jpg" alt="二维码" width="300" height="300" />
 
 
 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
api.py (4 changed lines)
@@ -1,3 +1,4 @@
+#encoding:utf-8
 import argparse
 import json
 import os
@@ -373,7 +374,7 @@ async def bing_search_chat(
 
 async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: List[List[str]] = Body(
+    history: Optional[List[List[str]]] = Body(
         [],
         description="History of previous questions and answers",
         example=[
@@ -391,7 +392,6 @@ async def chat(
         resp = answer_result.llm_output["answer"]
         history = answer_result.history
         pass
 
     return ChatMessage(
         question=question,
         response=resp,
@@ -8,7 +8,6 @@ from typing import List
 from utils import torch_gc
 from tqdm import tqdm
 from pypinyin import lazy_pinyin
-from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
 from models.base import (BaseAnswer,
                          AnswerResult)
 from models.loader.args import parser
@@ -59,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
 
 
 def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
+
     if filepath.lower().endswith(".md"):
         loader = UnstructuredFileLoader(filepath, mode="elements")
         docs = loader.load()
@@ -67,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
     elif filepath.lower().endswith(".pdf"):
+        # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x
+        from loader import UnstructuredPaddlePDFLoader
         loader = UnstructuredPaddlePDFLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
     elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
+        # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x
+        from loader import UnstructuredPaddleImageLoader
         loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(text_splitter=textsplitter)
@@ -1,34 +0,0 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-
-from typing import Any, List
-
-
-class MyEmbeddings(HuggingFaceEmbeddings):
-    def __init__(self, **kwargs: Any):
-        super().__init__(**kwargs)
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Compute doc embeddings using a HuggingFace transformer model.
-
-        Args:
-            texts: The list of texts to embed.
-
-        Returns:
-            List of embeddings, one for each text.
-        """
-        texts = list(map(lambda x: x.replace("\n", " "), texts))
-        embeddings = self.client.encode(texts, normalize_embeddings=True)
-        return embeddings.tolist()
-
-    def embed_query(self, text: str) -> List[float]:
-        """Compute query embeddings using a HuggingFace transformer model.
-
-        Args:
-            text: The text to embed.
-
-        Returns:
-            Embeddings for the text.
-        """
-        text = text.replace("\n", " ")
-        embedding = self.client.encode(text, normalize_embeddings=True)
-        return embedding.tolist()
@@ -1,121 +0,0 @@
-from langchain.vectorstores import FAISS
-from typing import Any, Callable, List, Optional, Tuple, Dict
-from langchain.docstore.document import Document
-from langchain.docstore.base import Docstore
-
-from langchain.vectorstores.utils import maximal_marginal_relevance
-from langchain.embeddings.base import Embeddings
-import uuid
-from langchain.docstore.in_memory import InMemoryDocstore
-
-import numpy as np
-
-
-def dependable_faiss_import() -> Any:
-    """Import faiss if available, otherwise raise error."""
-    try:
-        import faiss
-    except ImportError:
-        raise ValueError(
-            "Could not import faiss python package. "
-            "Please install it with `pip install faiss` "
-            "or `pip install faiss-cpu` (depending on Python version)."
-        )
-    return faiss
-
-
-class FAISSVS(FAISS):
-    def __init__(self,
-                 embedding_function: Callable[..., Any],
-                 index: Any,
-                 docstore: Docstore,
-                 index_to_docstore_id: Dict[int, str]):
-        super().__init__(embedding_function, index, docstore, index_to_docstore_id)
-
-    def max_marginal_relevance_search_by_vector(
-            self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
-    ) -> List[Tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-
-        Returns:
-            List of Documents with scores selected by maximal marginal relevance.
-        """
-        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
-        # -1 happens when not enough docs are returned.
-        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
-        mmr_selected = maximal_marginal_relevance(
-            np.array([embedding], dtype=np.float32), embeddings, k=k
-        )
-        selected_indices = [indices[0][i] for i in mmr_selected]
-        selected_scores = [scores[0][i] for i in mmr_selected]
-        docs = []
-        for i, score in zip(selected_indices, selected_scores):
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append((doc, score))
-        return docs
-
-    def max_marginal_relevance_search(
-            self,
-            query: str,
-            k: int = 4,
-            fetch_k: int = 20,
-            **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-
-        Returns:
-            List of Documents with scores selected by maximal marginal relevance.
-        """
-        embedding = self.embedding_function(query)
-        docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
-        return docs
-
-    @classmethod
-    def __from(
-            cls,
-            texts: List[str],
-            embeddings: List[List[float]],
-            embedding: Embeddings,
-            metadatas: Optional[List[dict]] = None,
-            **kwargs: Any,
-    ) -> FAISS:
-        faiss = dependable_faiss_import()
-        index = faiss.IndexFlatIP(len(embeddings[0]))
-        index.add(np.array(embeddings, dtype=np.float32))
-
-        # # my code, for speeding up search
-        # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
-        # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
-        # index.train(np.array(embeddings, dtype=np.float32))
-        # index.add(np.array(embeddings, dtype=np.float32))
-
-        documents = []
-        for i, text in enumerate(texts):
-            metadata = metadatas[i] if metadatas else {}
-            documents.append(Document(page_content=text, metadata=metadata))
-        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
-        docstore = InMemoryDocstore(
-            {index_to_id[i]: doc for i, doc in enumerate(documents)}
-        )
-        return cls(embedding.embed_query, index, docstore, index_to_id)
@@ -94,6 +94,12 @@ llm_model_dict = {
         "local_model_path": None,
         "provides": "MOSSLLMChain"
     },
+    "moss-int4": {
+        "name": "moss",
+        "pretrained_model_name": "fnlp/moss-moon-003-sft-int4",
+        "local_model_path": None,
+        "provides": "MOSSLLM"
+    },
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
@@ -155,6 +161,15 @@ llm_model_dict = {
         "provides": "FastChatOpenAILLMChain",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1",  # "name"修改为fastchat服务中的"api_base_url"
         "api_key": "EMPTY"
+    },
+    # 通过 fastchat 调用的模型请参考如下格式
+    "fastchat-chatglm-6b-int4": {
+        "name": "chatglm-6b-int4",  # "name"修改为fastchat服务中的"model_name"
+        "pretrained_model_name": "chatglm-6b-int4",
+        "local_model_path": None,
+        "provides": "FastChatOpenAILLMChain",  # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
+        "api_base_url": "http://localhost:8001/v1",  # "name"修改为fastchat服务中的"api_base_url"
+        "api_key": "EMPTY"
     },
     "fastchat-chatglm2-6b": {
         "name": "chatglm2-6b",  # "name"修改为fastchat服务中的"model_name"
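Editor's note: the "fastchat-*" entries above assume a FastChat server exposing an OpenAI-compatible API at `api_base_url`, with `name` matching the served model name. As a rough, hypothetical sanity check (the port and model name are only illustrative and must match your own fastchat deployment), the endpoint behind the new `fastchat-chatglm-6b-int4` entry could be exercised with the pre-1.0 `openai` client this project uses:

```python
# Hypothetical smoke test for the fastchat config entry above (openai<1.0 style).
# The base URL and model name are assumptions; they must match the running server.
import openai

openai.api_key = "EMPTY"                      # fastchat ignores the key
openai.api_base = "http://localhost:8001/v1"  # the entry's "api_base_url"
completion = openai.ChatCompletion.create(
    model="chatglm-6b-int4",                  # the entry's "name"
    messages=[{"role": "user", "content": "你好"}],
)
print(completion.choices[0].message.content)
```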
@@ -176,11 +191,13 @@ llm_model_dict = {
     # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
     # Max retries exceeded with url: /v1/chat/completions
     # 则需要将urllib3版本修改为1.25.11
+    # 如果依然报urllib3.exceptions.MaxRetryError: HTTPSConnectionPool,则将https改为http
+    # 参考https://zhuanlan.zhihu.com/p/350015032
 
     # 如果报出:raise NewConnectionError(
     # urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
     # Failed to establish a new connection: [WinError 10060]
-    # 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地
+    # 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
     "openai-chatgpt-3.5": {
         "name": "gpt-3.5-turbo",
         "pretrained_model_name": "gpt-3.5-turbo",
@@ -210,7 +227,7 @@ STREAMING = True
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
+PTUNING_DIR='./ptuning-v2'
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
@@ -238,8 +255,8 @@ LLM_HISTORY_LEN = 3
 # 知识库检索时返回的匹配内容条数
 VECTOR_SEARCH_TOP_K = 5
 
-# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
-VECTOR_SEARCH_SCORE_THRESHOLD = 390
+# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,建议设置为500左右,经测试设置为小于500时,匹配结果更精准
+VECTOR_SEARCH_SCORE_THRESHOLD = 500
 
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
docs/FAQ.md (19 changed lines)
@@ -177,3 +177,22 @@ download_with_progressbar(url, tmp_path)
 Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out`
 
 这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--!
+
+---
+
+Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`
+
+疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为:
+
+```
+try:
+    self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+except Exception as e:
+    pass
+```
+
+如果上述方式不起作用,则在.cache/huggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。
+
+注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。
+
+因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。
@@ -49,7 +49,7 @@ $ python loader/image_loader.py
 
 ## llama-cpp模型调用的说明
 
-1. 首先从huggingface hub中下载对应的模型,如https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/的[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。
+1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。
 2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。
 3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。
 4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错.
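Editor's note: as a minimal sketch of steps 1-2 in the hunk above (the destination directory is an assumption; the repo id and file name come from the links in the doc), the download plus restoring the original file name might look like:

```python
# Sketch: fetch the GGML weights with huggingface_hub, then copy the file out of
# the HF cache under its original name so llama-cpp-python can find it.
import os
import shutil
from huggingface_hub import snapshot_download

snapshot_dir = snapshot_download(repo_id="vicuna/ggml-vicuna-13b-1.1")
src = os.path.join(snapshot_dir, "ggml-vic13b-q5_1.bin")
dst = os.path.join("models", "ggml-vic13b-q5_1.bin")  # illustrative target path
os.makedirs("models", exist_ok=True)
shutil.copyfile(src, dst)
print(f"model ready at {dst}")
```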
Binary file not shown. (Before: 273 KiB)
Binary file not shown. (After: 192 KiB)
@@ -6,6 +6,7 @@ from queue import Queue
 from threading import Thread
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
 import torch
 import transformers
 
@@ -23,13 +24,12 @@ class ListenerToken:
         self._scores = _scores
 
 
-class AnswerResult:
+class AnswerResult(BaseModel):
     """
     消息实体
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
 
 
 class AnswerResultStream:
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
 
         with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
             for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
                 yield answerResult
 
     @abstractmethod
@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
 
 
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                 generate_with_callback(answer_result)
                 self.checkPoint.clear_torch_cache()
             else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
 
             generate_with_callback(answer_result)
 
@@ -1,6 +1,11 @@
 from abc import ABC
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+    Any, Dict, List, Optional, Generator, Collection, Set,
+    Callable,
+    Tuple,
+    Union)
+
 from models.loader import LoaderCheckPoint
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+
+from openai import (
+    ChatCompletion
+)
+
+import openai
+import logging
 import torch
 import transformers
+
+logger = logging.getLogger(__name__)
 
 
 def _build_message_template() -> Dict[str, str]:
     """
@@ -25,15 +47,26 @@ def _build_message_template() -> Dict[str, str]:
 # 将历史对话数组转换为文本格式
 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     build_messages: Collection[Dict[str, str]] = []
-    for i, (old_query, response) in enumerate(history):
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = old_query
-        system_build_message = _build_message_template()
-        system_build_message['role'] = 'system'
-        system_build_message['content'] = response
-        build_messages.append(user_build_message)
-        build_messages.append(system_build_message)
+
+    system_build_message = _build_message_template()
+    system_build_message['role'] = 'system'
+    system_build_message['content'] = "You are a helpful assistant."
+    build_messages.append(system_build_message)
+    if history:
+        for i, (user, assistant) in enumerate(history):
+            if user:
+                user_build_message = _build_message_template()
+                user_build_message['role'] = 'user'
+                user_build_message['content'] = user
+                build_messages.append(user_build_message)
+
+            if not assistant:
+                raise RuntimeError("历史数据结构不正确")
+            system_build_message = _build_message_template()
+            system_build_message['role'] = 'assistant'
+            system_build_message['content'] = assistant
+            build_messages.append(system_build_message)
 
     user_build_message = _build_message_template()
     user_build_message['role'] = 'user'
@@ -43,6 +76,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
 
 
 class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
+    client: Any
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: int = 6
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
@@ -108,6 +144,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
 
+    def _create_retry_decorator(self) -> Callable[[Any], Any]:
+        min_seconds = 1
+        max_seconds = 60
+        # Wait 2^x * 1 second between each retry starting with
+        # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+        return retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries),
+            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+            retry=(
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            ),
+            before_sleep=before_sleep_log(logger, logging.WARNING),
+        )
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the completion call."""
+        retry_decorator = self._create_retry_decorator()
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            return self.client.create(**kwargs)
+
+        return _completion_with_retry(**kwargs)
+
     def _call(
             self,
             inputs: Dict[str, Any],
@@ -121,32 +186,74 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
             run_manager: Optional[CallbackManagerForChainRun] = None,
             generate_with_callback: AnswerResultStream = None) -> None:
 
-        history = inputs[self.history_key]
-        streaming = inputs[self.streaming_key]
+        history = inputs.get(self.history_key, [])
+        streaming = inputs.get(self.streaming_key, False)
         prompt = inputs[self.prompt_key]
+        stop = inputs.get("stop", "stop")
         print(f"__call:{prompt}")
         try:
-            import openai
             # Not support yet
             # openai.api_key = "EMPTY"
             openai.api_key = self.api_key
             openai.api_base = self.api_base_url
-        except ImportError:
+            self.client = openai.ChatCompletion
+        except AttributeError:
             raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
             )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
+        msg = build_message_list(prompt, history=history)
 
-        history += [[prompt, completion.choices[0].message.content]]
-        answer_result = AnswerResult()
-        answer_result.history = history
-        answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        generate_with_callback(answer_result)
+        if streaming:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            out_str = ""
+            for stream_resp in self.completion_with_retry(
+                    messages=msg,
+                    **params
+            ):
+                role = stream_resp["choices"][0]["delta"].get("role", "")
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                out_str += token
+                history[-1] = [prompt, out_str]
+                answer_result = AnswerResult()
+                answer_result.history = history
+                answer_result.llm_output = {"answer": out_str}
+                generate_with_callback(answer_result)
+        else:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            response = self.completion_with_retry(
+                messages=msg,
+                **params
+            )
+            role = response["choices"][0]["message"].get("role", "")
+            content = response["choices"][0]["message"].get("content", "")
+            history += [[prompt, content]]
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": content}
+            generate_with_callback(answer_result)
+
+
+if __name__ == "__main__":
+    chain = FastChatOpenAILLMChain()
+    chain.set_api_key("EMPTY")
+    # chain.set_api_base_url("https://api.openai.com/v1")
+    # chain.call_model_name("gpt-3.5-turbo")
+    answer_result_stream_result = chain({"streaming": True,
+                                         "prompt": "你好",
+                                         "history": []
+                                         })
+    for answer_result in answer_result_stream_result['answer_result_stream']:
+        resp = answer_result.llm_output["answer"]
+        print(resp)
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             history += [[prompt, reply]]
             answer_result.history = history
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             answer_result.llm_output = {"answer": reply}
             generate_with_callback(answer_result)
@@ -1,3 +1,4 @@
+
 import argparse
 import os
 from configs.model_config import *
@@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
+parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
+parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
                     help='Load the model with 8-bit precision.')
@@ -149,7 +149,7 @@ class LoaderCheckPoint:
                                                  trust_remote_code=True).half()
                 # 可传入device_map自定义每张卡的部署情况
                 if self.device_map is None:
-                    if 'chatglm' in self.model_name.lower():
+                    if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower():
                         self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
                     elif 'moss' in self.model_name.lower():
                         self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
@@ -165,13 +165,6 @@ class LoaderCheckPoint:
                                                        dtype=torch.float16 if not self.load_in_8bit else torch.int8,
                                                        max_memory=max_memory,
                                                        no_split_module_classes=model._no_split_modules)
-                    # 对于chaglm和moss意外的模型应使用自动指定,而非调用chatglm的配置方式
-                    # 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
-                    # 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
-                    # 实测在bloom模型上如此
-                    # self.device_map = infer_auto_device_map(model,
-                    #                                         dtype=torch.int8,
-                    #                                         no_split_module_classes=model._no_split_modules)
 
                 model = dispatch_model(model, device_map=self.device_map)
             else:
@@ -472,12 +465,13 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                 self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
             except Exception as e:
+                print(e)
                 print("加载PrefixEncoder config.json失败")
 
         self.model, self.tokenizer = self._load_model()
@@ -487,14 +481,16 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("加载ptuning检查点成功!")
             except Exception as e:
+                print(e)
                 print("加载PrefixEncoder模型参数失败")
         # llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法
         if not self.is_llamacpp and not self.is_chatgmlcpp:
@@ -11,7 +11,7 @@ beautifulsoup4
 icetk
 cpm_kernels
 faiss-cpu
-gradio==3.28.3
+gradio==3.37.0
 fastapi~=0.95.0
 uvicorn~=0.21.1
 pypinyin~=0.48.0
webui.py (1 changed line)
@@ -104,6 +104,7 @@ def init_model():
     args_dict = vars(args)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
+    llm_model_ins.history_len = LLM_HISTORY_LEN
     try:
         local_doc_qa.init_cfg(llm_model=llm_model_ins)
         answer_result_stream_result = local_doc_qa.llm_model_chain(
webui_st.py (499 changed lines)
@@ -1,5 +1,5 @@
 import streamlit as st
-# from st_btn_select import st_btn_select
+from streamlit_chatbox import st_chatbox
 import tempfile
 ###### 从webui借用的代码 #####
 ###### 做了少量修改 #####
@@ -23,6 +23,7 @@ def get_vs_list():
     if not os.path.exists(KB_ROOT_PATH):
         return lst_default
     lst = os.listdir(KB_ROOT_PATH)
+    lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
     if not lst:
         return lst_default
     lst.sort()
@@ -31,7 +32,6 @@ def get_vs_list():
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
 llm_model_dict_list = list(llm_model_dict.keys())
-# flag_csv_logger = gr.CSVLogger()
 
 
 def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
@@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
             history[-1][-1] += source
         yield history, ""
     elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
+        local_doc_qa.top_k = vector_search_top_k
+        local_doc_qa.chunk_conent = chunk_conent
+        local_doc_qa.chunk_size = chunk_size
         for resp, history in local_doc_qa.get_knowledge_based_answer(
                 query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
             source = "\n\n"
@@ -95,17 +98,101 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
         yield history, ""
     logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
-    # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
 
 
-def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
+def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
+    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
+    filelist = []
+    if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
+        os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
+    qa = st.session_state.local_doc_qa
+    if qa.llm_model_chain and qa.embeddings:
+        if isinstance(files, list):
+            for file in files:
+                filename = os.path.split(file.name)[-1]
+                shutil.move(file.name, os.path.join(
+                    KB_ROOT_PATH, vs_id, "content", filename))
+                filelist.append(os.path.join(
+                    KB_ROOT_PATH, vs_id, "content", filename))
+            vs_path, loaded_files = qa.init_knowledge_vector_store(
+                filelist, vs_path, sentence_size)
+        else:
+            vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
+                                                         sentence_size)
+        if len(loaded_files):
+            file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
+        else:
+            file_status = "文件未成功加载,请重新上传文件"
+    else:
+        file_status = "模型未完成加载,请先在加载模型后再导入文件"
+        vs_path = None
+    logger.info(file_status)
+    return vs_path, None, history + [[None, file_status]]
+
+
+knowledge_base_test_mode_info = ("【注意】\n\n"
+                                 "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
+                                 "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
+                                 "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
+                                 """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
+                                 "4. 单条内容长度建议设置在100-150左右。")
+
+
+webui_title = """
+# 🎉langchain-ChatGLM WebUI🎉
+👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
+"""
+###### #####
+
+
+###### todo #####
+# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
+# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。
+# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
+# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
+# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
+###### #####
+
+
+###### 配置项 #####
+class ST_CONFIG:
+    default_mode = "知识库问答"
+    default_kb = ""
+###### #####
+
+
+class TempFile:
+    '''
+    为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
+    '''
+
+    def __init__(self, path):
+        self.name = path
+
+
+@st.cache_resource(show_spinner=False, max_entries=1)
+def load_model(
+        llm_model: str = LLM_MODEL,
+        embedding_model: str = EMBEDDING_MODEL,
+        use_ptuning_v2: bool = USE_PTUNING_V2,
+):
+    '''
+    对应init_model,利用streamlit cache避免模型重复加载
+    '''
     local_doc_qa = LocalDocQA()
     # 初始化消息
     args = parser.parse_args()
     args_dict = vars(args)
     args_dict.update(model=llm_model)
-    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
-    llm_model_ins = shared.loaderLLM()
+    if shared.loaderCheckPoint is None:  # avoid checkpoint reloading when reinit model
+        shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+    # shared.loaderCheckPoint.model_name is different by no_remote_model.
+    # if it is not set properly error occurs when reinit llm model(issue#473).
+    # as no_remote_model is removed from model_config, need workaround to set it automaticlly.
+    local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
+    no_remote_model = os.path.isdir(local_model_path)
+    llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
+    llm_model_ins.history_len = LLM_HISTORY_LEN
+
     try:
         local_doc_qa.init_cfg(llm_model=llm_model_ins,
@ -128,235 +215,9 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
|
||||||
return local_doc_qa
|
return local_doc_qa
|
||||||
|
|
||||||
|
|
||||||
# 暂未使用到,先保留
|
|
||||||
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
|
|
||||||
# try:
|
|
||||||
# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
|
||||||
# llm_model_ins.history_len = llm_history_len
|
|
||||||
# local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
|
||||||
# embedding_model=embedding_model,
|
|
||||||
# top_k=top_k)
|
|
||||||
# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
|
||||||
# logger.info(model_status)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(e)
|
|
||||||
# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
|
||||||
# logger.info(model_status)
|
|
||||||
# return history + [[None, model_status]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
|
||||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
|
||||||
filelist = []
|
|
||||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
|
|
||||||
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
|
|
||||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
|
||||||
if isinstance(files, list):
|
|
||||||
for file in files:
|
|
||||||
filename = os.path.split(file.name)[-1]
|
|
||||||
shutil.move(file.name, os.path.join(
|
|
||||||
KB_ROOT_PATH, vs_id, "content", filename))
|
|
||||||
filelist.append(os.path.join(
|
|
||||||
KB_ROOT_PATH, vs_id, "content", filename))
|
|
||||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
|
|
||||||
filelist, vs_path, sentence_size)
|
|
||||||
else:
|
|
||||||
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
|
||||||
sentence_size)
|
|
||||||
if len(loaded_files):
|
|
||||||
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
|
|
||||||
else:
|
|
||||||
file_status = "文件未成功加载,请重新上传文件"
|
|
||||||
else:
|
|
||||||
file_status = "模型未完成加载,请先在加载模型后再导入文件"
|
|
||||||
vs_path = None
|
|
||||||
logger.info(file_status)
|
|
||||||
return vs_path, None, history + [[None, file_status]]
|
|
||||||
|
|
||||||
|
|
||||||
knowledge_base_test_mode_info = ("【注意】\n\n"
|
|
||||||
"1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
|
|
||||||
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
|
|
||||||
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
|
|
||||||
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
|
|
||||||
"4. 单条内容长度建议设置在100-150左右。\n\n"
|
|
||||||
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
|
|
||||||
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
|
|
||||||
"相关参数将在后续版本中支持本界面直接修改。")
|
|
||||||
|
|
||||||
|
|
||||||
webui_title = """
|
|
||||||
# 🎉langchain-ChatGLM WebUI🎉
|
|
||||||
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
|
|
||||||
"""
|
|
||||||
###### #####
|
|
||||||
|
|
||||||
|
|
||||||
###### todo #####
|
|
||||||
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
|
|
||||||
# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。
|
|
||||||
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
|
|
||||||
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
|
|
||||||
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
|
|
||||||
###### #####
|
|
||||||
|
|
||||||
|
|
||||||
###### 配置项 #####
|
|
||||||
class ST_CONFIG:
|
|
||||||
user_bg_color = '#77ff77'
|
|
||||||
user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
|
|
||||||
robot_bg_color = '#ccccee'
|
|
||||||
robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
|
|
||||||
default_mode = '知识库问答'
|
|
||||||
defalut_kb = ''
|
|
||||||
###### #####
|
|
||||||
|
|
||||||
|
|
||||||
class MsgType:
|
|
||||||
'''
|
|
||||||
目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
|
|
||||||
'''
|
|
||||||
TEXT = 1
|
|
||||||
IMAGE = 2
|
|
||||||
VIDEO = 3
|
|
||||||
AUDIO = 4
|
|
||||||
|
|
||||||
|
|
||||||
class TempFile:
|
|
||||||
'''
|
|
||||||
为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, path):
|
|
||||||
self.name = path
|
|
||||||
|
|
||||||
|
|
||||||
-def init_session():
-    st.session_state.setdefault('history', [])
-
-
-# def get_query_params():
-#     '''
-#     可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
-#     该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
-#     方便将固定的配置分享给特定的人。
-#     '''
-#     params = st.experimental_get_query_params()
-#     return {k: v[0] for k, v in params.items() if v}
-
-
-def robot_say(msg, kb=''):
-    st.session_state['history'].append(
-        {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
-
-
-def user_say(msg):
-    st.session_state['history'].append(
-        {'is_user': True, 'type': MsgType.TEXT, 'content': msg})
-
-
-def format_md(msg, is_user=False, bg_color='', margin='10%'):
-    '''
-    将文本消息格式化为markdown文本
-    '''
-    if is_user:
-        bg_color = bg_color or ST_CONFIG.user_bg_color
-        text = f'''
-        <div style="background:{bg_color};
-                    margin-left:{margin};
-                    word-break:break-all;
-                    float:right;
-                    padding:2%;
-                    border-radius:2%;">
-        {msg}
-        </div>
-        '''
-    else:
-        bg_color = bg_color or ST_CONFIG.robot_bg_color
-        text = f'''
-        <div style="background:{bg_color};
-                    margin-right:{margin};
-                    word-break:break-all;
-                    padding:2%;
-                    border-radius:2%;">
-        {msg}
-        </div>
-        '''
-    return text
-
-
-def message(msg,
-            is_user=False,
-            msg_type=MsgType.TEXT,
-            icon='',
-            bg_color='',
-            margin='10%',
-            kb='',
-            ):
-    '''
-    渲染单条消息。目前仅支持文本
-    '''
-    cols = st.columns([1, 10, 1])
-    empty = cols[1].empty()
-    if is_user:
-        icon = icon or ST_CONFIG.user_icon
-        bg_color = bg_color or ST_CONFIG.user_bg_color
-        cols[2].image(icon, width=40)
-        if msg_type == MsgType.TEXT:
-            text = format_md(msg, is_user, bg_color, margin)
-            empty.markdown(text, unsafe_allow_html=True)
-        else:
-            raise RuntimeError('only support text message now.')
-    else:
-        icon = icon or ST_CONFIG.robot_icon
-        bg_color = bg_color or ST_CONFIG.robot_bg_color
-        cols[0].image(icon, width=40)
-        if kb:
-            cols[0].write(f'({kb})')
-        if msg_type == MsgType.TEXT:
-            text = format_md(msg, is_user, bg_color, margin)
-            empty.markdown(text, unsafe_allow_html=True)
-        else:
-            raise RuntimeError('only support text message now.')
-    return empty
-
-
-def output_messages(
-    user_bg_color='',
-    robot_bg_color='',
-    user_icon='',
-    robot_icon='',
-):
-    with chat_box.container():
-        last_response = None
-        for msg in st.session_state['history']:
-            bg_color = user_bg_color if msg['is_user'] else robot_bg_color
-            icon = user_icon if msg['is_user'] else robot_icon
-            empty = message(msg['content'],
-                            is_user=msg['is_user'],
-                            icon=icon,
-                            msg_type=msg['type'],
-                            bg_color=bg_color,
-                            kb=msg.get('kb', '')
-                            )
-            if not msg['is_user']:
-                last_response = empty
-        return last_response
-
-
-@st.cache_resource(show_spinner=False, max_entries=1)
-def load_model(llm_model: str, embedding_model: str):
-    '''
-    对应init_model,利用streamlit cache避免模型重复加载
-    '''
-    local_doc_qa = init_model(llm_model, embedding_model)
-    robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
-    return local_doc_qa
-
-
 # @st.cache_data
 def answer(query, vs_path='', history=[], mode='', score_threshold=0,
-           vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
+           vector_search_top_k=5, chunk_conent=True, chunk_size=100
            ):
     '''
     对应get_answer,--利用streamlit cache缓存相同问题的答案--
@@ -365,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0,
         vector_search_top_k, chunk_conent, chunk_size)


-def load_vector_store(
-        vs_id,
-        files,
-        sentence_size=100,
-        history=[],
-        one_conent=None,
-        one_content_segmentation=None,
-):
-    return get_vector_store(
-        local_doc_qa,
-        vs_id,
-        files,
-        sentence_size,
-        history,
-        one_conent,
-        one_content_segmentation,
-    )
+def use_kb_mode(m):
+    return m in ["知识库问答", "知识库测试"]


 # main ui
 st.set_page_config(webui_title, layout='wide')
-init_session()
-# params = get_query_params()
-# llm_model = params.get('llm_model', LLM_MODEL)
-# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)
-
-with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
-    local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)
-
-
-def use_kb_mode(m):
-    return m in ['知识库问答', '知识库测试']

+chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
+# 使用 help(st_chatbox) 查看自定义参数

 # sidebar
 modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
 with st.sidebar:
     def on_mode_change():
         m = st.session_state.mode
-        robot_say(f'已切换到"{m}"模式')
+        chat_box.robot_say(f'已切换到"{m}"模式')
         if m == '知识库测试':
-            robot_say(knowledge_base_test_mode_info)
+            chat_box.robot_say(knowledge_base_test_mode_info)

     index = 0
     try:
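Not part of the diff, but useful context for the hunk above: the hand-rolled session-state renderers (robot_say, user_say, format_md, message, output_messages) are replaced by a single st_chatbox component created at module level, and the remaining changes in this file mostly swap the old helpers for the component's methods. A minimal sketch of how those calls fit together, using only the methods that appear in this diff; the import path is an assumption, and the widgets are simplified compared to the project's sidebar and form layout:

import streamlit as st
from streamlit_chatbox import st_chatbox  # assumed import path; the diff only shows the st_chatbox() call

# One component instance stores and renders the whole conversation.
chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])

question = st.text_area('question', key='input_question', label_visibility='collapsed')
if st.button('发送'):
    chat_box.user_say(question)        # append the user's turn
    chat_box.robot_say('正在思考...')   # placeholder assistant turn
    chat_box.output_messages()         # draw the accumulated history
    chat_box.update_last_box_text('这里是最终回答')  # overwrite the placeholder in place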
@@ -416,7 +253,7 @@ with st.sidebar:
     mode = st.selectbox('对话模式', modes, index,
                         on_change=on_mode_change, key='mode')

-    with st.expander('模型配置', '知识' not in mode):
+    with st.expander('模型配置', not use_kb_mode(mode)):
         with st.form('model_config'):
             index = 0
             try:
@@ -425,9 +262,8 @@ with st.sidebar:
                 pass
             llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)

-            no_remote_model = st.checkbox('加载本地模型', False)
             use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
-            use_lora = st.checkbox('使用lora微调的权重', False)
             try:
                 index = embedding_model_dict_list.index(EMBEDDING_MODEL)
             except:
@@ -437,44 +273,52 @@ with st.sidebar:

             btn_load_model = st.form_submit_button('重新加载模型')
             if btn_load_model:
-                local_doc_qa = load_model(llm_model, embedding_model)
+                local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2)

-    if mode in ['知识库问答', '知识库测试']:
+    history_len = st.slider(
+        "LLM对话轮数", 1, 50, LLM_HISTORY_LEN)
+
+    if use_kb_mode(mode):
         vs_list = get_vs_list()
         vs_list.remove('新建知识库')

         def on_new_kb():
             name = st.session_state.kb_name
-            if name in vs_list:
-                st.error(f'名为“{name}”的知识库已存在。')
+            if not name:
+                st.sidebar.error(f'新建知识库名称不能为空!')
+            elif name in vs_list:
+                st.sidebar.error(f'名为“{name}”的知识库已存在。')
             else:
-                vs_list.append(name)
                 st.session_state.vs_path = name
+                st.session_state.kb_name = ''
+                new_kb_dir = os.path.join(KB_ROOT_PATH, name)
+                if not os.path.exists(new_kb_dir):
+                    os.makedirs(new_kb_dir)
+                st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。')

         def on_vs_change():
-            robot_say(f'已加载知识库: {st.session_state.vs_path}')
+            chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
         with st.expander('知识库配置', True):
             cols = st.columns([12, 10])
             kb_name = cols[0].text_input(
-                '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
-            if 'kb_name' not in st.session_state:
-                st.session_state.kb_name = kb_name
+                '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name')
             cols[1].button('新建知识库', on_click=on_new_kb)
+            index = 0
+            try:
+                index = vs_list.index(ST_CONFIG.default_kb)
+            except:
+                pass
             vs_path = st.selectbox(
-                '选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
+                '选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path')

             st.text('')

             score_threshold = st.slider(
                 '知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
             top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
-            history_len = st.slider(
-                'LLM对话轮数', 1, 50, LLM_HISTORY_LEN)  # 也许要跟知识库分开设置
-            # local_doc_qa.llm.set_history_len(history_len)
             chunk_conent = st.checkbox('启用上下文关联', False)
-            st.text('')
-            # chunk_conent = st.checkbox('分割文本', True)  # 知识库文本分割入库
             chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
+            st.text('')
             sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
             files = st.file_uploader('上传知识文件',
                                      ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
@@ -487,56 +331,61 @@ with st.sidebar:
                     with open(file, 'wb') as fp:
                         fp.write(f.getvalue())
                     file_list.append(TempFile(file))
-                _, _, history = load_vector_store(
+                _, _, history = get_vector_store(
                     vs_path, file_list, sentence_size, [], None, None)
                 st.session_state.files = []


-# main body
-chat_box = st.empty()
+# load model after params rendered
+with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."):
+    local_doc_qa = load_model(
+        llm_model,
+        embedding_model,
+        use_ptuning_v2,
+    )
+    local_doc_qa.llm_model_chain.history_len = history_len
+    if use_kb_mode(mode):
+        local_doc_qa.chunk_conent = chunk_conent
+        local_doc_qa.chunk_size = chunk_size
+    # local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用
+    st.session_state.local_doc_qa = local_doc_qa

-with st.form('my_form', clear_on_submit=True):
+# input form
+with st.form("my_form", clear_on_submit=True):
     cols = st.columns([8, 1])
-    question = cols[0].text_input(
+    question = cols[0].text_area(
         'temp', key='input_question', label_visibility='collapsed')

-    def on_send():
-        q = st.session_state.input_question
-        if q:
-            user_say(q)
-            if mode == 'LLM 对话':
-                robot_say('正在思考...')
-                last_response = output_messages()
-                for history, _ in answer(q,
-                                         history=[],
-                                         mode=mode):
-                    last_response.markdown(
-                        format_md(history[-1][-1], False),
-                        unsafe_allow_html=True
-                    )
-            elif use_kb_mode(mode):
-                robot_say('正在思考...', vs_path)
-                last_response = output_messages()
-                for history, _ in answer(q,
-                                         vs_path=os.path.join(
-                                             KB_ROOT_PATH, vs_path, "vector_store"),
-                                         history=[],
-                                         mode=mode,
-                                         score_threshold=score_threshold,
-                                         vector_search_top_k=top_k,
-                                         chunk_conent=chunk_conent,
-                                         chunk_size=chunk_size):
-                    last_response.markdown(
-                        format_md(history[-1][-1], False, 'ligreen'),
-                        unsafe_allow_html=True
-                    )
-            else:
-                robot_say('正在思考...')
-                last_response = output_messages()
-            st.session_state['history'][-1]['content'] = history[-1][-1]
-    submit = cols[1].form_submit_button('发送', on_click=on_send)
-
-output_messages()
-
-# st.write(st.session_state['history'])
+    if cols[1].form_submit_button("发送"):
+        chat_box.user_say(question)
+        history = []
+        if mode == "LLM 对话":
+            chat_box.robot_say("正在思考...")
+            chat_box.output_messages()
+            for history, _ in answer(question,
+                                     history=[],
+                                     mode=mode):
+                chat_box.update_last_box_text(history[-1][-1])
+        elif use_kb_mode(mode):
+            chat_box.robot_say(f"正在查询 [{vs_path}] ...")
+            chat_box.output_messages()
+            for history, _ in answer(question,
+                                     vs_path=os.path.join(
+                                         KB_ROOT_PATH, vs_path, 'vector_store'),
+                                     history=[],
+                                     mode=mode,
+                                     score_threshold=score_threshold,
+                                     vector_search_top_k=top_k,
+                                     chunk_conent=chunk_conent,
+                                     chunk_size=chunk_size):
+                chat_box.update_last_box_text(history[-1][-1])
+        else:
+            chat_box.robot_say(f"正在执行Bing搜索...")
+            chat_box.output_messages()
+            for history, _ in answer(question,
+                                     history=[],
+                                     mode=mode):
+                chat_box.update_last_box_text(history[-1][-1])
+
+# st.write(chat_box.history)
+chat_box.output_messages()
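One more sketch, again not part of the diff: in the new input form above, answer() is consumed as a generator. Each yielded (history, status) pair carries the partially generated reply in history[-1][-1], and chat_box.update_last_box_text overwrites the last assistant box so the reply appears to stream. A small self-contained illustration of that loop with a stand-in generator; fake_answer is hypothetical, while the project's answer() wraps LocalDocQA and takes the knowledge-base arguments shown above:

import time

import streamlit as st
from streamlit_chatbox import st_chatbox  # assumed import path, as noted earlier


def fake_answer(question):
    # Stand-in for the project's answer() generator: yield (history, status) pairs
    # while the reply grows, mimicking token streaming from the LLM.
    reply = ''
    for token in ['工伤保险', '是由用人单位缴费的', '社会保险制度。']:
        reply += token
        time.sleep(0.2)
        yield [[question, reply]], ''


chat_box = st_chatbox(greetings=['可以开始提问了。'])
question = st.text_area('question', key='input_question', label_visibility='collapsed')
if st.button('发送'):
    chat_box.user_say(question)
    chat_box.robot_say('正在思考...')
    chat_box.output_messages()
    for history, _ in fake_answer(question):
        chat_box.update_last_box_text(history[-1][-1])  # latest partial reply replaces the placeholder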