Merge branch 'dev' of github.com:hzg0601/langchain-ChatGLM-annotation into dev

merge upstream dev
This commit is contained in:
hzg0601 2023-07-14 13:45:48 +08:00
commit 760abab1d7
18 changed files with 562 additions and 382 deletions

7
api.py
View File

@ -384,8 +384,10 @@ async def chat(
],
),
):
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
streaming=True):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global local_doc_qa
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
app = FastAPI()
# Add CORS middleware to allow all origins

View File

@ -18,6 +18,7 @@ from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance
from langchain.chains.base import Chain
# patch HuggingFaceEmbeddings to make it hashable
@ -119,7 +120,7 @@ def search_result2docs(search_results):
class LocalDocQA:
llm: BaseAnswer = None
llm_model_chain: Chain = None
embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE
@ -129,10 +130,10 @@ class LocalDocQA:
def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE,
llm_model: BaseAnswer = None,
llm_model: Chain = None,
top_k=VECTOR_SEARCH_TOP_K,
):
self.llm = llm_model
self.llm_model_chain = llm_model
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device})
self.top_k = top_k
@ -236,8 +237,10 @@ class LocalDocQA:
else:
prompt = query
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@ -276,8 +279,10 @@ class LocalDocQA:
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins)

View File

@ -37,55 +37,55 @@ llm_model_dict = {
"name": "chatglm-6b-int4-qe",
"pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int4": {
"name": "chatglm-6b-int4",
"pretrained_model_name": "THUDM/chatglm-6b-int4",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int8": {
"name": "chatglm-6b-int8",
"pretrained_model_name": "THUDM/chatglm-6b-int8",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm-6b": {
"name": "chatglm-6b",
"pretrained_model_name": "THUDM/chatglm-6b",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b": {
"name": "chatglm2-6b",
"pretrained_model_name": "THUDM/chatglm2-6b",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int4": {
"name": "chatglm2-6b-int4",
"pretrained_model_name": "THUDM/chatglm2-6b-int4",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int8": {
"name": "chatglm2-6b-int8",
"pretrained_model_name": "THUDM/chatglm2-6b-int8",
"local_model_path": None,
"provides": "ChatGLM"
"provides": "ChatGLMLLMChain"
},
"chatyuan": {
"name": "chatyuan",
"pretrained_model_name": "ClueAI/ChatYuan-large-v2",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"moss": {
"name": "moss",
"pretrained_model_name": "fnlp/moss-moon-003-sft",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"moss-int4": {
"name": "moss",
@ -97,7 +97,13 @@ llm_model_dict = {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "LLamaLLM"
"provides": "LLamaLLMChain"
},
"vicuna-7b-hf": {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "LLamaLLMChain"
},
# 直接调用返回requests.exceptions.ConnectionError错误需要通过huggingface_hub包里的snapshot_download函数
# 下载模型如果snapshot_download还是返回网络错误多试几次一般是可以的
@ -107,7 +113,7 @@ llm_model_dict = {
"name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
# 实测加载bigscience/bloom-3b需要170秒左右暂不清楚为什么这么慢
@ -116,14 +122,14 @@ llm_model_dict = {
"name": "bloom-3b",
"pretrained_model_name": "bigscience/bloom-3b",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
"baichuan-7b": {
"name": "baichuan-7b",
"pretrained_model_name": "baichuan-inc/baichuan-7B",
"local_model_path": None,
"provides": "MOSSLLM"
"provides": "MOSSLLMChain"
},
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5": {
@ -137,7 +143,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
"provides": "LLamaLLM"
"provides": "LLamaLLMChain"
},
# 通过 fastchat 调用的模型请参考如下格式
@ -145,7 +151,7 @@ llm_model_dict = {
"name": "chatglm-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@ -153,7 +159,7 @@ llm_model_dict = {
"name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm2-6b",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1" # "name"修改为fastchat服务中的"api_base_url"
},
@ -162,7 +168,7 @@ llm_model_dict = {
"name": "vicuna-13b-hf", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "FastChatOpenAILLM", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLM"
"provides": "FastChatOpenAILLMChain", # 使用fastchat api时需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@ -177,7 +183,7 @@ llm_model_dict = {
"openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo",
"provides": "FastChatOpenAILLM",
"provides": "FastChatOpenAILLMChain",
"local_model_path": None,
"api_base_url": "https://api.openapi.com/v1",
"api_key": ""
@ -204,7 +210,10 @@ STREAMING = True
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
PTUNING_DIR='./ptuing-v2'
<<<<<<< HEAD
=======
>>>>>>> f68d347c25b4bdd07f293c65a6e44a673a11f614
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@ -233,7 +242,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K = 5
# 知识检索内容相关度 Score, 数值范围约为0-1100如果为0则不生效经测试设置为小于500时匹配结果更精准
VECTOR_SEARCH_SCORE_THRESHOLD = 0
VECTOR_SEARCH_SCORE_THRESHOLD = 390
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 247 KiB

View File

@ -1,4 +1,4 @@
from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
from .fastchat_openai_llm import FastChatOpenAILLM
from .chatglm_llm import ChatGLMLLMChain
from .llama_llm import LLamaLLMChain
from .fastchat_openai_llm import FastChatOpenAILLMChain
from .moss_llm import MOSSLLMChain

View File

@ -1,13 +1,15 @@
from models.base.base import (
AnswerResult,
BaseAnswer
)
BaseAnswer,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from models.base.remote_rpc_model import (
RemoteRpcModel
)
__all__ = [
"AnswerResult",
"BaseAnswer",
"RemoteRpcModel",
"AnswerResultStream",
"AnswerResultQueueSentinelTokenListenerQueue"
]

View File

@ -1,13 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional, List
from typing import Any, Dict, List, Optional, Generator
import traceback
from collections import deque
from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
import torch
import transformers
from models.loader import LoaderCheckPoint
class ListenerToken:
"""
观测结果
"""
input_ids: torch.LongTensor
_scores: torch.FloatTensor
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
self.input_ids = input_ids
self._scores = _scores
class AnswerResult:
@ -16,6 +29,123 @@ class AnswerResult:
"""
history: List[List[str]] = []
llm_output: Optional[dict] = None
listenerToken: ListenerToken = None
class AnswerResultStream:
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, answerResult: AnswerResult):
if self.callback_func is not None:
self.callback_func(answerResult)
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
"""
定义模型stopping_criteria 监听者在每次响应时将队列数据同步到AnswerResult
实现此监听器的目的是不同模型的预测输出可能不是矢量信息hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数
通过给 StoppingCriteriaList指定模型生成答案时停止的条件每个 StoppingCriteria 对象表示一个停止条件
当每轮预测任务开始时StoppingCriteria都会收到相同的预测结果最终由下层实现类确认是否结束
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测以实现更加精细的控制
"""
listenerQueue: deque = deque(maxlen=1)
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
"""
每次响应时将数据添加到响应队列
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}):
self.mfunc = func
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
"""
模型输出预测结果收集
通过定义generate_with_callback收集器AnswerResultStream收集模型预测的AnswerResult响应结果最终由下层实现类确认是否结束
结束条件包含如下
1模型预测结束收集器self.q队列收到 self.sentinel标识
2在处理迭代器队列消息时返回了break跳出迭代器触发了StopIteration事件
3模型预测出错
因为当前类是迭代器所以在for in 中执行了break后 __exit__ 方法会被调用最终stop_now属性会被更新然后抛出异常结束预测行为
迭代器收集的行为如下
创建Iteratorize迭代对象
定义generate_with_callback收集器AnswerResultStream
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
_generate_answer通过generate_with_callback定义的收集器收集上游checkpoint包装的AnswerResult消息体
由于self.q是阻塞模式每次预测后会被消费后才会执行下次预测
这时generate_with_callback会被阻塞
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
1消息为上游checkpoint包装的AnswerResult消息体返回下游处理
2消息为self.sentinel标识抛出StopIteration异常
主线程Iteratorize对象__exit__收到消息最终stop_now属性会被更新
异步线程检测stop_now属性被更新抛出异常结束预测行为
迭代行为结束
:param val:
:return:
"""
if self.stop_now:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
self.thread = Thread(target=gen)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
"""
暂无实现
:return:
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" break 后会执行 """
self.stop_now = True
class BaseAnswer(ABC):
@ -25,17 +155,25 @@ class BaseAnswer(ABC):
@abstractmethod
def _check_point(self) -> LoaderCheckPoint:
"""Return _check_point of llm."""
def generatorAnswer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
def generate_with_callback(callback=None, **kwargs):
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
self._generate_answer(**kwargs)
@property
@abstractmethod
def _history_len(self) -> int:
"""Return _history_len of llm."""
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs)
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
for answerResult in generator:
if answerResult.listenerToken:
output = answerResult.listenerToken.input_ids
yield answerResult
@abstractmethod
def set_history_len(self, history_len: int) -> None:
"""Return _history_len of llm."""
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
pass

View File

@ -1,70 +1,102 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class ChatGLM(BaseAnswer, LLM, ABC):
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
# 相关度
top_p = 0.4
# 候选词数量
top_k = 10
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "ChatGLM"
def _chain_type(self) -> str:
return "ChatGLMLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=[],
max_length=self.max_token,
temperature=self.temperature
)
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
stopping_criteria_list.append(listenerQueue)
if streaming:
history += [[]]
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:-1] if self.history_len > 1 else [],
history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)):
# self.checkPoint.clear_torch_cache()
history[-1] = [prompt, stream_resp]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
yield answer_result
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache()
else:
response, _ = self.checkPoint.model.chat(
@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
prompt,
history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
generate_with_callback(answer_result)

View File

@ -1,15 +1,15 @@
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Collection
from models.loader import LoaderCheckPoint
from models.base import (RemoteRpcModel,
AnswerResult)
from typing import (
Collection,
Dict
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
RemoteRpcModel,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
def _build_message_template() -> Dict[str, str]:
@ -22,18 +22,42 @@ def _build_message_template() -> Dict[str, str]:
}
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_messages.append(user_build_message)
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_messages.append(user_build_message)
return build_messages
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
checkPoint: LoaderCheckPoint = None
history = []
# history = []
history_len: int = 10
api_key: str = ""
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self,
checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1",
@ -44,19 +68,28 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "FastChat"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _api_key(self) -> str:
@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def call_model_name(self, model_name):
self.model_name = model_name
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
openai.key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
return completion.choices[0].message.content
# 将历史对话数组转换为文本格式
def build_message_list(self, query) -> Collection[Dict[str, str]]:
build_message_list: Collection[Dict[str, str]] = []
history = self.history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_message_list.append(user_build_message)
build_message_list.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_message_list.append(user_build_message)
return build_message_list
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
try:
import openai
# Not support yet
# openai.api_key = "EMPTY"
@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
messages=build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content}
yield answer_result
generate_with_callback(answer_result)

