修改Embeddings和FAISS缓存加载方式,知识库相关API接口支持多线程并发 (#1434)
* 修改Embeddings和FAISS缓存加载方式,支持多线程,支持内存FAISS * 知识库相关API接口支持多线程并发 * 根据新的API接口调整ApiRequest和测试用例 * 删除webui.py失效的启动说明
This commit is contained in:
parent
d0e654d847
commit
22ff073309
|
|
@ -4,12 +4,11 @@ import os
|
|||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
|
||||
from configs import VERSION, logger, log_verbose
|
||||
from configs import VERSION
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN
|
||||
import argparse
|
||||
import uvicorn
|
||||
from fastapi import Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
|
|
@ -18,8 +17,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
|||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||
update_docs, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
||||
import httpx
|
||||
from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
from typing import List
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
|
@ -126,79 +125,20 @@ def create_app():
|
|||
)(recreate_vector_store)
|
||||
|
||||
# LLM模型相关接口
|
||||
@app.post("/llm_model/list_models",
|
||||
app.post("/llm_model/list_models",
|
||||
tags=["LLM Model Management"],
|
||||
summary="列出当前已加载的模型")
|
||||
def list_models(
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
从fastchat controller获取已加载模型列表
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(controller_address + "/list_models")
|
||||
return BaseResponse(data=r.json()["models"])
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
data=[],
|
||||
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
||||
summary="列出当前已加载的模型",
|
||||
)(list_llm_models)
|
||||
|
||||
@app.post("/llm_model/stop",
|
||||
app.post("/llm_model/stop",
|
||||
tags=["LLM Model Management"],
|
||||
summary="停止指定的LLM模型(Model Worker)",
|
||||
)
|
||||
def stop_llm_model(
|
||||
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
向fastchat controller请求停止某个LLM模型。
|
||||
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name},
|
||||
)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
|
||||
)(stop_llm_model)
|
||||
|
||||
@app.post("/llm_model/change",
|
||||
app.post("/llm_model/change",
|
||||
tags=["LLM Model Management"],
|
||||
summary="切换指定的LLM模型(Model Worker)",
|
||||
)
|
||||
def change_llm_model(
|
||||
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
|
||||
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
):
|
||||
'''
|
||||
向fastchat controller请求切换LLM模型。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name, "new_model_name": new_model_name},
|
||||
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
|
||||
)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
||||
)(change_llm_model)
|
||||
|
||||
return app
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ def list_kbs():
|
|||
return ListResponse(data=list_kbs_from_db())
|
||||
|
||||
|
||||
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
) -> BaseResponse:
|
||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -38,8 +38,8 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
||||
async def delete_kb(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
def delete_kb(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
) -> BaseResponse:
|
||||
# Delete selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
import threading
|
||||
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
|
||||
embedding_model_dict, logger, log_verbose)
|
||||
from server.utils import embedding_device
|
||||
from contextlib import contextmanager
|
||||
from collections import OrderedDict
|
||||
from typing import List, Any, Union, Tuple
|
||||
|
||||
|
||||
class ThreadSafeObject:
    """Wrap an arbitrary object with a reentrant lock and a "loaded" event.

    Instances are meant to live inside a CachePool: acquiring one refreshes
    its position at the MRU end of the pool so LRU eviction stays correct.
    The event lets loader threads publish "ready" to threads blocked in
    CachePool.get().
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: key: {self._key}, obj: {self._obj}>"

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Hold the object's lock for the duration of the with-block and
        yield the wrapped object."""
        if not owner:
            owner = f"thread {threading.get_native_id()}"
        self._lock.acquire()
        try:
            if self._pool is not None:
                # refresh LRU position while we hold the lock
                self._pool._cache.move_to_end(self._key)
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self._key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self._key}。{msg}")
            self._lock.release()

    def start_loading(self):
        # mark the wrapped object as not-yet-usable
        self._loaded.clear()

    def finish_loading(self):
        # wake every thread blocked in wait_for_loading()
        self._loaded.set()

    def wait_for_loading(self):
        # block until a loader thread calls finish_loading()
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
|
||||
|
||||
|
||||
class CachePool:
    """Thread-aware LRU cache of ThreadSafeObject wrappers.

    cache_num <= 0 means unbounded; otherwise the least-recently-used
    entries are evicted once the size exceeds cache_num.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        self._cache = OrderedDict()
        # coarse lock used by subclasses to make check-then-load atomic
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        return list(self._cache.keys())

    def _check_count(self):
        # evict from the LRU end until we are back under the limit
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the cached entry, blocking until it has finished loading;
        None if the key is absent."""
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        # key=None pops the LRU entry (as a (key, value) pair, mirroring
        # OrderedDict.popitem); otherwise removes `key`, returning None if absent
        if key is None:
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Return a context manager holding the entry's lock; raises if the
        key is not cached."""
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(self, kb_name: str = None, embed_device: str = embedding_device()) -> Embeddings:
        """Load the embeddings configured for a knowledge base (cached).

        Falls back to EMBEDDING_MODEL when the kb has no record in the DB.
        """
        from server.db.repository.knowledge_base_repository import get_kb_detail

        # Fix: guard against a missing kb record (get_kb_detail may return
        # None), and drop the leftover debug print of the kb detail dict.
        kb_detail = get_kb_detail(kb_name=kb_name) or {}
        embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
|
||||
|
||||
|
||||
class EmbeddingsPool(CachePool):
    """Cache of loaded Embeddings instances, keyed by (model, device)."""

    def load_embeddings(self, model: str, device: str) -> Embeddings:
        # Double-lock pattern: `atomic` serializes the check-then-insert so
        # only one thread creates the cache entry; the entry's own lock is
        # then held while the model loads, and other threads block inside
        # self.get() (via wait_for_loading) until finish_loading() is called.
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        device = device or embedding_device()
        key = (model, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # release the pool-wide lock before the (slow) model load so
                # other keys can be served concurrently
                self.atomic.release()
                # NOTE(review): if constructing the embeddings below raises,
                # finish_loading() is never called and threads waiting in
                # get() block forever — consider a try/finally. TODO confirm.
                if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
                    embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
                elif 'bge-' in model:
                    embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
                                                          model_kwargs={'device': device},
                                                          query_instruction="为这个句子生成表示以用于检索相关文章:")
                    if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                        embeddings.query_instruction = ""
                else:
                    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
                item.obj = embeddings
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj
|
||||
|
||||
|
||||
embeddings_pool = EmbeddingsPool(cache_num=1)
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
from server.knowledge_base.kb_cache.base import *
|
||||
from server.knowledge_base.utils import get_vs_path
|
||||
from langchain.vectorstores import FAISS
|
||||
import os
|
||||
|
||||
|
||||
class ThreadSafeFaiss(ThreadSafeObject):
    """ThreadSafeObject whose wrapped object is a FAISS vector store."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        # number of documents in the underlying docstore
        # (reaches into FAISS internals via docstore._dict)
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped store to `path` (created if missing) while
        holding the object's lock; returns save_local's result."""
        with self.acquire():
            if not os.path.isdir(path) and create_path:
                os.makedirs(path)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self._key} 保存到磁盘")
        return ret

    def clear(self):
        """Delete every document from the wrapped store under the lock;
        returns the delete() result (empty list if the store was empty)."""
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # NOTE(review): assert is stripped under `python -O`;
                # an explicit check-and-raise would be more robust
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self._key} 清空")
        return ret
|
||||
|
||||
|
||||
class _FaissPool(CachePool):
    """Operations shared by pools of ThreadSafeFaiss vector stores."""

    def new_vector_store(
        self,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> FAISS:
        """Create an empty FAISS store bound to the given embeddings."""
        embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)

        # FAISS cannot be constructed truly empty: seed it with one dummy
        # document, then delete it so the store starts with zero docs.
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
        ids = list(vector_store.docstore._dict.keys())
        vector_store.delete(ids)
        return vector_store

    def save_vector_store(self, kb_name: str, path: str = None):
        """Persist the cached store for `kb_name` to `path`.

        Fix: previously a missing `path` was forwarded as None into
        ThreadSafeFaiss.save(), which crashed in os.path.isdir(None);
        default to the knowledge base's standard vector-store location.
        """
        if cache := self.get(kb_name):
            path = path or get_vs_path(kb_name)
            return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        """Drop the cached store for `kb_name` (no-op if it is not cached)."""
        if cache := self.get(kb_name):
            self.pop(kb_name)
            logger.info(f"成功释放向量库:{kb_name}")
|
||||
|
||||
|
||||
class KBFaissPool(_FaissPool):
    """Pool of disk-backed FAISS vector stores, one per knowledge base."""

    def load_vector_store(
        self,
        kb_name: str,
        create: bool = True,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        # Double-lock pattern: `atomic` serializes the check-then-insert so
        # only one thread creates the cache entry; the entry's own lock is
        # held during loading so other threads block in self.get() until
        # finish_loading() is called.
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            with item.acquire(msg="初始化"):
                # release the pool-wide lock before slow disk/model work
                self.atomic.release()
                logger.info(f"loading vector store in '{kb_name}' from disk.")
                vs_path = get_vs_path(kb_name)

                if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                    # existing store on disk: load with the kb's configured embeddings
                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
                    vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                elif create:
                    # create an empty vector store
                    if not os.path.exists(vs_path):
                        os.makedirs(vs_path)
                    vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                    vector_store.save_local(vs_path)
                else:
                    # NOTE(review): raising here leaves a never-loaded entry
                    # in the cache — later get() calls for this key will
                    # block forever; consider popping the item first. TODO confirm.
                    raise RuntimeError(f"knowledge base {kb_name} not exist.")
                item.obj = vector_store
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
    """Pool of purely in-memory (never persisted) FAISS vector stores."""

    def load_vector_store(
        self,
        kb_name: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        # Same double-lock pattern as KBFaissPool.load_vector_store, except
        # the store is always created fresh in memory instead of being
        # loaded from disk.
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            with item.acquire(msg="初始化"):
                # release the pool-wide lock before the slow build step
                self.atomic.release()
                logger.info(f"loading vector store in '{kb_name}' to memory.")
                # create an empty vector store
                vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                item.obj = vector_store
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
|
||||
|
||||
|
||||
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
||||
memo_faiss_pool = MemoFaissPool()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc concurrency smoke test: many threads hammer the same pools with
    # a random mix of add / search / clear operations.
    import time, random
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        # NOTE(review): the vs_name parameter is immediately overwritten, so
        # every worker hits the "samples" kb regardless of what was passed —
        # looks like debug leftover; confirm intent.
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        # NOTE(review): EmbeddingsPool.load_embeddings takes (model, device)
        # with no defaults — this zero-argument call appears to raise
        # TypeError; verify against the pool's signature.
        embeddings = embeddings_pool.load_embeddings()
        r = random.randint(1, 3)

        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
            if r == 1: # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2: # search docs
                docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3: # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            kb_faiss_pool.get(vs_name).clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(target=worker,
                             kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
                             daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()
|
||||
|
|
@ -117,13 +117,13 @@ def _save_files_in_thread(files: List[UploadFile],
|
|||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
|
||||
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
|
||||
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
API接口:上传文件,并/或向量化
|
||||
'''
|
||||
|
|
@ -148,7 +148,7 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
|
|||
|
||||
# 对保存的文件进行向量化
|
||||
if to_vector_store:
|
||||
result = await update_docs(
|
||||
result = update_docs(
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
file_names=file_names,
|
||||
override_custom_docs=True,
|
||||
|
|
@ -162,11 +162,11 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
|
|||
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
|
||||
delete_content: bool = Body(False),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
|
||||
delete_content: bool = Body(False),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -196,12 +196,12 @@ async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"])
|
|||
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def update_docs(
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
|
||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
def update_docs(
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
|
||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
更新知识库文档
|
||||
|
|
@ -302,11 +302,11 @@ def download_doc(
|
|||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
|
||||
|
||||
async def recreate_vector_store(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
allow_empty_kb: bool = Body(True),
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
def recreate_vector_store(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
allow_empty_kb: bool = Body(True),
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
):
|
||||
'''
|
||||
recreate vector store from the content.
|
||||
|
|
|
|||
|
|
@ -146,7 +146,6 @@ class KBService(ABC):
|
|||
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
||||
return docs
|
||||
|
||||
# TODO: milvus/pg需要实现该方法
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -3,62 +3,16 @@ import shutil
|
|||
|
||||
from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
CACHED_VS_NUM,
|
||||
EMBEDDING_MODEL,
|
||||
SCORE_THRESHOLD,
|
||||
logger, log_verbose,
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from functools import lru_cache
|
||||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc, embedding_device
|
||||
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
@lru_cache(CACHED_VS_NUM)
|
||||
def load_faiss_vector_store(
|
||||
knowledge_base_name: str,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
embeddings: Embeddings = None,
|
||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
) -> FAISS:
|
||||
logger.info(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
|
||||
if "index.faiss" in os.listdir(vs_path):
|
||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
else:
|
||||
# create an empty vector store
|
||||
doc = Document(page_content="init", metadata={})
|
||||
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
|
||||
ids = [k for k, v in search_index.docstore._dict.items()]
|
||||
search_index.delete(ids)
|
||||
search_index.save_local(vs_path)
|
||||
|
||||
if tick == 0: # vector store is loaded first time
|
||||
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
|
||||
|
||||
return search_index
|
||||
|
||||
|
||||
def refresh_vs_cache(kb_name: str):
|
||||
"""
|
||||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
from server.utils import torch_gc
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
|
|
@ -74,24 +28,15 @@ class FaissKBService(KBService):
|
|||
def get_kb_path(self):
|
||||
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
||||
|
||||
def load_vector_store(self) -> FAISS:
|
||||
return load_faiss_vector_store(
|
||||
knowledge_base_name=self.kb_name,
|
||||
embed_model=self.embed_model,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
|
||||
)
|
||||
def load_vector_store(self) -> ThreadSafeFaiss:
|
||||
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
|
||||
|
||||
def save_vector_store(self, vector_store: FAISS = None):
|
||||
vector_store = vector_store or self.load_vector_store()
|
||||
vector_store.save_local(self.vs_path)
|
||||
return vector_store
|
||||
|
||||
def refresh_vs_cache(self):
|
||||
refresh_vs_cache(self.kb_name)
|
||||
def save_vector_store(self):
|
||||
self.load_vector_store().save(self.vs_path)
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
vector_store = self.load_vector_store()
|
||||
return vector_store.docstore._dict.get(id)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
return vs.docstore._dict.get(id)
|
||||
|
||||
def do_init(self):
|
||||
self.kb_path = self.get_kb_path()
|
||||
|
|
@ -112,43 +57,38 @@ class FaissKBService(KBService):
|
|||
score_threshold: float = SCORE_THRESHOLD,
|
||||
embeddings: Embeddings = None,
|
||||
) -> List[Document]:
|
||||
search_index = self.load_vector_store()
|
||||
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
**kwargs,
|
||||
) -> List[Dict]:
|
||||
vector_store = self.load_vector_store()
|
||||
ids = vector_store.add_documents(docs)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
ids = vs.add_documents(docs)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vs.save_local(self.vs_path)
|
||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||
torch_gc()
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
self.refresh_vs_cache()
|
||||
return doc_infos
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile,
|
||||
**kwargs):
|
||||
vector_store = self.load_vector_store()
|
||||
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
|
||||
vector_store.delete(ids)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
self.refresh_vs_cache()
|
||||
|
||||
return vector_store
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||
if len(ids) > 0:
|
||||
vs.delete(ids)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vs.save_local(self.vs_path)
|
||||
return ids
|
||||
|
||||
def do_clear_vs(self):
|
||||
with kb_faiss_pool.atomic:
|
||||
kb_faiss_pool.pop(self.kb_name)
|
||||
shutil.rmtree(self.vs_path)
|
||||
os.makedirs(self.vs_path)
|
||||
self.refresh_vs_cache()
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
if super().exist_doc(file_name):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE, logger, log_verbose
|
||||
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
||||
list_files_from_folder, run_in_thread_pool,
|
||||
files2docs_in_thread,
|
||||
list_files_from_folder,files2docs_in_thread,
|
||||
KnowledgeFile,)
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db
|
||||
|
|
@ -72,7 +71,6 @@ def folder2db(
|
|||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
kb.refresh_vs_cache()
|
||||
elif mode == "fill_info_only":
|
||||
files = list_files_from_folder(kb_name)
|
||||
kb_files = file_to_kbfile(kb_name, files)
|
||||
|
|
@ -89,7 +87,6 @@ def folder2db(
|
|||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
kb.refresh_vs_cache()
|
||||
elif mode == "increament":
|
||||
db_files = kb.list_files()
|
||||
folder_files = list_files_from_folder(kb_name)
|
||||
|
|
@ -107,7 +104,6 @@ def folder2db(
|
|||
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
kb.refresh_vs_cache()
|
||||
else:
|
||||
print(f"unspported migrate mode: {mode}")
|
||||
|
||||
|
|
@ -139,7 +135,6 @@ def prune_db_files(kb_name: str):
|
|||
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
|
||||
if kb.vs_type() == SupportedVSType.FAISS:
|
||||
kb.save_vector_store()
|
||||
kb.refresh_vs_cache()
|
||||
return kb_files
|
||||
|
||||
def prune_folder_files(kb_name: str):
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from langchain.embeddings.openai import OpenAIEmbeddings
|
|||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from configs.model_config import (
|
||||
embedding_model_dict,
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose,
|
||||
)
|
||||
from functools import lru_cache
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
|
|
@ -19,25 +19,11 @@ from langchain.text_splitter import TextSplitter
|
|||
from pathlib import Path
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from server.utils import run_in_thread_pool
|
||||
from server.utils import run_in_thread_pool, embedding_device
|
||||
import io
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
if isinstance(self, HuggingFaceEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, OpenAIEmbeddings):
|
||||
return hash(self.model)
|
||||
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
# 检查是否包含预期外的字符或路径攻击关键字
|
||||
if "../" in knowledge_base_id:
|
||||
|
|
@ -72,19 +58,12 @@ def list_files_from_folder(kb_name: str):
|
|||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
|
||||
elif 'bge-' in model:
|
||||
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device},
|
||||
query_instruction="为这个句子生成表示以用于检索相关文章:")
|
||||
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
|
||||
embeddings.query_instruction = ""
|
||||
else:
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
|
||||
return embeddings
|
||||
def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
|
||||
'''
|
||||
从缓存中加载embeddings,可以避免多线程时竞争加载。
|
||||
'''
|
||||
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||
return embeddings_pool.load_embeddings(model=model, device=device)
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||
|
|
|
|||
|
|
@ -1,279 +1,70 @@
|
|||
from multiprocessing import Process, Queue
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger, log_verbose
|
||||
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
|
||||
from fastapi import Body
|
||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
||||
from server.utils import BaseResponse, fschat_controller_address
|
||||
import httpx
|
||||
|
||||
|
||||
host_ip = "0.0.0.0"
|
||||
controller_port = 20001
|
||||
model_worker_port = 20002
|
||||
openai_api_port = 8888
|
||||
base_url = "http://127.0.0.1:{}"
|
||||
def list_llm_models(
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
从fastchat controller获取已加载模型列表
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(controller_address + "/list_models")
|
||||
return BaseResponse(data=r.json()["models"])
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
data=[],
|
||||
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
|
||||
def create_controller_app(
|
||||
dispatch_method="shortest_queue",
|
||||
def stop_llm_model(
|
||||
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
向fastchat controller请求停止某个LLM模型。
|
||||
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name},
|
||||
)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
|
||||
def change_llm_model(
|
||||
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
|
||||
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
):
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.controller import app, Controller
|
||||
|
||||
controller = Controller(dispatch_method)
|
||||
sys.modules["fastchat.serve.controller"].controller = controller
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat Controller"
|
||||
return app
|
||||
|
||||
|
||||
def create_model_worker_app(
|
||||
worker_address=base_url.format(model_worker_port),
|
||||
controller_address=base_url.format(controller_port),
|
||||
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
|
||||
device=llm_device(),
|
||||
gpus=None,
|
||||
max_gpu_memory="20GiB",
|
||||
load_8bit=False,
|
||||
cpu_offloading=None,
|
||||
gptq_ckpt=None,
|
||||
gptq_wbits=16,
|
||||
gptq_groupsize=-1,
|
||||
gptq_act_order=False,
|
||||
awq_ckpt=None,
|
||||
awq_wbits=16,
|
||||
awq_groupsize=-1,
|
||||
model_names=[LLM_MODEL],
|
||||
num_gpus=1, # not in fastchat
|
||||
conv_template=None,
|
||||
limit_worker_concurrency=5,
|
||||
stream_interval=2,
|
||||
no_register=False,
|
||||
):
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||
import argparse
|
||||
import threading
|
||||
import fastchat.serve.model_worker
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
def _new_init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
'''
|
||||
向fastchat controller请求切换LLM模型。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name, "new_model_name": new_model_name},
|
||||
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
ModelWorker.init_heart_beat = _new_init_heart_beat
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
args.model_path = model_path
|
||||
args.model_names = model_names
|
||||
args.device = device
|
||||
args.load_8bit = load_8bit
|
||||
args.gptq_ckpt = gptq_ckpt
|
||||
args.gptq_wbits = gptq_wbits
|
||||
args.gptq_groupsize = gptq_groupsize
|
||||
args.gptq_act_order = gptq_act_order
|
||||
args.awq_ckpt = awq_ckpt
|
||||
args.awq_wbits = awq_wbits
|
||||
args.awq_groupsize = awq_groupsize
|
||||
args.gpus = gpus
|
||||
args.num_gpus = num_gpus
|
||||
args.max_gpu_memory = max_gpu_memory
|
||||
args.cpu_offloading = cpu_offloading
|
||||
args.worker_address = worker_address
|
||||
args.controller_address = controller_address
|
||||
args.conv_template = conv_template
|
||||
args.limit_worker_concurrency = limit_worker_concurrency
|
||||
args.stream_interval = stream_interval
|
||||
args.no_register = no_register
|
||||
|
||||
if args.gpus:
|
||||
if len(args.gpus.split(",")) < args.num_gpus:
|
||||
raise ValueError(
|
||||
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
||||
)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
|
||||
if gpus and num_gpus is None:
|
||||
num_gpus = len(gpus.split(','))
|
||||
args.num_gpus = num_gpus
|
||||
|
||||
gptq_config = GptqConfig(
|
||||
ckpt=gptq_ckpt or model_path,
|
||||
wbits=args.gptq_wbits,
|
||||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
|
||||
# torch.multiprocessing.set_start_method('spawn')
|
||||
worker = ModelWorker(
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address,
|
||||
worker_id=worker_id,
|
||||
model_path=args.model_path,
|
||||
model_names=args.model_names,
|
||||
limit_worker_concurrency=args.limit_worker_concurrency,
|
||||
no_register=args.no_register,
|
||||
device=args.device,
|
||||
num_gpus=args.num_gpus,
|
||||
max_gpu_memory=args.max_gpu_memory,
|
||||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = f"FastChat LLM Server ({LLM_MODEL})"
|
||||
return app
|
||||
|
||||
|
||||
def create_openai_api_app(
|
||||
controller_address=base_url.format(controller_port),
|
||||
api_keys=[],
|
||||
):
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_credentials=True,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app_settings.controller_address = controller_address
|
||||
app_settings.api_keys = api_keys
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat OpeanAI API Server"
|
||||
return app
|
||||
|
||||
|
||||
def run_controller(q):
|
||||
import uvicorn
|
||||
app = create_controller_app()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
q.put(1)
|
||||
|
||||
uvicorn.run(app, host=host_ip, port=controller_port)
|
||||
|
||||
|
||||
def run_model_worker(q, *args, **kwargs):
|
||||
import uvicorn
|
||||
app = create_model_worker_app(*args, **kwargs)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(2)
|
||||
|
||||
uvicorn.run(app, host=host_ip, port=model_worker_port)
|
||||
|
||||
|
||||
def run_openai_api(q):
|
||||
import uvicorn
|
||||
app = create_openai_api_app()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != 2:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(3)
|
||||
|
||||
uvicorn.run(app, host=host_ip, port=openai_api_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn")
|
||||
queue = Queue()
|
||||
logger.info(llm_model_dict[LLM_MODEL])
|
||||
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
|
||||
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
|
||||
if not model_path:
|
||||
logger.error("local_model_path 不能为空")
|
||||
else:
|
||||
controller_process = Process(
|
||||
target=run_controller,
|
||||
name=f"controller({os.getpid()})",
|
||||
args=(queue,),
|
||||
daemon=True,
|
||||
)
|
||||
controller_process.start()
|
||||
|
||||
model_worker_process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker({os.getpid()})",
|
||||
args=(queue,),
|
||||
# kwargs={"load_8bit": True},
|
||||
daemon=True,
|
||||
)
|
||||
model_worker_process.start()
|
||||
|
||||
openai_api_process = Process(
|
||||
target=run_openai_api,
|
||||
name=f"openai_api({os.getpid()})",
|
||||
args=(queue,),
|
||||
daemon=True,
|
||||
)
|
||||
openai_api_process.start()
|
||||
|
||||
try:
|
||||
model_worker_process.join()
|
||||
controller_process.join()
|
||||
openai_api_process.join()
|
||||
except KeyboardInterrupt:
|
||||
model_worker_process.terminate()
|
||||
controller_process.terminate()
|
||||
openai_api_process.terminate()
|
||||
|
||||
# 服务启动后接口调用示例:
|
||||
# import openai
|
||||
# openai.api_key = "EMPTY" # Not support yet
|
||||
# openai.api_base = "http://localhost:8888/v1"
|
||||
|
||||
# model = "chatglm2-6b"
|
||||
|
||||
# # create a chat completion
|
||||
# completion = openai.ChatCompletion.create(
|
||||
# model=model,
|
||||
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
|
||||
# )
|
||||
# # print the completion
|
||||
# print(completion.choices[0].message.content)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||
from typing import Literal, Optional, Callable, Generator, Dict, Any
|
||||
|
||||
|
||||
thread_pool = ThreadPoolExecutor()
|
||||
thread_pool = ThreadPoolExecutor(os.cpu_count())
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pprint import pprint
|
|||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
|
|
@ -84,7 +84,7 @@ def test_upload_docs():
|
|||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": docs}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@ from pathlib import Path
|
|||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from configs.model_config import LLM_MODEL, llm_model_dict
|
||||
from server.utils import api_address
|
||||
|
||||
from pprint import pprint
|
||||
import random
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ data = {
|
|||
}
|
||||
|
||||
|
||||
|
||||
def test_chat_fastchat(api="/chat/fastchat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data2 = {
|
||||
|
|
|
|||
6
webui.py
6
webui.py
|
|
@ -1,9 +1,3 @@
|
|||
# 运行方式:
|
||||
# 1. 安装必要的包:pip install streamlit-option-menu streamlit-chatbox>=1.1.6
|
||||
# 2. 运行本机fastchat服务:python server\llm_api.py 或者 运行对应的sh文件
|
||||
# 3. 运行API服务器:python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。
|
||||
# 4. 运行WEB UI:streamlit run webui.py --server.port 7860
|
||||
|
||||
import streamlit as st
|
||||
from webui_pages.utils import *
|
||||
from streamlit_option_menu import option_menu
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn
|
|||
from fastapi.responses import StreamingResponse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
|
||||
|
||||
|
|
@ -475,7 +476,7 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_api import create_kb
|
||||
response = run_async(create_kb(**data))
|
||||
response = create_kb(**data)
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
|
|
@ -497,7 +498,7 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_api import delete_kb
|
||||
response = run_async(delete_kb(knowledge_base_name))
|
||||
response = delete_kb(knowledge_base_name)
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
|
|
@ -584,7 +585,7 @@ class ApiRequest:
|
|||
filename = filename or file.name
|
||||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
filename = filename or file.name
|
||||
filename = filename or os.path.split(file.name)[-1]
|
||||
return filename, file
|
||||
|
||||
files = [convert_file(file) for file in files]
|
||||
|
|
@ -602,13 +603,13 @@ class ApiRequest:
|
|||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
upload_files = []
|
||||
for file, filename in files:
|
||||
for filename, file in files:
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
upload_files.append(UploadFile(file=temp_file, filename=filename))
|
||||
|
||||
response = run_async(upload_docs(upload_files, **data))
|
||||
response = upload_docs(upload_files, **data)
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
|
|
@ -643,7 +644,7 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import delete_docs
|
||||
response = run_async(delete_docs(**data))
|
||||
response = delete_docs(**data)
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
|
|
@ -676,7 +677,7 @@ class ApiRequest:
|
|||
}
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import update_docs
|
||||
response = run_async(update_docs(**data))
|
||||
response = update_docs(**data)
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
|
|
@ -710,7 +711,7 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import recreate_vector_store
|
||||
response = run_async(recreate_vector_store(**data))
|
||||
response = recreate_vector_store(**data)
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post(
|
||||
|
|
@ -721,14 +722,30 @@ class ApiRequest:
|
|||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
def list_running_models(self, controller_address: str = None):
|
||||
# LLM模型相关操作
|
||||
def list_running_models(
|
||||
self,
|
||||
controller_address: str = None,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
获取Fastchat中正运行的模型列表
|
||||
'''
|
||||
r = self.post(
|
||||
"/llm_model/list_models",
|
||||
)
|
||||
return r.json().get("data", [])
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"controller_address": controller_address,
|
||||
}
|
||||
if no_remote_api:
|
||||
from server.llm_api import list_llm_models
|
||||
return list_llm_models(**data).data
|
||||
else:
|
||||
r = self.post(
|
||||
"/llm_model/list_models",
|
||||
json=data,
|
||||
)
|
||||
return r.json().get("data", [])
|
||||
|
||||
def list_config_models(self):
|
||||
'''
|
||||
|
|
@ -740,30 +757,43 @@ class ApiRequest:
|
|||
self,
|
||||
model_name: str,
|
||||
controller_address: str = None,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
停止某个LLM模型。
|
||||
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"model_name": model_name,
|
||||
"controller_address": controller_address,
|
||||
}
|
||||
r = self.post(
|
||||
"/llm_model/stop",
|
||||
json=data,
|
||||
)
|
||||
return r.json()
|
||||
|
||||
if no_remote_api:
|
||||
from server.llm_api import stop_llm_model
|
||||
return stop_llm_model(**data).dict()
|
||||
else:
|
||||
r = self.post(
|
||||
"/llm_model/stop",
|
||||
json=data,
|
||||
)
|
||||
return r.json()
|
||||
|
||||
def change_llm_model(
|
||||
self,
|
||||
model_name: str,
|
||||
new_model_name: str,
|
||||
controller_address: str = None,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
向fastchat controller请求切换LLM模型。
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if not model_name or not new_model_name:
|
||||
return
|
||||
|
||||
|
|
@ -792,12 +822,17 @@ class ApiRequest:
|
|||
"new_model_name": new_model_name,
|
||||
"controller_address": controller_address,
|
||||
}
|
||||
r = self.post(
|
||||
"/llm_model/change",
|
||||
json=data,
|
||||
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
|
||||
)
|
||||
return r.json()
|
||||
|
||||
if no_remote_api:
|
||||
from server.llm_api import change_llm_model
|
||||
return change_llm_model(**data).dict()
|
||||
else:
|
||||
r = self.post(
|
||||
"/llm_model/change",
|
||||
json=data,
|
||||
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
|
||||
)
|
||||
return r.json()
|
||||
|
||||
|
||||
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue