Support online Embeddings; Lite mode now supports all knowledge base features (#1924)

New features:
- Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add a /other/embed_texts endpoint to the API (see the client sketch after this list)
- Add an --embed-model argument to init_database.py to specify which embedding model to use (local or online)
- For FAISS knowledge bases, support multiple vector stores per knowledge base; default location: {KB_PATH}/vector_store/{embed_model}
- Lite mode supports all knowledge base features. Its main limitations are:
  - local LLM and Embeddings models cannot be used
  - knowledge bases do not support PDF files
- init_database.py no longer clears the database tables by default when rebuilding knowledge bases; the new --clear-tables argument controls this manually
- The score_threshold range in the API and WebUI is changed to [0, 2] to better fit online embedding models
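
A minimal client sketch for the new endpoint. The URL, port, and JSON field names below are assumptions based on the embed_texts signature shown in this diff; verify them against your deployment:

```python
# Hypothetical call to the new /other/embed_texts endpoint.
# Assumes the API server listens on http://127.0.0.1:7861 and that the JSON
# body mirrors embed_texts(texts, embed_model, to_query) from this commit.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/other/embed_texts",
    json={
        "texts": ["hello", "world"],
        "embed_model": "zhipu-api",  # any local or online embedding model
        "to_query": False,           # True when embedding a search query
    },
)
vectors = resp.json()["data"]  # List[List[float]], one vector per input text
print(len(vectors), len(vectors[0]))
```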

Bug fixes:
- list_config_models in the API stripped sensitive information out of ONLINE_LLM_MODEL in place, which caused errors on the second round of API requests

Developer notes:
- Unified vector store identity: a vector store is now uniquely identified by the (kb_name, embed_model) tuple, which avoids faulty cache-loading logic for FAISS knowledge bases
- KBServiceFactory.get_service_by_name gains a default_embed_model parameter, used to set embed_model when building a new knowledge base
- Optimized Embeddings handling in kb_service (see the sketch after this list):
  - unified loading interface: server.utils.load_embeddings, which uses a global cache so Embeddings objects no longer need to be passed around
  - unified text embedding interface: server.knowledge_base.kb_service.base.[embed_texts, embed_documents]
- Rewrote the normalize function to remove the dependency on scikit-learn/scipy
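
A minimal sketch of how these pieces fit together, assembled from the diff below; names are taken from this commit, but this is not an official usage guide:

```python
# Sketch only: the unified embedding flow introduced by this commit.
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter, KBServiceFactory

# (kb_name, embed_model) now identifies a vector store uniquely; the factory
# falls back to default_embed_model when the kb is not yet in the database.
kb = KBServiceFactory.get_service_by_name("samples", default_embed_model="zhipu-api")

# EmbeddingsFunAdapter routes text through embed_texts() (local or online model)
# and L2-normalizes the result, so do_search can query by vector everywhere.
adapter = EmbeddingsFunAdapter("zhipu-api")
query_vector = adapter.embed_query("what is a knowledge base?")  # List[float]
```
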
liunux4odoo 2023-10-31 14:26:50 +08:00 committed by GitHub
parent deed92169f
commit 65592a45c3
18 changed files with 130 additions and 96 deletions

File: init_database.py

@@ -23,6 +23,11 @@ if __name__ == "__main__":
             '''
         )
     )
+    parser.add_argument(
+        "--clear-tables",
+        action="store_true",
+        help=("drop the database tables before recreating vector stores")
+    )
     parser.add_argument(
         "-u",
         "--update-in-db",
@@ -74,7 +79,7 @@ if __name__ == "__main__":
         "--embed-model",
         type=str,
         default=EMBEDDING_MODEL,
-        help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+        help=("specify embeddings model.")
     )

     if len(sys.argv) <= 1:
@@ -84,9 +89,12 @@ if __name__ == "__main__":
     start_time = datetime.now()

     create_tables()  # confirm tables exist
-    if args.recreate_vs:
+    if args.clear_tables:
         reset_tables()
         print("database tables reset")
+
+    if args.recreate_vs:
         print("recreating all vector stores")
         folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
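
With these flags, a full rebuild that also resets the tables must now be requested explicitly, e.g. `python init_database.py --recreate-vs --embed-model zhipu-api --clear-tables` (the model name here is only illustrative).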

File: requirements (Lite variant; exact filename not shown)

@@ -7,18 +7,19 @@ openai
 # torchvision
 # torchaudio
 fastapi>=0.103.1
+python-multipart
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11
-# unstructured[all-docs]>=0.10.4
-# python-magic-bin; sys_platform == 'win32'
+unstructured[docx,csv]>=0.10.4  # add pdf if needed
+python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
-# faiss-cpu
+faiss-cpu
 # accelerate
 # spacy
-# PyMuPDF==1.22.5
-# rapidocr_onnxruntime>=1.3.2
+# PyMuPDF==1.22.5  # install if pdf support is needed
+# rapidocr_onnxruntime>=1.3.2  # install if pdf support is needed
 requests
 pathlib

File: server/api.py

@@ -74,7 +74,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              )(search_engine_chat)

     # knowledge base related endpoints
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_knowledge_routes(app)

     # LLM model related endpoints

File: server/chat/knowledge_base_chat.py

@@ -19,7 +19,7 @@ from server.knowledge_base.kb_doc_api import search_docs
 async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
                               history: List[History] = Body([],
                                                             description="历史对话",
                                                             examples=[[

File: server/embeddings_api.py

@@ -16,6 +16,7 @@ def embed_texts(
 ) -> BaseResponse:
     '''
     Vectorize the given texts. Return format: BaseResponse(data=List[List[float]])
+    TODO: a cache might be worth adding here to reduce token consumption
     '''
     try:
         if embed_model in list_embed_models():  # use a local Embeddings model

File: server/knowledge_base/kb_cache/base.py

@@ -1,11 +1,8 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path
+from server.utils import embedding_device, get_model_path, list_online_embed_models
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -100,12 +97,21 @@ class CachePool:
         else:
             return cache

-    def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
+    def load_kb_embeddings(
+        self,
+        kb_name: str,
+        embed_device: str = embedding_device(),
+        default_embed_model: str = EMBEDDING_MODEL,
+    ) -> Embeddings:
         from server.db.repository.knowledge_base_repository import get_kb_detail
+        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

-        kb_detail = get_kb_detail(kb_name=kb_name)
-        print(kb_detail)
-        embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
+        kb_detail = get_kb_detail(kb_name)
+        embed_model = kb_detail.get("embed_model", default_embed_model)
+        if embed_model in list_online_embed_models():
+            return EmbeddingsFunAdapter(embed_model)
+        else:
+            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
@@ -121,10 +127,12 @@ class EmbeddingsPool(CachePool):
             with item.acquire(msg="初始化"):
                 self.atomic.release()
                 if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
+                    from langchain.embeddings.openai import OpenAIEmbeddings
                     embeddings = OpenAIEmbeddings(model_name=model,
                                                   openai_api_key=get_model_path(model),
                                                   chunk_size=CHUNK_SIZE)
                 elif 'bge-' in model:
+                    from langchain.embeddings import HuggingFaceBgeEmbeddings
                     if 'zh' in model:
                         # for chinese model
                         query_instruction = "为这个句子生成表示以用于检索相关文章:"
@@ -140,6 +148,7 @@
                     if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                         embeddings.query_instruction = ""
                 else:
+                    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                     embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
                 item.obj = embeddings
                 item.finish_loading()

File: server/knowledge_base/kb_cache/faiss_cache.py

@@ -64,23 +64,24 @@ class KBFaissPool(_FaissPool):
     def load_vector_store(
             self,
             kb_name: str,
-            vector_name: str = "vector_store",
+            vector_name: str = None,
             create: bool = True,
             embed_model: str = EMBEDDING_MODEL,
             embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
+        vector_name = vector_name or embed_model
         cache = self.get((kb_name, vector_name))  # a tuple key works better than a concatenated string
         if cache is None:
             item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
             self.set((kb_name, vector_name), item)
             with item.acquire(msg="初始化"):
                 self.atomic.release()
-                logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
+                logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                 vs_path = get_vs_path(kb_name, vector_name)

                 if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
+                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                     vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                 elif create:
                     # create an empty vector store
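
Under the new layout, each embedding model gets its own sibling directory inside the knowledge base. A quick illustration using get_vs_path from this commit; the model names are placeholders and the absolute prefix depends on your KB root path:

```python
# Illustration only: per-model vector store directories after this commit.
from server.knowledge_base.utils import get_vs_path

print(get_vs_path("samples", "m3e-base"))   # .../samples/vector_store/m3e-base
print(get_vs_path("samples", "zhipu-api"))  # .../samples/vector_store/zhipu-api
```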

File: server/knowledge_base/kb_service/base.py

@@ -6,7 +6,6 @@ import os
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
-from sklearn.preprocessing import normalize

 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
@@ -31,6 +30,16 @@ from server.embeddings_api import embed_texts
 from server.embeddings_api import embed_documents


+def normalize(embeddings: List[List[float]]) -> np.ndarray:
+    '''
+    Replacement for sklearn.preprocessing.normalize (L2 norm), so that scipy and scikit-learn need not be installed
+    '''
+    norm = np.linalg.norm(embeddings, axis=1)
+    norm = np.reshape(norm, (norm.shape[0], 1))
+    norm = np.tile(norm, (1, len(embeddings[0])))
+    return np.divide(embeddings, norm)
+
+
 class SupportedVSType:
     FAISS = 'faiss'
     MILVUS = 'milvus'
@@ -52,6 +61,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def __repr__(self) -> str:
+        return f"{self.kb_name} @ {self.embed_model}"
+
     def save_vector_store(self):
         '''
         Save the vector store: FAISS saves to disk, milvus saves to its database. PGVector is not supported yet.
@@ -209,7 +221,6 @@ class KBService(ABC):
                   query: str,
                   top_k: int,
                   score_threshold: float,
-                  embeddings: Embeddings,
                   ) -> List[Document]:
         """
         Search the knowledge base; subclasses implement their own logic
@@ -267,11 +278,15 @@ class KBServiceFactory:
             return DefaultKBService(kb_name)

     @staticmethod
-    def get_service_by_name(kb_name: str
+    def get_service_by_name(kb_name: str,
+                            default_vs_type: SupportedVSType = SupportedVSType.FAISS,
+                            default_embed_model: str = EMBEDDING_MODEL,
                             ) -> KBService:
         _, vs_type, embed_model = load_kb_from_db(kb_name)
-        if vs_type is None and os.path.isdir(get_kb_path(kb_name)):  # faiss knowledge base not in db
-            vs_type = "faiss"
+        if vs_type is None:  # faiss knowledge base not in db
+            vs_type = default_vs_type
+        if embed_model is None:
+            embed_model = default_embed_model
         return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

     @staticmethod
@@ -357,7 +372,7 @@ class EmbeddingsFunAdapter(Embeddings):
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
-        return normalize(embeddings)
+        return normalize(embeddings).tolist()

     def embed_query(self, text: str) -> List[float]:
         embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
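
The normalize helper added above can be sanity-checked against the sklearn L2 behaviour it replaces; a minimal standalone sketch:

```python
import numpy as np

# Copy of the helper added in this commit, for a standalone check.
def normalize(embeddings):
    norm = np.linalg.norm(embeddings, axis=1)
    norm = np.reshape(norm, (norm.shape[0], 1))
    norm = np.tile(norm, (1, len(embeddings[0])))
    return np.divide(embeddings, norm)

vecs = [[3.0, 4.0], [1.0, 0.0]]
out = normalize(vecs)
assert np.allclose(out, [[0.6, 0.8], [1.0, 0.0]])     # same as sklearn's L2 normalize
assert np.allclose(np.linalg.norm(out, axis=1), 1.0)  # every row is unit length
```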

File: server/knowledge_base/kb_service/faiss_kb_service.py

@@ -1,14 +1,10 @@
 import os
 import shutil
-from configs import (
-    KB_ROOT_PATH,
-    SCORE_THRESHOLD,
-    logger, log_verbose,
-)
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from configs import SCORE_THRESHOLD
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
 from langchain.docstore.document import Document
 from typing import List, Dict, Optional
@@ -17,16 +13,16 @@
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
-    vector_name: str = "vector_store"
+    vector_name: str = None

     def vs_type(self) -> str:
         return SupportedVSType.FAISS

     def get_vs_path(self):
-        return os.path.join(self.get_kb_path(), self.vector_name)
+        return get_vs_path(self.kb_name, self.vector_name)

     def get_kb_path(self):
-        return os.path.join(KB_ROOT_PATH, self.kb_name)
+        return get_kb_path(self.kb_name)

     def load_vector_store(self) -> ThreadSafeFaiss:
         return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
@@ -41,6 +37,7 @@ class FaissKBService(KBService):
             return vs.docstore._dict.get(id)

     def do_init(self):
+        self.vector_name = self.vector_name or self.embed_model
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()
@@ -58,8 +55,10 @@ class FaissKBService(KBService):
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
                   ) -> List[Document]:
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs

     def do_add_doc(self,

File: server/knowledge_base/kb_service/milvus_kb_service.py

@@ -59,7 +59,10 @@ class MilvusKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_milvus()
-        return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         # TODO: workaround for bug #10492 in langchain

File: server/knowledge_base/kb_service/pg_kb_service.py

@@ -53,8 +53,10 @@ class PGKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_pg_vector()
-        return score_threshold_process(score_threshold, top_k,
-                                       self.pg_vector.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         ids = self.pg_vector.add_documents(docs)

File: server/knowledge_base/kb_service/zilliz_kb_service.py

@@ -59,7 +59,10 @@ class ZillizKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_zilliz()
-        return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         for doc in docs:

File: server/knowledge_base/migrate.py

@@ -67,6 +67,7 @@ def folder2db(
     kb_names = kb_names or list_kbs_from_folder()
     for kb_name in kb_names:
         kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
+        if not kb.exists():
+            kb.create_kb()

         # clear the vector store and rebuild from local files

File: server/knowledge_base/utils.py

@@ -39,7 +39,7 @@ def get_doc_path(knowledge_base_name: str):

 def get_vs_path(knowledge_base_name: str, vector_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), vector_name)
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)

 def get_file_path(knowledge_base_name: str, doc_name: str):

File: server/utils.py

@@ -237,6 +237,7 @@ class ChatMessage(BaseModel):

 def torch_gc():
     try:
+        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
@@ -251,6 +252,8 @@ def torch_gc():
                    "to allow timely cleanup of the memory occupied by torch.")
             logger.error(f'{e.__class__.__name__}: {msg}',
                          exc_info=e if log_verbose else None)
+    except Exception:
+        ...
def run_async(cor):
@@ -719,7 +722,7 @@ def list_online_embed_models() -> List[str]:
     ret = []
     for k, v in list_config_llm_models()["online"].items():
-        provider = v.get("provider")
+        if provider := v.get("provider"):
             worker_class = getattr(model_workers, provider, None)
             if worker_class is not None and worker_class.can_embedding():
                 ret.append(k)

File: webui.py

@@ -2,6 +2,7 @@ import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
 from webui_pages.dialogue.dialogue import dialogue_page, chat_box
+from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
 import os
 import sys
 from configs import VERSION
@@ -29,15 +30,11 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
-    }
-
-    if not is_lite:
-        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
-        pages.update({
-            "知识库管理": {
-                "icon": "hdd-stack",
-                "func": knowledge_base_page,
-            },
-        })
+        "知识库管理": {
+            "icon": "hdd-stack",
+            "func": knowledge_base_page,
+        },
+    }

     with st.sidebar:
         st.image(

File: webui_pages/dialogue/dialogue.py

@@ -54,11 +54,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = f"{text} 当前知识库: `{cur_kb}`。"
         st.toast(text)

-    if is_lite:
-        dialogue_modes = ["LLM 对话",
-                          "搜索引擎问答",
-                          ]
-    else:
-        dialogue_modes = ["LLM 对话",
+    dialogue_modes = ["LLM 对话",
                       "知识库问答",
                       "搜索引擎问答",
@@ -116,12 +111,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 st.success(msg)
                 st.session_state["prev_llm_model"] = llm_model

-        if is_lite:
-            index_prompt = {
-                "LLM 对话": "llm_chat",
-                "搜索引擎问答": "search_engine_chat",
-            }
-        else:
-            index_prompt = {
+        index_prompt = {
             "LLM 对话": "llm_chat",
             "自定义Agent问答": "agent_chat",
@@ -167,7 +156,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

             ## Bge models can score above 1
-            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
+            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()

File: webui_pages/knowledge_base/knowledge_base.py

@@ -103,6 +103,9 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )

+        if is_lite:
+            embed_models = list_online_embed_models()
+        else:
+            embed_models = list_embed_models() + list_online_embed_models()

         embed_model = cols[1].selectbox(