View File

@ -1,19 +1,22 @@
from abc import ABC
from langchain.llms.base import LLM
import random
import torch
import transformers
from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any,Union
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
def __call__(self, input_ids: Union[torch.LongTensor, list],
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
# llama-cpp模型返回的是list,为兼容性考虑需要判断input_ids和scores的类型将list转换为torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
scores = torch.tensor(scores) if isinstance(scores, list) else scores
@ -23,7 +26,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
return scores
class LLamaLLM(BaseAnswer, LLM, ABC):
class LLamaLLMChain(BaseAnswer, Chain, ABC):
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 3
@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
min_length: int = 0
logits_processor: LogitsProcessorList = None
stopping_criteria: Optional[StoppingCriteriaList] = None
eos_token_id: Optional[int] = [2]
state: object = {'max_new_tokens': 50,
'seed': 1,
'temperature': 0, 'top_p': 0.1,
'top_k': 40, 'typical_p': 1,
'repetition_penalty': 1.2,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
'truncation_length': 2048, 'custom_stopping_strings': '',
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
'pre_layer': 0, 'gpu_memory_0': 0}
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "LLamaLLM"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _check_point(self) -> LoaderCheckPoint:
@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
formatted_history += "### Human{}\n### Assistant".format(query)
return formatted_history
def prepare_inputs_for_generation(self,
input_ids: torch.LongTensor):
"""
预生成注意力掩码和 输入序列中每个位置的索引的张量
# TODO 没有思路
:return:
"""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
attention_mask = self.get_masks(input_ids, input_ids.device)
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions
)
return input_ids, position_ids, attention_mask
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
self.stopping_criteria = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
self.stopping_criteria.append(listenerQueue)
# TODO 需要实现chat对话模块和注意力模型目前_call为langchain的LLM拓展的api默认为无提示词模式如果需要操作注意力模型可以参考chat_glm的实现
soft_prompt = self.history_to_text(query=prompt, history=history)
if self.logits_processor is None:
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(InvalidScoreLogitsProcessor())
@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"logits_processor": self.logits_processor}
# 向量转换
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
truncation_length=self.max_new_tokens)
gen_kwargs.update({'inputs': input_ids})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if self.stopping_criteria is None:
self.stopping_criteria = transformers.StoppingCriteriaList()
# 观测输出
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp模型的参数与transformers的参数字段有较大差异直接调用会返回不支持的字段错误
@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
if "llama_cpp" in self.checkPoint.model.__str__():
import inspect
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
gen_kwargs.keys())
common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入响应很慢慢到哭泣
# ?为什么会不支持GPU呢不应该啊
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
output_ids = torch.tensor(
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
else:
output_ids = self.checkPoint.model.generate(**gen_kwargs)
@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}")
print(f"+++++++++++++++++++++++++++++++++++")
return reply
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# TODO 需要实现chat对话模块和注意力模型目前_call为langchain的LLM拓展的api默认为无提示词模式如果需要操作注意力模型可以参考chat_glm的实现
softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult()
answer_result.history = history + [[prompt, response]]
answer_result.llm_output = {"answer": response}
yield answer_result
history += [[prompt, reply]]
answer_result.history = history
if listenerQueue.listenerQueue.__len__() > 0:
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)

