清理不必要的依赖,增加星火API需要的websockets (#1463)
This commit is contained in:
parent
99b862dfc8
commit
c4cb4e19e5
|
|
@ -1,7 +1,6 @@
|
|||
langchain==0.0.287
|
||||
fschat[model_worker]==0.2.28
|
||||
openai
|
||||
zhipuai
|
||||
sentence_transformers
|
||||
transformers>=4.31.0
|
||||
torch~=2.0.0
|
||||
|
|
@ -34,3 +33,4 @@ streamlit-aggrid>=0.3.4.post3
|
|||
httpx~=0.24.1
|
||||
watchdog
|
||||
tqdm
|
||||
websockets
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ accelerate
|
|||
spacy
|
||||
PyMuPDF==1.22.5
|
||||
rapidocr_onnxruntime>=1.3.1
|
||||
websockets
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
|
|
|||
|
|
@ -7,4 +7,5 @@ streamlit-chatbox>=1.1.6
|
|||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
nltk
|
||||
watchdog
|
||||
watchdog
|
||||
websockets
|
||||
|
|
|
|||
|
|
@ -1,18 +1,13 @@
|
|||
import _thread as thread
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
import ssl
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import websocket # 使用websocket_client
|
||||
answer = ""
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
|
|
@ -57,46 +52,6 @@ class Ws_Param(object):
|
|||
return url
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws,one,two):
|
||||
print(" ")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
thread.start_new_thread(run, (ws,))
|
||||
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature))
|
||||
ws.send(data)
|
||||
|
||||
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
# print(message)
|
||||
data = json.loads(message)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
print(f'请求错误: {code}, {data}')
|
||||
ws.close()
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
print(content,end ="")
|
||||
global answer
|
||||
answer += content
|
||||
# print(1)
|
||||
if status == 2:
|
||||
ws.close()
|
||||
|
||||
|
||||
def gen_params(appid, domain,question, temperature):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
|
|
@ -122,16 +77,3 @@ def gen_params(appid, domain,question, temperature):
|
|||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def main(appid, api_key, api_secret, Spark_url,domain, question, temperature):
|
||||
# print("星火:")
|
||||
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
|
||||
ws.appid = appid
|
||||
ws.question = question
|
||||
ws.domain = domain
|
||||
ws.temperature = temperature
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
|
|
|||
|
|
@ -2,12 +2,10 @@ from server.model_workers.base import ApiModelWorker
|
|||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
import httpx
|
||||
from pprint import pprint
|
||||
from server.model_workers import SparkApi
|
||||
import websockets
|
||||
from server.utils import iter_over_async, asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
|
||||
|
||||
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):
|
||||
|
|
|
|||
Loading…
Reference in New Issue