Merge remote-tracking branch 'origin/dev' into dev
commit ab4c8d2e5d
@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/
@@ -1,6 +1,5 @@
 import os
 import logging
-import torch
 # Log format
 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
 logger = logging.getLogger()
@@ -32,8 +31,8 @@ embedding_model_dict = {
 # Name of the selected Embedding model
 EMBEDDING_MODEL = "m3e-base"
 
-# Device on which the Embedding model runs
-EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+# Device on which the Embedding model runs. "auto" detects it automatically; it can also be set manually to "cuda", "mps" or "cpu".
+EMBEDDING_DEVICE = "auto"
 
 
 llm_model_dict = {
@@ -77,15 +76,14 @@ llm_model_dict = {
     },
 }
 
-
 # LLM name
 LLM_MODEL = "chatglm2-6b"
 
 # Number of history dialogue turns
 HISTORY_LEN = 3
 
-# Device on which the LLM runs
-LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+# Device on which the LLM runs. Same options as the Embedding device.
+LLM_DEVICE = "auto"
 
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
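With EMBEDDING_DEVICE and LLM_DEVICE both defaulting to "auto", the concrete device is no longer computed inside the config file; it is resolved at call time by the llm_device()/embedding_device() helpers this commit adds to server/utils.py. A minimal sketch of the intended resolution, assuming torch is importable on the machine running the models:

    from server.utils import llm_device, embedding_device

    # "auto" (or any value other than "cuda"/"mps"/"cpu") falls back to detect_device(),
    # so these return "cuda", "mps" or "cpu" depending on what torch reports.
    print(llm_device())
    print(embedding_device())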
@@ -0,0 +1,2 @@
+from .mypdfloader import RapidOCRPDFLoader
+from .myimgloader import RapidOCRLoader
@@ -0,0 +1,25 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRLoader(UnstructuredFileLoader):
+    def _get_elements(self) -> List:
+        def img2text(filepath):
+            from rapidocr_onnxruntime import RapidOCR
+            resp = ""
+            ocr = RapidOCR()
+            result, _ = ocr(filepath)
+            if result:
+                ocr_result = [line[1] for line in result]
+                resp += "\n".join(ocr_result)
+            return resp
+
+        text = img2text(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+    loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
+    docs = loader.load()
+    print(docs)
@@ -0,0 +1,37 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRPDFLoader(UnstructuredFileLoader):
+    def _get_elements(self) -> List:
+        def pdf2text(filepath):
+            import fitz
+            from rapidocr_onnxruntime import RapidOCR
+            import numpy as np
+            ocr = RapidOCR()
+            doc = fitz.open(filepath)
+            resp = ""
+            for page in doc:
+                # TODO: adjust processing according to the order of text and images
+                text = page.get_text("")
+                resp += text + "\n"
+
+                img_list = page.get_images()
+                for img in img_list:
+                    pix = fitz.Pixmap(doc, img[0])
+                    img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+                    result, _ = ocr(img_array)
+                    if result:
+                        ocr_result = [line[1] for line in result]
+                        resp += "\n".join(ocr_result)
+            return resp
+
+        text = pdf2text(self.file_path)
+        from unstructured.partition.text import partition_text
+        return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+    loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
+    docs = loader.load()
+    print(docs)
@@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
 faiss-cpu
 accelerate
 spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1
 
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
@@ -16,6 +16,8 @@ faiss-cpu
 nltk
 accelerate
 spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1
 
 # uncomment libs if you want to use corresponding vector store
 # pymilvus==2.1.3 # requires milvus==2.1.3
@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 from typing import List, Union, Dict
 
 
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
 
     def create_kb(self):
@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device
 
 
 _VECTOR_STORE_TICKS = {}
@@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
@@ -74,13 +73,18 @@ class FaissKBService(KBService):
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)
 
-    def load_vector_store(self):
+    def load_vector_store(self) -> FAISS:
         return load_faiss_vector_store(
             knowledge_base_name=self.kb_name,
             embed_model=self.embed_model,
             tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
         )
 
+    def save_vector_store(self, vector_store: FAISS = None):
+        vector_store = vector_store or self.load_vector_store()
+        vector_store.save_local(self.vs_path)
+        return vector_store
+
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
 
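The new save_vector_store() persists whichever FAISS index is in memory, loading one first when none is passed in. A minimal usage sketch, assuming the service is constructed from just the knowledge-base name ("samples" is a hypothetical name):

    kb = FaissKBService("samples")           # hypothetical KB name; constructor signature assumed
    vector_store = kb.save_vector_store()    # writes the index to kb.vs_path and returns it
    kb.refresh_vs_cache()                    # bump the cache tick so other callers reload it

This is the same save-then-refresh sequence that the migrate.py hunks below switch to.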
@@ -117,11 +121,11 @@ class FaissKBService(KBService):
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
+        return vector_store
 
     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
         vector_store = self.load_vector_store()
 
         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@@ -133,7 +137,7 @@ class FaissKBService(KBService):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
 
-        return True
+        return vector_store
 
     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device
 
-
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)
@@ -69,6 +69,7 @@ def folder2db(
             print(result)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "fill_info_only":
         files = list_files_from_folder(kb_name)
@@ -85,6 +86,7 @@ def folder2db(
             kb.update_doc(kb_file, not_refresh_vs_cache=True)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "increament":
         db_files = kb.list_files()
@@ -102,6 +104,7 @@ def folder2db(
             print(result)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     else:
         print(f"unspported migrate mode: {mode}")
@@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
     files = list(set(files_in_db) - set(files_in_folder))
     kb_files = file_to_kbfile(kb_name, files)
     for kb_file in kb_files:
-        kb.delete_doc(kb_file)
+        kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+    if kb.vs_type() == SupportedVSType.FAISS:
+        kb.save_vector_store()
+        kb.refresh_vs_cache()
     return kb_files
 
 def prune_folder_files(kb_name: str):
@@ -87,7 +87,8 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
-              "PyPDFLoader": [".pdf"],
+              "RapidOCRPDFLoader": [".pdf"],
+              "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt',
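With these mappings, a file's extension now decides whether one of the new OCR loaders is used. A rough sketch of the effect, assuming KnowledgeFile takes a file name plus a knowledge-base name (signature assumed, names hypothetical):

    from server.knowledge_base.utils import KnowledgeFile

    kb_file = KnowledgeFile(filename="scan.pdf", knowledge_base_name="samples")
    print(kb_file.document_loader_name)   # expected: "RapidOCRPDFLoader" after this change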
@@ -196,7 +197,10 @@ class KnowledgeFile:
 
         print(f"{self.document_loader_name} used for {self.filepath}")
         try:
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
+            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+                document_loaders_module = importlib.import_module('document_loaders')
+            else:
+                document_loaders_module = importlib.import_module('langchain.document_loaders')
             DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
         except Exception as e:
             print(e)
@@ -4,8 +4,8 @@ import sys
 import os
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 
 
 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,
@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
 
 
 class BaseResponse(BaseModel):
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
 
 
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM need not install torch.
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
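These helpers also back the change to get_model_worker_config() above: whatever "device" value a worker config carries is normalized through llm_device() before it reaches fastchat. A small sketch (the model name is just the LLM_MODEL shown earlier):

    from server.utils import get_model_worker_config

    config = get_model_worker_config("chatglm2-6b")
    print(config["device"])   # always one of "cuda", "mps" or "cpu" after this commit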
startup.py (96 changed lines)
@@ -14,12 +14,13 @@ except:
     pass
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -28,10 +29,12 @@ from configs import VERSION
 
 def create_controller_app(
         dispatch_method: str,
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.controller import app, Controller
+    from fastchat.serve.controller import app, Controller, logger
+    logger.setLevel(log_level)
 
     controller = Controller(dispatch_method)
     sys.modules["fastchat.serve.controller"].controller = controller
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
||||||
import fastchat.constants
|
import fastchat.constants
|
||||||
fastchat.constants.LOGDIR = LOG_PATH
|
fastchat.constants.LOGDIR = LOG_PATH
|
||||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
|
||||||
import argparse
|
import argparse
|
||||||
import threading
|
import threading
|
||||||
import fastchat.serve.model_worker
|
import fastchat.serve.model_worker
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# workaround to make program exit with Ctrl+c
|
# workaround to make program exit with Ctrl+c
|
||||||
# it should be deleted after pr is merged by fastchat
|
# it should be deleted after pr is merged by fastchat
|
||||||
|
|
@@ -136,10 +140,14 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
 def create_openai_api_app(
         controller_address: str,
         api_keys: List = [],
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+    from fastchat.utils import build_logger
+    logger = build_logger("openai_api", "openai_api.log")
+    logger.setLevel(log_level)
 
     app.add_middleware(
         CORSMiddleware,
@@ -149,6 +157,7 @@ def create_openai_api_app(
         allow_headers=["*"],
     )
 
+    sys.modules["fastchat.serve.openai_api_server"].logger = logger
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys
 
@@ -158,6 +167,9 @@ def create_openai_api_app(
 
 
 def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+    if q is None or not isinstance(run_seq, int):
+        return
+
     if run_seq == 1:
         @app.on_event("startup")
         async def on_startup():
@@ -176,15 +188,22 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
         q.put(run_seq)
 
 
-def run_controller(q: Queue, run_seq: int = 1):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     import uvicorn
+    import sys
 
-    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+    app = create_controller_app(
+        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
+        log_level=log_level,
+    )
     _set_app_seq(app, q, run_seq)
 
     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
-    uvicorn.run(app, host=host, port=port)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
 
 
 def run_model_worker(
@@ -192,10 +211,12 @@ def run_model_worker(
         controller_address: str = "",
         q: Queue = None,
         run_seq: int = 2,
+        log_level: str ="INFO",
 ):
     import uvicorn
+    import sys
 
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -204,21 +225,28 @@ def run_model_worker(
     kwargs["controller_address"] = controller_address or fschat_controller_address()
     kwargs["worker_address"] = fschat_model_worker_address()
 
-    app = create_model_worker_app(**kwargs)
+    app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_seq(app, q, run_seq)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
 
-    uvicorn.run(app, host=host, port=port)
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
 
 
-def run_openai_api(q: Queue, run_seq: int = 3):
+def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     import uvicorn
+    import sys
 
     controller_addr = fschat_controller_address()
-    app = create_openai_api_app(controller_addr)  # todo: not support keys yet.
+    app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
     _set_app_seq(app, q, run_seq)
 
     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
     uvicorn.run(app, host=host, port=port)
 
 
@@ -238,13 +266,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
 def run_webui(q: Queue, run_seq: int = 5):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
-    while True:
-        no = q.get()
-        if no != run_seq - 1:
-            q.put(no)
-        else:
-            break
-    q.put(run_seq)
+    if q is not None and isinstance(run_seq, int):
+        while True:
+            no = q.get()
+            if no != run_seq - 1:
+                q.put(no)
+            else:
+                break
+        q.put(run_seq)
+
     p = subprocess.Popen(["streamlit", "run", "webui.py",
                           "--server.address", host,
                           "--server.port", str(port)])
@@ -314,11 +344,18 @@ def parse_args() -> argparse.ArgumentParser:
         help="run webui.py server",
         dest="webui",
     )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        action="store_true",
+        help="减少fastchat服务log信息",
+        dest="quiet",
+    )
     args = parser.parse_args()
     return args, parser
 
 
-def dump_server_info(after_start=False):
+def dump_server_info(after_start=False, args=None):
     import platform
     import langchain
     import fastchat
|
||||||
print(f"项目版本:{VERSION}")
|
print(f"项目版本:{VERSION}")
|
||||||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
|
print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
|
||||||
pprint(llm_model_dict[LLM_MODEL])
|
pprint(llm_model_dict[LLM_MODEL])
|
||||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
|
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||||
if after_start:
|
if after_start:
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"服务端运行信息:")
|
print(f"服务端运行信息:")
|
||||||
|
|
@@ -354,6 +391,7 @@ if __name__ == "__main__":
     mp.set_start_method("spawn")
     queue = Queue()
     args, parser = parse_args()
+
     if args.all_webui:
         args.openai_api = True
         args.model_worker = True
@@ -372,19 +410,23 @@ if __name__ == "__main__":
         args.api = False
         args.webui = False
 
-    dump_server_info()
+    dump_server_info(args=args)
 
     if len(sys.argv) > 1:
         logger.info(f"正在启动服务:")
         logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
 
     processes = {}
+    if args.quiet:
+        log_level = "ERROR"
+    else:
+        log_level = "INFO"
 
     if args.openai_api:
         process = Process(
             target=run_controller,
             name=f"controller({os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, len(processes) + 1, log_level),
             daemon=True,
         )
         process.start()
|
|
@ -405,7 +447,7 @@ if __name__ == "__main__":
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_model_worker,
|
target=run_model_worker,
|
||||||
name=f"model_worker({os.getpid()})",
|
name=f"model_worker({os.getpid()})",
|
||||||
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
|
args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
process.start()
|
process.start()
|
||||||
|
|
@@ -440,7 +482,7 @@ if __name__ == "__main__":
             no = queue.get()
             if no == len(processes):
                 time.sleep(0.5)
-                dump_server_info(True)
+                dump_server_info(after_start=True, args=args)
                 break
             else:
                 queue.put(no)
Binary files not shown (one image added, 7.9 KiB).
webui.py (4 changed lines)
@@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
 from webui_pages import *
 import os
 from configs import VERSION
+from server.utils import api_address
 
-api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
+api = ApiRequest(base_url=api_address())
 
 if __name__ == "__main__":
     st.set_page_config(