View File

@ -1,3 +1,4 @@
import argparse
import os
from configs.model_config import *
@ -45,7 +46,6 @@ parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the m
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint")
parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
# Accelerate/transformers
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
help='Load the model with 8-bit precision.')

View File

@ -20,6 +20,7 @@ class LoaderCheckPoint:
no_remote_model: bool = False
# 模型名称
model_name: str = None
pretrained_model_name: str = None
tokenizer: object = None
# 模型全路径
model_path: str = None
@ -67,48 +68,49 @@ class LoaderCheckPoint:
self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False)
def _load_model_config(self, model_name):
def _load_model_config(self):
if self.model_path:
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
print(f"load_model_config {checkpoint}...")
try:
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
return model_config
except Exception as e:
print(e)
return checkpoint
def _load_model(self, model_name):
def _load_model(self):
"""
加载自定义位置的model
:param model_name:
:return:
"""
print(f"Loading {model_name}...")
t0 = time.time()
if self.model_path:
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
print(f"Loading {checkpoint}...")
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
LoaderClass = AutoModel
else:
LoaderClass = AutoModelForCausalLM
@ -146,10 +148,10 @@ class LoaderCheckPoint:
trust_remote_code=True).half()
# 可传入device_map自定义每张卡的部署情况
if self.device_map is None:
if 'chatglm' in model_name.lower():
if 'chatglm' in self.model_name.lower():
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
elif 'moss' in self.model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
else:
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试GPU负载也相对均衡
@ -321,7 +323,7 @@ class LoaderCheckPoint:
return device_map
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
try:
from accelerate import init_empty_weights
@ -336,16 +338,6 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`."
) from exc
if self.model_path:
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=checkpoint)
@ -452,7 +444,7 @@ class LoaderCheckPoint:
def reload_model(self):
self.unload_model()
self.model_config = self._load_model_config(self.model_name)
self.model_config = self._load_model_config()
if self.use_ptuning_v2:
try:
@ -464,7 +456,7 @@ class LoaderCheckPoint:
except Exception as e:
print("加载PrefixEncoder config.json失败")
self.model, self.tokenizer = self._load_model(self.model_name)
self.model, self.tokenizer = self._load_model()
if self.lora:
self._add_lora_to_model([self.lora])

View File

@ -1,11 +1,19 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
import torch
# todo 建议重写instruction,在该instruction下各模型的表现比较差
META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS.
@ -20,41 +28,65 @@ META_INSTRUCTION = \
Capabilities and tools that MOSS can possess.
"""
# todo 在MOSSLLM类下各模型的响应速度很慢后续要检查一下原因
class MOSSLLM(BaseAnswer, LLM, ABC):
class MOSSLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 2048
temperature: float = 0.7
top_p = 0.8
# history = []
checkPoint: LoaderCheckPoint = None
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "MOSS"
def _chain_type(self) -> str:
return "MOSSLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
return self.history_len
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
def set_history_len(self, history_len: int) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
pass
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
if len(history) > 0:
history = history[-self.history_len:] if self.history_len > 0 else []
prompt_w_history = str(history)
@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences=1,
eos_token_id=106068,
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
generate_with_callback(answer_result)

View File

@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if llm_model:
llm_model_info = llm_model_dict[llm_model]
if loaderCheckPoint.no_remote_model:
loaderCheckPoint.model_name = llm_model_info['name']
else:
loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_path = llm_model_info["local_model_path"]

View File

@ -1,39 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
async def dispatch(args: Namespace):
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
history = [
("which city is this?", "tokyo"),
("why?", "she's japanese"),
]
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
streaming=False):
resp = answer_result.llm_output["answer"]
print(resp)
if __name__ == '__main__':
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(dispatch(args))

View File

@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp
@ -101,11 +104,12 @@ def init_model():
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if local_doc_qa.llm and local_doc_qa.embeddings:
if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list():
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
def delete_file(vs_id, files_to_delete, chatbot):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
status = f"成功删除知识库{vs_id}"
logger.info(status)
chatbot = chatbot + [[None, status]]
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
visible=True), \
gr.update(visible=False), chatbot, gr.update(visible=False)
except Exception as e:
logger.error(e)
@ -333,7 +339,8 @@ default_theme_args = dict(
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
vs_path, file_status, model_status = gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
""), gr.State(
model_status)
gr.Markdown(webui_title)
with gr.Tab("对话"):

View File

@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp + (
@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
args_dict.update(model=llm_model)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
@ -468,7 +470,7 @@ with st.sidebar:
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
history_len = st.slider(
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
local_doc_qa.llm.set_history_len(history_len)
# local_doc_qa.llm.set_history_len(history_len)
chunk_conent = st.checkbox('启用上下文关联', False)
st.text('')
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库