Merge branch 'dev' of github.com:hzg0601/langchain-ChatGLM-annotation into dev
merge upstream dev
commit 760abab1d7
@@ -174,4 +174,4 @@ embedding/*
pyrightconfig.json
loader/tmp_files
-flagged/*
+flagged/*
api.py
@@ -384,8 +384,10 @@ async def chat(
],
),
):
-for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-streaming=True):
+answer_result_stream_result = local_doc_qa.llm_model_chain(
+{"prompt": question, "history": history, "streaming": True})
+
+for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
global local_doc_qa

llm_model_ins = shared.loaderLLM()
-llm_model_ins.set_history_len(LLM_HISTORY_LEN)

app = FastAPI()
# Add CORS middleware to allow all origins
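Annotation: the api.py hunks above replace the old local_doc_qa.llm.generatorAnswer(...) generator with a LangChain-style chain call. The caller now passes a dict of prompt/history/streaming and iterates the generator stored under the chain's "answer_result_stream" output key. A minimal consumption sketch follows; the helper name stream_llm_answers is illustrative, only the call pattern is taken from the diff.

# Minimal sketch of the new chain-style streaming call used throughout this commit.
# Assumes local_doc_qa.llm_model_chain was set up via LocalDocQA.init_cfg(llm_model=...).
def stream_llm_answers(local_doc_qa, question: str, history: list):
    result = local_doc_qa.llm_model_chain(
        {"prompt": question, "history": history, "streaming": True})
    for answer_result in result["answer_result_stream"]:
        # Each AnswerResult carries the accumulated answer text and the updated history.
        yield answer_result.llm_output["answer"], answer_result.history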
@@ -18,6 +18,7 @@ from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance
+from langchain.chains.base import Chain


# patch HuggingFaceEmbeddings to make it hashable
@@ -119,7 +120,7 @@ def search_result2docs(search_results):


class LocalDocQA:
-llm: BaseAnswer = None
+llm_model_chain: Chain = None
embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE
@@ -129,10 +130,10 @@ class LocalDocQA:
def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE,
-llm_model: BaseAnswer = None,
+llm_model: Chain = None,
top_k=VECTOR_SEARCH_TOP_K,
):
-self.llm = llm_model
+self.llm_model_chain = llm_model
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device})
self.top_k = top_k
@@ -236,8 +237,10 @@ class LocalDocQA:
else:
prompt = query

-for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-streaming=streaming):
+answer_result_stream_result = self.llm_model_chain(
+{"prompt": prompt, "history": chat_history, "streaming": streaming})
+
+for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@@ -276,8 +279,10 @@ class LocalDocQA:
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)

-for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-streaming=streaming):
+answer_result_stream_result = self.llm_model_chain(
+{"prompt": prompt, "history": chat_history, "streaming": streaming})
+
+for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@@ -296,7 +301,7 @@ class LocalDocQA:
def update_file_from_vector_store(self,
filepath: str or List[str],
vs_path,
-docs: List[Document],):
+docs: List[Document], ):
vector_store = load_vector_store(vs_path, self.embeddings)
status = vector_store.update_doc(filepath, docs)
return status
@@ -320,7 +325,6 @@ if __name__ == "__main__":
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
-llm_model_ins.set_history_len(LLM_HISTORY_LEN)

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins)
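Annotation: taken together, the hunks above switch LocalDocQA from holding a BaseAnswer instance (self.llm) to holding a LangChain Chain (self.llm_model_chain), and drop the set_history_len() call from the __main__ block. A minimal initialisation sketch mirroring that __main__ block; the args dict content is a placeholder.

# Sketch of the post-refactor initialisation flow, based on the __main__ hunk above.
import models.shared as shared
from models.loader import LoaderCheckPoint

shared.loaderCheckPoint = LoaderCheckPoint({})   # placeholder for vars(parsed_args)
llm_chain = shared.loaderLLM()                   # now returns a Chain-style model wrapper
local_doc_qa = LocalDocQA()                      # class patched above
local_doc_qa.init_cfg(llm_model=llm_chain)       # llm_model parameter is now typed as Chain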
@@ -37,55 +37,55 @@ llm_model_dict = {
"name": "chatglm-6b-int4-qe",
"pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int4": {
"name": "chatglm-6b-int4",
"pretrained_model_name": "THUDM/chatglm-6b-int4",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm-6b-int8": {
"name": "chatglm-6b-int8",
"pretrained_model_name": "THUDM/chatglm-6b-int8",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm-6b": {
"name": "chatglm-6b",
"pretrained_model_name": "THUDM/chatglm-6b",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm2-6b": {
"name": "chatglm2-6b",
"pretrained_model_name": "THUDM/chatglm2-6b",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int4": {
"name": "chatglm2-6b-int4",
"pretrained_model_name": "THUDM/chatglm2-6b-int4",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatglm2-6b-int8": {
"name": "chatglm2-6b-int8",
"pretrained_model_name": "THUDM/chatglm2-6b-int8",
"local_model_path": None,
-"provides": "ChatGLM"
+"provides": "ChatGLMLLMChain"
},
"chatyuan": {
"name": "chatyuan",
"pretrained_model_name": "ClueAI/ChatYuan-large-v2",
"local_model_path": None,
-"provides": "MOSSLLM"
+"provides": "MOSSLLMChain"
},
"moss": {
"name": "moss",
"pretrained_model_name": "fnlp/moss-moon-003-sft",
"local_model_path": None,
-"provides": "MOSSLLM"
+"provides": "MOSSLLMChain"
},
"moss-int4": {
"name": "moss",
@@ -97,7 +97,13 @@ llm_model_dict = {
"name": "vicuna-13b-hf",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
-"provides": "LLamaLLM"
+"provides": "LLamaLLMChain"
},
+"vicuna-7b-hf": {
+"name": "vicuna-13b-hf",
+"pretrained_model_name": "vicuna-13b-hf",
+"local_model_path": None,
+"provides": "LLamaLLMChain"
+},
# 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数
# 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的,
@@ -107,7 +113,7 @@ llm_model_dict = {
"name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1",
"local_model_path": None,
-"provides": "MOSSLLM"
+"provides": "MOSSLLMChain"

},
# 实测加载bigscience/bloom-3b需要170秒左右,暂不清楚为什么这么慢
@@ -116,14 +122,14 @@ llm_model_dict = {
"name": "bloom-3b",
"pretrained_model_name": "bigscience/bloom-3b",
"local_model_path": None,
-"provides": "MOSSLLM"
+"provides": "MOSSLLMChain"

},
"baichuan-7b": {
"name": "baichuan-7b",
"pretrained_model_name": "baichuan-inc/baichuan-7B",
"local_model_path": None,
-"provides": "MOSSLLM"
+"provides": "MOSSLLMChain"
},
# llama-cpp模型的兼容性问题参考https://github.com/abetlen/llama-cpp-python/issues/204
"ggml-vicuna-13b-1.1-q5": {
@@ -137,7 +143,7 @@ llm_model_dict = {
# 需要手动从https://github.com/abetlen/llama-cpp-python/releases/tag/下载对应的wheel安装
# 实测v0.1.63与本模型的vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin可以兼容
"local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
-"provides": "LLamaLLM"
+"provides": "LLamaLLMChain"
},

# 通过 fastchat 调用的模型请参考如下格式
@@ -145,7 +151,7 @@ llm_model_dict = {
"name": "chatglm-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm-6b",
"local_model_path": None,
-"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@@ -153,7 +159,7 @@ llm_model_dict = {
"name": "chatglm2-6b", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "chatglm2-6b",
"local_model_path": None,
-"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1" # "name"修改为fastchat服务中的"api_base_url"
},

@@ -162,7 +168,7 @@ llm_model_dict = {
"name": "vicuna-13b-hf", # "name"修改为fastchat服务中的"model_name"
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
-"provides": "FastChatOpenAILLM", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLM"
+"provides": "FastChatOpenAILLMChain", # 使用fastchat api时,需保证"provides"为"FastChatOpenAILLMChain"
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@@ -171,13 +177,13 @@ llm_model_dict = {
# 则需要将urllib3版本修改为1.25.11

# 如果报出:raise NewConnectionError(
-# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
+# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
# Failed to establish a new connection: [WinError 10060]
# 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地
"openai-chatgpt-3.5": {
"name": "gpt-3.5-turbo",
"pretrained_model_name": "gpt-3.5-turbo",
-"provides": "FastChatOpenAILLM",
+"provides": "FastChatOpenAILLMChain",
"local_model_path": None,
"api_base_url": "https://api.openapi.com/v1",
"api_key": ""
@@ -204,7 +210,10 @@ STREAMING = True
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
PTUNING_DIR='./ptuing-v2'
<<<<<<< HEAD
=======

>>>>>>> f68d347c25b4bdd07f293c65a6e44a673a11f614
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -233,7 +242,7 @@ LLM_HISTORY_LEN = 3
VECTOR_SEARCH_TOP_K = 5

# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
-VECTOR_SEARCH_SCORE_THRESHOLD = 0
+VECTOR_SEARCH_SCORE_THRESHOLD = 390

NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
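Annotation: the configuration hunks above only rename the "provides" targets to the new chain classes (ChatGLM -> ChatGLMLLMChain, MOSSLLM -> MOSSLLMChain, LLamaLLM -> LLamaLLMChain, FastChatOpenAILLM -> FastChatOpenAILLMChain) and raise VECTOR_SEARCH_SCORE_THRESHOLD from 0 to 390; note that the @@ -204,7 +210,10 @@ hunk also commits literal merge-conflict markers around PTUNING_DIR, which leaves the config module broken until resolved. The "provides" string is presumably resolved to a class exported from the models package (next hunk); a hypothetical sketch of that lookup, not taken verbatim from this commit.

# Hypothetical illustration of how a "provides" entry selects a chain class.
import models
import models.shared as shared
from configs.model_config import llm_model_dict

info = llm_model_dict["chatglm-6b"]
chain_cls = getattr(models, info["provides"])              # e.g. ChatGLMLLMChain after this commit
llm_chain = chain_cls(checkPoint=shared.loaderCheckPoint)  # constructor signature shown in the chatglm_llm.py hunk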
Binary file not shown (image diff; previous file size: 247 KiB).
@@ -1,4 +1,4 @@
-from .chatglm_llm import ChatGLM
-from .llama_llm import LLamaLLM
-from .moss_llm import MOSSLLM
-from .fastchat_openai_llm import FastChatOpenAILLM
+from .chatglm_llm import ChatGLMLLMChain
+from .llama_llm import LLamaLLMChain
+from .fastchat_openai_llm import FastChatOpenAILLMChain
+from .moss_llm import MOSSLLMChain
@@ -1,13 +1,15 @@
from models.base.base import (
AnswerResult,
-BaseAnswer
-)
+BaseAnswer,
+AnswerResultStream,
+AnswerResultQueueSentinelTokenListenerQueue)
from models.base.remote_rpc_model import (
RemoteRpcModel
)

__all__ = [
"AnswerResult",
"BaseAnswer",
"RemoteRpcModel",
"AnswerResultStream",
"AnswerResultQueueSentinelTokenListenerQueue"
]
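Annotation: models/base now additionally exports AnswerResultStream and AnswerResultQueueSentinelTokenListenerQueue. According to the base.py hunk that follows, a backend implements _generate_answer(inputs, run_manager, generate_with_callback) and pushes AnswerResult objects into the callback, while BaseAnswer.generatorAnswer turns that callback into a generator. A minimal sketch of such a backend method; everything except the exported names is illustrative.

# Minimal sketch of a backend using the newly exported streaming primitives.
from models.base import AnswerResult, AnswerResultStream

def _generate_answer(self, inputs, run_manager=None,
                     generate_with_callback: AnswerResultStream = None):
    # Build one AnswerResult per generated chunk and hand it to the collector.
    answer_result = AnswerResult()
    answer_result.history = inputs["history"] + [[inputs["prompt"], "<model output>"]]
    answer_result.llm_output = {"answer": "<model output>"}
    generate_with_callback(answer_result)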
@ -1,13 +1,26 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List
|
||||
from typing import Any, Dict, List, Optional, Generator
|
||||
import traceback
|
||||
from collections import deque
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from models.loader import LoaderCheckPoint
|
||||
import torch
|
||||
import transformers
|
||||
from models.loader import LoaderCheckPoint
|
||||
|
||||
|
||||
class ListenerToken:
|
||||
"""
|
||||
观测结果
|
||||
"""
|
||||
|
||||
input_ids: torch.LongTensor
|
||||
_scores: torch.FloatTensor
|
||||
|
||||
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
|
||||
self.input_ids = input_ids
|
||||
self._scores = _scores
|
||||
|
||||
|
||||
class AnswerResult:
|
||||
|
|
@ -16,6 +29,123 @@ class AnswerResult:
|
|||
"""
|
||||
history: List[List[str]] = []
|
||||
llm_output: Optional[dict] = None
|
||||
listenerToken: ListenerToken = None
|
||||
|
||||
|
||||
class AnswerResultStream:
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
||||
def __call__(self, answerResult: AnswerResult):
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(answerResult)
|
||||
|
||||
|
||||
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
|
||||
"""
|
||||
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
|
||||
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
|
||||
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
|
||||
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
|
||||
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
|
||||
"""
|
||||
|
||||
listenerQueue: deque = deque(maxlen=1)
|
||||
|
||||
def __init__(self):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
"""
|
||||
每次响应时将数据添加到响应队列
|
||||
:param input_ids:
|
||||
:param _scores:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
|
||||
return False
|
||||
|
||||
|
||||
class Iteratorize:
|
||||
"""
|
||||
Transforms a function that takes a callback
|
||||
into a lazy iterator (generator).
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}):
|
||||
self.mfunc = func
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
self.stop_now = False
|
||||
|
||||
def _callback(val):
|
||||
"""
|
||||
模型输出预测结果收集
|
||||
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
|
||||
结束条件包含如下
|
||||
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
|
||||
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
|
||||
3、模型预测出错
|
||||
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
|
||||
迭代器收集的行为如下
|
||||
创建Iteratorize迭代对象,
|
||||
定义generate_with_callback收集器AnswerResultStream
|
||||
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
|
||||
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
|
||||
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
|
||||
这时generate_with_callback会被阻塞
|
||||
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
|
||||
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
|
||||
2、消息为self.sentinel标识,抛出StopIteration异常
|
||||
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
|
||||
异步线程检测stop_now属性被更新,抛出异常结束预测行为
|
||||
迭代行为结束
|
||||
:param val:
|
||||
:return:
|
||||
"""
|
||||
if self.stop_now:
|
||||
raise ValueError
|
||||
self.q.put(val)
|
||||
|
||||
def gen():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
self.q.put(self.sentinel)
|
||||
|
||||
self.thread = Thread(target=gen)
|
||||
self.thread.start()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True, None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
return obj
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
暂无实现
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
""" break 后会执行 """
|
||||
self.stop_now = True
|
||||
|
||||
|
||||
class BaseAnswer(ABC):
|
||||
|
|
@ -25,17 +155,25 @@ class BaseAnswer(ABC):
|
|||
@abstractmethod
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
"""Return _check_point of llm."""
|
||||
def generatorAnswer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
|
||||
def generate_with_callback(callback=None, **kwargs):
|
||||
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
|
||||
self._generate_answer(**kwargs)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _history_len(self) -> int:
|
||||
"""Return _history_len of llm."""
|
||||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs)
|
||||
|
||||
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
|
||||
for answerResult in generator:
|
||||
if answerResult.listenerToken:
|
||||
output = answerResult.listenerToken.input_ids
|
||||
yield answerResult
|
||||
|
||||
@abstractmethod
|
||||
def set_history_len(self, history_len: int) -> None:
|
||||
"""Return _history_len of llm."""
|
||||
|
||||
def generatorAnswer(self, prompt: str,
|
||||
history: List[List[str]] = [],
|
||||
streaming: bool = False):
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,70 +1,102 @@
|
|||
from abc import ABC
|
||||
from langchain.llms.base import LLM
|
||||
from typing import Optional, List
|
||||
from langchain.chains.base import Chain
|
||||
from typing import Any, Dict, List, Optional, Generator
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||
from models.loader import LoaderCheckPoint
|
||||
from models.base import (BaseAnswer,
|
||||
AnswerResult)
|
||||
AnswerResult,
|
||||
AnswerResultStream,
|
||||
AnswerResultQueueSentinelTokenListenerQueue)
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class ChatGLM(BaseAnswer, LLM, ABC):
|
||||
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
|
||||
max_token: int = 10000
|
||||
temperature: float = 0.01
|
||||
top_p = 0.9
|
||||
# 相关度
|
||||
top_p = 0.4
|
||||
# 候选词数量
|
||||
top_k = 10
|
||||
checkPoint: LoaderCheckPoint = None
|
||||
# history = []
|
||||
history_len: int = 10
|
||||
streaming_key: str = "streaming" #: :meta private:
|
||||
history_key: str = "history" #: :meta private:
|
||||
prompt_key: str = "prompt" #: :meta private:
|
||||
output_key: str = "answer_result_stream" #: :meta private:
|
||||
|
||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||
super().__init__()
|
||||
self.checkPoint = checkPoint
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ChatGLM"
|
||||
def _chain_type(self) -> str:
|
||||
return "ChatGLMLLMChain"
|
||||
|
||||
@property
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
return self.checkPoint
|
||||
|
||||
@property
|
||||
def _history_len(self) -> int:
|
||||
return self.history_len
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
def set_history_len(self, history_len: int = 10) -> None:
|
||||
self.history_len = history_len
|
||||
:meta private:
|
||||
"""
|
||||
return [self.prompt_key]
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Generator]:
|
||||
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||
return {self.output_key: generator}
|
||||
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
history = inputs[self.history_key]
|
||||
streaming = inputs[self.streaming_key]
|
||||
prompt = inputs[self.prompt_key]
|
||||
print(f"__call:{prompt}")
|
||||
response, _ = self.checkPoint.model.chat(
|
||||
self.checkPoint.tokenizer,
|
||||
prompt,
|
||||
history=[],
|
||||
max_length=self.max_token,
|
||||
temperature=self.temperature
|
||||
)
|
||||
print(f"response:{response}")
|
||||
print(f"+++++++++++++++++++++++++++++++++++")
|
||||
return response
|
||||
|
||||
def generatorAnswer(self, prompt: str,
|
||||
history: List[List[str]] = [],
|
||||
streaming: bool = False):
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
|
||||
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
|
||||
stopping_criteria_list.append(listenerQueue)
|
||||
if streaming:
|
||||
history += [[]]
|
||||
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
|
||||
self.checkPoint.tokenizer,
|
||||
prompt,
|
||||
history=history[-self.history_len:-1] if self.history_len > 1 else [],
|
||||
history=history[-self.history_len:-1] if self.history_len > 0 else [],
|
||||
max_length=self.max_token,
|
||||
temperature=self.temperature
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
stopping_criteria=stopping_criteria_list
|
||||
)):
|
||||
# self.checkPoint.clear_torch_cache()
|
||||
history[-1] = [prompt, stream_resp]
|
||||
answer_result = AnswerResult()
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": stream_resp}
|
||||
yield answer_result
|
||||
if listenerQueue.listenerQueue.__len__() > 0:
|
||||
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
|
||||
generate_with_callback(answer_result)
|
||||
self.checkPoint.clear_torch_cache()
|
||||
else:
|
||||
response, _ = self.checkPoint.model.chat(
|
||||
|
|
@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
|
|||
prompt,
|
||||
history=history[-self.history_len:] if self.history_len > 0 else [],
|
||||
max_length=self.max_token,
|
||||
temperature=self.temperature
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
stopping_criteria=stopping_criteria_list
|
||||
)
|
||||
self.checkPoint.clear_torch_cache()
|
||||
history += [[prompt, response]]
|
||||
answer_result = AnswerResult()
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": response}
|
||||
yield answer_result
|
||||
if listenerQueue.listenerQueue.__len__() > 0:
|
||||
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
|
||||
|
||||
generate_with_callback(answer_result)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from abc import ABC
|
||||
import requests
|
||||
from typing import Optional, List
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from typing import Any, Dict, List, Optional, Generator, Collection
|
||||
from models.loader import LoaderCheckPoint
|
||||
from models.base import (RemoteRpcModel,
|
||||
AnswerResult)
|
||||
from typing import (
|
||||
Collection,
|
||||
Dict
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from models.base import (BaseAnswer,
|
||||
RemoteRpcModel,
|
||||
AnswerResult,
|
||||
AnswerResultStream,
|
||||
AnswerResultQueueSentinelTokenListenerQueue)
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
def _build_message_template() -> Dict[str, str]:
|
||||
|
|
@ -22,41 +22,74 @@ def _build_message_template() -> Dict[str, str]:
|
|||
}
|
||||
|
||||
|
||||
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
||||
# 将历史对话数组转换为文本格式
|
||||
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
|
||||
build_messages: Collection[Dict[str, str]] = []
|
||||
for i, (old_query, response) in enumerate(history):
|
||||
user_build_message = _build_message_template()
|
||||
user_build_message['role'] = 'user'
|
||||
user_build_message['content'] = old_query
|
||||
system_build_message = _build_message_template()
|
||||
system_build_message['role'] = 'system'
|
||||
system_build_message['content'] = response
|
||||
build_messages.append(user_build_message)
|
||||
build_messages.append(system_build_message)
|
||||
|
||||
user_build_message = _build_message_template()
|
||||
user_build_message['role'] = 'user'
|
||||
user_build_message['content'] = query
|
||||
build_messages.append(user_build_message)
|
||||
return build_messages
|
||||
|
||||
|
||||
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
|
||||
api_base_url: str = "http://localhost:8000/v1"
|
||||
model_name: str = "chatglm-6b"
|
||||
max_token: int = 10000
|
||||
temperature: float = 0.01
|
||||
top_p = 0.9
|
||||
checkPoint: LoaderCheckPoint = None
|
||||
history = []
|
||||
# history = []
|
||||
history_len: int = 10
|
||||
api_key: str = ""
|
||||
|
||||
def __init__(self,
|
||||
streaming_key: str = "streaming" #: :meta private:
|
||||
history_key: str = "history" #: :meta private:
|
||||
prompt_key: str = "prompt" #: :meta private:
|
||||
output_key: str = "answer_result_stream" #: :meta private:
|
||||
|
||||
def __init__(self,
|
||||
checkPoint: LoaderCheckPoint = None,
|
||||
# api_base_url:str="http://localhost:8000/v1",
|
||||
# model_name:str="chatglm-6b",
|
||||
# api_key:str=""
|
||||
# api_base_url:str="http://localhost:8000/v1",
|
||||
# model_name:str="chatglm-6b",
|
||||
# api_key:str=""
|
||||
):
|
||||
super().__init__()
|
||||
self.checkPoint = checkPoint
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "FastChat"
|
||||
def _chain_type(self) -> str:
|
||||
return "LLamaLLMChain"
|
||||
|
||||
@property
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
return self.checkPoint
|
||||
|
||||
@property
|
||||
def _history_len(self) -> int:
|
||||
return self.history_len
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
def set_history_len(self, history_len: int = 10) -> None:
|
||||
self.history_len = history_len
|
||||
:meta private:
|
||||
"""
|
||||
return [self.prompt_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@property
|
||||
def _api_key(self) -> str:
|
||||
|
|
@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
|||
def call_model_name(self, model_name):
|
||||
self.model_name = model_name
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Generator]:
|
||||
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||
return {self.output_key: generator}
|
||||
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
|
||||
history = inputs[self.history_key]
|
||||
streaming = inputs[self.streaming_key]
|
||||
prompt = inputs[self.prompt_key]
|
||||
print(f"__call:{prompt}")
|
||||
try:
|
||||
import openai
|
||||
# Not support yet
|
||||
# openai.api_key = "EMPTY"
|
||||
openai.key = self.api_key
|
||||
openai.api_base = self.api_base_url
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
# create a chat completion
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=self.model_name,
|
||||
messages=self.build_message_list(prompt)
|
||||
)
|
||||
print(f"response:{completion.choices[0].message.content}")
|
||||
print(f"+++++++++++++++++++++++++++++++++++")
|
||||
return completion.choices[0].message.content
|
||||
|
||||
# 将历史对话数组转换为文本格式
|
||||
def build_message_list(self, query) -> Collection[Dict[str, str]]:
|
||||
build_message_list: Collection[Dict[str, str]] = []
|
||||
history = self.history[-self.history_len:] if self.history_len > 0 else []
|
||||
for i, (old_query, response) in enumerate(history):
|
||||
user_build_message = _build_message_template()
|
||||
user_build_message['role'] = 'user'
|
||||
user_build_message['content'] = old_query
|
||||
system_build_message = _build_message_template()
|
||||
system_build_message['role'] = 'system'
|
||||
system_build_message['content'] = response
|
||||
build_message_list.append(user_build_message)
|
||||
build_message_list.append(system_build_message)
|
||||
|
||||
user_build_message = _build_message_template()
|
||||
user_build_message['role'] = 'user'
|
||||
user_build_message['content'] = query
|
||||
build_message_list.append(user_build_message)
|
||||
return build_message_list
|
||||
|
||||
def generatorAnswer(self, prompt: str,
|
||||
history: List[List[str]] = [],
|
||||
streaming: bool = False):
|
||||
|
||||
try:
|
||||
import openai
|
||||
# Not support yet
|
||||
# openai.api_key = "EMPTY"
|
||||
|
|
@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
|||
# create a chat completion
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=self.model_name,
|
||||
messages=self.build_message_list(prompt)
|
||||
messages=build_message_list(prompt)
|
||||
)
|
||||
print(f"response:{completion.choices[0].message.content}")
|
||||
print(f"+++++++++++++++++++++++++++++++++++")
|
||||
|
||||
history += [[prompt, completion.choices[0].message.content]]
|
||||
answer_result = AnswerResult()
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": completion.choices[0].message.content}
|
||||
|
||||
yield answer_result
|
||||
generate_with_callback(answer_result)
|
||||
|
|
|
|||
|
|
@ -1,29 +1,32 @@
|
|||
from abc import ABC
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
import random
|
||||
import torch
|
||||
import transformers
|
||||
from abc import ABC
|
||||
from langchain.chains.base import Chain
|
||||
from typing import Any, Dict, List, Optional, Generator, Union
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||
from typing import Optional, List, Dict, Any,Union
|
||||
from models.loader import LoaderCheckPoint
|
||||
from models.base import (BaseAnswer,
|
||||
AnswerResult)
|
||||
AnswerResult,
|
||||
AnswerResultStream,
|
||||
AnswerResultQueueSentinelTokenListenerQueue)
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
|
||||
def __call__(self, input_ids: Union[torch.LongTensor, list],
|
||||
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
|
||||
# llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
|
||||
input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids
|
||||
scores = torch.tensor(scores) if isinstance(scores,list) else scores
|
||||
input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
|
||||
scores = torch.tensor(scores) if isinstance(scores, list) else scores
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 5] = 5e4
|
||||
return scores
|
||||
|
||||
|
||||
class LLamaLLM(BaseAnswer, LLM, ABC):
|
||||
class LLamaLLMChain(BaseAnswer, Chain, ABC):
|
||||
checkPoint: LoaderCheckPoint = None
|
||||
# history = []
|
||||
history_len: int = 3
|
||||
|
|
@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||
min_length: int = 0
|
||||
logits_processor: LogitsProcessorList = None
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None
|
||||
eos_token_id: Optional[int] = [2]
|
||||
|
||||
state: object = {'max_new_tokens': 50,
|
||||
'seed': 1,
|
||||
'temperature': 0, 'top_p': 0.1,
|
||||
'top_k': 40, 'typical_p': 1,
|
||||
'repetition_penalty': 1.2,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'min_length': 0,
|
||||
'penalty_alpha': 0,
|
||||
'num_beams': 1,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
|
||||
'truncation_length': 2048, 'custom_stopping_strings': '',
|
||||
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
|
||||
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
|
||||
'pre_layer': 0, 'gpu_memory_0': 0}
|
||||
streaming_key: str = "streaming" #: :meta private:
|
||||
history_key: str = "history" #: :meta private:
|
||||
prompt_key: str = "prompt" #: :meta private:
|
||||
output_key: str = "answer_result_stream" #: :meta private:
|
||||
|
||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||
super().__init__()
|
||||
self.checkPoint = checkPoint
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "LLamaLLM"
|
||||
def _chain_type(self) -> str:
|
||||
return "LLamaLLMChain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.prompt_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@property
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
|
|
@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||
formatted_history += "### Human:{}\n### Assistant:".format(query)
|
||||
return formatted_history
|
||||
|
||||
def prepare_inputs_for_generation(self,
|
||||
input_ids: torch.LongTensor):
|
||||
"""
|
||||
预生成注意力掩码和 输入序列中每个位置的索引的张量
|
||||
# TODO 没有思路
|
||||
:return:
|
||||
"""
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Generator]:
|
||||
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||
return {self.output_key: generator}
|
||||
|
||||
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
|
||||
attention_mask = self.get_masks(input_ids, input_ids.device)
|
||||
|
||||
position_ids = self.get_position_ids(
|
||||
input_ids,
|
||||
device=input_ids.device,
|
||||
mask_positions=mask_positions
|
||||
)
|
||||
|
||||
return input_ids, position_ids, attention_mask
|
||||
|
||||
@property
|
||||
def _history_len(self) -> int:
|
||||
return self.history_len
|
||||
|
||||
def set_history_len(self, history_len: int = 10) -> None:
|
||||
self.history_len = history_len
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
history = inputs[self.history_key]
|
||||
streaming = inputs[self.streaming_key]
|
||||
prompt = inputs[self.prompt_key]
|
||||
print(f"__call:{prompt}")
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings
|
||||
self.stopping_criteria = transformers.StoppingCriteriaList()
|
||||
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
|
||||
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
|
||||
self.stopping_criteria.append(listenerQueue)
|
||||
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
|
||||
soft_prompt = self.history_to_text(query=prompt, history=history)
|
||||
if self.logits_processor is None:
|
||||
self.logits_processor = LogitsProcessorList()
|
||||
self.logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
|
|
@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||
"logits_processor": self.logits_processor}
|
||||
|
||||
# 向量转换
|
||||
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
|
||||
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
|
||||
|
||||
input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
|
||||
truncation_length=self.max_new_tokens)
|
||||
|
||||
gen_kwargs.update({'inputs': input_ids})
|
||||
# 注意力掩码
|
||||
# gen_kwargs.update({'attention_mask': attention_mask})
|
||||
# gen_kwargs.update({'position_ids': position_ids})
|
||||
if self.stopping_criteria is None:
|
||||
self.stopping_criteria = transformers.StoppingCriteriaList()
|
||||
# 观测输出
|
||||
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
|
||||
# llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误
|
||||
|
|
@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||
if "llama_cpp" in self.checkPoint.model.__str__():
|
||||
import inspect
|
||||
|
||||
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
|
||||
common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys}
|
||||
#? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
|
||||
#?为什么会不支持GPU呢,不应该啊?
|
||||
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
|
||||
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
|
||||
gen_kwargs.keys())
|
||||
common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
|
||||
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
|
||||
# ?为什么会不支持GPU呢,不应该啊?
|
||||
output_ids = torch.tensor(
|
||||
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
|
||||
|
||||
else:
|
||||
output_ids = self.checkPoint.model.generate(**gen_kwargs)
|
||||
|
|
@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||
reply = self.decode(output_ids[0][-new_tokens:])
|
||||
print(f"response:{reply}")
|
||||
print(f"+++++++++++++++++++++++++++++++++++")
|
||||
return reply
|
||||
|
||||
def generatorAnswer(self, prompt: str,
|
||||
history: List[List[str]] = [],
|
||||
streaming: bool = False):
|
||||
|
||||
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
|
||||
softprompt = self.history_to_text(prompt,history=history)
|
||||
response = self._call(prompt=softprompt, stop=['\n###'])
|
||||
|
||||
answer_result = AnswerResult()
|
||||
answer_result.history = history + [[prompt, response]]
|
||||
answer_result.llm_output = {"answer": response}
|
||||
yield answer_result
|
||||
history += [[prompt, reply]]
|
||||
answer_result.history = history
|
||||
if listenerQueue.listenerQueue.__len__() > 0:
|
||||
answer_result.listenerToken = listenerQueue.listenerQueue.pop()
|
||||
answer_result.llm_output = {"answer": reply}
|
||||
generate_with_callback(answer_result)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
from configs.model_config import *
|
||||
|
|
@ -45,7 +46,6 @@ parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the m
|
|||
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
|
||||
parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint")
|
||||
parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
|
||||
|
||||
# Accelerate/transformers
|
||||
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
|
||||
help='Load the model with 8-bit precision.')
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class LoaderCheckPoint:
|
|||
no_remote_model: bool = False
|
||||
# 模型名称
|
||||
model_name: str = None
|
||||
pretrained_model_name: str = None
|
||||
tokenizer: object = None
|
||||
# 模型全路径
|
||||
model_path: str = None
|
||||
|
|
@ -35,11 +36,11 @@ class LoaderCheckPoint:
|
|||
# 因此主要的解决思路是清理环境变量里PATH下的不匹配的cuda版本,一劳永逸的方法是:
|
||||
# 0. 在终端执行`pip uninstall bitsandbytes`
|
||||
# 1. 删除.bashrc文件下关于PATH的条目
|
||||
# 2. 在终端执行 `echo $PATH >> .bashrc`
|
||||
# 2. 在终端执行 `echo $PATH >> .bashrc`
|
||||
# 3. 删除.bashrc文件下PATH中关于不匹配的cuda版本路径
|
||||
# 4. 在终端执行`source .bashrc`
|
||||
# 5. 再执行`pip install bitsandbytes`
|
||||
|
||||
|
||||
load_in_8bit: bool = False
|
||||
is_llamacpp: bool = False
|
||||
bf16: bool = False
|
||||
|
|
@ -67,48 +68,49 @@ class LoaderCheckPoint:
|
|||
self.load_in_8bit = params.get('load_in_8bit', False)
|
||||
self.bf16 = params.get('bf16', False)
|
||||
|
||||
|
||||
def _load_model_config(self, model_name):
|
||||
def _load_model_config(self):
|
||||
|
||||
if self.model_path:
|
||||
self.model_path = re.sub("\s","",self.model_path)
|
||||
self.model_path = re.sub("\s", "", self.model_path)
|
||||
checkpoint = Path(f'{self.model_path}')
|
||||
else:
|
||||
if not self.no_remote_model:
|
||||
checkpoint = model_name
|
||||
else:
|
||||
if self.no_remote_model:
|
||||
raise ValueError(
|
||||
"本地模型local_model_path未配置路径"
|
||||
)
|
||||
else:
|
||||
checkpoint = self.pretrained_model_name
|
||||
|
||||
print(f"load_model_config {checkpoint}...")
|
||||
try:
|
||||
|
||||
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
return model_config
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return checkpoint
|
||||
|
||||
def _load_model(self, model_name):
|
||||
def _load_model(self):
|
||||
"""
|
||||
加载自定义位置的model
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
print(f"Loading {model_name}...")
|
||||
t0 = time.time()
|
||||
|
||||
if self.model_path:
|
||||
self.model_path = re.sub("\s","",self.model_path)
|
||||
self.model_path = re.sub("\s", "", self.model_path)
|
||||
checkpoint = Path(f'{self.model_path}')
|
||||
else:
|
||||
if not self.no_remote_model:
|
||||
checkpoint = model_name
|
||||
else:
|
||||
if self.no_remote_model:
|
||||
raise ValueError(
|
||||
"本地模型local_model_path未配置路径"
|
||||
)
|
||||
else:
|
||||
checkpoint = self.pretrained_model_name
|
||||
|
||||
print(f"Loading {checkpoint}...")
|
||||
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
|
||||
if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
|
||||
if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
|
||||
LoaderClass = AutoModel
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
|
|
@ -134,11 +136,11 @@ class LoaderCheckPoint:
|
|||
# 支持自定义cuda设备
|
||||
elif ":" in self.llm_device:
|
||||
model = LoaderClass.from_pretrained(checkpoint,
|
||||
config=self.model_config,
|
||||
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
|
||||
trust_remote_code=True).half().to(self.llm_device)
|
||||
config=self.model_config,
|
||||
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
|
||||
trust_remote_code=True).half().to(self.llm_device)
|
||||
else:
|
||||
from accelerate import dispatch_model,infer_auto_device_map
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
|
||||
model = LoaderClass.from_pretrained(checkpoint,
|
||||
config=self.model_config,
|
||||
|
|
@ -146,29 +148,29 @@ class LoaderCheckPoint:
|
|||
trust_remote_code=True).half()
|
||||
# 可传入device_map自定义每张卡的部署情况
|
||||
if self.device_map is None:
|
||||
if 'chatglm' in model_name.lower():
|
||||
if 'chatglm' in self.model_name.lower():
|
||||
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
||||
elif 'moss' in model_name.lower():
|
||||
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
|
||||
elif 'moss' in self.model_name.lower():
|
||||
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
|
||||
else:
|
||||
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
|
||||
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
|
||||
from accelerate.utils import get_balanced_memory
|
||||
max_memory = get_balanced_memory(model,
|
||||
dtype=torch.int8 if self.load_in_8bit else None,
|
||||
low_zero=False,
|
||||
no_split_module_classes=model._no_split_modules)
|
||||
self.device_map = infer_auto_device_map(model,
|
||||
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=model._no_split_modules)
|
||||
max_memory = get_balanced_memory(model,
|
||||
dtype=torch.int8 if self.load_in_8bit else None,
|
||||
low_zero=False,
|
||||
no_split_module_classes=model._no_split_modules)
|
||||
self.device_map = infer_auto_device_map(model,
|
||||
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=model._no_split_modules)
|
||||
# 对于chaglm和moss意外的模型应使用自动指定,而非调用chatglm的配置方式
|
||||
# 其他模型定义的层类几乎不可能与chatglm和moss一致,使用chatglm_auto_configure_device_map
|
||||
# 百分百会报错,使用infer_auto_device_map虽然可能导致负载不均衡,但至少不会报错
|
||||
# 实测在bloom模型上如此
|
||||
# self.device_map = infer_auto_device_map(model,
|
||||
# dtype=torch.int8,
|
||||
# no_split_module_classes=model._no_split_modules)
|
||||
# self.device_map = infer_auto_device_map(model,
|
||||
# dtype=torch.int8,
|
||||
# no_split_module_classes=model._no_split_modules)
|
||||
|
||||
model = dispatch_model(model, device_map=self.device_map)
|
||||
else:
|
||||
|
|
@ -202,7 +204,7 @@ class LoaderCheckPoint:
|
|||
|
||||
# tokenizer = model.tokenizer
|
||||
# todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
|
||||
#* -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
|
||||
# * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
return model, tokenizer
|
||||
|
|
@ -231,7 +233,7 @@ class LoaderCheckPoint:
|
|||
llm_int8_enable_fp32_cpu_offload=False)
|
||||
|
||||
with init_empty_weights():
|
||||
model = LoaderClass.from_config(self.model_config,trust_remote_code = True)
|
||||
model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
|
||||
model.tie_weights()
|
||||
if self.device_map is not None:
|
||||
params['device_map'] = self.device_map
|
||||
|
|
@ -294,7 +296,7 @@ class LoaderCheckPoint:
|
|||
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
||||
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
||||
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
||||
|
||||
|
||||
encode = ""
|
||||
if 'chatglm2' in self.model_name:
|
||||
device_map = {
|
||||
|
|
@ -302,13 +304,13 @@ class LoaderCheckPoint:
|
|||
f"{layer_prefix}.rotary_pos_emb": 0,
|
||||
f"{layer_prefix}.output_layer": 0,
|
||||
f"{layer_prefix}.encoder.final_layernorm": 0,
|
||||
f"base_model.model.output_layer": 0
|
||||
f"base_model.model.output_layer": 0
|
||||
}
|
||||
encode = ".encoder"
|
||||
else:
|
||||
device_map = {f'{layer_prefix}.word_embeddings': 0,
|
||||
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
|
||||
f'base_model.model.lm_head': 0, }
|
||||
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
|
||||
f'base_model.model.lm_head': 0, }
|
||||
used = 2
|
||||
gpu_target = 0
|
||||
for i in range(num_trans_layers):
|
||||
|
|
@ -321,7 +323,7 @@ class LoaderCheckPoint:
|
|||
|
||||
return device_map
|
||||
|
||||
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
|
||||
def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
|
||||
try:
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
|
@ -336,16 +338,6 @@ class LoaderCheckPoint:
|
|||
"`pip install bitsandbytes``pip install accelerate`."
|
||||
) from exc
|
||||
|
||||
if self.model_path:
|
||||
checkpoint = Path(f'{self.model_path}')
|
||||
else:
|
||||
if not self.no_remote_model:
|
||||
checkpoint = model_name
|
||||
else:
|
||||
raise ValueError(
|
||||
"本地模型local_model_path未配置路径"
|
||||
)
|
||||
|
||||
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
|
||||
pretrained_model_name_or_path=checkpoint)
|
||||
|
||||
|
|
@ -452,7 +444,7 @@ class LoaderCheckPoint:
|
|||
|
||||
def reload_model(self):
|
||||
self.unload_model()
|
||||
self.model_config = self._load_model_config(self.model_name)
|
||||
self.model_config = self._load_model_config()
|
||||
|
||||
if self.use_ptuning_v2:
|
||||
try:
|
||||
|
|
@ -464,7 +456,7 @@ class LoaderCheckPoint:
|
|||
except Exception as e:
|
||||
print("加载PrefixEncoder config.json失败")
|
||||
|
||||
self.model, self.tokenizer = self._load_model(self.model_name)
|
||||
self.model, self.tokenizer = self._load_model()
|
||||
|
||||
if self.lora:
|
||||
self._add_lora_to_model([self.lora])
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
from abc import ABC
|
||||
from langchain.llms.base import LLM
|
||||
from typing import Optional, List
|
||||
from langchain.chains.base import Chain
|
||||
from typing import Any, Dict, List, Optional, Generator, Union
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||
from models.loader import LoaderCheckPoint
|
||||
from models.base import (BaseAnswer,
|
||||
AnswerResult)
|
||||
AnswerResult,
|
||||
AnswerResultStream,
|
||||
AnswerResultQueueSentinelTokenListenerQueue)
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import torch
|
||||
|
||||
# todo 建议重写instruction,在该instruction下,各模型的表现比较差
|
||||
META_INSTRUCTION = \
|
||||
"""You are an AI assistant whose name is MOSS.
|
||||
|
|
@ -20,41 +28,65 @@ META_INSTRUCTION = \
|
|||
Capabilities and tools that MOSS can possess.
|
||||
"""
|
||||
|
||||
|
||||
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
|
||||
class MOSSLLM(BaseAnswer, LLM, ABC):
|
||||
class MOSSLLMChain(BaseAnswer, Chain, ABC):
|
||||
max_token: int = 2048
|
||||
temperature: float = 0.7
|
||||
top_p = 0.8
|
||||
# history = []
|
||||
checkPoint: LoaderCheckPoint = None
|
||||
history_len: int = 10
|
||||
streaming_key: str = "streaming" #: :meta private:
|
||||
history_key: str = "history" #: :meta private:
|
||||
prompt_key: str = "prompt" #: :meta private:
|
||||
output_key: str = "answer_result_stream" #: :meta private:
|
||||
|
||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||
super().__init__()
|
||||
self.checkPoint = checkPoint
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "MOSS"
|
||||
def _chain_type(self) -> str:
|
||||
return "MOSSLLMChain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.prompt_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@property
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
return self.checkPoint
|
||||
|
||||
@property
|
||||
def _history_len(self) -> int:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Generator]:
|
||||
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||
return {self.output_key: generator}
|
||||
|
||||
return self.history_len
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
|
||||
def set_history_len(self, history_len: int) -> None:
|
||||
self.history_len = history_len
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
pass
|
||||
|
||||
def generatorAnswer(self, prompt: str,
|
||||
history: List[List[str]] = [],
|
||||
streaming: bool = False):
|
||||
history = inputs[self.history_key]
|
||||
streaming = inputs[self.streaming_key]
|
||||
prompt = inputs[self.prompt_key]
|
||||
print(f"__call:{prompt}")
|
||||
if len(history) > 0:
|
||||
history = history[-self.history_len:] if self.history_len > 0 else []
|
||||
prompt_w_history = str(history)
|
||||
|
|
@ -66,7 +98,7 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
|
|||
inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
# max_length似乎可以设的小一些,而repetion_penalty应大一些,否则chatyuan,bloom等模型为满足max会重复输出
|
||||
#
|
||||
#
|
||||
outputs = self.checkPoint.model.generate(
|
||||
inputs.input_ids.cuda(),
|
||||
attention_mask=inputs.attention_mask.cuda(),
|
||||
|
|
@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
|
|||
num_return_sequences=1,
|
||||
eos_token_id=106068,
|
||||
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
|
||||
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
|
||||
skip_special_tokens=True)
|
||||
self.checkPoint.clear_torch_cache()
|
||||
history += [[prompt, response]]
|
||||
answer_result = AnswerResult()
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": response}
|
||||
|
||||
yield answer_result
|
||||
|
||||
|
||||
generate_with_callback(answer_result)
|
||||
|
|
|
|||
|
|
@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

# 如果指定了参数,则使用参数的配置
if llm_model:
llm_model_info = llm_model_dict[llm_model]

-if loaderCheckPoint.no_remote_model:
-loaderCheckPoint.model_name = llm_model_info['name']
-else:
-loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
+loaderCheckPoint.model_name = llm_model_info['name']
+loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']

loaderCheckPoint.model_path = llm_model_info["local_model_path"]
@@ -1,39 +0,0 @@
-import sys
-import os
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
-import asyncio
-from argparse import Namespace
-from models.loader.args import parser
-from models.loader import LoaderCheckPoint
-
-
-import models.shared as shared
-
-
-
-async def dispatch(args: Namespace):
-args_dict = vars(args)
-
-shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
-
-llm_model_ins = shared.loaderLLM()
-
-history = [
-("which city is this?", "tokyo"),
-("why?", "she's japanese"),
-
-]
-for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
-streaming=False):
-resp = answer_result.llm_output["answer"]
-
-print(resp)
-
-if __name__ == '__main__':
-args = None
-args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
-
-loop = asyncio.new_event_loop()
-asyncio.set_event_loop(loop)
-loop.run_until_complete(dispatch(args))
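Annotation: the deleted script above exercised the removed generatorAnswer() API. An equivalent smoke test under the chain interface introduced by this commit would call the loaded model with a dict and iterate its answer_result_stream; a sketch follows (paths and model names copied from the deleted file, the call pattern from the hunks above).

# Sketch of the deleted smoke test rewritten against the new chain API.
import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import parser

args = parser.parse_args(args=['--model-dir', '/media/checkpoint/',
                               '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
shared.loaderCheckPoint = LoaderCheckPoint(vars(args))
llm_chain = shared.loaderLLM()

history = [("which city is this?", "tokyo"), ("why?", "she's japanese")]
result = llm_chain({"prompt": "你好? ", "history": history, "streaming": False})
for answer_result in result["answer_result_stream"]:
    print(answer_result.llm_output["answer"])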
webui.py
@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
-for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
-streaming=streaming):
+
+answer_result_stream_result = local_doc_qa.llm_model_chain(
+{"prompt": query, "history": history, "streaming": streaming})
+
+for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp
@@ -101,11 +104,12 @@ def init_model():
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
    try:
        local_doc_qa.init_cfg(llm_model=llm_model_ins)
        generator = local_doc_qa.llm.generatorAnswer("你好")
        for answer_result in generator:
        answer_result_stream_result = local_doc_qa.llm_model_chain(
            {"prompt": "你好", "history": [], "streaming": False})

        for answer_result in answer_result_stream_result['answer_result_stream']:
            print(answer_result.llm_output)
        reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
        logger.info(reply)
@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u

def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    filelist = []
    if local_doc_qa.llm and local_doc_qa.embeddings:
    if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
        if isinstance(files, list):
            for file in files:
                filename = os.path.split(file.name)[-1]
@@ -165,8 +169,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte

def change_vs_name_input(vs_id, history):
    if vs_id == "新建知识库":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
               gr.update(choices=[]), gr.update(visible=False)
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
               gr.update(choices=[]), gr.update(visible=False)
    else:
        vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
        if "index.faiss" in os.listdir(vs_path):
@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):


def add_vs_name(vs_name, chatbot):
    if vs_name is None or vs_name.strip() == "" :
    if vs_name is None or vs_name.strip() == "":
        vs_status = "知识库名称不能为空,请重新填写知识库名称"
        chatbot = chatbot + [[None, vs_status]]
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list():
    return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())


def delete_file(vs_id, files_to_delete, chatbot):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
        rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
        if "fail" in status:
            vs_status = "文件删除失败。"
        elif len(rested_files)>0:
        elif len(rested_files) > 0:
            vs_status = "文件删除成功。"
        else:
            vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
        logger.info(",".join(files_to_delete)+vs_status)
        logger.info(",".join(files_to_delete) + vs_status)
        chatbot = chatbot + [[None, vs_status]]
        return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
        status = f"成功删除知识库{vs_id}"
        logger.info(status)
        chatbot = chatbot + [[None, status]]
        return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
        return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
            visible=True), \
               gr.update(visible=False), chatbot, gr.update(visible=False)
    except Exception as e:
        logger.error(e)
@@ -333,7 +339,8 @@ default_theme_args = dict(

with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
    vs_path, file_status, model_status = gr.State(
        os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
        os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
        ""), gr.State(
        model_status)
    gr.Markdown(webui_title)
    with gr.Tab("对话"):
@@ -386,8 +393,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                    load_folder_button = gr.Button("上传文件夹并加载知识库")
                with gr.Tab("删除文件"):
                    files_to_delete = gr.CheckboxGroup(choices=[],
                                                       label="请从知识库已有文件中选择要删除的文件",
                                                       interactive=True)
                                                       label="请从知识库已有文件中选择要删除的文件",
                                                       interactive=True)
                    delete_file_button = gr.Button("从知识库中删除选中文件")
            vs_refresh.click(fn=refresh_vs_list,
                             inputs=[],
@@ -455,9 +462,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                with vs_setting:
                    vs_refresh = gr.Button("更新已有知识库选项")
                    select_vs_test = gr.Dropdown(get_vs_list(),
                                                 label="请选择要加载的知识库",
                                                 interactive=True,
                                                 value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
                                                 label="请选择要加载的知识库",
                                                 interactive=True,
                                                 value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
                    vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
                                         lines=1,
                                         interactive=True,
@@ -497,8 +504,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                               inputs=[vs_name, chatbot],
                               outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
            select_vs_test.change(fn=change_vs_name_input,
                                  inputs=[select_vs_test, chatbot],
                                  outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
                                  inputs=[select_vs_test, chatbot],
                                  outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
            load_file_button.click(get_vector_store,
                                   show_progress=True,
                                   inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],
14
webui_st.py
@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
        yield history + [[query,
                          "请选择知识库后进行测试,当前未选择知识库。"]], ""
    else:
        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
                                                              streaming=streaming):
        answer_result_stream_result = local_doc_qa.llm_model_chain(
            {"prompt": query, "history": history, "streaming": streaming})

        for answer_result in answer_result_stream_result['answer_result_stream']:
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][-1] = resp + (
@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
    args_dict.update(model=llm_model)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

    try:
        local_doc_qa.init_cfg(llm_model=llm_model_ins,
                              embedding_model=embedding_model)
        generator = local_doc_qa.llm.generatorAnswer("你好")
        for answer_result in generator:
        answer_result_stream_result = local_doc_qa.llm_model_chain(
            {"prompt": "你好", "history": [], "streaming": False})

        for answer_result in answer_result_stream_result['answer_result_stream']:
            print(answer_result.llm_output)
        reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
        logger.info(reply)
@@ -468,7 +470,7 @@ with st.sidebar:
        top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
        history_len = st.slider(
            'LLM对话轮数', 1, 50, LLM_HISTORY_LEN)  # perhaps this should be configured separately from the knowledge base
        local_doc_qa.llm.set_history_len(history_len)
        # local_doc_qa.llm.set_history_len(history_len)
        chunk_conent = st.checkbox('启用上下文关联', False)
        st.text('')
        # chunk_conent = st.checkbox('分割文本', True)  # split knowledge-base text before ingestion