merge dev

This commit is contained in:
imClumsyPanda 2023-08-16 22:19:18 +08:00
commit 641fcaee40
25 changed files with 412 additions and 140 deletions

View File

@ -15,6 +15,7 @@
* [3. Configure settings](README.md#3.-设置配置项)
* [4. Knowledge base initialization and migration](README.md#4.-知识库初始化与迁移)
* [5. Start the API service or Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [6. One-click startup](README.md#6.-一键启动)
* [FAQ](README.md#常见问题)
* [Roadmap](README.md#路线图)
* [Project discussion group](README.md#项目交流群)
@ -248,14 +249,23 @@ gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```
Here, `gpus` sets the IDs of the GPUs to use, e.g. "0,1";
`num_gpus` sets the number of GPUs to use;
`max_gpu_memory` sets the amount of GPU memory to use on each card.
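As a rough illustration (placeholder values only, assuming a two-GPU machine), these three settings might be filled in as follows:
```python
# Hypothetical values; adjust GPU IDs, count and memory budget to your hardware.
gpus = "0,1"              # comma-separated IDs of the GPUs to use
num_gpus = 2              # number of GPUs the model is sharded across
max_gpu_memory = "20GiB"  # GPU memory budget per card
```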
##### 5.1.2 Starting the LLM service with the command-line script llm_api_launch.py
**!!! Note:**
**1. The llm_api_launch.py script only works on Linux and macOS; on Windows, please use WSL;**
**2. To load a non-default model, specify it with the command-line argument --model-path-address; the model_config.py settings will not be read;**
**!!!**
From the project root, run the [server/llm_api_launch.py](server/llm_api_launch.py) script to start the **LLM model** service:
```shell
@ -274,7 +284,7 @@ $ python server/llm_api_launch.py --model-path-address model1@host1@port1 model
$ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
```
Starting the LLM service this way runs the FastChat service in the background via nohup. To stop the service, run the following command:
```shell
$ python server/llm_api_shutdown.py --serve all
@ -349,11 +359,17 @@ $ streamlit run webui.py --server.port 666
---
### 6. One-click startup
⚠️ **Note:**
**1. The one-click startup scripts only work on Linux and macOS; on Windows, please use WSL;**
**2. To load a non-default model, specify it with the command-line argument `--model-path-address`; the `model_config.py` settings will not be read.**
#### 6.1 API service one-click startup script
A one-click API startup script has been added that starts the FastChat background services and this project's API service in a single step. Example invocations:
Start the default model:
@ -372,32 +388,49 @@ $ python server/api_allinone.py --model-path-address model1@host1@port1 model2@h
```shell
python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```
For other parameters, see the individual scripts and the FastChat documentation.
#### 6.2 Web UI one-click startup script
Load a local model:
```shell
$ python webui_allinone.py
```
Call a remote API service:
```shell
$ python webui_allinone.py --use-remote-api
```
Run the Web UI service in the background:
```shell
$ python webui_allinone.py --nohup
```
Load multiple non-default models:
```shell
$ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
```
Multi-GPU startup:
```shell
$ python webui_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```
For other parameters, see the individual scripts and the FastChat documentation.
The two one-click startup scripts above run multiple services in the background. To stop all of them, use the `shutdown_all.sh` script:
```shell
bash shutdown_all.sh
```
## FAQ
See the [FAQ](docs/FAQ.md).

View File

@ -46,7 +46,7 @@ llm_model_dict = {
"chatglm-6b-int4": {
"local_model_path": "THUDM/chatglm-6b-int4",
"api_base_url": "http://localhost:8001/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@ -64,7 +64,7 @@ llm_model_dict = {
"vicuna-13b-hf": {
"local_model_path": "",
"api_base_url": "http://localhost:8000/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
@ -85,6 +85,7 @@ llm_model_dict = {
},
}
# LLM 名称
LLM_MODEL = "chatglm2-6b"

View File

@ -1,11 +1,10 @@
langchain==0.0.257
openai
sentence_transformers
fschat==0.2.20
transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

View File

@ -1,11 +1,10 @@
langchain==0.0.257
openai
sentence_transformers
fschat==0.2.20
transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

View File

@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from server.utils import FastAPIOffline as FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -25,7 +27,8 @@ async def document():
def create_app():
app = FastAPI()
app = FastAPI(title="Langchain-Chatchat API Server")
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
@ -83,6 +86,12 @@ def create_app():
summary="获取知识库内的文件列表"
)(list_docs)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
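For context, a hedged client-side sketch of calling the newly registered `search_docs` route; the API host/port are assumed from the project defaults (0.0.0.0:7861), and the field names follow kb_doc_api.py:
```python
# Hypothetical request against the new /knowledge_base/search_docs endpoint.
import httpx

resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,
        "score_threshold": 0.5,
    },
)
for doc in resp.json():  # a list serialized from List[DocumentWithScore]
    print(doc["score"], doc["page_content"][:80])
```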

View File

@ -22,9 +22,7 @@ parser.add_argument("--api-host", type=str, default="0.0.0.0")
parser.add_argument("--api-port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
api_args = ["api-host","api-port","ssl_keyfile","ssl_certfile"]
@ -41,6 +39,11 @@ def run_api(host, port, **kwargs):
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
print("Luanching api_allinoneit would take a while, please be patient...")
print("正在启动api_allinoneLLM服务启动约3-10分钟请耐心等待...")
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args)
run_api(
host=args.api_host,
@ -48,3 +51,5 @@ if __name__ == "__main__":
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
print("Luanching api_allinone done.")
print("api_allinone启动完毕.")

View File

@ -1,26 +1,27 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K)
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
history: List[History] = Body([],
description="历史对话",
examples=[[
@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = kb.search_docs(query, top_k)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(

View File

@ -1,7 +1,7 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL
from configs.model_config import llm_model_dict, LLM_MODEL, logger
from pydantic import BaseModel
@ -33,19 +33,23 @@ async def openai_chat(msg: OpenAiChatMsgIn):
data = msg.dict()
data["streaming"] = True
data.pop("stream")
response = openai.ChatCompletion.create(**data)
if msg.stream:
for chunk in response.choices[0].message.content:
print(chunk)
yield chunk
else:
answer = ""
for chunk in response.choices[0].message.content:
answer += chunk
print(answer)
yield(answer)
try:
response = openai.ChatCompletion.create(**data)
if msg.stream:
for chunk in response.choices[0].message.content:
print(chunk)
yield chunk
else:
answer = ""
for chunk in response.choices[0].message.content:
answer += chunk
print(answer)
yield(answer)
except Exception as e:
print(type(e))
logger.error(e)
return StreamingResponse(
get_response(msg),
media_type='text/event-stream',

View File

@ -27,7 +27,7 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
file_ext=kb_file.ext,
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name,
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
)
kb.file_count += 1
session.add(new_file)

View File

@ -1,13 +1,32 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List
from typing import List, Dict
from langchain.docstore.document import Document
class DocumentWithScore(Document):
score: float = None
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data
async def list_docs(

View File

@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db, get_file_detail, delete_file_from_db
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@ -112,9 +112,10 @@ class KBService(ABC):
def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
embeddings = self._load_embeddings()
docs = self.do_search(query, top_k, embeddings)
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
@abstractmethod

View File

@ -81,12 +81,13 @@ class FaissKBService(KBService):
def do_search(self,
query: str,
top_k: int,
embeddings: Embeddings,
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,

View File

@ -45,7 +45,8 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)

View File

@ -43,7 +43,8 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)

View File

@ -102,6 +102,7 @@ class KnowledgeFile:
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
)
self.text_splitter_name = "SpacyTextSplitter"
else:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)

View File

@ -1,16 +1,18 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
host_ip = "0.0.0.0"
controller_port = 20001
model_worker_port = 20002
openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
queue = Queue()
def set_httpx_timeout(timeout=60.0):
@ -30,33 +32,50 @@ def create_controller_app(
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
return app
def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
model_names=[LLM_MODEL],
device=LLM_DEVICE,
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,
cpu_offloading=None,
gptq_ckpt=None,
gptq_wbits=16,
gptq_groupsize=-1,
gptq_act_order=None,
gpus=None,
num_gpus=1,
max_gpu_memory="20GiB",
cpu_offloading=None,
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
gptq_act_order=False,
awq_ckpt=None,
awq_wbits=16,
awq_groupsize=-1,
model_names=[LLM_MODEL],
num_gpus=1, # not in fastchat
conv_template=None,
limit_worker_concurrency=5,
stream_interval=2,
no_register=False,
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
from fastchat.serve import model_worker
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
import argparse
import threading
import fastchat.serve.model_worker
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
parser = argparse.ArgumentParser()
args = parser.parse_args()
@ -68,12 +87,16 @@ def create_model_worker_app(
args.gptq_wbits = gptq_wbits
args.gptq_groupsize = gptq_groupsize
args.gptq_act_order = gptq_act_order
args.awq_ckpt = awq_ckpt
args.awq_wbits = awq_wbits
args.awq_groupsize = awq_groupsize
args.gpus = gpus
args.num_gpus = num_gpus
args.max_gpu_memory = max_gpu_memory
args.cpu_offloading = cpu_offloading
args.worker_address = worker_address
args.controller_address = controller_address
args.conv_template = conv_template
args.limit_worker_concurrency = limit_worker_concurrency
args.stream_interval = stream_interval
args.no_register = no_register
@ -95,6 +118,12 @@ def create_model_worker_app(
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# torch.multiprocessing.set_start_method('spawn')
worker = ModelWorker(
controller_addr=args.controller_address,
@ -110,19 +139,21 @@ def create_model_worker_app(
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
def create_openai_api_app(
host=host_ip,
port=openai_api_port,
controller_address=base_url.format(controller_port),
api_keys=[],
):
@ -141,6 +172,8 @@ def create_openai_api_app(
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
@ -193,6 +226,8 @@ def run_openai_api(q):
if __name__ == "__main__":
mp.set_start_method("spawn")
queue = Queue()
logger.info(llm_model_dict[LLM_MODEL])
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
@ -209,15 +244,14 @@ if __name__ == "__main__":
)
controller_process.start()
# cuda 没办法用在fork的多进程中
# model_worker_process = Process(
# target=run_model_worker,
# name=f"model_worker({os.getpid()})",
# args=(queue,),
# # kwargs={"load_8bit": True},
# daemon=True,
# )
# model_worker_process.start()
model_worker_process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(queue,),
# kwargs={"load_8bit": True},
daemon=True,
)
model_worker_process.start()
openai_api_process = Process(
target=run_openai_api,
@ -227,11 +261,14 @@ if __name__ == "__main__":
)
openai_api_process.start()
run_model_worker(queue)
controller_process.join()
# model_worker_process.join()
openai_api_process.join()
try:
model_worker_process.join()
controller_process.join()
openai_api_process.join()
except KeyboardInterrupt:
model_worker_process.terminate()
controller_process.terminate()
openai_api_process.terminate()
# 服务启动后接口调用示例:
# import openai
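The example referred to by the trailing comment is cut off by this hunk; a minimal, assumed sketch of calling the FastChat OpenAI-compatible endpoint (port 8888 and `api_key="EMPTY"` as configured above) might look like:
```python
# Assumed completion of the truncated example above; not taken verbatim from the repo.
import openai

openai.api_base = "http://127.0.0.1:8888/v1"  # openai_api_port defined above
openai.api_key = "EMPTY"                      # FastChat does not check the key by default

resp = openai.ChatCompletion.create(
    model="chatglm2-6b",  # LLM_MODEL from model_config.py
    messages=[{"role": "user", "content": "你好"}],
)
print(resp.choices[0].message.content)
```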

View File

@ -159,17 +159,7 @@ server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
"controller-address"
]
args = parser.parse_args()
# 必须要加http//:否则InvalidSchema: No connection adapters were found
args = argparse.Namespace(**vars(args),
**{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
# 0,controller, model_worker, openai_api_server
# 1, 命令行选项
@ -211,12 +201,13 @@ def string_args(args, args_list):
return args_str
def launch_worker(item,args=args,worker_args=worker_args):
def launch_worker(item,args,worker_args=worker_args):
log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
# 先分割model-path-address,在传到string_args中分析参数
args.model_path, args.worker_host, args.worker_port = item.split("@")
args.worker_address = f"http://{args.worker_host}:{args.worker_port}"
print("*" * 80)
print(f"如长时间未启动,请到{LOG_PATH}{log_name}.log下查看日志")
worker_str_args = string_args(args, worker_args)
print(worker_str_args)
worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
@ -225,30 +216,44 @@ def launch_worker(item,args=args,worker_args=worker_args):
subprocess.run(worker_check_sh, shell=True, check=True)
def launch_all(args=args,
def launch_all(args,
controller_args=controller_args,
worker_args=worker_args,
server_args=server_args
):
print(f"Launching llm service,logs are located in {LOG_PATH}...")
print(f"开始启动LLM服务,请到{LOG_PATH}下监控各模块日志...")
controller_str_args = string_args(args, controller_args)
controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
subprocess.run(controller_sh, shell=True, check=True)
subprocess.run(controller_check_sh, shell=True, check=True)
print(f"worker启动时间视设备不同而不同约需3-10分钟请耐心等待...")
if isinstance(args.model_path_address, str):
launch_worker(args.model_path_address,args=args,worker_args=worker_args)
else:
for idx, item in enumerate(args.model_path_address):
print(f"开始加载第{idx}个模型:{item}")
launch_worker(item)
launch_worker(item,args=args,worker_args=worker_args)
server_str_args = string_args(args, server_args)
server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
subprocess.run(server_sh, shell=True, check=True)
subprocess.run(server_check_sh, shell=True, check=True)
print("Launching LLM service done!")
print("LLM服务启动完毕。")
if __name__ == "__main__":
launch_all()
args = parser.parse_args()
# 必须要加http//:否则InvalidSchema: No connection adapters were found
args = argparse.Namespace(**vars(args),
**{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
launch_all(args=args)

BIN
server/static/favicon.png Normal file (binary file not shown; 7.1 KiB)

View File

@ -2,14 +2,10 @@ import pydantic
from pydantic import BaseModel
from typing import List
import torch
from fastapi_offline import FastAPIOffline
import fastapi_offline
from fastapi import FastAPI
from pathlib import Path
import asyncio
# patch fastapi_offline to use local static assests
fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
from typing import Any, Optional
class BaseResponse(BaseModel):
@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
if done:
break
yield obj
def MakeFastAPIOffline(
app: FastAPI,
static_dir = Path(__file__).parent / "static",
static_url = "/static-offline-docs",
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
) -> None:
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
from fastapi import Request
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
from fastapi.staticfiles import StaticFiles
from starlette.responses import HTMLResponse
openapi_url = app.openapi_url
swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
def remove_route(url: str) -> None:
'''
remove original route from app
'''
index = None
for i, r in enumerate(app.routes):
if r.path.lower() == url.lower():
index = i
break
if isinstance(index, int):
app.routes.pop(i)
# Set up static file mount
app.mount(
static_url,
StaticFiles(directory=Path(static_dir).as_posix()),
name="static-offline-docs",
)
if docs_url is not None:
remove_route(docs_url)
remove_route(swagger_ui_oauth2_redirect_url)
# Define the doc and redoc pages, pointing at the right files
@app.get(docs_url, include_in_schema=False)
async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
root = request.scope.get("root_path")
favicon = f"{root}{static_url}/favicon.png"
return get_swagger_ui_html(
openapi_url=f"{root}{openapi_url}",
title=app.title + " - Swagger UI",
oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
swagger_css_url=f"{root}{static_url}/swagger-ui.css",
swagger_favicon_url=favicon,
)
@app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect() -> HTMLResponse:
return get_swagger_ui_oauth2_redirect_html()
if redoc_url is not None:
remove_route(redoc_url)
@app.get(redoc_url, include_in_schema=False)
async def redoc_html(request: Request) -> HTMLResponse:
root = request.scope.get("root_path")
favicon = f"{root}{static_url}/favicon.png"
return get_redoc_html(
openapi_url=f"{root}{openapi_url}",
title=app.title + " - ReDoc",
redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
with_google_fonts=False,
redoc_favicon_url=favicon,
)
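A short usage sketch (not part of this diff) showing how the helper is meant to be applied, mirroring the `MakeFastAPIOffline(app)` calls added in api.py and llm_api.py:
```python
# Hypothetical standalone usage; the app title and port are placeholders.
from fastapi import FastAPI
import uvicorn

from server.utils import MakeFastAPIOffline

app = FastAPI(title="Example API")
MakeFastAPIOffline(app)  # /docs and /redoc now load Swagger/ReDoc assets from server/static

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7861)
```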

1
shutdown_all.sh Normal file
View File

@ -0,0 +1 @@
ps -eo pid,user,cmd|grep -P 'server/api.py|webui.py|fastchat.serve'|grep -v grep|awk '{print $1}'|xargs kill -9

View File

@ -13,7 +13,11 @@ import os
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
if __name__ == "__main__":
st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded")
st.set_page_config(
"Langchain-Chatchat WebUI",
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
initial_sidebar_state="expanded",
)
if not chat_box.chat_inited:
st.toast(

View File

@ -34,27 +34,31 @@ parser.add_argument("--theme.secondaryBackgroundColor",type=str,default='"#f5f5f
parser.add_argument("--theme.textColor",type=str,default='"#000000"')
web_args = ["server.port","theme.base","theme.primaryColor","theme.secondaryBackgroundColor","theme.textColor"]
args = parser.parse_args()
def launch_api(args=args,args_list=api_args,log_name=None):
print("launch api ...")
def launch_api(args,args_list=api_args,log_name=None):
print("Launching api ...")
print("启动API服务...")
if not log_name:
log_name = f"{LOG_PATH}api_{args.api_host}_{args.api_port}"
print(f"logs on api are written in {log_name}")
print(f"API日志位于{log_name}下,如启动异常请查看日志")
args_str = string_args(args,args_list)
api_sh = "python server/{script} {args_str} >{log_name}.log 2>&1 &".format(
script="api.py",args_str=args_str,log_name=log_name)
subprocess.run(api_sh, shell=True, check=True)
print("launch api done!")
print("启动API服务完毕.")
def launch_webui(args=args,args_list=web_args,log_name=None):
def launch_webui(args,args_list=web_args,log_name=None):
print("Launching webui...")
print("启动webui服务...")
if not log_name:
log_name = f"{LOG_PATH}webui"
print(f"logs on api are written in {log_name}")
args_str = string_args(args,args_list)
if args.nohup:
print(f"logs on api are written in {log_name}")
print(f"webui服务日志位于{log_name}下,如启动异常请查看日志")
webui_sh = "streamlit run webui.py {args_str} >{log_name}.log 2>&1 &".format(
args_str=args_str,log_name=log_name)
else:
@ -62,12 +66,18 @@ def launch_webui(args=args,args_list=web_args,log_name=None):
args_str=args_str)
subprocess.run(webui_sh, shell=True, check=True)
print("launch webui done!")
print("启动webui服务完毕.")
if __name__ == "__main__":
print("Starting webui_allineone.py, it would take a while, please be patient....")
print(f"开始启动webui_allinone,启动LLM服务需要约3-10分钟请耐心等待如长时间未启动请到{LOG_PATH}下查看日志...")
args = parser.parse_args()
print("*"*80)
if not args.use_remote_api:
launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args)
launch_api(args=args,args_list=api_args)
launch_webui(args=args,args_list=web_args)
print("Start webui_allinone.py done!")
print("感谢耐心等待启动webui_allinone完毕。")

View File

@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
# score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答":
@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest):
text = ""
r = api.chat_chat(prompt, history)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
break
text += t
chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
@ -108,7 +111,9 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest):
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)

View File

@ -49,21 +49,27 @@ def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
def knowledge_base_page(api: ApiRequest):
try:
kb_list = get_kb_details()
kb_list = {x["kb_name"]: x for x in get_kb_details()}
except Exception as e:
st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
st.stop()
kb_names = [x["kb_name"] for x in kb_list]
kb_names = list(kb_list.keys())
if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names:
selected_kb_index = kb_names.index(st.session_state["selected_kb_name"])
else:
selected_kb_index = 0
def format_selected_kb(kb_name: str) -> str:
if kb := kb_list.get(kb_name):
return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
else:
return kb_name
selected_kb = st.selectbox(
"请选择或新建知识库:",
kb_list + ["新建知识库"],
format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})" if type(s) != str else s,
kb_names + ["新建知识库"],
format_func=format_selected_kb,
index=selected_kb_index
)
@ -117,9 +123,7 @@ def knowledge_base_page(api: ApiRequest):
st.experimental_rerun()
elif selected_kb:
kb = selected_kb["kb_name"]
kb = selected_kb
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)

View File

@ -6,8 +6,10 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
import httpx
import asyncio
@ -24,6 +26,7 @@ from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
def set_httpx_timeout(timeout=60.0):
'''
设置httpx默认timeout到60秒
@ -80,7 +83,8 @@ class ApiRequest:
return httpx.stream("GET", url, params=params, **kwargs)
else:
return httpx.get(url, params=params, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
async def aget(
@ -100,7 +104,8 @@ class ApiRequest:
return await client.stream("GET", url, params=params, **kwargs)
else:
return await client.get(url, params=params, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
def post(
@ -121,7 +126,8 @@ class ApiRequest:
return httpx.stream("POST", url, data=data, json=json, **kwargs)
else:
return httpx.post(url, data=data, json=json, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
async def apost(
@ -142,7 +148,8 @@ class ApiRequest:
return await client.stream("POST", url, data=data, json=json, **kwargs)
else:
return await client.post(url, data=data, json=json, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
def delete(
@ -162,7 +169,8 @@ class ApiRequest:
return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return httpx.delete(url, data=data, json=json, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
async def adelete(
@ -183,7 +191,8 @@ class ApiRequest:
return await client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return await client.delete(url, data=data, json=json, **kwargs)
except:
except Exception as e:
logger.error(e)
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
@ -195,11 +204,14 @@ class ApiRequest:
except:
loop = asyncio.new_event_loop()
for chunk in iter_over_async(response.body_iterator, loop):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
try:
for chunk in iter_over_async(response.body_iterator, loop):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
except Exception as e:
logger.error(e)
def _httpx_stream2generator(
self,
@ -209,12 +221,26 @@ class ApiRequest:
'''
将httpx.stream返回的GeneratorContextManager转化为普通生成器
'''
with response as r:
for chunk in r.iter_text(None):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
try:
with response as r:
for chunk in r.iter_text(None):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "errorMsg": str(e)}
# 对话相关操作
@ -287,6 +313,7 @@ class ApiRequest:
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
no_remote_api: bool = None,
@ -301,6 +328,7 @@ class ApiRequest:
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"local_doc_url": no_remote_api,
@ -353,6 +381,21 @@ class ApiRequest:
# 知识库相关操作
def _check_httpx_json_response(
self,
response: httpx.Response,
errorMsg: str = f"无法连接API服务器请确认已执行python server\\api.py",
) -> Dict:
'''
check whether httpx returns correct data with normal Response.
errors in APIs with streaming support are checked in _httpx_stream2generator
'''
try:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "errorMsg": errorMsg or str(e)}
def list_knowledge_bases(
self,
no_remote_api: bool = None,
@ -369,7 +412,8 @@ class ApiRequest:
return response.data
else:
response = self.get("/knowledge_base/list_knowledge_bases")
return response.json().get("data")
data = self._check_httpx_json_response(response)
return data.get("data", [])
def create_knowledge_base(
self,
@ -399,7 +443,7 @@ class ApiRequest:
"/knowledge_base/create_knowledge_base",
json=data,
)
return response.json()
return self._check_httpx_json_response(response)
def delete_knowledge_base(
self,
@ -421,7 +465,7 @@ class ApiRequest:
"/knowledge_base/delete_knowledge_base",
json=f"{knowledge_base_name}",
)
return response.json()
return self._check_httpx_json_response(response)
def list_kb_docs(
self,
@ -443,7 +487,8 @@ class ApiRequest:
"/knowledge_base/list_docs",
params={"knowledge_base_name": knowledge_base_name}
)
return response.json().get("data")
data = self._check_httpx_json_response(response)
return data.get("data", [])
def upload_kb_doc(
self,
@ -487,7 +532,7 @@ class ApiRequest:
data={"knowledge_base_name": knowledge_base_name, "override": override},
files={"file": (filename, file)},
)
return response.json()
return self._check_httpx_json_response(response)
def delete_kb_doc(
self,
@ -517,7 +562,7 @@ class ApiRequest:
"/knowledge_base/delete_doc",
json=data,
)
return response.json()
return self._check_httpx_json_response(response)
def update_kb_doc(
self,
@ -540,7 +585,7 @@ class ApiRequest:
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
)
return response.json()
return self._check_httpx_json_response(response)
def recreate_vector_store(
self,
@ -572,10 +617,20 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store",
json=data,
stream=True,
timeout=False,
)
return self._httpx_stream2generator(response, as_json=True)
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
return error message if error occured when requests API
'''
if isinstance(data, dict) and key in data:
return data[key]
return ""
if __name__ == "__main__":
api = ApiRequest(no_remote_api=True)