Merge remote-tracking branch 'origin/dev' into dev

commit ab4c8d2e5d
zqt, 2023-09-01 18:10:32 +08:00
18 changed files with 207 additions and 56 deletions

.gitignore
View File

@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/

View File

@@ -1,6 +1,5 @@
 import os
 import logging
-import torch

 # Log format
 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
 logger = logging.getLogger()
@@ -32,8 +31,8 @@ embedding_model_dict = {

 # Selected Embedding model name
 EMBEDDING_MODEL = "m3e-base"

-# Device for the Embedding model
-EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+# Device for the Embedding model. "auto" detects it automatically; it can also be set manually to one of "cuda", "mps", "cpu".
+EMBEDDING_DEVICE = "auto"

 llm_model_dict = {
@@ -77,15 +76,14 @@ llm_model_dict = {
     },
 }

 # LLM name
 LLM_MODEL = "chatglm2-6b"

 # Number of history dialogue turns
 HISTORY_LEN = 3

-# Device for the LLM
-LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+# Device for the LLM. Same options as the Embedding device.
+LLM_DEVICE = "auto"

 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
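
As a quick illustration of the new "auto" setting (a hypothetical snippet, not part of the commit; the resolution helpers are added to server.utils later in this same commit):

    from configs.model_config import EMBEDDING_DEVICE   # "auto"
    from server.utils import embedding_device

    # Any value outside {"cuda", "mps", "cpu"} is resolved at runtime by detect_device().
    print(embedding_device(EMBEDDING_DEVICE))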

View File

@@ -0,0 +1,2 @@
+from .mypdfloader import RapidOCRPDFLoader
+from .myimgloader import RapidOCRLoader

View File

@@ -0,0 +1,25 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRLoader(UnstructuredFileLoader):
+    def _get_elements(self) -> List:
+        def img2text(filepath):
+            from rapidocr_onnxruntime import RapidOCR
+            resp = ""
+            ocr = RapidOCR()
+            result, _ = ocr(filepath)
+            if result:
+                ocr_result = [line[1] for line in result]
+                resp += "\n".join(ocr_result)
+            return resp
+
+        text = img2text(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+    loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
+    docs = loader.load()
+    print(docs)

View File

@@ -0,0 +1,37 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRPDFLoader(UnstructuredFileLoader):
+    def _get_elements(self) -> List:
+        def pdf2text(filepath):
+            import fitz
+            from rapidocr_onnxruntime import RapidOCR
+            import numpy as np
+            ocr = RapidOCR()
+
+            doc = fitz.open(filepath)
+            resp = ""
+            for page in doc:
+                # TODO: adjust the processing according to the order of text and images
+                text = page.get_text("")
+                resp += text + "\n"
+
+                img_list = page.get_images()
+                for img in img_list:
+                    pix = fitz.Pixmap(doc, img[0])
+                    img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+                    result, _ = ocr(img_array)
+                    if result:
+                        ocr_result = [line[1] for line in result]
+                        resp += "\n".join(ocr_result)
+            return resp
+
+        text = pdf2text(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+    loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
+    docs = loader.load()
+    print(docs)

View File

@@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
 faiss-cpu
 accelerate
 spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1

 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3

View File

@@ -16,6 +16,8 @@ faiss-cpu
 nltk
 accelerate
 spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1

 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3

View File

@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 from typing import List, Union, Dict
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)

     def create_kb(self):

View File

@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device

 _VECTOR_STORE_TICKS = {}
@@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
@@ -74,13 +73,18 @@ class FaissKBService(KBService):
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)

-    def load_vector_store(self):
+    def load_vector_store(self) -> FAISS:
         return load_faiss_vector_store(
             knowledge_base_name=self.kb_name,
             embed_model=self.embed_model,
             tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
         )

+    def save_vector_store(self, vector_store: FAISS = None):
+        vector_store = vector_store or self.load_vector_store()
+        vector_store.save_local(self.vs_path)
+        return vector_store
+
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
@@ -117,11 +121,11 @@ class FaissKBService(KBService):
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
+        return vector_store

     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
         vector_store = self.load_vector_store()

         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@@ -133,7 +137,7 @@ class FaissKBService(KBService):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()

-        return True
+        return vector_store

     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
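
For orientation, a small usage sketch of the new persistence hook (not from the diff; assumes an existing FaissKBService instance named kb):

    # Persist the FAISS index after batch changes.
    vs = kb.save_vector_store()            # no argument: loads the current store, then writes it to kb.vs_path
    kb.save_vector_store(vector_store=vs)  # or pass a store you already hold to avoid reloading it
    kb.refresh_vs_cache()                  # invalidate the in-process cache so readers pick up the new index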

View File

@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text

-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device


 class PGKBService(KBService):
     pg_vector: PGVector

-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)

View File

@@ -69,6 +69,7 @@ def folder2db(
                 print(result)

         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "fill_info_only":
         files = list_files_from_folder(kb_name)
@@ -85,6 +86,7 @@ def folder2db(
             kb.update_doc(kb_file, not_refresh_vs_cache=True)

         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "increament":
         db_files = kb.list_files()
@@ -102,6 +104,7 @@ def folder2db(
                 print(result)

         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     else:
         print(f"unspported migrate mode: {mode}")
@@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
     files = list(set(files_in_db) - set(files_in_folder))
     kb_files = file_to_kbfile(kb_name, files)
     for kb_file in kb_files:
-        kb.delete_doc(kb_file)
+        kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+    if kb.vs_type() == SupportedVSType.FAISS:
+        kb.save_vector_store()
+        kb.refresh_vs_cache()
     return kb_files

 def prune_folder_files(kb_name: str):

View File

@@ -87,7 +87,8 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
-              "PyPDFLoader": [".pdf"],
+              "RapidOCRPDFLoader": [".pdf"],
+              "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt',
@@ -196,6 +197,9 @@ class KnowledgeFile:
         print(f"{self.document_loader_name} used for {self.filepath}")

         try:
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
+            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+                document_loaders_module = importlib.import_module('document_loaders')
+            else:
+                document_loaders_module = importlib.import_module('langchain.document_loaders')
             DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
         except Exception as e:
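
To illustrate the loader routing above, a minimal standalone sketch (hypothetical helper; the extension mapping and the import fallback mirror the diff, while load_file itself is made up):

    import importlib
    from pathlib import Path

    # Extensions now routed to the project-local OCR loaders (subset of LOADER_DICT above).
    OCR_LOADERS = {".pdf": "RapidOCRPDFLoader", ".png": "RapidOCRLoader",
                   ".jpg": "RapidOCRLoader", ".jpeg": "RapidOCRLoader", ".bmp": "RapidOCRLoader"}

    def load_file(filepath: str):
        loader_name = OCR_LOADERS.get(Path(filepath).suffix.lower(), "UnstructuredFileLoader")
        # Project-local loaders live in the document_loaders package; everything else
        # still comes from langchain.document_loaders.
        if loader_name in ("RapidOCRPDFLoader", "RapidOCRLoader"):
            module = importlib.import_module("document_loaders")
        else:
            module = importlib.import_module("langchain.document_loaders")
        DocumentLoader = getattr(module, loader_name)
        return DocumentLoader(filepath).load()

    # docs = load_file("tests/samples/ocr_test.pdf")  # OCR text ends up in langchain Document objects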

View File

@@ -4,8 +4,8 @@ import sys
 import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device

 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,

View File

@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional


 class BaseResponse(BaseModel):
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM can skip installing torch.
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
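
A short usage sketch of the new device helpers (not part of the diff; printed values depend on the local hardware):

    from server.utils import detect_device, llm_device, embedding_device

    # With LLM_DEVICE / EMBEDDING_DEVICE set to "auto", any value outside
    # {"cuda", "mps", "cpu"} falls back to detect_device().
    print(detect_device())          # "cuda" on a CUDA box, "mps" on Apple Silicon, otherwise "cpu"
    print(llm_device())             # resolves the configured LLM_DEVICE
    print(embedding_device("mps"))  # explicit valid values pass through unchanged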

View File

@@ -14,12 +14,13 @@ except:
     pass

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -28,10 +29,12 @@ from configs import VERSION

 def create_controller_app(
         dispatch_method: str,
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.controller import app, Controller
+    from fastchat.serve.controller import app, Controller, logger
+    logger.setLevel(log_level)

     controller = Controller(dispatch_method)
     sys.modules["fastchat.serve.controller"].controller = controller
@@ -41,13 +44,14 @@ def create_controller_app(
     return app


-def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
     import argparse
     import threading
     import fastchat.serve.model_worker
+    logger.setLevel(log_level)

     # workaround to make program exit with Ctrl+c
     # it should be deleted after pr is merged by fastchat
@@ -136,10 +140,14 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
 def create_openai_api_app(
         controller_address: str,
         api_keys: List = [],
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+    from fastchat.utils import build_logger
+    logger = build_logger("openai_api", "openai_api.log")
+    logger.setLevel(log_level)

     app.add_middleware(
         CORSMiddleware,
@@ -149,6 +157,7 @@ def create_openai_api_app(
         allow_headers=["*"],
     )

+    sys.modules["fastchat.serve.openai_api_server"].logger = logger
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys
@@ -158,6 +167,9 @@ def create_openai_api_app(

 def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+    if q is None or not isinstance(run_seq, int):
+        return
     if run_seq == 1:
         @app.on_event("startup")
         async def on_startup():
@@ -176,15 +188,22 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
             q.put(run_seq)


-def run_controller(q: Queue, run_seq: int = 1):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
     import uvicorn
+    import sys

-    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+    app = create_controller_app(
+        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
+        log_level=log_level,
+    )
     _set_app_seq(app, q, run_seq)

     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
-    uvicorn.run(app, host=host, port=port)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())


 def run_model_worker(
@@ -192,10 +211,12 @@ def run_model_worker(
         controller_address: str = "",
         q: Queue = None,
         run_seq: int = 2,
+        log_level: str = "INFO",
 ):
     import uvicorn
+    import sys
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -204,21 +225,28 @@ def run_model_worker(
     kwargs["controller_address"] = controller_address or fschat_controller_address()
     kwargs["worker_address"] = fschat_model_worker_address()

-    app = create_model_worker_app(**kwargs)
+    app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_seq(app, q, run_seq)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__

-    uvicorn.run(app, host=host, port=port)
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())


-def run_openai_api(q: Queue, run_seq: int = 3):
+def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     import uvicorn
+    import sys

     controller_addr = fschat_controller_address()
-    app = create_openai_api_app(controller_addr)  # todo: not support keys yet.
+    app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
     _set_app_seq(app, q, run_seq)

     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
     uvicorn.run(app, host=host, port=port)
@@ -238,6 +266,8 @@ def run_api_server(q: Queue, run_seq: int = 4):

 def run_webui(q: Queue, run_seq: int = 5):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    while True:
-        no = q.get()
-        if no != run_seq - 1:
+    if q is not None and isinstance(run_seq, int):
+        while True:
+            no = q.get()
+            if no != run_seq - 1:
@@ -314,11 +344,18 @@ def parse_args() -> argparse.ArgumentParser:
         help="run webui.py server",
         dest="webui",
     )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        action="store_true",
+        help="reduce fastchat service log output",
+        dest="quiet",
+    )
     args = parser.parse_args()
     return args, parser


-def dump_server_info(after_start=False):
+def dump_server_info(after_start=False, args=None):
     import platform
     import langchain
     import fastchat
@@ -331,9 +368,9 @@ def dump_server_info(after_start=False):
     print(f"Project version: {VERSION}")
     print(f"langchain version: {langchain.__version__}. fastchat version: {fastchat.__version__}")
     print("\n")
-    print(f"Current LLM model: {LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"Current LLM model: {LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"Current Embeddings model: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"Current Embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"Server runtime information:")
@@ -354,6 +391,7 @@ if __name__ == "__main__":
     mp.set_start_method("spawn")
     queue = Queue()
     args, parser = parse_args()
+
     if args.all_webui:
         args.openai_api = True
         args.model_worker = True
@@ -372,19 +410,23 @@ if __name__ == "__main__":
         args.api = False
         args.webui = False

-    dump_server_info()
+    dump_server_info(args=args)

     if len(sys.argv) > 1:
         logger.info(f"Starting services:")
         logger.info(f"To view llm_api logs, go to {LOG_PATH}")

     processes = {}

+    if args.quiet:
+        log_level = "ERROR"
+    else:
+        log_level = "INFO"
+
     if args.openai_api:
         process = Process(
             target=run_controller,
             name=f"controller({os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, len(processes) + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -405,7 +447,7 @@ if __name__ == "__main__":
         process = Process(
             target=run_model_worker,
             name=f"model_worker({os.getpid()})",
-            args=(args.model_name, args.controller_address, queue, len(processes) + 1),
+            args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -440,7 +482,7 @@ if __name__ == "__main__":
         no = queue.get()
         if no == len(processes):
             time.sleep(0.5)
-            dump_server_info(True)
+            dump_server_info(after_start=True, args=args)
             break
         else:
             queue.put(no)
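
The new guard in _set_app_seq (and the matching check in run_webui) lets these runner functions be used without the startup queue; a hypothetical sketch, assuming run_controller is importable from this launcher module:

    from multiprocessing import Process

    # Launch only the fastchat controller, with log noise suppressed via the new
    # log_level parameter; q=None is tolerated because _set_app_seq now returns early.
    p = Process(target=run_controller, kwargs={"q": None, "log_level": "ERROR"}, daemon=True)
    p.start()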

tests/samples/ocr_test.jpg (new binary file, 7.9 KiB, not shown)

tests/samples/ocr_test.pdf (new binary file, not shown)

View File

@@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
 from webui_pages import *
 import os
 from configs import VERSION
+from server.utils import api_address

-api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
+api = ApiRequest(base_url=api_address())

 if __name__ == "__main__":
     st.set_page_config(