适配远程LLM调用

This commit is contained in:
glide-the 2023-05-31 22:11:28 +08:00
parent 99e9d1d730
commit 24324563d6
4 changed files with 22 additions and 9 deletions

View File

@ -69,12 +69,23 @@ llm_model_dict = {
"local_model_path": None,
"provides": "LLamaLLM"
},
"fastChatOpenAI": {
"fast-chat-chatglm-6b": {
"name": "FastChatOpenAI",
"pretrained_model_name": "FastChatOpenAI",
"local_model_path": None,
"provides": "FastChatOpenAILLM"
}
"provides": "FastChatOpenAILLM",
"api_base_url": "http://localhost:8000/v1",
"model_name": "chatglm-6b"
},
"fast-chat-vicuna-13b-hf": {
"name": "FastChatOpenAI",
"pretrained_model_name": "vicuna-13b-hf",
"local_model_path": None,
"provides": "FastChatOpenAILLM",
"api_base_url": "http://localhost:8000/v1",
"model_name": "vicuna-13b-hf"
},
}
# LLM 名称

View File

@ -111,9 +111,9 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
messages=self.build_message_list(prompt)
)
self.history += [[prompt, completion.choices[0].message.content]]
history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult()
answer_result.history = self.history
answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content}
yield answer_result

View File

@ -34,11 +34,14 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
loaderCheckPoint.model_path = llm_model_info["local_model_path"]
if 'FastChat' in loaderCheckPoint.model_name:
if 'FastChatOpenAILLM' in llm_model_info["local_model_path"]:
loaderCheckPoint.unload_model()
else:
loaderCheckPoint.reload_model()
provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
modelInsLLM.call_model_name(llm_model_info['model_name'])
return modelInsLLM

View File

@ -18,14 +18,13 @@ async def dispatch(args: Namespace):
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_api_base_url("http://localhost:8000/v1")
llm_model_ins.call_model_name("chatglm-6b")
history = [
("which city is this?", "tokyo"),
("why?", "she's japanese"),
]
for answer_result in llm_model_ins.generatorAnswer(prompt="她在做什么? ", history=history,
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
streaming=False):
resp = answer_result.llm_output["answer"]