Add fschat support for Azure; deprecate LANGCHAIN_LLM_MODEL (#1873)

* Add support for agentlm

* Add support for agentlm and its related prompts

* Rework some Agent functionality and add an embedding optimization

* Modify some of the Agent tools

* Add some of Langchain's built-in tools

* Fix some compatibility bugs

* Restore the knowledge base

* Restore the knowledge base

* Fix Azure issues
zR 2023-10-25 21:32:40 +08:00 committed by GitHub
parent 27418ba9fa
commit 35a7ca74c0
12 changed files with 174 additions and 67 deletions


@@ -112,46 +112,6 @@ HISTORY_LEN = 3
TEMPERATURE = 0.7
# TOP_P = 0.95  # ChatOpenAI does not yet support this parameter
LANGCHAIN_LLM_MODEL = {
    # Models that Langchain supports directly, without going through the fschat wrapper.
    # When calling ChatGPT, if you get urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
    # Max retries exceeded with url: /v1/chat/completions
    # then pin urllib3 to version 1.25.11.
    # If urllib3.exceptions.MaxRetryError: HTTPSConnectionPool still appears, change https to http.
    # See https://zhuanlan.zhihu.com/p/350015032
    # If you get raise NewConnectionError(
    # urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
    # Failed to establish a new connection: [WinError 10060]
    # then mainland-China and Hong Kong IPs have been blocked by OpenAI; switch to Japan, Singapore, or a similar region.
    # If you see WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
    # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
    # you need proxy access (an already-running proxy client may fail to intercept this traffic);
    # set openai_proxy in the config, or set the OPENAI_PROXY environment variable,
    # e.g. "openai_proxy": 'http://127.0.0.1:4780'
    # The names of these config entries must not be changed.
    "Azure-OpenAI": {
        "deployment_name": "your Azure deployment name",
        "model_version": "0701",
        "openai_api_type": "azure",
        "api_base_url": "https://your Azure point.azure.com",
        "api_version": "2023-07-01-preview",
        "api_key": "your Azure api key",
        "openai_proxy": "",
    },
    "OpenAI": {
        "model_name": "your openai model name(such as gpt-4)",
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "your OPENAI_API_KEY",
        "openai_proxy": "",
    },
    "Anthropic": {
        "model_name": "your claude model name(such as claude2-100k)",
        "api_key": "your ANTHROPIC_API_KEY",
    }
}
ONLINE_LLM_MODEL = {
    # Online models. Assign a different port to each online API in server_config.
    # For registration details and API keys, see http://open.bigmodel.cn
@@ -205,6 +165,15 @@ ONLINE_LLM_MODEL = {
        "secret_key": "",
        "provider": "BaiChuanWorker",
    },
    # Azure API
    "azure-api": {
        "deployment_name": "",  # name of the Azure deployment
        "resource_name": "",  # fill in only the resource_name part of https://{resource_name}.openai.azure.com/openai/; leave out the rest
        "api_version": "",  # the API version, not the model version
        "api_key": "",
        "provider": "AzureWorker",
    },
}
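For orientation, the three Azure fields above are joined into the chat-completions endpoint that the new AzureWorker calls (see azure.py below). A minimal sketch, assuming only the URL format shown in this commit; the helper name build_azure_url is illustrative, not part of the codebase:

# Illustrative helper (not in the commit): shows how deployment_name,
# resource_name, and api_version from the "azure-api" entry map onto the
# Azure OpenAI REST endpoint used by request_azure_api below.
def build_azure_url(resource_name: str, deployment_name: str, api_version: str) -> str:
    return (f"https://{resource_name}.openai.azure.com/openai"
            f"/deployments/{deployment_name}/chat/completions"
            f"?api-version={api_version}")

# build_azure_url("my-resource", "my-gpt35-deployment", "2023-07-01-preview")
# -> "https://my-resource.openai.azure.com/openai/deployments/my-gpt35-deployment/chat/completions?api-version=2023-07-01-preview"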


@@ -50,7 +50,7 @@ streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
httpx[brotli,http2,socks]>=0.25.0
watchdog
tqdm
websockets


@@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                          get_httpx_client, get_model_worker_config)


@@ -5,3 +5,4 @@ from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker


@@ -0,0 +1,117 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import TEMPERATURE


def request_azure_api(
        messages: List[Dict[str, str]],
        resource_name: str = None,
        api_key: str = None,
        deployment_name: str = None,
        api_version: str = "2023-07-01-preview",
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        model_name: str = "azure-api",
):
    config = get_model_worker_config(model_name)
    data = dict(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
    )
    url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'api-key': api_key,
    }
    with get_httpx_client() as client:
        with client.stream("POST", url, headers=headers, json=data) as response:
            # Azure streams SSE lines of the form "data: {...}"; strip the prefix and parse each event.
            for line in response.iter_lines():
                if not line.strip() or "[DONE]" in line:
                    continue
                if line.startswith("data: "):
                    line = line[6:]
                resp = json.loads(line)
                yield resp


class AzureWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str,
            worker_addr: str,
            model_names: List[str] = ["azure-api"],
            version: str = "gpt-35-turbo",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)

        config = self.get_config()
        self.resource_name = config.get("resource_name")
        self.api_key = config.get("api_key")
        self.api_version = config.get("api_version")
        self.version = version
        self.deployment_name = config.get("deployment_name")

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        messages = self.prompt_to_messages(params["prompt"])
        text = ""
        for resp in request_azure_api(messages=messages,
                                      resource_name=self.resource_name,
                                      api_key=self.api_key,
                                      api_version=self.api_version,
                                      deployment_name=self.deployment_name,
                                      temperature=params.get("temperature"),
                                      max_tokens=params.get("max_tokens")):
            if choices := resp["choices"]:
                if chunk := choices[0].get("delta", {}).get("content"):
                    text += chunk
                    yield json.dumps(
                        {
                            "error_code": 0,
                            "text": text
                        },
                        ensure_ascii=False
                    ).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: confirm whether this template needs adjusting
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    worker = AzureWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21008",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21008)
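generate_stream_gate above yields the accumulated reply as NUL-terminated JSON chunks, the framing fastchat API workers use for streaming. A minimal consumer sketch, assuming only that framing (the function name and the `stream` iterable are hypothetical, not part of the commit):

import json

def iter_worker_chunks(stream):
    # Reassemble b"\0"-delimited JSON chunks from an iterable of raw bytes,
    # e.g. httpx's response.iter_bytes() against the worker's streaming endpoint.
    buffer = b""
    for piece in stream:
        buffer += piece
        while b"\0" in buffer:
            chunk, buffer = buffer.split(b"\0", 1)
            if chunk:
                yield json.loads(chunk.decode("utf-8"))

# Each chunk looks like {"error_code": 0, "text": "<accumulated reply so far>"},
# so a client can simply display the latest "text" value it has received.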


@@ -138,9 +138,9 @@ if __name__ == "__main__":
    worker = BaiChuanWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21001",
        worker_addr="http://127.0.0.1:21007",
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21001)
    uvicorn.run(app, port=21007)
    # do_request()


@@ -1,17 +1,16 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
import sys
import json
import httpx
from cachetools import cached, TTLCache
from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict
MODEL_VERSIONS = {
    "ernie-bot": "completions",
    "ernie-bot-4": "completions_pro",
    "ernie-bot-turbo": "eb-instant",
    "bloomz-7b": "bloomz_7b1",
    "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
@@ -46,7 +45,7 @@ MODEL_VERSIONS = {
}
@cached(TTLCache(1, 1800))  # tested: the cached token stays valid; it is currently refreshed every 30 minutes
@cached(TTLCache(1, 1800))  # tested: the cached token stays valid; it is currently refreshed every 30 minutes
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
    """
    Generate an authentication signature (access token) from the AK/SK
@@ -62,12 +61,12 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
def request_qianfan_api(
    messages: List[Dict[str, str]],
    temperature: float = TEMPERATURE,
    model_name: str = "qianfan-api",
    version: str = None,
        messages: List[Dict[str, str]],
        temperature: float = TEMPERATURE,
        model_name: str = "qianfan-api",
        version: str = None,
) -> Dict:
    BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
    BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
               '/{model_version}?access_token={access_token}'
    config = get_model_worker_config(model_name)
    version = version or config.get("version")
@@ -83,6 +82,7 @@ def request_qianfan_api(
        model_version=version_url or MODEL_VERSIONS[version],
        access_token=access_token,
)
    payload = {
        "messages": messages,
        "temperature": temperature,
@@ -101,6 +101,7 @@
                if line.startswith("data: "):
                    line = line[6:]
                resp = json.loads(line)
                breakpoint()
                yield resp
@@ -108,6 +109,7 @@ class QianFanWorker(ApiModelWorker):
    """
    Baidu Qianfan
    """

    def __init__(
            self,
            *,
@@ -128,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
    def generate_stream_gate(self, params):
        messages = self.prompt_to_messages(params["prompt"])
        text=""
        text = ""
        for resp in request_qianfan_api(messages,
                                        temperature=params.get("temperature"),
                                        model_name=self.model_names[0]):
@@ -147,7 +149,7 @@ class QianFanWorker(ApiModelWorker):
                },
                ensure_ascii=False
            ).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")


@@ -5,7 +5,7 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -388,7 +388,6 @@ def list_config_llm_models() -> Dict[str, Dict]:
    workers.insert(0, LLM_MODEL)
    return {
        "local": MODEL_PATH["llm_model"],
        "langchain": LANGCHAIN_LLM_MODEL,
        "online": ONLINE_LLM_MODEL,
        "worker": workers,
    }
@@ -436,10 +435,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
    # online model APIs
    if model_name in LANGCHAIN_LLM_MODEL:
        config["langchain_model"] = True
        config["worker_class"] = ""
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
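With the LANGCHAIN_LLM_MODEL branch removed, an online model's config now comes only from ONLINE_LLM_MODEL plus FSCHAT_MODEL_WORKERS. A rough sketch of what the lookup yields for the new entry, assuming the config values shown above (illustrative, not output captured from the code):

# Hypothetical illustration of the merged result for the new Azure entry:
config = get_model_worker_config("azure-api")
# config["provider"] == "AzureWorker"   (from ONLINE_LLM_MODEL["azure-api"])
# config["online_api"] is True          (set by the branch shown above)
# no "langchain_model" key is set anywhere anymore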


@@ -0,0 +1,26 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.azure import request_azure_api
import pytest
from configs import TEMPERATURE, ONLINE_LLM_MODEL


@pytest.mark.parametrize("version", ["azure-api"])
def test_azure(version):
    messages = [{"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "介绍一下自己"}]
    for resp in request_azure_api(
            messages=messages,
            # request_azure_api identifies the endpoint by resource_name (see server/model_workers/azure.py)
            resource_name=ONLINE_LLM_MODEL["azure-api"]["resource_name"],
            api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
            deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
            api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
            temperature=TEMPERATURE,
            max_tokens=1024,
            model_name="azure-api"
    ):
        # streamed Azure chunks carry "choices" rather than a "code" field
        assert "error" not in resp


if __name__ == "__main__":
    test_azure("azure-api")


@@ -7,7 +7,7 @@ from server.model_workers.baichuan import request_baichuan_api
from pprint import pprint


def test_qwen():
def test_baichuan():
    messages = [{"role": "user", "content": "hello"}]
    for x in request_baichuan_api(messages):


@@ -17,4 +17,4 @@ def test_qianfan(version):
        pprint(x)
        assert isinstance(x, dict)
        assert "error_code" not in x
        i += 1
        i += 1


@@ -4,7 +4,7 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
from typing import List, Dict
@@ -83,7 +83,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        return x

    running_models = list(api.list_running_models())
    # running_models += LANGCHAIN_LLM_MODEL.keys()
    available_models = []
    config_models = api.list_config_models()
    worker_models = list(config_models.get("worker", {}))  # only list models configured in FSCHAT_MODEL_WORKERS
@@ -93,8 +92,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    for k, v in config_models.get("online", {}).items():  # list models in ONLINE_MODELS that are accessed directly
        if not v.get("provider") and k not in running_models:
            available_models.append(k)
    for k, v in config_models.get("langchain", {}).items():  # list models supported by LANGCHAIN_LLM_MODEL
        available_models.append(k)
    llm_models = running_models + available_models
    index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
    llm_model = st.selectbox("选择LLM模型",