Rework the Embeddings and FAISS cache loading; knowledge-base API endpoints now support multi-threaded concurrency (#1434)

* Rework Embeddings and FAISS cache loading to support multi-threading and an in-memory FAISS store

* Make the knowledge-base API endpoints support multi-threaded concurrency

* Adjust ApiRequest and the test cases to match the new API endpoints

* Remove the obsolete startup instructions from webui.py
liunux4odoo 2023-09-11 20:41:41 +08:00 committed by GitHub
parent d0e654d847
commit 22ff073309
16 changed files with 497 additions and 530 deletions

View File

@ -4,12 +4,11 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION, logger, log_verbose
from configs import VERSION
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
@ -18,8 +17,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -126,79 +125,20 @@ def create_app():
)(recreate_vector_store)
# LLM模型相关接口
@app.post("/llm_model/list_models",
app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型")
def list_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
summary="列出当前已加载的模型",
)(list_llm_models)
@app.post("/llm_model/stop",
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
)(stop_llm_model)
@app.post("/llm_model/change",
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
)(change_llm_model)
return app
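
For context, the registration style used above works because FastAPI's `app.post(...)` returns a decorator, so an already-imported handler can be attached to a route without redefining it inline. A minimal self-contained sketch (the handler body here is a stub, not the project's implementation):

```python
from fastapi import Body, FastAPI

app = FastAPI()

def list_models(
    controller_address: str = Body(None, description="Fastchat controller address"),
) -> dict:
    # Handler defined as a plain function (in this commit it lives in server.llm_api);
    # the Body() defaults still drive request parsing and the OpenAPI schema.
    return {"models": []}

# app.post(...) returns a decorator, so calling it with an existing function
# registers the route exactly like the @app.post(...) syntax would.
app.post("/llm_model/list_models",
         tags=["LLM Model Management"],
         summary="List the currently loaded models",
         )(list_models)
```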

View File

@ -12,10 +12,10 @@ def list_kbs():
return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -38,8 +38,8 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):

View File

@ -0,0 +1,137 @@
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
embedding_model_dict, logger, log_verbose)
from server.utils import embedding_device
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
class ThreadSafeObject:
def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
self._obj = obj
self._key = key
self._pool = pool
self._lock = threading.RLock()
self._loaded = threading.Event()
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self._key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}")
self._lock.release()
def start_loading(self):
self._loaded.clear()
def finish_loading(self):
self._loaded.set()
def wait_for_loading(self):
self._loaded.wait()
@property
def obj(self):
return self._obj
@obj.setter
def obj(self, val: Any):
self._obj = val
class CachePool:
def __init__(self, cache_num: int = -1):
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()
def keys(self) -> List[str]:
return list(self._cache.keys())
def _check_count(self):
if isinstance(self._cache_num, int) and self._cache_num > 0:
while len(self._cache) > self._cache_num:
self._cache.popitem(last=False)
def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache
def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
self._cache[key] = obj
self._check_count()
return obj
def pop(self, key: str = None) -> ThreadSafeObject:
if key is None:
return self._cache.popitem(last=False)
else:
return self._cache.pop(key, None)
def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
cache = self.get(key)
if cache is None:
raise RuntimeError(f"请求的资源 {key} 不存在")
elif isinstance(cache, ThreadSafeObject):
self._cache.move_to_end(key)
return cache.acquire(owner=owner, msg=msg)
else:
return cache
def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
from server.db.repository.knowledge_base_repository import get_kb_detail
kb_detail = get_kb_detail(kb_name=kb_name)
print(kb_detail)
embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str, device: str) -> Embeddings:
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = device or embedding_device()
key = (model, device)
if not self.get(key):
item = ThreadSafeObject(key, pool=self)
self.set(key, item)
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:
self.atomic.release()
return self.get(key).obj
embeddings_pool = EmbeddingsPool(cache_num=1)
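
A short usage sketch of the pool defined above (the model name and device are placeholders; any entry from `embedding_model_dict` works the same way):

```python
from server.knowledge_base.kb_cache.base import embeddings_pool

# The first caller loads the model while holding the item's lock; concurrent
# callers block in CachePool.get() (wait_for_loading) until that same
# instance is ready, so the model is only loaded once.
embeddings = embeddings_pool.load_embeddings(model="m3e-base", device="cpu")
vectors = embeddings.embed_documents(["first chunk of text", "second chunk of text"])
print(len(vectors), len(vectors[0]))
```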

View File

@ -0,0 +1,157 @@
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores import FAISS
import os
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
def save(self, path: str, create_path: bool = True):
with self.acquire():
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘")
return ret
def clear(self):
ret = []
with self.acquire():
ids = list(self._obj.docstore._dict.keys())
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空")
return ret
class _FaissPool(CachePool):
def new_vector_store(
self,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
# create an empty vector store
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
def save_vector_store(self, kb_name: str, path: str=None):
if cache := self.get(kb_name):
return cache.save(path)
def unload_vector_store(self, kb_name: str):
if cache := self.get(kb_name):
self.pop(kb_name)
logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
class MemoFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' to memory.")
# create an empty vector store
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool()
if __name__ == "__main__":
import time, random
from pprint import pprint
kb_names = ["vs1", "vs2", "vs3"]
# for name in kb_names:
# memo_faiss_pool.load_vector_store(name)
def worker(vs_name: str, name: str):
vs_name = "samples"
time.sleep(random.randint(1, 5))
embeddings = embeddings_pool.load_embeddings()
r = random.randint(1, 3)
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
if r == 1: # add docs
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()
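
Outside of the stress test above, the intended calling pattern is the same one FaissKBService adopts later in this commit: borrow the FAISS object under its lock, query or mutate it, then persist through the pool. A condensed sketch (the kb name and save path are placeholders):

```python
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool

# load_vector_store() returns the ThreadSafeFaiss wrapper; acquire() yields
# the underlying FAISS object while holding its RLock.
with kb_faiss_pool.load_vector_store("samples").acquire(owner="example") as vs:
    ids = vs.add_texts(["a new chunk of text"])
    print(ids)

# save() re-acquires the lock itself, so it can be called outside the with-block.
kb_faiss_pool.save_vector_store("samples", path="knowledge_base/samples/vector_store")
```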

View File

@ -117,13 +117,13 @@ def _save_files_in_thread(files: List[UploadFile],
# yield json.dumps(result, ensure_ascii=False)
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
API接口上传文件/或向量化
'''
@ -148,7 +148,7 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
# 对保存的文件进行向量化
if to_vector_store:
result = await update_docs(
result = update_docs(
knowledge_base_name=knowledge_base_name,
file_names=file_names,
override_custom_docs=True,
@ -162,11 +162,11 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -196,12 +196,12 @@ async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"])
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
async def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
更新知识库文档
@ -302,11 +302,11 @@ def download_doc(
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
):
'''
recreate vector store from the content.
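
The switch from `async def` to plain `def` on these endpoints is what the concurrency in the commit title relies on: FastAPI dispatches sync handlers to a worker threadpool, so the blocking FAISS and embedding work no longer stalls the event loop and several requests can be processed in parallel. A minimal standalone illustration of that behaviour (not project code):

```python
import time
from fastapi import FastAPI

demo = FastAPI()

@demo.get("/blocking")
def blocking_endpoint():
    # A sync (def) handler runs in Starlette's threadpool, so this sleep
    # ties up one worker thread instead of blocking the whole event loop.
    time.sleep(5)
    return {"done": True}
```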

View File

@ -146,7 +146,6 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
# TODO: milvus/pg需要实现该方法
def get_doc_by_id(self, id: str) -> Optional[Document]:
return None

View File

@ -3,62 +3,16 @@ import shutil
from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
SCORE_THRESHOLD,
logger, log_verbose,
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Optional
from langchain.docstore.document import Document
from server.utils import torch_gc, embedding_device
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
) -> FAISS:
logger.info(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
def refresh_vs_cache(kb_name: str):
"""
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
from server.utils import torch_gc
class FaissKBService(KBService):
@ -74,24 +28,15 @@ class FaissKBService(KBService):
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> FAISS:
return load_faiss_vector_store(
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
def save_vector_store(self, vector_store: FAISS = None):
vector_store = vector_store or self.load_vector_store()
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def save_vector_store(self):
self.load_vector_store().save(self.vs_path)
def get_doc_by_id(self, id: str) -> Optional[Document]:
vector_store = self.load_vector_store()
return vector_store.docstore._dict.get(id)
with self.load_vector_store().acquire() as vs:
return vs.docstore._dict.get(id)
def do_init(self):
self.kb_path = self.get_kb_path()
@ -112,43 +57,38 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = self.load_vector_store()
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
vector_store = self.load_vector_store()
ids = vector_store.add_documents(docs)
with self.load_vector_store().acquire() as vs:
ids = vs.add_documents(docs)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return doc_infos
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return vector_store
with self.load_vector_store().acquire() as vs:
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
if len(ids) > 0:
vs.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
return ids
def do_clear_vs(self):
with kb_faiss_pool.atomic:
kb_faiss_pool.pop(self.kb_name)
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
self.refresh_vs_cache()
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):

View File

@ -1,7 +1,6 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE, logger, log_verbose
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool,
files2docs_in_thread,
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db
@ -72,7 +71,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "fill_info_only":
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)
@ -89,7 +87,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "increament":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
@ -107,7 +104,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
else:
print(f"unspported migrate mode: {mode}")
@ -139,7 +135,6 @@ def prune_db_files(kb_name: str):
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files
def prune_folder_files(kb_name: str):

View File

@ -4,13 +4,13 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
embedding_model_dict,
EMBEDDING_MODEL,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
logger, log_verbose,
)
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
@ -19,25 +19,11 @@ from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool
from server.utils import run_in_thread_pool, embedding_device
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id:
@ -72,19 +58,12 @@ def list_files_from_folder(kb_name: str):
if os.path.isfile(os.path.join(doc_path, file))]
@lru_cache(1)
def load_embeddings(model: str, device: str):
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
'''
从缓存中加载embeddings可以避免多线程时竞争加载
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
return embeddings_pool.load_embeddings(model=model, device=device)
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],

View File

@ -1,279 +1,70 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger, log_verbose
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address
import httpx
host_ip = "0.0.0.0"
controller_port = 20001
model_worker_port = 20002
openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
def list_llm_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def create_controller_app(
dispatch_method="shortest_queue",
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
return app
def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,
cpu_offloading=None,
gptq_ckpt=None,
gptq_wbits=16,
gptq_groupsize=-1,
gptq_act_order=False,
awq_ckpt=None,
awq_wbits=16,
awq_groupsize=-1,
model_names=[LLM_MODEL],
num_gpus=1, # not in fastchat
conv_template=None,
limit_worker_concurrency=5,
stream_interval=2,
no_register=False,
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
import argparse
import threading
import fastchat.serve.model_worker
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.model_path = model_path
args.model_names = model_names
args.device = device
args.load_8bit = load_8bit
args.gptq_ckpt = gptq_ckpt
args.gptq_wbits = gptq_wbits
args.gptq_groupsize = gptq_groupsize
args.gptq_act_order = gptq_act_order
args.awq_ckpt = awq_ckpt
args.awq_wbits = awq_wbits
args.awq_groupsize = awq_groupsize
args.gpus = gpus
args.num_gpus = num_gpus
args.max_gpu_memory = max_gpu_memory
args.cpu_offloading = cpu_offloading
args.worker_address = worker_address
args.controller_address = controller_address
args.conv_template = conv_template
args.limit_worker_concurrency = limit_worker_concurrency
args.stream_interval = stream_interval
args.no_register = no_register
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if gpus and num_gpus is None:
num_gpus = len(gpus.split(','))
args.num_gpus = num_gpus
gptq_config = GptqConfig(
ckpt=gptq_ckpt or model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# torch.multiprocessing.set_start_method('spawn')
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
def create_openai_api_app(
controller_address=base_url.format(controller_port),
api_keys=[],
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def run_controller(q):
import uvicorn
app = create_controller_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
q.put(1)
uvicorn.run(app, host=host_ip, port=controller_port)
def run_model_worker(q, *args, **kwargs):
import uvicorn
app = create_model_worker_app(*args, **kwargs)
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 1:
q.put(no)
else:
break
q.put(2)
uvicorn.run(app, host=host_ip, port=model_worker_port)
def run_openai_api(q):
import uvicorn
app = create_openai_api_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 2:
q.put(no)
else:
break
q.put(3)
uvicorn.run(app, host=host_ip, port=openai_api_port)
if __name__ == "__main__":
mp.set_start_method("spawn")
queue = Queue()
logger.info(llm_model_dict[LLM_MODEL])
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
if not model_path:
logger.error("local_model_path 不能为空")
else:
controller_process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue,),
daemon=True,
)
controller_process.start()
model_worker_process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(queue,),
# kwargs={"load_8bit": True},
daemon=True,
)
model_worker_process.start()
openai_api_process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue,),
daemon=True,
)
openai_api_process.start()
try:
model_worker_process.join()
controller_process.join()
openai_api_process.join()
except KeyboardInterrupt:
model_worker_process.terminate()
controller_process.terminate()
openai_api_process.terminate()
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")

View File

@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any
thread_pool = ThreadPoolExecutor()
thread_pool = ThreadPoolExecutor(os.cpu_count())
class BaseResponse(BaseModel):
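
For reference, passing `os.cpu_count()` pins the pool size explicitly; with no argument, `ThreadPoolExecutor` defaults to `min(32, os.cpu_count() + 4)` workers. A tiny standalone sketch of the sized pool:

```python
import os
from concurrent.futures import ThreadPoolExecutor

# One worker per CPU core instead of the default min(32, os.cpu_count() + 4).
thread_pool = ThreadPoolExecutor(os.cpu_count())

def blocking_task(i: int) -> int:
    return i * i

results = list(thread_pool.map(blocking_task, range(8)))
print(results)
```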

View File

@ -14,7 +14,7 @@ from pprint import pprint
api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url)
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
kb = "kb_for_api_test"
@ -84,7 +84,7 @@ def test_upload_docs():
print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
data = {"knowledge_base_name": kb, "override": True, "docs": docs}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200

View File

@ -5,8 +5,9 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from server.utils import api_address
from pprint import pprint
import random

View File

@ -47,7 +47,6 @@ data = {
}
def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}"
data2 = {

View File

@ -1,9 +1,3 @@
# 运行方式:
# 1. 安装必要的包pip install streamlit-option-menu streamlit-chatbox>=1.1.6
# 2. 运行本机fastchat服务python server\llm_api.py 或者 运行对应的sh文件
# 3. 运行API服务器python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。
# 4. 运行WEB UIstreamlit run webui.py --server.port 7860
import streamlit as st
from webui_pages.utils import *
from streamlit_option_menu import option_menu

View File

@ -20,6 +20,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib
import json
import os
from io import BytesIO
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
@ -475,7 +476,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_api import create_kb
response = run_async(create_kb(**data))
response = create_kb(**data)
return response.dict()
else:
response = self.post(
@ -497,7 +498,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_api import delete_kb
response = run_async(delete_kb(knowledge_base_name))
response = delete_kb(knowledge_base_name)
return response.dict()
else:
response = self.post(
@ -584,7 +585,7 @@ class ApiRequest:
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or file.name
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
@ -602,13 +603,13 @@ class ApiRequest:
from tempfile import SpooledTemporaryFile
upload_files = []
for file, filename in files:
for filename, file in files:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
upload_files.append(UploadFile(file=temp_file, filename=filename))
response = run_async(upload_docs(upload_files, **data))
response = upload_docs(upload_files, **data)
return response.dict()
else:
if isinstance(data["docs"], dict):
@ -643,7 +644,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_docs
response = run_async(delete_docs(**data))
response = delete_docs(**data)
return response.dict()
else:
response = self.post(
@ -676,7 +677,7 @@ class ApiRequest:
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_docs
response = run_async(update_docs(**data))
response = update_docs(**data)
return response.dict()
else:
if isinstance(data["docs"], dict):
@ -710,7 +711,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_doc_api import recreate_vector_store
response = run_async(recreate_vector_store(**data))
response = recreate_vector_store(**data)
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
@ -721,14 +722,30 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None):
# LLM模型相关操作
def list_running_models(
self,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
获取Fastchat中正运行的模型列表
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"controller_address": controller_address,
}
if no_remote_api:
from server.llm_api import list_llm_models
return list_llm_models(**data).data
else:
r = self.post(
"/llm_model/list_models",
json=data,
)
return r.json().get("data", [])
def list_config_models(self):
'''
@ -740,30 +757,43 @@ class ApiRequest:
self,
model_name: str,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
if no_remote_api:
from server.llm_api import stop_llm_model
return stop_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
向fastchat controller请求切换LLM模型
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if not model_name or not new_model_name:
return
@ -792,12 +822,17 @@ class ApiRequest:
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
if no_remote_api:
from server.llm_api import change_llm_model
return change_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
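
With the `no_remote_api` branches above, the same ApiRequest calls can run entirely in-process, which is how the updated tests exercise the new synchronous endpoints. A short usage sketch (the base URL and file path are placeholders):

```python
from webui_pages.utils import ApiRequest

# In-process mode: knowledge-base and model-management calls go straight to
# the now-synchronous server functions, with no HTTP round trip to the API server.
api = ApiRequest("http://127.0.0.1:7861", no_remote_api=True)

resp = api.upload_kb_docs(["/path/to/some_doc.md"], knowledge_base_name="samples")
print(resp["code"], resp["msg"])

print(api.list_running_models())
```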