diff --git a/configs/model_config.py b/configs/model_config.py
index 3a2357b..ba925b2 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -69,12 +69,23 @@ llm_model_dict = {
         "local_model_path": None,
         "provides": "LLamaLLM"
     },
-    "fastChatOpenAI": {
+    "fast-chat-chatglm-6b": {
         "name": "FastChatOpenAI",
         "pretrained_model_name": "FastChatOpenAI",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM"
-    }
+        "provides": "FastChatOpenAILLM",
+        "api_base_url": "http://localhost:8000/v1",
+        "model_name": "chatglm-6b"
+    },
+
+    "fast-chat-vicuna-13b-hf": {
+        "name": "FastChatOpenAI",
+        "pretrained_model_name": "vicuna-13b-hf",
+        "local_model_path": None,
+        "provides": "FastChatOpenAILLM",
+        "api_base_url": "http://localhost:8000/v1",
+        "model_name": "vicuna-13b-hf"
+    },
 }
 
 # LLM 名称
diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py
index abf8834..5228c42 100644
--- a/models/fastchat_openai_llm.py
+++ b/models/fastchat_openai_llm.py
@@ -111,9 +111,9 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
             messages=self.build_message_list(prompt)
         )
 
-        self.history += [[prompt, completion.choices[0].message.content]]
+        history += [[prompt, completion.choices[0].message.content]]
         answer_result = AnswerResult()
-        answer_result.history = self.history
+        answer_result.history = history
         answer_result.llm_output = {"answer": completion.choices[0].message.content}
 
         yield answer_result
diff --git a/models/shared.py b/models/shared.py
index c78cb44..0525750 100644
--- a/models/shared.py
+++ b/models/shared.py
@@ -34,11 +34,14 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
 
     loaderCheckPoint.model_path = llm_model_info["local_model_path"]
 
-    if 'FastChat' in loaderCheckPoint.model_name:
+    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
         loaderCheckPoint.unload_model()
     else:
         loaderCheckPoint.reload_model()
 
     provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
     modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
+    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
+        modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
+        modelInsLLM.call_model_name(llm_model_info['model_name'])
     return modelInsLLM
diff --git a/test/models/test_fastchat_openai_llm.py b/test/models/test_fastchat_openai_llm.py
index 9949fcc..1b09bf7 100644
--- a/test/models/test_fastchat_openai_llm.py
+++ b/test/models/test_fastchat_openai_llm.py
@@ -18,14 +18,13 @@ async def dispatch(args: Namespace):
 
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_api_base_url("http://localhost:8000/v1")
-    llm_model_ins.call_model_name("chatglm-6b")
+
     history = [
         ("which city is this?", "tokyo"),
        ("why?", "she's japanese"),
     ]
-    for answer_result in llm_model_ins.generatorAnswer(prompt="她在做什么? ", history=history,
+    for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
                                                        streaming=False):
         resp = answer_result.llm_output["answer"]
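
The sketch below is not part of the diff: it is a minimal, hedged check of the endpoint that the new api_base_url / model_name config fields point at, using the same pre-1.0 openai client calls that fastchat_openai_llm.py relies on (openai.api_base, ChatCompletion.create, completion.choices[0].message.content). The FastChat launch commands in the comments and the THUDM/chatglm-6b model path are assumptions, not something this change prescribes.

# Hedged endpoint check, assuming FastChat was started along these lines:
#   python -m fastchat.serve.controller
#   python -m fastchat.serve.model_worker --model-path THUDM/chatglm-6b
#   python -m fastchat.serve.openai_api_server --host localhost --port 8000
# so that http://localhost:8000/v1 serves an OpenAI-compatible API.
import openai  # pre-1.0 client, matching the module-level usage in fastchat_openai_llm.py

openai.api_key = "EMPTY"                      # FastChat does not validate the key
openai.api_base = "http://localhost:8000/v1"  # same value as api_base_url in the config

completion = openai.ChatCompletion.create(
    model="chatglm-6b",                       # same value as model_name in the config
    messages=[{"role": "user", "content": "which city is this?"}],
)
print(completion.choices[0].message.content)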