修复Azure 不设置Max token的bug (#2254)

This commit is contained in:
zR 2023-12-02 16:50:56 +08:00 committed by GitHub
parent 12113be6ec
commit dcb76984bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 22 additions and 16 deletions

View File

@@ -5,4 +5,4 @@ from .server_config import *
 from .prompt_config import *
-VERSION = "v0.2.8-preview"
+VERSION = "v0.2.9-preview"

View File

@@ -5,7 +5,7 @@ langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat>=0.2.33
 xformers>=0.0.22.post7
-openai>=1.3.6
+openai>=1.3.7
 sentence_transformers
 transformers>=4.35.2
 torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
@@ -53,7 +53,7 @@ vllm==0.2.2; sys_platform == "linux"
 arxiv>=2.0.0
 youtube-search>=2.1.2
-duckduckgo-search>=4.9.3
+duckduckgo-search>=3.9.9
 metaphor-python>=0.1.23
 # WebUI requirements

View File

@@ -5,7 +5,7 @@ langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat>=0.2.33
 xformers>=0.0.22.post7
-openai>=1.3.6
+openai>=1.3.7
 sentence_transformers
 transformers>=4.35.2
 torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
@@ -37,6 +37,7 @@ pandas~=2.0.3
 einops>=0.7.0
 transformers_stream_generator==0.0.4
 vllm==0.2.2; sys_platform == "linux"
+httpx[brotli,http2,socks]>=0.25.2
 # Online api libs dependencies
@@ -53,5 +54,5 @@ vllm==0.2.2; sys_platform == "linux"
 arxiv>=2.0.0
 youtube-search>=2.1.2
-duckduckgo-search>=4.9.3
+duckduckgo-search>=3.9.9
 metaphor-python>=0.1.23

View File

@@ -1,7 +1,7 @@
 langchain==0.0.344
 pydantic==1.10.13
 fschat>=0.2.33
-openai>=1.3.6
+openai>=1.3.7
 fastapi>=0.104.1
 python-multipart
 nltk~=3.8.1
@@ -39,14 +39,15 @@ zhipuai>=1.0.7 # zhipu
 numpy~=1.24.4
 pandas~=2.0.3
-streamlit~=1.28.1
+streamlit>=1.29.0
 streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
-streamlit-chatbox==1.1.11
+streamlit-antd-components>=0.2.3
+streamlit-chatbox>=1.1.11
+streamlit-modal>=0.1.0
 streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
-watchdog
-tqdm
+httpx[brotli,http2,socks]>=0.25.2
+watchdog>=3.0.0
+tqdm>=4.66.1
 websockets
 einops>=0.7.0
@@ -58,5 +59,5 @@ einops>=0.7.0
 arxiv>=2.0.0
 youtube-search>=2.1.2
-duckduckgo-search>=4.9.3
+duckduckgo-search>=3.9.9
 metaphor-python>=0.1.23

View File

@@ -1,4 +1,5 @@
 import sys
+import os
 from fastchat.conversation import Conversation
 from server.model_workers.base import *
 from server.utils import get_httpx_client
@@ -19,16 +20,16 @@ class AzureWorker(ApiModelWorker):
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
         super().__init__(**kwargs)
         self.version = version

     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
         data = dict(
             messages=params.messages,
             temperature=params.temperature,
-            max_tokens=params.max_tokens,
+            max_tokens=params.max_tokens if params.max_tokens else None,
             stream=True,
         )
         url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
@@ -47,6 +48,7 @@ class AzureWorker(ApiModelWorker):
         with get_httpx_client() as client:
             with client.stream("POST", url, headers=headers, json=data) as response:
+                print(data)
                 for line in response.iter_lines():
                     if not line.strip() or "[DONE]" in line:
                         continue
@@ -60,6 +62,7 @@ class AzureWorker(ApiModelWorker):
                         "error_code": 0,
                         "text": text
                     }
+                    print(text)
             else:
                 self.logger.error(f"请求 Azure API 时发生错误:{resp}")
@@ -91,4 +94,4 @@ if __name__ == "__main__":
     )
     sys.modules["fastchat.serve.model_worker"].worker = worker
     MakeFastAPIOffline(app)
     uvicorn.run(app, port=21008)

View File

@@ -43,6 +43,7 @@ def get_ChatOpenAI(
     config = get_model_worker_config(model_name)
     if model_name == "openai-api":
         model_name = config.get("model_name")
+
     model = ChatOpenAI(
         streaming=streaming,
         verbose=verbose,