清理不必要的依赖,增加星火API需要的websockets (#1463)

This commit is contained in:
liunux4odoo 2023-09-13 15:35:04 +08:00 committed by GitHub
parent 99b862dfc8
commit c4cb4e19e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 5 additions and 63 deletions

View File

@ -1,7 +1,6 @@
langchain==0.0.287 langchain==0.0.287
fschat[model_worker]==0.2.28 fschat[model_worker]==0.2.28
openai openai
zhipuai
sentence_transformers sentence_transformers
transformers>=4.31.0 transformers>=4.31.0
torch~=2.0.0 torch~=2.0.0
@ -34,3 +33,4 @@ streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1 httpx~=0.24.1
watchdog watchdog
tqdm tqdm
websockets

View File

@ -18,6 +18,7 @@ accelerate
spacy spacy
PyMuPDF==1.22.5 PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1 rapidocr_onnxruntime>=1.3.1
websockets
# uncomment libs if you want to use corresponding vector store # uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3 # pymilvus==2.1.3 # requires milvus==2.1.3

View File

@ -8,3 +8,4 @@ streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1 httpx~=0.24.1
nltk nltk
watchdog watchdog
websockets

View File

@ -1,18 +1,13 @@
import _thread as thread
import base64 import base64
import datetime import datetime
import hashlib import hashlib
import hmac import hmac
import json
from urllib.parse import urlparse from urllib.parse import urlparse
import ssl
from datetime import datetime from datetime import datetime
from time import mktime from time import mktime
from urllib.parse import urlencode from urllib.parse import urlencode
from wsgiref.handlers import format_date_time from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
answer = ""
class Ws_Param(object): class Ws_Param(object):
# 初始化 # 初始化
@ -57,46 +52,6 @@ class Ws_Param(object):
return url return url
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws,one,two):
print(" ")
# 收到websocket连接建立的处理
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature))
ws.send(data)
# 收到websocket消息的处理
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question, temperature): def gen_params(appid, domain,question, temperature):
""" """
通过appid和用户的提问来生成请参数 通过appid和用户的提问来生成请参数
@ -122,16 +77,3 @@ def gen_params(appid, domain,question, temperature):
} }
} }
return data return data
def main(appid, api_key, api_secret, Spark_url,domain, question, temperature):
# print("星火:")
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.temperature = temperature
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

View File

@ -2,12 +2,10 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
import httpx
from pprint import pprint
from server.model_workers import SparkApi from server.model_workers import SparkApi
import websockets import websockets
from server.utils import iter_over_async, asyncio from server.utils import iter_over_async, asyncio
from typing import List, Dict from typing import List
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):