From c4cb4e19e53a9eb6ec571f03d3de5aba6f529b7c Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Wed, 13 Sep 2023 15:35:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B8=85=E7=90=86=E4=B8=8D=E5=BF=85=E8=A6=81?= =?UTF-8?q?=E7=9A=84=E4=BE=9D=E8=B5=96=EF=BC=8C=E5=A2=9E=E5=8A=A0=E6=98=9F?= =?UTF-8?q?=E7=81=ABAPI=E9=9C=80=E8=A6=81=E7=9A=84websockets=20(#1463)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 2 +- requirements_api.txt | 1 + requirements_webui.txt | 3 +- server/model_workers/SparkApi.py | 58 -------------------------------- server/model_workers/xinghuo.py | 4 +-- 5 files changed, 5 insertions(+), 63 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8badee0..abd0f1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ langchain==0.0.287 fschat[model_worker]==0.2.28 openai -zhipuai sentence_transformers transformers>=4.31.0 torch~=2.0.0 @@ -34,3 +33,4 @@ streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 watchdog tqdm +websockets diff --git a/requirements_api.txt b/requirements_api.txt index fa4cf1e..9a05708 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -18,6 +18,7 @@ accelerate spacy PyMuPDF==1.22.5 rapidocr_onnxruntime>=1.3.1 +websockets # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/requirements_webui.txt b/requirements_webui.txt index da66c30..8d49ae0 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -7,4 +7,5 @@ streamlit-chatbox>=1.1.6 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 nltk -watchdog \ No newline at end of file +watchdog +websockets diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py index 3b4f096..e1dce6a 100644 --- a/server/model_workers/SparkApi.py +++ b/server/model_workers/SparkApi.py @@ -1,18 +1,13 @@ -import _thread as thread import base64 import datetime import hashlib import hmac -import json from urllib.parse import urlparse -import ssl from datetime import datetime from time import mktime from urllib.parse import urlencode from wsgiref.handlers import format_date_time -import websocket # 使用websocket_client -answer = "" class Ws_Param(object): # 初始化 @@ -57,46 +52,6 @@ class Ws_Param(object): return url -# 收到websocket错误的处理 -def on_error(ws, error): - print("### error:", error) - - -# 收到websocket关闭的处理 -def on_close(ws,one,two): - print(" ") - - -# 收到websocket连接建立的处理 -def on_open(ws): - thread.start_new_thread(run, (ws,)) - - -def run(ws, *args): - data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature)) - ws.send(data) - - -# 收到websocket消息的处理 -def on_message(ws, message): - # print(message) - data = json.loads(message) - code = data['header']['code'] - if code != 0: - print(f'请求错误: {code}, {data}') - ws.close() - else: - choices = data["payload"]["choices"] - status = choices["status"] - content = choices["text"][0]["content"] - print(content,end ="") - global answer - answer += content - # print(1) - if status == 2: - ws.close() - - def gen_params(appid, domain,question, temperature): """ 通过appid和用户的提问来生成请参数 @@ -122,16 +77,3 @@ def gen_params(appid, domain,question, temperature): } } return data - - -def main(appid, api_key, api_secret, Spark_url,domain, question, temperature): - # print("星火:") - wsParam = Ws_Param(appid, api_key, api_secret, Spark_url) - websocket.enableTrace(False) - wsUrl = wsParam.create_url() - ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open) - ws.appid = appid - ws.question = question - ws.domain = domain - ws.temperature = temperature - ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 9c90765..499e8bc 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -2,12 +2,10 @@ from server.model_workers.base import ApiModelWorker from fastchat import conversation as conv import sys import json -import httpx -from pprint import pprint from server.model_workers import SparkApi import websockets from server.utils import iter_over_async, asyncio -from typing import List, Dict +from typing import List async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):