Azure 的fschat支持,弃用Langchain-MODEL (#1873)
* 支持了agentlm * 支持了agentlm和相关提示词 * 修改了Agent的一些功能,加入了Embed方面的一个优化 * 修改了部分Agent的工具 * 增加一些Langchain的自带工具 * 修复一些兼容性的bug * 恢复知识库 * 恢复知识库 * 修复Azure问题
This commit is contained in:
parent
27418ba9fa
commit
35a7ca74c0
|
|
@ -112,46 +112,6 @@ HISTORY_LEN = 3
|
|||
TEMPERATURE = 0.7
|
||||
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
||||
|
||||
|
||||
LANGCHAIN_LLM_MODEL = {
|
||||
# 不需要走Fschat封装的,Langchain直接支持的模型。
|
||||
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
||||
# Max retries exceeded with url: /v1/chat/completions
|
||||
# 则需要将urllib3版本修改为1.25.11
|
||||
# 如果依然报urllib3.exceptions.MaxRetryError: HTTPSConnectionPool,则将https改为http
|
||||
# 参考https://zhuanlan.zhihu.com/p/350015032
|
||||
|
||||
# 如果报出:raise NewConnectionError(
|
||||
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
|
||||
# Failed to establish a new connection: [WinError 10060]
|
||||
# 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
|
||||
|
||||
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
|
||||
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
|
||||
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
|
||||
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
|
||||
|
||||
# 这些配置文件的名字不能改动
|
||||
"Azure-OpenAI": {
|
||||
"deployment_name": "your Azure deployment name",
|
||||
"model_version": "0701",
|
||||
"openai_api_type": "azure",
|
||||
"api_base_url": "https://your Azure point.azure.com",
|
||||
"api_version": "2023-07-01-preview",
|
||||
"api_key": "your Azure api key",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"OpenAI": {
|
||||
"model_name": "your openai model name(such as gpt-4)",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": "your OPENAI_API_KEY",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"Anthropic": {
|
||||
"model_name": "your claude model name(such as claude2-100k)",
|
||||
"api_key":"your ANTHROPIC_API_KEY",
|
||||
}
|
||||
}
|
||||
ONLINE_LLM_MODEL = {
|
||||
# 线上模型。请在server_config中为每个在线API设置不同的端口
|
||||
# 具体注册及api key获取请前往 http://open.bigmodel.cn
|
||||
|
|
@ -205,6 +165,15 @@ ONLINE_LLM_MODEL = {
|
|||
"secret_key": "",
|
||||
"provider": "BaiChuanWorker",
|
||||
},
|
||||
|
||||
# Azure API
|
||||
"azure-api": {
|
||||
"deployment_name": "", # 部署容器的名字
|
||||
"resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写
|
||||
"api_version": "", # API的版本,不是模型版本
|
||||
"api_key": "",
|
||||
"provider": "AzureWorker",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ streamlit-option-menu>=0.3.6
|
|||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox==1.1.10
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
httpx[brotli,http2,socks]>=0.25.0
|
||||
watchdog
|
||||
tqdm
|
||||
websockets
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from fastapi import Body
|
||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
|
||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
||||
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
|
||||
get_httpx_client, get_model_worker_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,3 +5,4 @@ from .qianfan import QianFanWorker
|
|||
from .fangzhou import FangZhouWorker
|
||||
from .qwen import QwenWorker
|
||||
from .baichuan import BaiChuanWorker
|
||||
from .azure import AzureWorker
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from server.utils import get_model_worker_config, get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from configs import TEMPERATURE
|
||||
|
||||
|
||||
def request_azure_api(
|
||||
messages: List[Dict[str, str]],
|
||||
resource_name: str = None,
|
||||
api_key: str = None,
|
||||
deployment_name: str = None,
|
||||
api_version: str = "2023-07-01-preview",
|
||||
temperature: float = TEMPERATURE,
|
||||
max_tokens: int = None,
|
||||
model_name: str = "azure-api",
|
||||
):
|
||||
config = get_model_worker_config(model_name)
|
||||
data = dict(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
)
|
||||
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
'api-key': api_key,
|
||||
}
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip() or "[DONE]" in line:
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
resp = json.loads(line)
|
||||
yield resp
|
||||
|
||||
class AzureWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
model_names: List[str] = ["azure-api"],
|
||||
version: str = "gpt-35-turbo",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 8192)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
self.resource_name = config.get("resource_name")
|
||||
self.api_key = config.get("api_key")
|
||||
self.api_version = config.get("api_version")
|
||||
self.version = version
|
||||
self.deployment_name = config.get("deployment_name")
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
super().generate_stream_gate(params)
|
||||
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
text = ""
|
||||
for resp in request_azure_api(messages=messages,
|
||||
resource_name=self.resource_name,
|
||||
api_key=self.api_key,
|
||||
api_version=self.api_version,
|
||||
deployment_name=self.deployment_name,
|
||||
temperature=params.get("temperature"),
|
||||
max_tokens=params.get("max_tokens")):
|
||||
if choices := resp["choices"]:
|
||||
if chunk := choices[0].get("delta", {}).get("content"):
|
||||
text += chunk
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.base_model_worker import app
|
||||
|
||||
worker = AzureWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21008",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21008)
|
||||
|
|
@ -138,9 +138,9 @@ if __name__ == "__main__":
|
|||
|
||||
worker = BaiChuanWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21001",
|
||||
worker_addr="http://127.0.0.1:21007",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21001)
|
||||
uvicorn.run(app, port=21007)
|
||||
# do_request()
|
||||
|
|
@ -1,17 +1,16 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from configs.model_config import TEMPERATURE
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
import httpx
|
||||
from cachetools import cached, TTLCache
|
||||
from server.utils import get_model_worker_config, get_httpx_client
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
|
||||
MODEL_VERSIONS = {
|
||||
"ernie-bot": "completions",
|
||||
"ernie-bot-4": "completions_pro",
|
||||
"ernie-bot-turbo": "eb-instant",
|
||||
"bloomz-7b": "bloomz_7b1",
|
||||
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||
|
|
@ -46,7 +45,7 @@ MODEL_VERSIONS = {
|
|||
}
|
||||
|
||||
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
|
|
@ -62,12 +61,12 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
|||
|
||||
|
||||
def request_qianfan_api(
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "qianfan-api",
|
||||
version: str = None,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "qianfan-api",
|
||||
version: str = None,
|
||||
) -> Dict:
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||
'/{model_version}?access_token={access_token}'
|
||||
config = get_model_worker_config(model_name)
|
||||
version = version or config.get("version")
|
||||
|
|
@ -83,6 +82,7 @@ def request_qianfan_api(
|
|||
model_version=version_url or MODEL_VERSIONS[version],
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
|
|
@ -101,6 +101,7 @@ def request_qianfan_api(
|
|||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
resp = json.loads(line)
|
||||
breakpoint()
|
||||
yield resp
|
||||
|
||||
|
||||
|
|
@ -108,6 +109,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
"""
|
||||
百度千帆
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -128,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
|
||||
def generate_stream_gate(self, params):
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
text=""
|
||||
text = ""
|
||||
for resp in request_qianfan_api(messages,
|
||||
temperature=params.get("temperature"),
|
||||
model_name=self.model_names[0]):
|
||||
|
|
@ -147,7 +149,7 @@ class QianFanWorker(ApiModelWorker):
|
|||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||
from pathlib import Path
|
||||
import asyncio
|
||||
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
|
||||
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
|
@ -388,7 +388,6 @@ def list_config_llm_models() -> Dict[str, Dict]:
|
|||
workers.insert(0, LLM_MODEL)
|
||||
return {
|
||||
"local": MODEL_PATH["llm_model"],
|
||||
"langchain": LANGCHAIN_LLM_MODEL,
|
||||
"online": ONLINE_LLM_MODEL,
|
||||
"worker": workers,
|
||||
}
|
||||
|
|
@ -436,10 +435,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
|||
config.update(ONLINE_LLM_MODEL.get(model_name, {}))
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
||||
|
||||
# 在线模型API
|
||||
if model_name in LANGCHAIN_LLM_MODEL:
|
||||
config["langchain_model"] = True
|
||||
config["worker_class"] = ""
|
||||
|
||||
if model_name in ONLINE_LLM_MODEL:
|
||||
config["online_api"] = True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.model_workers.azure import request_azure_api
|
||||
import pytest
|
||||
from configs import TEMPERATURE, ONLINE_LLM_MODEL
|
||||
@pytest.mark.parametrize("version", ["azure-api"])
|
||||
def test_azure(version):
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "介绍一下自己"}]
|
||||
for resp in request_azure_api(
|
||||
messages=messages,
|
||||
api_base_url=ONLINE_LLM_MODEL["azure-api"]["api_base_url"],
|
||||
api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
|
||||
deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
|
||||
api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
|
||||
temperature=TEMPERATURE,
|
||||
max_tokens=1024,
|
||||
model_name="azure-api"
|
||||
):
|
||||
assert resp["code"] == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_azure("azure-api")
|
||||
|
|
@ -7,7 +7,7 @@ from server.model_workers.baichuan import request_baichuan_api
|
|||
from pprint import pprint
|
||||
|
||||
|
||||
def test_qwen():
|
||||
def test_baichuan():
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
for x in request_baichuan_api(messages):
|
||||
|
|
|
|||
|
|
@ -17,4 +17,4 @@ def test_qianfan(version):
|
|||
pprint(x)
|
||||
assert isinstance(x, dict)
|
||||
assert "error_code" not in x
|
||||
i += 1
|
||||
i += 1
|
||||
|
|
@ -4,7 +4,7 @@ from streamlit_chatbox import *
|
|||
from datetime import datetime
|
||||
import os
|
||||
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
|
||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
|
|
@ -83,7 +83,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
return x
|
||||
|
||||
running_models = list(api.list_running_models())
|
||||
# running_models += LANGCHAIN_LLM_MODEL.keys()
|
||||
available_models = []
|
||||
config_models = api.list_config_models()
|
||||
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
|
||||
|
|
@ -93,8 +92,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
|
||||
if not v.get("provider") and k not in running_models:
|
||||
available_models.append(k)
|
||||
for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
|
||||
available_models.append(k)
|
||||
llm_models = running_models + available_models
|
||||
index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
|
||||
llm_model = st.selectbox("选择LLM模型:",
|
||||
|
|
|
|||
Loading…
Reference in New Issue