清理不必要的依赖,增加星火API需要的websockets (#1463)
This commit is contained in:
parent
99b862dfc8
commit
c4cb4e19e5
|
|
@ -1,7 +1,6 @@
|
||||||
langchain==0.0.287
|
langchain==0.0.287
|
||||||
fschat[model_worker]==0.2.28
|
fschat[model_worker]==0.2.28
|
||||||
openai
|
openai
|
||||||
zhipuai
|
|
||||||
sentence_transformers
|
sentence_transformers
|
||||||
transformers>=4.31.0
|
transformers>=4.31.0
|
||||||
torch~=2.0.0
|
torch~=2.0.0
|
||||||
|
|
@ -34,3 +33,4 @@ streamlit-aggrid>=0.3.4.post3
|
||||||
httpx~=0.24.1
|
httpx~=0.24.1
|
||||||
watchdog
|
watchdog
|
||||||
tqdm
|
tqdm
|
||||||
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ accelerate
|
||||||
spacy
|
spacy
|
||||||
PyMuPDF==1.22.5
|
PyMuPDF==1.22.5
|
||||||
rapidocr_onnxruntime>=1.3.1
|
rapidocr_onnxruntime>=1.3.1
|
||||||
|
websockets
|
||||||
|
|
||||||
# uncomment libs if you want to use corresponding vector store
|
# uncomment libs if you want to use corresponding vector store
|
||||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ streamlit-aggrid>=0.3.4.post3
|
||||||
httpx~=0.24.1
|
httpx~=0.24.1
|
||||||
nltk
|
nltk
|
||||||
watchdog
|
watchdog
|
||||||
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,13 @@
|
||||||
import _thread as thread
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import ssl
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import mktime
|
from time import mktime
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from wsgiref.handlers import format_date_time
|
from wsgiref.handlers import format_date_time
|
||||||
|
|
||||||
import websocket # 使用websocket_client
|
|
||||||
answer = ""
|
|
||||||
|
|
||||||
class Ws_Param(object):
|
class Ws_Param(object):
|
||||||
# 初始化
|
# 初始化
|
||||||
|
|
@ -57,46 +52,6 @@ class Ws_Param(object):
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
# 收到websocket错误的处理
|
|
||||||
def on_error(ws, error):
|
|
||||||
print("### error:", error)
|
|
||||||
|
|
||||||
|
|
||||||
# 收到websocket关闭的处理
|
|
||||||
def on_close(ws,one,two):
|
|
||||||
print(" ")
|
|
||||||
|
|
||||||
|
|
||||||
# 收到websocket连接建立的处理
|
|
||||||
def on_open(ws):
|
|
||||||
thread.start_new_thread(run, (ws,))
|
|
||||||
|
|
||||||
|
|
||||||
def run(ws, *args):
|
|
||||||
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature))
|
|
||||||
ws.send(data)
|
|
||||||
|
|
||||||
|
|
||||||
# 收到websocket消息的处理
|
|
||||||
def on_message(ws, message):
|
|
||||||
# print(message)
|
|
||||||
data = json.loads(message)
|
|
||||||
code = data['header']['code']
|
|
||||||
if code != 0:
|
|
||||||
print(f'请求错误: {code}, {data}')
|
|
||||||
ws.close()
|
|
||||||
else:
|
|
||||||
choices = data["payload"]["choices"]
|
|
||||||
status = choices["status"]
|
|
||||||
content = choices["text"][0]["content"]
|
|
||||||
print(content,end ="")
|
|
||||||
global answer
|
|
||||||
answer += content
|
|
||||||
# print(1)
|
|
||||||
if status == 2:
|
|
||||||
ws.close()
|
|
||||||
|
|
||||||
|
|
||||||
def gen_params(appid, domain,question, temperature):
|
def gen_params(appid, domain,question, temperature):
|
||||||
"""
|
"""
|
||||||
通过appid和用户的提问来生成请参数
|
通过appid和用户的提问来生成请参数
|
||||||
|
|
@ -122,16 +77,3 @@ def gen_params(appid, domain,question, temperature):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def main(appid, api_key, api_secret, Spark_url,domain, question, temperature):
|
|
||||||
# print("星火:")
|
|
||||||
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
|
|
||||||
websocket.enableTrace(False)
|
|
||||||
wsUrl = wsParam.create_url()
|
|
||||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
|
|
||||||
ws.appid = appid
|
|
||||||
ws.question = question
|
|
||||||
ws.domain = domain
|
|
||||||
ws.temperature = temperature
|
|
||||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,10 @@ from server.model_workers.base import ApiModelWorker
|
||||||
from fastchat import conversation as conv
|
from fastchat import conversation as conv
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import httpx
|
|
||||||
from pprint import pprint
|
|
||||||
from server.model_workers import SparkApi
|
from server.model_workers import SparkApi
|
||||||
import websockets
|
import websockets
|
||||||
from server.utils import iter_over_async, asyncio
|
from server.utils import iter_over_async, asyncio
|
||||||
from typing import List, Dict
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):
|
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue