Support online Embeddings; Lite mode now supports all knowledge-base features (#1924)

New features:
- Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add a /other/embed_texts endpoint to the API (see the sketch after this list)
- Add an --embed-model argument to init_database.py to specify which embedding model to use (local or online)
- For FAISS knowledge bases, support multiple vector stores per knowledge base; default location: {KB_PATH}/vector_store/{embed_model}
- Lite mode now supports all knowledge-base features. Its main limitations are:
  - local LLM and Embeddings models cannot be used
  - knowledge bases do not support PDF files
- init_database.py no longer clears the database tables by default when rebuilding knowledge bases; pass the new --clear-tables flag to do so explicitly
- Widen the score_threshold range in the API and WebUI to [0, 2] to better fit online embedding models (the L2 distance between two unit-normalized vectors can reach 2, so capping at 1 could filter out valid matches)
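As a quick illustration, the new endpoint can be exercised over plain HTTP. This is a minimal sketch, not part of the commit: it assumes the API server listens on the project's default port 7861 and that the response follows the usual BaseResponse envelope; the request fields mirror the embed_texts signature visible in the diffs below. Vector stores can likewise be rebuilt against a chosen model with `python init_database.py --recreate-vs --clear-tables --embed-model zhipu-api`.

```python
# Minimal sketch (assumptions: host/port and response envelope as noted above).
import requests

resp = requests.post(
    "http://127.0.0.1:7861/other/embed_texts",
    json={
        "texts": ["你好", "如何新建知识库?"],
        "embed_model": "zhipu-api",  # local or online embedding model
        "to_query": False,           # True when embedding a retrieval query
    },
)
vectors = resp.json()["data"]        # List[List[float]], one vector per text
print(len(vectors), len(vectors[0]))
```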

Bug fixes:
- list_config_models in the API deleted the sensitive fields of ONLINE_LLM_MODEL from the live config, so the second API request failed

For developers:
- Unify vector-store identification: the tuple (kb_name, embed_model) now uniquely identifies a vector store, avoiding faulty cache-loading logic for FAISS knowledge bases
- Add a default_embed_model parameter to KBServiceFactory.get_service_by_name, used to set embed_model when constructing a new knowledge base
- Streamline Embeddings handling in kb_service:
  - unified loading interface: server.utils.load_embeddings, which uses a global cache so Embeddings objects no longer need to be passed around
  - unified text-embedding interfaces: server.knowledge_base.kb_service.base.[embed_texts, embed_documents]
- Rewrite the normalize function to drop the scikit-learn/scipy dependency (see the sketch after this list)
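For reference, here is the new numpy-only normalize (reproduced from the diff below) with a small sanity check; the sample vectors and printed values are my own illustration, not from the commit:

```python
from typing import List
import numpy as np

def normalize(embeddings: List[List[float]]) -> np.ndarray:
    # Row-wise L2 normalization; replaces sklearn.preprocessing.normalize.
    norm = np.linalg.norm(embeddings, axis=1)        # per-row L2 norms
    norm = np.reshape(norm, (norm.shape[0], 1))      # shape (n, 1)
    norm = np.tile(norm, (1, len(embeddings[0])))    # expand to (n, d)
    return np.divide(embeddings, norm)

out = normalize([[3.0, 4.0], [1.0, 0.0]])
print(out)                          # [[0.6 0.8] [1.  0. ]]
print(np.linalg.norm(out, axis=1))  # [1. 1.] -- every row is unit length
```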
liunux4odoo 2023-10-31 14:26:50 +08:00 committed by GitHub
parent deed92169f
commit 65592a45c3
18 changed files with 130 additions and 96 deletions

View File

@@ -23,6 +23,11 @@ if __name__ == "__main__":
             '''
         )
     )
+    parser.add_argument(
+        "--clear-tables",
+        action="store_true",
+        help=("drop the database tables before recreate vector stores")
+    )
     parser.add_argument(
         "-u",
         "--update-in-db",
@@ -74,7 +79,7 @@ if __name__ == "__main__":
         "--embed-model",
         type=str,
         default=EMBEDDING_MODEL,
-        help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+        help=("specify embeddings model.")
     )

     if len(sys.argv) <= 1:
@@ -84,9 +89,12 @@ if __name__ == "__main__":
     start_time = datetime.now()

     create_tables() # confirm tables exist
-    if args.recreate_vs:
+    if args.clear_tables:
         reset_tables()
         print("database talbes reseted")
+
+    if args.recreate_vs:
         print("recreating all vector stores")
         folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:

View File

@@ -7,18 +7,19 @@ openai
 # torchvision
 # torchaudio
 fastapi>=0.103.1
+python-multipart
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11
-# unstructured[all-docs]>=0.10.4
-# python-magic-bin; sys_platform == 'win32'
+unstructured[docx,csv]>=0.10.4 # add pdf if need
+python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
-# faiss-cpu
+faiss-cpu
 # accelerate
 # spacy
-# PyMuPDF==1.22.5
-# rapidocr_onnxruntime>=1.3.2
+# PyMuPDF==1.22.5 # install if need pdf
+# rapidocr_onnxruntime>=1.3.2 # install if need pdf
 requests
 pathlib

View File

@@ -74,8 +74,7 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              )(search_engine_chat)

     # 知识库相关接口
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_knowledge_routes(app)

     # LLM模型相关接口
     app.post("/llm_model/list_running_models",

View File

@@ -19,7 +19,7 @@ from server.knowledge_base.kb_doc_api import search_docs
 async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
                               history: List[History] = Body([],
                                                             description="历史对话",
                                                             examples=[[

View File

@@ -16,6 +16,7 @@ def embed_texts(
 ) -> BaseResponse:
     '''
     对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    TODO: 也许需要加入缓存机制,减少 token 消耗
     '''
     try:
         if embed_model in list_embed_models(): # 使用本地Embeddings模型

View File

@@ -1,11 +1,8 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path
+from server.utils import embedding_device, get_model_path, list_online_embed_models
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -100,13 +97,22 @@ class CachePool:
         else:
             return cache

-    def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
+    def load_kb_embeddings(
+            self,
+            kb_name: str,
+            embed_device: str = embedding_device(),
+            default_embed_model: str = EMBEDDING_MODEL,
+    ) -> Embeddings:
         from server.db.repository.knowledge_base_repository import get_kb_detail
+        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

-        kb_detail = get_kb_detail(kb_name=kb_name)
-        print(kb_detail)
-        embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
-        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
+        kb_detail = get_kb_detail(kb_name)
+        embed_model = kb_detail.get("embed_model", default_embed_model)
+
+        if embed_model in list_online_embed_models():
+            return EmbeddingsFunAdapter(embed_model)
+        else:
+            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)


 class EmbeddingsPool(CachePool):
@@ -121,10 +127,12 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
+                from langchain.embeddings.openai import OpenAIEmbeddings
                 embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
+                from langchain.embeddings import HuggingFaceBgeEmbeddings
                 if 'zh' in model:
                     # for chinese model
                     query_instruction = "为这个句子生成表示以用于检索相关文章:"
@@ -140,6 +148,7 @@ class EmbeddingsPool(CachePool):
                 if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                     embeddings.query_instruction = ""
             else:
+                from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                 embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
             item.obj = embeddings
             item.finish_loading()

View File

@@ -64,23 +64,24 @@ class KBFaissPool(_FaissPool):
     def load_vector_store(
             self,
             kb_name: str,
-            vector_name: str = "vector_store",
+            vector_name: str = None,
             create: bool = True,
             embed_model: str = EMBEDDING_MODEL,
             embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
+        vector_name = vector_name or embed_model
         cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
         if cache is None:
             item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
             self.set((kb_name, vector_name), item)
             with item.acquire(msg="初始化"):
                 self.atomic.release()
-                logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
+                logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                 vs_path = get_vs_path(kb_name, vector_name)
                 if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
+                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                     vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                 elif create:
                     # create an empty vector store

View File

@@ -6,7 +6,6 @@ import os
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
-from sklearn.preprocessing import normalize

 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
@@ -31,6 +30,16 @@ from server.embeddings_api import embed_texts
 from server.embeddings_api import embed_documents


+def normalize(embeddings: List[List[float]]) -> np.ndarray:
+    '''
+    sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
+    '''
+    norm = np.linalg.norm(embeddings, axis=1)
+    norm = np.reshape(norm, (norm.shape[0], 1))
+    norm = np.tile(norm, (1, len(embeddings[0])))
+    return np.divide(embeddings, norm)
+
+
 class SupportedVSType:
     FAISS = 'faiss'
     MILVUS = 'milvus'
@@ -52,6 +61,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def __repr__(self) -> str:
+        return f"{self.kb_name} @ {self.embed_model}"
+
     def save_vector_store(self):
         '''
         保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持
@@ -209,7 +221,6 @@ class KBService(ABC):
                   query: str,
                   top_k: int,
                   score_threshold: float,
-                  embeddings: Embeddings,
                   ) -> List[Document]:
         """
         搜索知识库子类实自己逻辑
@@ -267,11 +278,15 @@ class KBServiceFactory:
         return DefaultKBService(kb_name)

     @staticmethod
-    def get_service_by_name(kb_name: str
+    def get_service_by_name(kb_name: str,
+                            default_vs_type: SupportedVSType = SupportedVSType.FAISS,
+                            default_embed_model: str = EMBEDDING_MODEL,
                             ) -> KBService:
         _, vs_type, embed_model = load_kb_from_db(kb_name)
-        if vs_type is None and os.path.isdir(get_kb_path(kb_name)):  # faiss knowledge base not in db
-            vs_type = "faiss"
+        if vs_type is None:  # faiss knowledge base not in db
+            vs_type = default_vs_type
+        if embed_model is None:
+            embed_model = default_embed_model
         return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

     @staticmethod
@@ -357,7 +372,7 @@ class EmbeddingsFunAdapter(Embeddings):
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
-        return normalize(embeddings)
+        return normalize(embeddings).tolist()

     def embed_query(self, text: str) -> List[float]:
         embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data

View File

@@ -1,14 +1,10 @@
 import os
 import shutil
-from configs import (
-    KB_ROOT_PATH,
-    SCORE_THRESHOLD,
-    logger, log_verbose,
-)
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from configs import SCORE_THRESHOLD
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
 from langchain.docstore.document import Document
 from typing import List, Dict, Optional
@@ -17,16 +13,16 @@ from typing import List, Dict, Optional
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
-    vector_name: str = "vector_store"
+    vector_name: str = None

     def vs_type(self) -> str:
         return SupportedVSType.FAISS

     def get_vs_path(self):
-        return os.path.join(self.get_kb_path(), self.vector_name)
+        return get_vs_path(self.kb_name, self.vector_name)

     def get_kb_path(self):
-        return os.path.join(KB_ROOT_PATH, self.kb_name)
+        return get_kb_path(self.kb_name)

     def load_vector_store(self) -> ThreadSafeFaiss:
         return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
@@ -41,6 +37,7 @@ class FaissKBService(KBService):
             return vs.docstore._dict.get(id)

     def do_init(self):
+        self.vector_name = self.vector_name or self.embed_model
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()

@@ -58,8 +55,10 @@ class FaissKBService(KBService):
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
                   ) -> List[Document]:
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs

     def do_add_doc(self,
def do_add_doc(self, def do_add_doc(self,

View File

@@ -59,7 +59,10 @@ class MilvusKBService(KBService):
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_milvus()
-        return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         # TODO: workaround for bug #10492 in langchain

View File

@@ -53,8 +53,10 @@ class PGKBService(KBService):
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_pg_vector()
-        return score_threshold_process(score_threshold, top_k,
-                                       self.pg_vector.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         ids = self.pg_vector.add_documents(docs)

View File

@@ -59,7 +59,10 @@ class ZillizKBService(KBService):
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_zilliz()
-        return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         for doc in docs:

View File

@@ -67,7 +67,8 @@ def folder2db(
     kb_names = kb_names or list_kbs_from_folder()
     for kb_name in kb_names:
         kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
-        kb.create_kb()
+        if not kb.exists():
+            kb.create_kb()

         # 清除向量库,从本地文件重建
         if mode == "recreate_vs":

View File

@@ -39,7 +39,7 @@ def get_doc_path(knowledge_base_name: str):
 def get_vs_path(knowledge_base_name: str, vector_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), vector_name)
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)

 def get_file_path(knowledge_base_name: str, doc_name: str):

View File

@@ -237,20 +237,23 @@ class ChatMessage(BaseModel):
 def torch_gc():
-    import torch
-    if torch.cuda.is_available():
-        # with torch.cuda.device(DEVICE):
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-    elif torch.backends.mps.is_available():
-        try:
-            from torch.mps import empty_cache
-            empty_cache()
-        except Exception as e:
-            msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
-                   "以支持及时清理 torch 产生的内存占用。")
-            logger.error(f'{e.__class__.__name__}: {msg}',
-                         exc_info=e if log_verbose else None)
+    try:
+        import torch
+        if torch.cuda.is_available():
+            # with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        elif torch.backends.mps.is_available():
+            try:
+                from torch.mps import empty_cache
+                empty_cache()
+            except Exception as e:
+                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
+                       "以支持及时清理 torch 产生的内存占用。")
+                logger.error(f'{e.__class__.__name__}: {msg}',
+                             exc_info=e if log_verbose else None)
+    except Exception:
+        ...


 def run_async(cor):

@@ -719,10 +722,10 @@ def list_online_embed_models() -> List[str]:
     ret = []
     for k, v in list_config_llm_models()["online"].items():
-        provider = v.get("provider")
-        worker_class = getattr(model_workers, provider, None)
-        if worker_class is not None and worker_class.can_embedding():
-            ret.append(k)
+        if provider := v.get("provider"):
+            worker_class = getattr(model_workers, provider, None)
+            if worker_class is not None and worker_class.can_embedding():
+                ret.append(k)
     return ret

View File

@@ -2,6 +2,7 @@ import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
 from webui_pages.dialogue.dialogue import dialogue_page, chat_box
+from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
 import os
 import sys
 from configs import VERSION
@@ -29,15 +30,11 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
+        "知识库管理": {
+            "icon": "hdd-stack",
+            "func": knowledge_base_page,
+        },
     }

-    if not is_lite:
-        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
-        pages.update({
-            "知识库管理": {
-                "icon": "hdd-stack",
-                "func": knowledge_base_page,
-            },
-        })
-
     with st.sidebar:
         st.image(

View File

@@ -54,16 +54,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 text = f"{text} 当前知识库: `{cur_kb}`。"
             st.toast(text)

-        if is_lite:
-            dialogue_modes = ["LLM 对话",
-                              "搜索引擎问答",
-                              ]
-        else:
-            dialogue_modes = ["LLM 对话",
-                              "知识库问答",
-                              "搜索引擎问答",
-                              "自定义Agent问答",
-                              ]
+        dialogue_modes = ["LLM 对话",
+                          "知识库问答",
+                          "搜索引擎问答",
+                          "自定义Agent问答",
+                          ]
         dialogue_mode = st.selectbox("请选择对话模式:",
                                      dialogue_modes,
                                      index=0,
@@ -116,18 +111,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     st.success(msg)
                 st.session_state["prev_llm_model"] = llm_model

-        if is_lite:
-            index_prompt = {
-                "LLM 对话": "llm_chat",
-                "搜索引擎问答": "search_engine_chat",
-            }
-        else:
-            index_prompt = {
-                "LLM 对话": "llm_chat",
-                "自定义Agent问答": "agent_chat",
-                "搜索引擎问答": "search_engine_chat",
-                "知识库问答": "knowledge_base_chat",
-            }
+        index_prompt = {
+            "LLM 对话": "llm_chat",
+            "自定义Agent问答": "agent_chat",
+            "搜索引擎问答": "search_engine_chat",
+            "知识库问答": "knowledge_base_chat",
+        }
         prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
         prompt_template_name = prompt_templates_kb_list[0]
         if "prompt_template_select" not in st.session_state:
@@ -167,7 +156,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

         ## Bge 模型会超过1
-        score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
+        score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

     elif dialogue_mode == "搜索引擎问答":
         search_engine_list = api.list_search_engines()

View File

@@ -103,7 +103,10 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )

-        embed_models = list_embed_models() + list_online_embed_models()
+        if is_lite:
+            embed_models = list_online_embed_models()
+        else:
+            embed_models = list_embed_models() + list_online_embed_models()

         embed_model = cols[1].selectbox(
             "Embedding 模型",