merge dev
commit 641fcaee40

README.md (53 lines changed)
@@ -15,6 +15,7 @@

* [3. Set configuration items](README.md#3.-设置配置项)
* [4. Initialize and migrate the knowledge base](README.md#4.-知识库初始化与迁移)
* [5. Start the API service or Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [6. One-click launch](README.md#6.-一键启动)
* [FAQ](README.md#常见问题)
* [Roadmap](README.md#路线图)
* [Project discussion group](README.md#项目交流群)
@@ -248,14 +249,23 @@ gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```

Here `gpus` sets the IDs of the GPUs to use, e.g. "0,1";

`num_gpus` sets the number of GPUs to use;

`max_gpu_memory` sets the amount of GPU memory to use on each card.
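For reference, a minimal sketch of how these three settings map onto the keyword arguments of `create_model_worker_app()` in `server/llm_api.py` (the import path and the values shown are illustrative assumptions, not part of this commit):

```python
# Illustrative only: override the GPU defaults shown in the server/llm_api.py diff later in this commit.
from server.llm_api import create_model_worker_app  # assumed import path

app = create_model_worker_app(
    gpus="0,1",              # IDs of the GPUs to use, e.g. "0,1"
    num_gpus=2,              # number of GPUs to use
    max_gpu_memory="10GiB",  # GPU memory budget per card
)
```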
##### 5.1.2 Start the LLM service with the command-line script llm_api_launch.py

**!!! Note:**

**1. The llm_api_launch.py script only works on Linux and macOS; on Windows, please use WSL;**

**2. To load a non-default model, specify it with the command-line argument --model-path-address; the model_config.py settings are not read;**

**!!!**

From the project root, run the [server/llm_api_launch.py](server/llm_api.py) script to start the **LLM model** service:

```shell
@@ -274,7 +284,7 @@ $ python server/llm_api_launch.py --model-path-addresss model1@host1@port1 model
$ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
```

Note: launching the LLM service this way runs the fastchat service in the background via nohup; to stop it, run the command below (this script **only works on Linux and macOS**):
Note: launching the LLM service this way runs the FastChat service in the background via nohup; to stop it, run the command below:

```shell
$ python server/llm_api_shutdown.py --serve all
@@ -349,11 +359,17 @@ $ streamlit run webui.py --server.port 666

---

### 6. One-click launch

⚠️ **Note:**

**1. The one-click launch scripts only work on Linux and Mac devices; on Windows, please use WSL;**

**2. To load a non-default model, specify it with the command-line argument `--model-path-address`; the `model_config.py` settings are not read.**

#### 6.1 One-click API service launch script

A new one-click launch script starts the FastChat background services and the API service provided by this project in one step; usage examples:

Run with the default model:
@@ -372,32 +388,49 @@ $ python server/api_allinone.py --model-path-address model1@host1@port1 model2@h
```shell
python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```

For other parameters, see the individual scripts and the FastChat service documentation.

#### 6.2 One-click Web UI launch script

Load a local model:

```shell
$ python webui_allinone.py
```

Use a remote API service:

```shell
$ python webui_allinone.py --use-remote-api
```

Run the Web UI service in the background:

```shell
$ python webui_allinone.py --nohup
```

Load multiple non-default models:

```shell
$ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
```

Multi-GPU launch:

```shell
$ python webui_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```

For other parameters, see the individual scripts and the FastChat service documentation.

Both one-click launch scripts run several services in the background; to stop all of them, use the `shutdown_all.sh` script:

```shell
bash shutdown_all.sh
```

## FAQ

See the [FAQ](docs/FAQ.md).
@@ -46,7 +46,7 @@ llm_model_dict = {
    "chatglm-6b-int4": {
        "local_model_path": "THUDM/chatglm-6b-int4",
        "api_base_url": "http://localhost:8001/v1",  # "name"修改为fastchat服务中的"api_base_url"
        "api_base_url": "http://localhost:8888/v1",  # "name"修改为fastchat服务中的"api_base_url"
        "api_key": "EMPTY"
    },

@@ -64,7 +64,7 @@ llm_model_dict = {
    "vicuna-13b-hf": {
        "local_model_path": "",
        "api_base_url": "http://localhost:8000/v1",  # "name"修改为fastchat服务中的"api_base_url"
        "api_base_url": "http://localhost:8888/v1",  # "name"修改为fastchat服务中的"api_base_url"
        "api_key": "EMPTY"
    },

@@ -85,6 +85,7 @@ llm_model_dict = {
    },
}

# LLM 名称
LLM_MODEL = "chatglm2-6b"
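The new 8888 port matches the FastChat OpenAI-compatible API started by `server/llm_api.py` (see `openai_api_port = 8888` later in this diff). A minimal sketch of talking to that endpoint with the `openai` client, following the usage comment in that file; the model name assumes the default `LLM_MODEL`:

```python
import openai

# Sketch (not part of the commit): point the OpenAI SDK at the local FastChat openai_api_server.
openai.api_base = "http://localhost:8888/v1"
openai.api_key = "EMPTY"

reply = openai.ChatCompletion.create(
    model="chatglm2-6b",  # default LLM_MODEL in configs/model_config.py
    messages=[{"role": "user", "content": "你好"}],
)
print(reply.choices[0].message.content)
```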
@@ -1,11 +1,10 @@
langchain==0.0.257
openai
sentence_transformers
fschat==0.2.20
transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

@@ -1,11 +1,10 @@
langchain==0.0.257
openai
sentence_transformers
fschat==0.2.20
transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
@@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from server.utils import FastAPIOffline as FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
                         search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
                                              update_doc, download_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
                                              update_doc, download_doc, recreate_vector_store,
                                              search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -25,7 +27,8 @@ async def document():

def create_app():
    app = FastAPI()
    app = FastAPI(title="Langchain-Chatchat API Server")
    MakeFastAPIOffline(app)
    # Add CORS middleware to allow all origins
    # 在config.py中设置OPEN_DOMAIN=True,允许跨域
    # set OPEN_DOMAIN=True in config.py to allow cross-domain

@@ -83,6 +86,12 @@ def create_app():
             summary="获取知识库内的文件列表"
             )(list_docs)

    app.post("/knowledge_base/search_docs",
             tags=["Knowledge Base Management"],
             response_model=List[DocumentWithScore],
             summary="搜索知识库"
             )(search_docs)

    app.post("/knowledge_base/upload_doc",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
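The new `/knowledge_base/search_docs` route registered above returns a `List[DocumentWithScore]`. A minimal sketch of calling it over HTTP with httpx; the host and port follow the project defaults (`127.0.0.1:7861`) and the `samples` knowledge base is only an example:

```python
import httpx

# Sketch: query the new search endpoint (assumes the API server from server/api.py is running).
resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,              # number of matched vectors (VECTOR_SEARCH_TOP_K default)
        "score_threshold": 0.5,  # 0-1; smaller scores mean higher relevance, 1 disables filtering
    },
)
for doc in resp.json():          # each item is a serialized DocumentWithScore
    print(doc["score"], doc["page_content"][:80])
```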
@@ -22,9 +22,7 @@ parser.add_argument("--api-host", type=str, default="0.0.0.0")
parser.add_argument("--api-port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)

api_args = ["api-host","api-port","ssl_keyfile","ssl_certfile"]

@@ -41,6 +39,11 @@ def run_api(host, port, **kwargs):
    uvicorn.run(app, host=host, port=port)

if __name__ == "__main__":
    print("Luanching api_allinone,it would take a while, please be patient...")
    print("正在启动api_allinone,LLM服务启动约3-10分钟,请耐心等待...")
    # 初始化消息
    args = parser.parse_args()
    args_dict = vars(args)
    launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args)
    run_api(
        host=args.api_host,

@@ -48,3 +51,5 @@ if __name__ == "__main__":
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )
    print("Luanching api_allinone done.")
    print("api_allinone启动完毕.")
@@ -1,26 +1,27 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                  VECTOR_SEARCH_TOP_K)
                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs


def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                        score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
                        history: List[History] = Body([],
                                                      description="历史对话",
                                                      examples=[[

@@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
        openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
        model_name=LLM_MODEL
    )
    docs = kb.search_docs(query, top_k)
    docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
    context = "\n".join([doc.page_content for doc in docs])

    chat_prompt = ChatPromptTemplate.from_messages(
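With this change, `knowledge_base_chat` delegates retrieval to `search_docs` and filters by `score_threshold` before building the prompt context. A minimal in-process sketch of that call path (arguments mirror the code above; the query and knowledge base name are illustrative):

```python
# Sketch: the retrieval step knowledge_base_chat performs after this change.
from server.knowledge_base.kb_doc_api import search_docs

docs = search_docs(
    query="你好",
    knowledge_base_name="samples",
    top_k=3,
    score_threshold=0.5,  # 1.0 keeps every match; around 0.5 is the suggested starting point
)
context = "\n".join(doc.page_content for doc in docs)  # same join used by knowledge_base_chat
```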
@@ -1,7 +1,7 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL
from configs.model_config import llm_model_dict, LLM_MODEL, logger
from pydantic import BaseModel

@@ -33,19 +33,23 @@ async def openai_chat(msg: OpenAiChatMsgIn):
        data = msg.dict()
        data["streaming"] = True
        data.pop("stream")
        response = openai.ChatCompletion.create(**data)

        if msg.stream:
            for chunk in response.choices[0].message.content:
                print(chunk)
                yield chunk
        else:
            answer = ""
            for chunk in response.choices[0].message.content:
                answer += chunk
            print(answer)
            yield(answer)

        try:
            response = openai.ChatCompletion.create(**data)
            if msg.stream:
                for chunk in response.choices[0].message.content:
                    print(chunk)
                    yield chunk
            else:
                answer = ""
                for chunk in response.choices[0].message.content:
                    answer += chunk
                print(answer)
                yield(answer)
        except Exception as e:
            print(type(e))
            logger.error(e)

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
@@ -27,7 +27,7 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
                              file_ext=kb_file.ext,
                              kb_name=kb_file.kb_name,
                              document_loader_name=kb_file.document_loader_name,
                              text_splitter_name=kb_file.text_splitter_name,
                              text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
                              )
    kb.file_count += 1
    session.add(new_file)
@@ -1,13 +1,32 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List
from typing import List, Dict
from langchain.docstore.document import Document


class DocumentWithScore(Document):
    score: float = None


def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
                knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
                ) -> List[DocumentWithScore]:
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
    docs = kb.search_docs(query, top_k, score_threshold)
    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]

    return data


async def list_docs(
@@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
    list_docs_from_db, get_file_detail, delete_file_from_db
)

from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (
    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,

@@ -112,9 +112,10 @@ class KBService(ABC):
    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ):
        embeddings = self._load_embeddings()
        docs = self.do_search(query, top_k, embeddings)
        docs = self.do_search(query, top_k, score_threshold, embeddings)
        return docs

    @abstractmethod
@@ -81,12 +81,13 @@ class FaissKBService(KBService):
    def do_search(self,
                  query: str,
                  top_k: int,
                  embeddings: Embeddings,
                  score_threshold: float = SCORE_THRESHOLD,
                  embeddings: Embeddings = None,
                  ) -> List[Document]:
        search_index = load_vector_store(self.kb_name,
                                         embeddings=embeddings,
                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name))
        docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
        docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
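The switch from `similarity_search` to `similarity_search_with_score` is what exposes the per-document score: each hit comes back as a `(Document, score)` tuple, which `search_docs` then unpacks into `DocumentWithScore`. A small illustrative sketch against an in-memory LangChain FAISS store (the texts and embedding model are assumptions, not project defaults):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Illustrative store; FaissKBService loads its persisted index via load_vector_store() instead.
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
index = FAISS.from_texts(["你好", "知识库问答示例"], embeddings)

hits = index.similarity_search_with_score("你好", k=2, score_threshold=0.5)
for doc, score in hits:  # (Document, score) pairs; smaller scores mean closer matches
    print(round(score, 3), doc.page_content)
```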
@@ -45,7 +45,8 @@ class MilvusKBService(KBService):
    def do_drop_kb(self):
        self.milvus.col.drop()

    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
        # todo: support score threshold
        self._load_milvus(embeddings=embeddings)
        return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
@@ -43,7 +43,8 @@ class PGKBService(KBService):
        '''))
        connect.commit()

    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
        # todo: support score threshold
        self._load_pg_vector(embeddings=embeddings)
        return self.pg_vector.similarity_search(query, top_k)
@@ -102,6 +102,7 @@ class KnowledgeFile:
                chunk_size=CHUNK_SIZE,
                chunk_overlap=OVERLAP_SIZE,
            )
            self.text_splitter_name = "SpacyTextSplitter"
        else:
            text_splitter_module = importlib.import_module('langchain.text_splitter')
            TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
@@ -1,16 +1,18 @@
|
|||
from multiprocessing import Process, Queue
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
|
||||
from server.utils import MakeFastAPIOffline
|
||||
|
||||
|
||||
host_ip = "0.0.0.0"
|
||||
controller_port = 20001
|
||||
model_worker_port = 20002
|
||||
openai_api_port = 8888
|
||||
base_url = "http://127.0.0.1:{}"
|
||||
queue = Queue()
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
|
|
@@ -30,33 +32,50 @@ def create_controller_app(
|
|||
controller = Controller(dispatch_method)
|
||||
sys.modules["fastchat.serve.controller"].controller = controller
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat Controller"
|
||||
return app
|
||||
|
||||
|
||||
def create_model_worker_app(
|
||||
worker_address=base_url.format(model_worker_port),
|
||||
controller_address=base_url.format(controller_port),
|
||||
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
|
||||
model_names=[LLM_MODEL],
|
||||
device=LLM_DEVICE,
|
||||
gpus=None,
|
||||
max_gpu_memory="20GiB",
|
||||
load_8bit=False,
|
||||
cpu_offloading=None,
|
||||
gptq_ckpt=None,
|
||||
gptq_wbits=16,
|
||||
gptq_groupsize=-1,
|
||||
gptq_act_order=None,
|
||||
gpus=None,
|
||||
num_gpus=1,
|
||||
max_gpu_memory="20GiB",
|
||||
cpu_offloading=None,
|
||||
worker_address=base_url.format(model_worker_port),
|
||||
controller_address=base_url.format(controller_port),
|
||||
gptq_act_order=False,
|
||||
awq_ckpt=None,
|
||||
awq_wbits=16,
|
||||
awq_groupsize=-1,
|
||||
model_names=[LLM_MODEL],
|
||||
num_gpus=1, # not in fastchat
|
||||
conv_template=None,
|
||||
limit_worker_concurrency=5,
|
||||
stream_interval=2,
|
||||
no_register=False,
|
||||
):
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
|
||||
from fastchat.serve import model_worker
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||
import argparse
|
||||
import threading
|
||||
import fastchat.serve.model_worker
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
def _new_init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
ModelWorker.init_heart_beat = _new_init_heart_beat
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
|
@@ -68,12 +87,16 @@ def create_model_worker_app(
|
|||
args.gptq_wbits = gptq_wbits
|
||||
args.gptq_groupsize = gptq_groupsize
|
||||
args.gptq_act_order = gptq_act_order
|
||||
args.awq_ckpt = awq_ckpt
|
||||
args.awq_wbits = awq_wbits
|
||||
args.awq_groupsize = awq_groupsize
|
||||
args.gpus = gpus
|
||||
args.num_gpus = num_gpus
|
||||
args.max_gpu_memory = max_gpu_memory
|
||||
args.cpu_offloading = cpu_offloading
|
||||
args.worker_address = worker_address
|
||||
args.controller_address = controller_address
|
||||
args.conv_template = conv_template
|
||||
args.limit_worker_concurrency = limit_worker_concurrency
|
||||
args.stream_interval = stream_interval
|
||||
args.no_register = no_register
|
||||
|
|
@@ -95,6 +118,12 @@ def create_model_worker_app(
|
|||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
|
||||
# torch.multiprocessing.set_start_method('spawn')
|
||||
worker = ModelWorker(
|
||||
controller_addr=args.controller_address,
|
||||
|
|
@@ -110,19 +139,21 @@ def create_model_worker_app(
|
|||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = f"FastChat LLM Server ({LLM_MODEL})"
|
||||
return app
|
||||
|
||||
|
||||
def create_openai_api_app(
|
||||
host=host_ip,
|
||||
port=openai_api_port,
|
||||
controller_address=base_url.format(controller_port),
|
||||
api_keys=[],
|
||||
):
|
||||
|
|
@@ -141,6 +172,8 @@ def create_openai_api_app(
|
|||
app_settings.controller_address = controller_address
|
||||
app_settings.api_keys = api_keys
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat OpeanAI API Server"
|
||||
return app
|
||||
|
||||
|
||||
|
|
@@ -193,6 +226,8 @@ def run_openai_api(q):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn")
|
||||
queue = Queue()
|
||||
logger.info(llm_model_dict[LLM_MODEL])
|
||||
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
|
||||
|
||||
|
|
@@ -209,15 +244,14 @@ if __name__ == "__main__":
|
|||
)
|
||||
controller_process.start()
|
||||
|
||||
# cuda 没办法用在fork的多进程中
|
||||
# model_worker_process = Process(
|
||||
# target=run_model_worker,
|
||||
# name=f"model_worker({os.getpid()})",
|
||||
# args=(queue,),
|
||||
# # kwargs={"load_8bit": True},
|
||||
# daemon=True,
|
||||
# )
|
||||
# model_worker_process.start()
|
||||
model_worker_process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker({os.getpid()})",
|
||||
args=(queue,),
|
||||
# kwargs={"load_8bit": True},
|
||||
daemon=True,
|
||||
)
|
||||
model_worker_process.start()
|
||||
|
||||
openai_api_process = Process(
|
||||
target=run_openai_api,
|
||||
|
|
@@ -227,11 +261,14 @@
|
|||
)
|
||||
openai_api_process.start()
|
||||
|
||||
run_model_worker(queue)
|
||||
|
||||
controller_process.join()
|
||||
# model_worker_process.join()
|
||||
openai_api_process.join()
|
||||
try:
|
||||
model_worker_process.join()
|
||||
controller_process.join()
|
||||
openai_api_process.join()
|
||||
except KeyboardInterrupt:
|
||||
model_worker_process.terminate()
|
||||
controller_process.terminate()
|
||||
openai_api_process.terminate()
|
||||
|
||||
# 服务启动后接口调用示例:
|
||||
# import openai
|
||||
|
|
|
|||
|
|
@@ -159,17 +159,7 @@ server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
|
|||
"controller-address"
|
||||
]
|
||||
|
||||
args = parser.parse_args()
|
||||
# 必须要加http//:,否则InvalidSchema: No connection adapters were found
|
||||
args = argparse.Namespace(**vars(args),
|
||||
**{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
|
||||
|
||||
if args.gpus:
|
||||
if len(args.gpus.split(",")) < args.num_gpus:
|
||||
raise ValueError(
|
||||
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
||||
)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
|
||||
# 0,controller, model_worker, openai_api_server
|
||||
# 1, 命令行选项
|
||||
|
|
@@ -211,12 +201,13 @@ def string_args(args, args_list):
|
|||
return args_str
|
||||
|
||||
|
||||
def launch_worker(item,args=args,worker_args=worker_args):
|
||||
def launch_worker(item,args,worker_args=worker_args):
|
||||
log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
|
||||
# 先分割model-path-address,在传到string_args中分析参数
|
||||
args.model_path, args.worker_host, args.worker_port = item.split("@")
|
||||
args.worker_address = f"http://{args.worker_host}:{args.worker_port}"
|
||||
print("*" * 80)
|
||||
print(f"如长时间未启动,请到{LOG_PATH}{log_name}.log下查看日志")
|
||||
worker_str_args = string_args(args, worker_args)
|
||||
print(worker_str_args)
|
||||
worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
|
||||
|
|
@@ -225,30 +216,44 @@ def launch_worker(item,args=args,worker_args=worker_args):
|
|||
subprocess.run(worker_check_sh, shell=True, check=True)
|
||||
|
||||
|
||||
def launch_all(args=args,
|
||||
def launch_all(args,
|
||||
controller_args=controller_args,
|
||||
worker_args=worker_args,
|
||||
server_args=server_args
|
||||
):
|
||||
print(f"Launching llm service,logs are located in {LOG_PATH}...")
|
||||
print(f"开始启动LLM服务,请到{LOG_PATH}下监控各模块日志...")
|
||||
controller_str_args = string_args(args, controller_args)
|
||||
controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
|
||||
controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
|
||||
subprocess.run(controller_sh, shell=True, check=True)
|
||||
subprocess.run(controller_check_sh, shell=True, check=True)
|
||||
|
||||
print(f"worker启动时间视设备不同而不同,约需3-10分钟,请耐心等待...")
|
||||
if isinstance(args.model_path_address, str):
|
||||
launch_worker(args.model_path_address,args=args,worker_args=worker_args)
|
||||
else:
|
||||
for idx, item in enumerate(args.model_path_address):
|
||||
print(f"开始加载第{idx}个模型:{item}")
|
||||
launch_worker(item)
|
||||
launch_worker(item,args=args,worker_args=worker_args)
|
||||
|
||||
server_str_args = string_args(args, server_args)
|
||||
server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
|
||||
server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
|
||||
subprocess.run(server_sh, shell=True, check=True)
|
||||
subprocess.run(server_check_sh, shell=True, check=True)
|
||||
|
||||
print("Launching LLM service done!")
|
||||
print("LLM服务启动完毕。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
launch_all()
|
||||
args = parser.parse_args()
|
||||
# 必须要加http//:,否则InvalidSchema: No connection adapters were found
|
||||
args = argparse.Namespace(**vars(args),
|
||||
**{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
|
||||
|
||||
if args.gpus:
|
||||
if len(args.gpus.split(",")) < args.num_gpus:
|
||||
raise ValueError(
|
||||
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
||||
)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
launch_all(args=args)
|
||||
|
|
|
|||
Binary image file added (7.1 KiB); contents not shown.
|
|
@@ -2,14 +2,10 @@ import pydantic
|
|||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import torch
|
||||
from fastapi_offline import FastAPIOffline
|
||||
import fastapi_offline
|
||||
from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
|
||||
# patch fastapi_offline to use local static assests
|
||||
fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
|
|
@@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
|
|||
if done:
|
||||
break
|
||||
yield obj
|
||||
|
||||
|
||||
def MakeFastAPIOffline(
|
||||
app: FastAPI,
|
||||
static_dir = Path(__file__).parent / "static",
|
||||
static_url = "/static-offline-docs",
|
||||
docs_url: Optional[str] = "/docs",
|
||||
redoc_url: Optional[str] = "/redoc",
|
||||
) -> None:
|
||||
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
|
||||
from fastapi import Request
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
get_swagger_ui_oauth2_redirect_html,
|
||||
)
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
openapi_url = app.openapi_url
|
||||
swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
|
||||
|
||||
def remove_route(url: str) -> None:
|
||||
'''
|
||||
remove original route from app
|
||||
'''
|
||||
index = None
|
||||
for i, r in enumerate(app.routes):
|
||||
if r.path.lower() == url.lower():
|
||||
index = i
|
||||
break
|
||||
if isinstance(index, int):
|
||||
app.routes.pop(i)
|
||||
|
||||
# Set up static file mount
|
||||
app.mount(
|
||||
static_url,
|
||||
StaticFiles(directory=Path(static_dir).as_posix()),
|
||||
name="static-offline-docs",
|
||||
)
|
||||
|
||||
if docs_url is not None:
|
||||
remove_route(docs_url)
|
||||
remove_route(swagger_ui_oauth2_redirect_url)
|
||||
|
||||
# Define the doc and redoc pages, pointing at the right files
|
||||
@app.get(docs_url, include_in_schema=False)
|
||||
async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
|
||||
root = request.scope.get("root_path")
|
||||
favicon = f"{root}{static_url}/favicon.png"
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=f"{root}{openapi_url}",
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
|
||||
swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
|
||||
swagger_css_url=f"{root}{static_url}/swagger-ui.css",
|
||||
swagger_favicon_url=favicon,
|
||||
)
|
||||
|
||||
@app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
|
||||
async def swagger_ui_redirect() -> HTMLResponse:
|
||||
return get_swagger_ui_oauth2_redirect_html()
|
||||
|
||||
if redoc_url is not None:
|
||||
remove_route(redoc_url)
|
||||
|
||||
@app.get(redoc_url, include_in_schema=False)
|
||||
async def redoc_html(request: Request) -> HTMLResponse:
|
||||
root = request.scope.get("root_path")
|
||||
favicon = f"{root}{static_url}/favicon.png"
|
||||
|
||||
return get_redoc_html(
|
||||
openapi_url=f"{root}{openapi_url}",
|
||||
title=app.title + " - ReDoc",
|
||||
redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
|
||||
with_google_fonts=False,
|
||||
redoc_favicon_url=favicon,
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -0,0 +1 @@
ps -eo pid,user,cmd|grep -P 'server/api.py|webui.py|fastchat.serve'|grep -v grep|awk '{print $1}'|xargs kill -9
||||
webui.py (6 lines changed)

@@ -13,7 +13,11 @@ import os
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)

if __name__ == "__main__":
    st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded")
    st.set_page_config(
        "Langchain-Chatchat WebUI",
        os.path.join("img", "chatchat_icon_blue_square_v2.png"),
        initial_sidebar_state="expanded",
    )

    if not chat_box.chat_inited:
        st.toast(
|
|
@@ -34,27 +34,31 @@ parser.add_argument("--theme.secondaryBackgroundColor",type=str,default='"#f5f5f
|
|||
parser.add_argument("--theme.textColor",type=str,default='"#000000"')
|
||||
web_args = ["server.port","theme.base","theme.primaryColor","theme.secondaryBackgroundColor","theme.textColor"]
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def launch_api(args=args,args_list=api_args,log_name=None):
|
||||
print("launch api ...")
|
||||
def launch_api(args,args_list=api_args,log_name=None):
|
||||
print("Launching api ...")
|
||||
print("启动API服务...")
|
||||
if not log_name:
|
||||
log_name = f"{LOG_PATH}api_{args.api_host}_{args.api_port}"
|
||||
print(f"logs on api are written in {log_name}")
|
||||
print(f"API日志位于{log_name}下,如启动异常请查看日志")
|
||||
args_str = string_args(args,args_list)
|
||||
api_sh = "python server/{script} {args_str} >{log_name}.log 2>&1 &".format(
|
||||
script="api.py",args_str=args_str,log_name=log_name)
|
||||
subprocess.run(api_sh, shell=True, check=True)
|
||||
print("launch api done!")
|
||||
print("启动API服务完毕.")
|
||||
|
||||
|
||||
def launch_webui(args=args,args_list=web_args,log_name=None):
|
||||
def launch_webui(args,args_list=web_args,log_name=None):
|
||||
print("Launching webui...")
|
||||
print("启动webui服务...")
|
||||
if not log_name:
|
||||
log_name = f"{LOG_PATH}webui"
|
||||
print(f"logs on api are written in {log_name}")
|
||||
|
||||
args_str = string_args(args,args_list)
|
||||
if args.nohup:
|
||||
print(f"logs on api are written in {log_name}")
|
||||
print(f"webui服务日志位于{log_name}下,如启动异常请查看日志")
|
||||
webui_sh = "streamlit run webui.py {args_str} >{log_name}.log 2>&1 &".format(
|
||||
args_str=args_str,log_name=log_name)
|
||||
else:
|
||||
|
|
@@ -62,12 +66,18 @@ def launch_webui(args=args,args_list=web_args,log_name=None):
|
|||
args_str=args_str)
|
||||
subprocess.run(webui_sh, shell=True, check=True)
|
||||
print("launch webui done!")
|
||||
print("启动webui服务完毕.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting webui_allineone.py, it would take a while, please be patient....")
|
||||
print(f"开始启动webui_allinone,启动LLM服务需要约3-10分钟,请耐心等待,如长时间未启动,请到{LOG_PATH}下查看日志...")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("*"*80)
|
||||
if not args.use_remote_api:
|
||||
launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args)
|
||||
launch_api(args=args,args_list=api_args)
|
||||
launch_webui(args=args,args_list=web_args)
|
||||
print("Start webui_allinone.py done!")
|
||||
print("感谢耐心等待,启动webui_allinone完毕。")
|
||||
|
|
@@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
                key="selected_kb",
            )
            kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
            # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
            score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
            # chunk_content = st.checkbox("关联上下文", False, disabled=True)
            # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
        elif dialogue_mode == "搜索引擎问答":

@@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest):
            text = ""
            r = api.chat_chat(prompt, history)
            for t in r:
                if error_msg := check_error_msg(t):  # check whether error occured
                    st.error(error_msg)
                    break
                text += t
                chat_box.update_msg(text)
            chat_box.update_msg(text, streaming=False)  # 更新最终的字符串,去除光标

@@ -108,7 +111,9 @@ def dialogue_page(api: ApiRequest):
                Markdown("...", in_expander=True, title="知识库匹配结果"),
            ])
            text = ""
            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
                if error_msg := check_error_msg(d):  # check whether error occured
                    st.error(error_msg)
                text += d["answer"]
                chat_box.update_msg(text, 0)
            chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)

@@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest):
            ])
            text = ""
            for d in api.search_engine_chat(prompt, search_engine, se_top_k):
                if error_msg := check_error_msg(d):  # check whether error occured
                    st.error(error_msg)
                text += d["answer"]
                chat_box.update_msg(text, 0)
            chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
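`check_error_msg` (added in the Web UI utils file later in this diff) is what lets the streaming loops above surface API failures: it returns the `errorMsg` field when a yielded chunk is an error dict and an empty string otherwise. A tiny usage sketch (the import location is an assumption):

```python
# Sketch of the helper contract used by the streaming loops above.
from webui_pages.utils import check_error_msg  # assumed import path

assert check_error_msg({"code": 500, "errorMsg": "无法连接API服务器"}) == "无法连接API服务器"
assert check_error_msg("正常的流式文本块") == ""  # plain text chunks pass through silently
```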
|
|
@@ -49,21 +49,27 @@ def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:

def knowledge_base_page(api: ApiRequest):
    try:
        kb_list = get_kb_details()
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
    except Exception as e:
        st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    kb_names = [x["kb_name"] for x in kb_list]
    kb_names = list(kb_list.keys())

    if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names:
        selected_kb_index = kb_names.index(st.session_state["selected_kb_name"])
    else:
        selected_kb_index = 0

    def format_selected_kb(kb_name: str) -> str:
        if kb := kb_list.get(kb_name):
            return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
        else:
            return kb_name

    selected_kb = st.selectbox(
        "请选择或新建知识库:",
        kb_list + ["新建知识库"],
        format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})" if type(s) != str else s,
        kb_names + ["新建知识库"],
        format_func=format_selected_kb,
        index=selected_kb_index
    )

@@ -117,9 +123,7 @@ def knowledge_base_page(api: ApiRequest):
            st.experimental_rerun()

    elif selected_kb:
        kb = selected_kb["kb_name"]

        kb = selected_kb

        # 上传文件
        # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
|
@@ -6,8 +6,10 @@ from configs.model_config import (
|
|||
DEFAULT_VS_TYPE,
|
||||
KB_ROOT_PATH,
|
||||
LLM_MODEL,
|
||||
SCORE_THRESHOLD,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SEARCH_ENGINE_TOP_K,
|
||||
logger,
|
||||
)
|
||||
import httpx
|
||||
import asyncio
|
||||
|
|
@@ -24,6 +26,7 @@ from configs.model_config import NLTK_DATA_PATH
|
|||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
'''
|
||||
设置httpx默认timeout到60秒。
|
||||
|
|
@@ -80,7 +83,8 @@ class ApiRequest:
|
|||
return httpx.stream("GET", url, params=params, **kwargs)
|
||||
else:
|
||||
return httpx.get(url, params=params, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
async def aget(
|
||||
|
|
@@ -100,7 +104,8 @@ class ApiRequest:
|
|||
return await client.stream("GET", url, params=params, **kwargs)
|
||||
else:
|
||||
return await client.get(url, params=params, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
def post(
|
||||
|
|
@@ -121,7 +126,8 @@ class ApiRequest:
|
|||
return httpx.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return httpx.post(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
async def apost(
|
||||
|
|
@@ -142,7 +148,8 @@ class ApiRequest:
|
|||
return await client.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return await client.post(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
def delete(
|
||||
|
|
@@ -162,7 +169,8 @@ class ApiRequest:
|
|||
return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return httpx.delete(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
async def adelete(
|
||||
|
|
@@ -183,7 +191,8 @@ class ApiRequest:
|
|||
return await client.stream("DELETE", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return await client.delete(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
retry -= 1
|
||||
|
||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||
|
|
@@ -195,11 +204,14 @@ class ApiRequest:
|
|||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
for chunk in iter_over_async(response.body_iterator, loop):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
try:
|
||||
for chunk in iter_over_async(response.body_iterator, loop):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def _httpx_stream2generator(
|
||||
self,
|
||||
|
|
@@ -209,12 +221,26 @@ class ApiRequest:
|
|||
'''
|
||||
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
||||
'''
|
||||
with response as r:
|
||||
for chunk in r.iter_text(None):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
try:
|
||||
with response as r:
|
||||
for chunk in r.iter_text(None):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
except httpx.ConnectError as e:
|
||||
msg = f"无法连接API服务器,请确认已执行python server\\api.py"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": msg}
|
||||
except httpx.ReadTimeout as e:
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": msg}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": str(e)}
|
||||
|
||||
# 对话相关操作
|
||||
|
||||
|
|
@@ -287,6 +313,7 @@ class ApiRequest:
|
|||
query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
|
|
@@ -301,6 +328,7 @@ class ApiRequest:
|
|||
"query": query,
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"local_doc_url": no_remote_api,
|
||||
|
|
@@ -353,6 +381,21 @@ class ApiRequest:
|
|||
|
||||
# 知识库相关操作
|
||||
|
||||
def _check_httpx_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
|
||||
) -> Dict:
|
||||
'''
|
||||
check whether httpx returns correct data with normal Response.
|
||||
error in api with streaming support was checked in _httpx_stream2enerator
|
||||
'''
|
||||
try:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"code": 500, "errorMsg": errorMsg or str(e)}
|
||||
|
||||
def list_knowledge_bases(
|
||||
self,
|
||||
no_remote_api: bool = None,
|
||||
|
|
@@ -369,7 +412,8 @@ class ApiRequest:
|
|||
return response.data
|
||||
else:
|
||||
response = self.get("/knowledge_base/list_knowledge_bases")
|
||||
return response.json().get("data")
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
def create_knowledge_base(
|
||||
self,
|
||||
|
|
@@ -399,7 +443,7 @@ class ApiRequest:
|
|||
"/knowledge_base/create_knowledge_base",
|
||||
json=data,
|
||||
)
|
||||
return response.json()
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def delete_knowledge_base(
|
||||
self,
|
||||
|
|
@@ -421,7 +465,7 @@ class ApiRequest:
|
|||
"/knowledge_base/delete_knowledge_base",
|
||||
json=f"{knowledge_base_name}",
|
||||
)
|
||||
return response.json()
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def list_kb_docs(
|
||||
self,
|
||||
|
|
@@ -443,7 +487,8 @@ class ApiRequest:
|
|||
"/knowledge_base/list_docs",
|
||||
params={"knowledge_base_name": knowledge_base_name}
|
||||
)
|
||||
return response.json().get("data")
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
def upload_kb_doc(
|
||||
self,
|
||||
|
|
@@ -487,7 +532,7 @@ class ApiRequest:
|
|||
data={"knowledge_base_name": knowledge_base_name, "override": override},
|
||||
files={"file": (filename, file)},
|
||||
)
|
||||
return response.json()
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def delete_kb_doc(
|
||||
self,
|
||||
|
|
@@ -517,7 +562,7 @@ class ApiRequest:
|
|||
"/knowledge_base/delete_doc",
|
||||
json=data,
|
||||
)
|
||||
return response.json()
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def update_kb_doc(
|
||||
self,
|
||||
|
|
@@ -540,7 +585,7 @@ class ApiRequest:
|
|||
"/knowledge_base/update_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
|
||||
)
|
||||
return response.json()
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def recreate_vector_store(
|
||||
self,
|
||||
|
|
@@ -572,10 +617,20 @@ class ApiRequest:
|
|||
"/knowledge_base/recreate_vector_store",
|
||||
json=data,
|
||||
stream=True,
|
||||
timeout=False,
|
||||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
|
||||
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
|
||||
'''
|
||||
return error message if error occured when requests API
|
||||
'''
|
||||
if isinstance(data, dict) and key in data:
|
||||
return data[key]
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = ApiRequest(no_remote_api=True)
|
||||
|
||||
|
|
|
|||