清理不必要的依赖,增加星火API需要的websockets (#1463)

This commit is contained in:
liunux4odoo 2023-09-13 15:35:04 +08:00 committed by GitHub
parent 99b862dfc8
commit c4cb4e19e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 5 additions and 63 deletions

View File

@ -1,7 +1,6 @@
langchain==0.0.287
fschat[model_worker]==0.2.28
openai
zhipuai
sentence_transformers
transformers>=4.31.0
torch~=2.0.0
@ -34,3 +33,4 @@ streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets

View File

@ -18,6 +18,7 @@ accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1
websockets
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3

View File

@ -7,4 +7,5 @@ streamlit-chatbox>=1.1.6
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
nltk
watchdog
watchdog
websockets

View File

@ -1,18 +1,13 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
answer = ""
class Ws_Param(object):
# 初始化
@ -57,46 +52,6 @@ class Ws_Param(object):
return url
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws,one,two):
print(" ")
# 收到websocket连接建立的处理
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature=ws.temperature))
ws.send(data)
# 收到websocket消息的处理
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question, temperature):
"""
通过appid和用户的提问来生成请参数
@ -122,16 +77,3 @@ def gen_params(appid, domain,question, temperature):
}
}
return data
def main(appid, api_key, api_secret, Spark_url,domain, question, temperature):
# print("星火:")
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.temperature = temperature
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

View File

@ -2,12 +2,10 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
import httpx
from pprint import pprint
from server.model_workers import SparkApi
import websockets
from server.utils import iter_over_async, asyncio
from typing import List, Dict
from typing import List
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):