Azure support via fschat; deprecate LANGCHAIN_LLM_MODEL (#1873)

* Support agentlm

* Support agentlm and its prompt templates

* Modify some Agent features and add an optimization on the embedding side

* Modify some of the Agent tools

* Add some of Langchain's built-in tools

* Fix some compatibility bugs

* Restore the knowledge base

* Restore the knowledge base

* Fix the Azure issue
zR 2023-10-25 21:32:40 +08:00 committed by GitHub
parent 27418ba9fa
commit 35a7ca74c0
12 changed files with 174 additions and 67 deletions

View File

@@ -112,46 +112,6 @@ HISTORY_LEN = 3
TEMPERATURE = 0.7
# TOP_P = 0.95  # ChatOpenAI does not support this parameter yet
-LANGCHAIN_LLM_MODEL = {
-    # Models supported directly by Langchain, without going through the Fschat wrapper.
-    # If calling chatgpt raises urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
-    # Max retries exceeded with url: /v1/chat/completions
-    # then downgrade urllib3 to 1.25.11.
-    # If urllib3.exceptions.MaxRetryError: HTTPSConnectionPool still occurs, change https to http.
-    # See https://zhuanlan.zhihu.com/p/350015032
-    # If you get raise NewConnectionError(
-    # urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
-    # Failed to establish a new connection: [WinError 10060]
-    # then mainland China and Hong Kong IPs are blocked by OpenAI; switch to an IP in Japan, Singapore, etc.
-    # If you see WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
-    # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
-    # you need to go through a proxy (a normally running proxy app may fail to intercept this traffic);
-    # set openai_proxy in the config or the OPENAI_PROXY environment variable,
-    # e.g. "openai_proxy": 'http://127.0.0.1:4780'
-    # The names of these configuration entries must not be changed.
-    "Azure-OpenAI": {
-        "deployment_name": "your Azure deployment name",
-        "model_version": "0701",
-        "openai_api_type": "azure",
-        "api_base_url": "https://your Azure point.azure.com",
-        "api_version": "2023-07-01-preview",
-        "api_key": "your Azure api key",
-        "openai_proxy": "",
-    },
-    "OpenAI": {
-        "model_name": "your openai model name(such as gpt-4)",
-        "api_base_url": "https://api.openai.com/v1",
-        "api_key": "your OPENAI_API_KEY",
-        "openai_proxy": "",
-    },
-    "Anthropic": {
-        "model_name": "your claude model name(such as claude2-100k)",
-        "api_key": "your ANTHROPIC_API_KEY",
-    }
-}
ONLINE_LLM_MODEL = {
    # Online models. Please set a different port for each online API in server_config.
    # For registration details and API keys, visit http://open.bigmodel.cn
@@ -205,6 +165,15 @@ ONLINE_LLM_MODEL = {
        "secret_key": "",
        "provider": "BaiChuanWorker",
    },
+    # Azure API
+    "azure-api": {
+        "deployment_name": "",  # name of the deployment
+        "resource_name": "",  # fill in only the resource_name part of https://{resource_name}.openai.azure.com/openai/; leave the rest out
+        "api_version": "",  # the API version, not the model version
+        "api_key": "",
+        "provider": "AzureWorker",
+    },
}
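For reference, a filled-in azure-api entry, and the endpoint URL that the new AzureWorker derives from it, might look like the sketch below. All values are hypothetical placeholders; the URL pattern matches request_azure_api in the new azure.py further down.

# Hypothetical example values; substitute your own Azure resource details.
azure_api = {
    "deployment_name": "gpt-35-turbo-prod",  # deployment name, not the model name
    "resource_name": "my-openai-resource",   # from https://{resource_name}.openai.azure.com/
    "api_version": "2023-07-01-preview",     # REST API version, not the model version
    "api_key": "xxxx",
    "provider": "AzureWorker",
}

# The worker assembles the chat-completions endpoint like this:
url = (f"https://{azure_api['resource_name']}.openai.azure.com/openai"
       f"/deployments/{azure_api['deployment_name']}"
       f"/chat/completions?api-version={azure_api['api_version']}")
# -> https://my-openai-resource.openai.azure.com/openai/deployments/gpt-35-turbo-prod/chat/completions?api-version=2023-07-01-preview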

View File

@@ -50,7 +50,7 @@ streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
+httpx[brotli,http2,socks]>=0.25.0
watchdog
tqdm
websockets
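The added httpx extras are functional, not cosmetic: brotli enables Brotli response decoding, http2 enables HTTP/2 support, and socks allows routing through SOCKS proxies. A minimal sketch of what they unlock with httpx 0.25.x (the proxy URL is a hypothetical placeholder):

import httpx

# http2=True needs the "http2" extra; a socks5:// proxy scheme needs the "socks" extra.
client = httpx.Client(http2=True, proxies="socks5://127.0.0.1:1080")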

View File

@@ -1,5 +1,5 @@
from fastapi import Body
-from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
+from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                          get_httpx_client, get_model_worker_config)

View File

@@ -5,3 +5,4 @@ from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
+from .azure import AzureWorker

View File

@@ -0,0 +1,117 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import TEMPERATURE


def request_azure_api(
        messages: List[Dict[str, str]],
        resource_name: str = None,
        api_key: str = None,
        deployment_name: str = None,
        api_version: str = "2023-07-01-preview",
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        model_name: str = "azure-api",
):
    config = get_model_worker_config(model_name)
    data = dict(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
    )
    url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'api-key': api_key,
    }

    with get_httpx_client() as client:
        with client.stream("POST", url, headers=headers, json=data) as response:
            for line in response.iter_lines():
                if not line.strip() or "[DONE]" in line:
                    continue
                if line.startswith("data: "):
                    line = line[6:]
                resp = json.loads(line)
                yield resp


class AzureWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str,
            worker_addr: str,
            model_names: List[str] = ["azure-api"],
            version: str = "gpt-35-turbo",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)

        config = self.get_config()
        self.resource_name = config.get("resource_name")
        self.api_key = config.get("api_key")
        self.api_version = config.get("api_version")
        self.version = version
        self.deployment_name = config.get("deployment_name")

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        messages = self.prompt_to_messages(params["prompt"])
        text = ""

        for resp in request_azure_api(messages=messages,
                                      resource_name=self.resource_name,
                                      api_key=self.api_key,
                                      api_version=self.api_version,
                                      deployment_name=self.deployment_name,
                                      temperature=params.get("temperature"),
                                      max_tokens=params.get("max_tokens")):
            if choices := resp["choices"]:
                if chunk := choices[0].get("delta", {}).get("content"):
                    text += chunk
                    yield json.dumps(
                        {
                            "error_code": 0,
                            "text": text
                        },
                        ensure_ascii=False
                    ).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: confirm whether this template needs changes
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    worker = AzureWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21008",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21008)
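A quick way to exercise the streaming helper outside the worker is a loop like the sketch below; it assumes real resource, deployment, and key values (placeholders here) and accumulates delta chunks the same way generate_stream_gate does.

from server.model_workers.azure import request_azure_api

text = ""
for resp in request_azure_api(
    messages=[{"role": "user", "content": "Hello"}],
    resource_name="my-openai-resource",   # hypothetical
    api_key="xxxx",                       # hypothetical
    deployment_name="gpt-35-turbo-prod",  # hypothetical
    api_version="2023-07-01-preview",
    max_tokens=256,
):
    # each SSE event carries an incremental content delta
    if choices := resp.get("choices"):
        if chunk := choices[0].get("delta", {}).get("content"):
            text += chunk
print(text)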

View File

@@ -138,9 +138,9 @@ if __name__ == "__main__":
    worker = BaiChuanWorker(
        controller_addr="http://127.0.0.1:20001",
-        worker_addr="http://127.0.0.1:21001",
+        worker_addr="http://127.0.0.1:21007",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
-    uvicorn.run(app, port=21001)
+    uvicorn.run(app, port=21007)
    # do_request()

View File

@@ -1,17 +1,16 @@
+import sys
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
-import sys
import json
-import httpx
from cachetools import cached, TTLCache
from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict

MODEL_VERSIONS = {
    "ernie-bot": "completions",
+    "ernie-bot-4": "completions_pro",
    "ernie-bot-turbo": "eb-instant",
    "bloomz-7b": "bloomz_7b1",
    "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
@@ -67,7 +66,7 @@ def request_qianfan_api(
    model_name: str = "qianfan-api",
    version: str = None,
) -> Dict:
-    BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
+    BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
               '/{model_version}?access_token={access_token}'
    config = get_model_worker_config(model_name)
    version = version or config.get("version")
@@ -83,6 +82,7 @@ def request_qianfan_api(
        model_version=version_url or MODEL_VERSIONS[version],
        access_token=access_token,
    )
+
    payload = {
        "messages": messages,
        "temperature": temperature,
@@ -101,6 +101,7 @@ def request_qianfan_api(
                if line.startswith("data: "):
                    line = line[6:]
                resp = json.loads(line)
+                breakpoint()
                yield resp
@@ -108,6 +109,7 @@ class QianFanWorker(ApiModelWorker):
    """
    Baidu Qianfan
    """
+
    def __init__(
            self,
            *,
@@ -128,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
    def generate_stream_gate(self, params):
        messages = self.prompt_to_messages(params["prompt"])
-        text=""
+        text = ""
        for resp in request_qianfan_api(messages,
                                        temperature=params.get("temperature"),
                                        model_name=self.model_names[0]):
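For context on the ernie-bot-4 addition above: the version string only selects the final path segment of the Qianfan URL via MODEL_VERSIONS. A sketch with a hypothetical access token:

MODEL_VERSIONS = {"ernie-bot": "completions", "ernie-bot-4": "completions_pro"}
BASE_URL = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'
            '/{model_version}?access_token={access_token}')

url = BASE_URL.format(model_version=MODEL_VERSIONS["ernie-bot-4"], access_token="xxxx")
# -> https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=xxxx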

View File

@@ -5,7 +5,7 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -388,7 +388,6 @@ def list_config_llm_models() -> Dict[str, Dict]:
        workers.insert(0, LLM_MODEL)
    return {
        "local": MODEL_PATH["llm_model"],
-        "langchain": LANGCHAIN_LLM_MODEL,
        "online": ONLINE_LLM_MODEL,
        "worker": workers,
    }
@@ -436,10 +435,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))

-    # online model API
-    if model_name in LANGCHAIN_LLM_MODEL:
-        config["langchain_model"] = True
-        config["worker_class"] = ""
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True

View File

@@ -0,0 +1,26 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.azure import request_azure_api
import pytest
from configs import TEMPERATURE, ONLINE_LLM_MODEL


@pytest.mark.parametrize("version", ["azure-api"])
def test_azure(version):
    messages = [{"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "介绍一下自己"}]
    for resp in request_azure_api(
        messages=messages,
        resource_name=ONLINE_LLM_MODEL["azure-api"]["resource_name"],  # request_azure_api takes resource_name, not api_base_url
        api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
        deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
        api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
        temperature=TEMPERATURE,
        max_tokens=1024,
        model_name="azure-api"
    ):
        assert "choices" in resp  # Azure streaming responses carry choices, not a code field


if __name__ == "__main__":
    test_azure("azure-api")

View File

@@ -7,7 +7,7 @@ from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint

-def test_qwen():
+def test_baichuan():
    messages = [{"role": "user", "content": "hello"}]
    for x in request_baichuan_api(messages):

View File

@@ -4,7 +4,7 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
from typing import List, Dict
@@ -83,7 +83,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        return x

    running_models = list(api.list_running_models())
-    # running_models += LANGCHAIN_LLM_MODEL.keys()
    available_models = []
    config_models = api.list_config_models()
    worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
@@ -93,8 +92,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    for k, v in config_models.get("online", {}).items():  # list models in ONLINE_MODELS that are accessed directly
        if not v.get("provider") and k not in running_models:
            available_models.append(k)
-    for k, v in config_models.get("langchain", {}).items():  # list models supported via LANGCHAIN_LLM_MODEL
-        available_models.append(k)
    llm_models = running_models + available_models
    index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
    llm_model = st.selectbox("选择LLM模型",