diff --git a/.gitignore b/.gitignore index c4178a9..a7ef90f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ logs .idea/ __pycache__/ -knowledge_base/ +/knowledge_base/ configs/*.py .vscode/ .pytest_cache/ diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 4169adb..53097db 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,6 +1,5 @@ import os import logging -import torch # 日志格式 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() @@ -32,8 +31,8 @@ embedding_model_dict = { # 选用的 Embedding 名称 EMBEDDING_MODEL = "m3e-base" -# Embedding 模型运行设备 -EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 +EMBEDDING_DEVICE = "auto" llm_model_dict = { @@ -77,15 +76,14 @@ llm_model_dict = { }, } - # LLM 名称 LLM_MODEL = "chatglm2-6b" # 历史对话轮数 HISTORY_LEN = 3 -# LLM 运行设备 -LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +# LLM 运行设备。可选项同Embedding 运行设备。 +LLM_DEVICE = "auto" # 日志存储路径 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") @@ -167,4 +165,4 @@ BING_SUBSCRIPTION_KEY = "" # 是否开启中文标题加强,以及标题增强的相关配置 # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 -ZH_TITLE_ENHANCE = False \ No newline at end of file +ZH_TITLE_ENHANCE = False diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py new file mode 100644 index 0000000..a4d6b28 --- /dev/null +++ b/document_loaders/__init__.py @@ -0,0 +1,2 @@ +from .mypdfloader import RapidOCRPDFLoader +from .myimgloader import RapidOCRLoader \ No newline at end of file diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py new file mode 100644 index 0000000..8648192 --- /dev/null +++ b/document_loaders/myimgloader.py @@ -0,0 +1,25 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def img2text(filepath): + from rapidocr_onnxruntime import RapidOCR + resp = "" + ocr = RapidOCR() + result, _ = ocr(filepath) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = img2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg") + docs = loader.load() + print(docs) diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py new file mode 100644 index 0000000..71e063d --- /dev/null +++ b/document_loaders/mypdfloader.py @@ -0,0 +1,37 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRPDFLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def pdf2text(filepath): + import fitz + from rapidocr_onnxruntime import RapidOCR + import numpy as np + ocr = RapidOCR() + doc = fitz.open(filepath) + resp = "" + for page in doc: + # TODO: 依据文本与图片顺序调整处理方式 + text = page.get_text("") + resp += text + "\n" + + img_list = page.get_images() + for img in img_list: + pix = fitz.Pixmap(doc, img[0]) + img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) + result, _ = ocr(img_array) + 
if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = pdf2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf") + docs = loader.load() + print(docs) diff --git a/requirements.txt b/requirements.txt index e40f665..4271f3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,8 @@ SQLAlchemy==2.0.19 faiss-cpu accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/requirements_api.txt b/requirements_api.txt index 58dbc0c..bdecf3c 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -16,6 +16,8 @@ faiss-cpu nltk accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ca8d1ae..79b1518 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import ( ) from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - EMBEDDING_DEVICE, EMBEDDING_MODEL) + EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, ) +from server.utils import embedding_device from typing import List, Union, Dict @@ -45,7 +46,7 @@ class KBService(ABC): self.doc_path = get_doc_path(self.kb_name) self.do_init() - def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings: + def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) def create_kb(self): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 3601b57..f17b2da 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,7 +5,6 @@ from configs.model_config import ( KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, - EMBEDDING_DEVICE, SCORE_THRESHOLD ) from server.knowledge_base.kb_service.base import KBService, SupportedVSType @@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings from typing import List from langchain.docstore.document import Document -from server.utils import torch_gc +from server.utils import torch_gc, embedding_device _VECTOR_STORE_TICKS = {} @@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {} def load_faiss_vector_store( knowledge_base_name: str, embed_model: str = EMBEDDING_MODEL, - embed_device: str = EMBEDDING_DEVICE, + embed_device: str = embedding_device(), embeddings: Embeddings = None, tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. 
-): +) -> FAISS: print(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name) if embeddings is None: @@ -74,13 +73,18 @@ class FaissKBService(KBService): def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) - def load_vector_store(self): + def load_vector_store(self) -> FAISS: return load_faiss_vector_store( knowledge_base_name=self.kb_name, embed_model=self.embed_model, tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), ) + def save_vector_store(self, vector_store: FAISS = None): + vector_store = vector_store or self.load_vector_store() + vector_store.save_local(self.vs_path) + return vector_store + def refresh_vs_cache(self): refresh_vs_cache(self.kb_name) @@ -117,11 +121,11 @@ class FaissKBService(KBService): if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) self.refresh_vs_cache() + return vector_store def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - embeddings = self._load_embeddings() vector_store = self.load_vector_store() ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] @@ -133,7 +137,7 @@ class FaissKBService(KBService): vector_store.save_local(self.vs_path) self.refresh_vs_cache() - return True + return vector_store def do_clear_vs(self): shutil.rmtree(self.vs_path) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 4b52625..3e3dd52 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector from langchain.vectorstores.pgvector import DistanceStrategy from sqlalchemy import text -from configs.model_config import EMBEDDING_DEVICE, kbs_config +from configs.model_config import kbs_config from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ score_threshold_process from server.knowledge_base.utils import load_embeddings, KnowledgeFile - +from server.utils import embedding_device as get_embedding_device class PGKBService(KBService): pg_vector: PGVector - def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None): + def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None): _embeddings = embeddings if _embeddings is None: _embeddings = load_embeddings(self.embed_model, embedding_device) diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index af506e2..4285b79 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -69,6 +69,7 @@ def folder2db( print(result) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() elif mode == "fill_info_only": files = list_files_from_folder(kb_name) @@ -85,6 +86,7 @@ def folder2db( kb.update_doc(kb_file, not_refresh_vs_cache=True) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() elif mode == "increament": db_files = kb.list_files() @@ -102,6 +104,7 @@ def folder2db( print(result) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() else: print(f"unspported migrate mode: {mode}") @@ -131,7 +134,10 @@ def prune_db_files(kb_name: str): files = list(set(files_in_db) - set(files_in_folder)) kb_files = file_to_kbfile(kb_name, files) for kb_file in kb_files: - kb.delete_doc(kb_file) + kb.delete_doc(kb_file, 
not_refresh_vs_cache=True) + if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() + kb.refresh_vs_cache() return kb_files def prune_folder_files(kb_name: str): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 8cab754..a8a9bcc 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -87,7 +87,8 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "UnstructuredMarkdownLoader": ['.md'], "CustomJSONLoader": [".json"], "CSVLoader": [".csv"], - "PyPDFLoader": [".pdf"], + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], "UnstructuredFileLoader": ['.eml', '.msg', '.rst', '.rtf', '.txt', '.xml', '.doc', '.docx', '.epub', '.odt', @@ -196,7 +197,10 @@ class KnowledgeFile: print(f"{self.document_loader_name} used for {self.filepath}") try: - document_loaders_module = importlib.import_module('langchain.document_loaders') + if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, self.document_loader_name) except Exception as e: print(e) diff --git a/server/llm_api.py b/server/llm_api.py index 7ef5891..d9667e4 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -4,8 +4,8 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger -from server.utils import MakeFastAPIOffline, set_httpx_timeout +from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger +from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device host_ip = "0.0.0.0" @@ -34,7 +34,7 @@ def create_model_worker_app( worker_address=base_url.format(model_worker_port), controller_address=base_url.format(controller_port), model_path=llm_model_dict[LLM_MODEL].get("local_model_path"), - device=LLM_DEVICE, + device=llm_device(), gpus=None, max_gpu_memory="20GiB", load_8bit=False, diff --git a/server/utils.py b/server/utils.py index 167b672..d716582 100644 --- a/server/utils.py +++ b/server/utils.py @@ -5,8 +5,8 @@ import torch from fastapi import FastAPI from pathlib import Path import asyncio -from configs.model_config import LLM_MODEL -from typing import Any, Optional +from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE +from typing import Literal, Optional class BaseResponse(BaseModel): @@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict: config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() config.update(llm_model_dict.get(model_name, {})) config.update(FSCHAT_MODEL_WORKERS.get(model_name, {})) + config["device"] = llm_device(config.get("device")) return config @@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None): httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout + + +# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch +def detect_device() -> Literal["cuda", "mps", "cpu"]: + try: + import torch + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + except: + pass + return "cpu" + + +def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]: + if device not in ["cuda", "mps", "cpu"]: + device = detect_device() + return device + 
+ +def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]: + if device not in ["cuda", "mps", "cpu"]: + device = detect_device() + return device diff --git a/startup.py b/startup.py index 64a3bcc..5ef05ce 100644 --- a/startup.py +++ b/startup.py @@ -14,12 +14,13 @@ except: pass sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \ +from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \ logger from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, FSCHAT_OPENAI_API, ) from server.utils import (fschat_controller_address, fschat_model_worker_address, - fschat_openai_api_address, set_httpx_timeout) + fschat_openai_api_address, set_httpx_timeout, + llm_device, embedding_device, get_model_worker_config) from server.utils import MakeFastAPIOffline, FastAPI import argparse from typing import Tuple, List @@ -28,10 +29,12 @@ from configs import VERSION def create_controller_app( dispatch_method: str, + log_level: str = "INFO", ) -> FastAPI: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.controller import app, Controller + from fastchat.serve.controller import app, Controller, logger + logger.setLevel(log_level) controller = Controller(dispatch_method) sys.modules["fastchat.serve.controller"].controller = controller @@ -41,13 +44,14 @@ def create_controller_app( return app -def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]: +def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id + from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger import argparse import threading import fastchat.serve.model_worker + logger.setLevel(log_level) # workaround to make program exit with Ctrl+c # it should be deleted after pr is merged by fastchat @@ -136,10 +140,14 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI] def create_openai_api_app( controller_address: str, api_keys: List = [], + log_level: str = "INFO", ) -> FastAPI: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings + from fastchat.utils import build_logger + logger = build_logger("openai_api", "openai_api.log") + logger.setLevel(log_level) app.add_middleware( CORSMiddleware, @@ -149,6 +157,7 @@ def create_openai_api_app( allow_headers=["*"], ) + sys.modules["fastchat.serve.openai_api_server"].logger = logger app_settings.controller_address = controller_address app_settings.api_keys = api_keys @@ -158,6 +167,9 @@ def create_openai_api_app( def _set_app_seq(app: FastAPI, q: Queue, run_seq: int): + if q is None or not isinstance(run_seq, int): + return + if run_seq == 1: @app.on_event("startup") async def on_startup(): @@ -176,15 +188,22 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int): q.put(run_seq) -def run_controller(q: Queue, run_seq: int = 1): +def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"): import uvicorn + import sys - app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method")) + app = create_controller_app( + 
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"), + log_level=log_level, + ) _set_app_seq(app, q, run_seq) host = FSCHAT_CONTROLLER["host"] port = FSCHAT_CONTROLLER["port"] - uvicorn.run(app, host=host, port=port) + if log_level == "ERROR": + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) def run_model_worker( @@ -192,10 +211,12 @@ def run_model_worker( controller_address: str = "", q: Queue = None, run_seq: int = 2, + log_level: str ="INFO", ): import uvicorn + import sys - kwargs = FSCHAT_MODEL_WORKERS[model_name].copy() + kwargs = get_model_worker_config(model_name) host = kwargs.pop("host") port = kwargs.pop("port") model_path = llm_model_dict[model_name].get("local_model_path", "") @@ -204,21 +225,28 @@ def run_model_worker( kwargs["controller_address"] = controller_address or fschat_controller_address() kwargs["worker_address"] = fschat_model_worker_address() - app = create_model_worker_app(**kwargs) + app = create_model_worker_app(log_level=log_level, **kwargs) _set_app_seq(app, q, run_seq) + if log_level == "ERROR": + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ - uvicorn.run(app, host=host, port=port) + uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) -def run_openai_api(q: Queue, run_seq: int = 3): +def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"): import uvicorn + import sys controller_addr = fschat_controller_address() - app = create_openai_api_app(controller_addr) # todo: not support keys yet. + app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. _set_app_seq(app, q, run_seq) host = FSCHAT_OPENAI_API["host"] port = FSCHAT_OPENAI_API["port"] + if log_level == "ERROR": + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ uvicorn.run(app, host=host, port=port) @@ -238,13 +266,15 @@ def run_api_server(q: Queue, run_seq: int = 4): def run_webui(q: Queue, run_seq: int = 5): host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] - while True: - no = q.get() - if no != run_seq - 1: - q.put(no) - else: - break - q.put(run_seq) + + if q is not None and isinstance(run_seq, int): + while True: + no = q.get() + if no != run_seq - 1: + q.put(no) + else: + break + q.put(run_seq) p = subprocess.Popen(["streamlit", "run", "webui.py", "--server.address", host, "--server.port", str(port)]) @@ -314,11 +344,18 @@ def parse_args() -> argparse.ArgumentParser: help="run webui.py server", dest="webui", ) + parser.add_argument( + "-q", + "--quiet", + action="store_true", + help="减少fastchat服务log信息", + dest="quiet", + ) args = parser.parse_args() return args, parser -def dump_server_info(after_start=False): +def dump_server_info(after_start=False, args=None): import platform import langchain import fastchat @@ -331,9 +368,9 @@ def dump_server_info(after_start=False): print(f"项目版本:{VERSION}") print(f"langchain版本:{langchain.__version__}. 
fastchat版本:{fastchat.__version__}") print("\n") - print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}") + print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}") pprint(llm_model_dict[LLM_MODEL]) - print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}") + print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") if after_start: print("\n") print(f"服务端运行信息:") @@ -354,6 +391,7 @@ if __name__ == "__main__": mp.set_start_method("spawn") queue = Queue() args, parser = parse_args() + if args.all_webui: args.openai_api = True args.model_worker = True @@ -372,19 +410,23 @@ if __name__ == "__main__": args.api = False args.webui = False - dump_server_info() + dump_server_info(args=args) if len(sys.argv) > 1: logger.info(f"正在启动服务:") logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") processes = {} + if args.quiet: + log_level = "ERROR" + else: + log_level = "INFO" if args.openai_api: process = Process( target=run_controller, name=f"controller({os.getpid()})", - args=(queue, len(processes) + 1), + args=(queue, len(processes) + 1, log_level), daemon=True, ) process.start() @@ -405,7 +447,7 @@ if __name__ == "__main__": process = Process( target=run_model_worker, name=f"model_worker({os.getpid()})", - args=(args.model_name, args.controller_address, queue, len(processes) + 1), + args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level), daemon=True, ) process.start() @@ -440,7 +482,7 @@ if __name__ == "__main__": no = queue.get() if no == len(processes): time.sleep(0.5) - dump_server_info(True) + dump_server_info(after_start=True, args=args) break else: queue.put(no) diff --git a/tests/samples/ocr_test.jpg b/tests/samples/ocr_test.jpg new file mode 100644 index 0000000..70c199b Binary files /dev/null and b/tests/samples/ocr_test.jpg differ diff --git a/tests/samples/ocr_test.pdf b/tests/samples/ocr_test.pdf new file mode 100644 index 0000000..3a137ad Binary files /dev/null and b/tests/samples/ocr_test.pdf differ diff --git a/webui.py b/webui.py index 58fc0e3..0cda9eb 100644 --- a/webui.py +++ b/webui.py @@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu from webui_pages import * import os from configs import VERSION +from server.utils import api_address -api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) + +api = ApiRequest(base_url=api_address()) if __name__ == "__main__": st.set_page_config(
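Usage sketch (illustrative, not part of the patch): with the LOADER_DICT changes in server/knowledge_base/utils.py, .pdf files and .png/.jpg/.jpeg/.bmp images are routed to the new RapidOCR loaders from the document_loaders package. A minimal standalone example, assuming it is run from the repository root with PyMuPDF and rapidocr_onnxruntime installed:

    from document_loaders import RapidOCRPDFLoader, RapidOCRLoader

    # PDF: embedded text is read page by page with PyMuPDF (fitz), and images
    # on each page are additionally run through rapidocr_onnxruntime.
    pdf_docs = RapidOCRPDFLoader(file_path="tests/samples/ocr_test.pdf").load()

    # Image: the file is OCR'd and the resulting text is partitioned by unstructured.
    img_docs = RapidOCRLoader(file_path="tests/samples/ocr_test.jpg").load()

    print(pdf_docs)
    print(img_docs)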
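Device selection sketch: EMBEDDING_DEVICE and LLM_DEVICE in configs/model_config.py.example now default to "auto", and the actual device is resolved by the new helpers in server/utils.py. A brief illustration of the intended behavior, assuming torch may be absent on machines that do not run the LLM:

    from server.utils import llm_device, embedding_device

    # Any configured value outside {"cuda", "mps", "cpu"} (including "auto")
    # falls back to detect_device(), which probes torch.cuda / torch.backends.mps
    # and returns "cpu" when torch is missing or reports no accelerator.
    print(llm_device())         # e.g. "cuda" on a GPU host
    print(embedding_device())   # "cpu" on a machine without torch

    # An explicit, valid value is passed through unchanged.
    print(embedding_device("mps"))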
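Log-level note: the new -q/--quiet flag in startup.py sets the fastchat controller, model worker, and openai_api loggers to ERROR and passes the lowered level to uvicorn for the controller and model worker, so routine fastchat request logging is suppressed while the summary from dump_server_info() is still printed.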