From 65592a45c34bb4faeb02ce0fd8147e83d79c9549 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Tue, 31 Oct 2023 14:26:50 +0800
Subject: [PATCH] Support online Embeddings; Lite mode now covers all
 knowledge-base features (#1924)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

New features:
- Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api.
- Add a /other/embed_texts endpoint to the API.
- Add an --embed-model parameter to init_database.py to specify which
  embedding model to use (local or online).
- Support multiple vector stores per FAISS knowledge base, one per embedding
  model, stored by default at {KB_PATH}/vector_store/{embed_model}.
- Lite mode now supports all knowledge-base features. Its main remaining
  limitations are:
  - local LLM and Embeddings models cannot be used;
  - knowledge bases cannot ingest PDF files.
- init_database.py no longer clears the database tables by default when
  rebuilding vector stores; the new --clear-tables flag controls that
  explicitly.
- The score_threshold range in the API and WebUI is widened to [0, 2] to
  better fit online embedding models.

Bug fixes:
- list_config_models in the API deleted sensitive fields from
  ONLINE_LLM_MODEL in place, which broke the second API request.

Developer notes:
- Unify vector-store identity: a store is keyed by (kb_name, embed_model),
  which avoids cache-loading errors for FAISS knowledge bases.
- KBServiceFactory.get_service_by_name gains a default_embed_model parameter,
  used to set embed_model when creating a new knowledge base.
- Clean up Embeddings handling in kb_service:
  - unified loading interface server.utils.load_embeddings, backed by a
    global cache so Embeddings objects no longer need to be passed around;
  - unified text-embedding interface
    server.knowledge_base.kb_service.base.[embed_texts, embed_documents];
  - rewrite the normalize function to drop the scikit-learn/scipy dependency.
---
 init_database.py                              | 12 +++++-
 requirements_lite.txt                         | 11 +++---
 server/api.py                                 |  3 +-
 server/chat/knowledge_base_chat.py            |  2 +-
 server/embeddings_api.py                      |  1 +
 server/knowledge_base/kb_cache/base.py        | 27 ++++++++-----
 server/knowledge_base/kb_cache/faiss_cache.py |  7 ++--
 server/knowledge_base/kb_service/base.py      | 27 ++++++++++---
 .../kb_service/faiss_kb_service.py            | 23 ++++++-----
 .../kb_service/milvus_kb_service.py           |  5 ++-
 .../kb_service/pg_kb_service.py               |  6 ++-
 .../kb_service/zilliz_kb_service.py           |  5 ++-
 server/knowledge_base/migrate.py              |  3 +-
 server/knowledge_base/utils.py                |  2 +-
 server/utils.py                               | 39 ++++++++++---------
 webui.py                                      | 13 +++----
 webui_pages/dialogue/dialogue.py              | 35 ++++++-----------
 webui_pages/knowledge_base/knowledge_base.py  |  5 ++-
 18 files changed, 130 insertions(+), 96 deletions(-)

diff --git a/init_database.py b/init_database.py
index f7f0805..0b15b1e 100644
--- a/init_database.py
+++ b/init_database.py
@@ -23,6 +23,11 @@ if __name__ == "__main__":
             '''
         )
     )
+    parser.add_argument(
+        "--clear-tables",
+        action="store_true",
+        help=("drop the database tables before recreating vector stores")
+    )
     parser.add_argument(
         "-u",
         "--update-in-db",
@@ -74,7 +79,7 @@
         "--embed-model",
         type=str,
         default=EMBEDDING_MODEL,
-        help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+        help=("specify the embeddings model.")
    )
 
     if len(sys.argv) <= 1:
@@ -84,9 +89,12 @@
         start_time = datetime.now()
 
         create_tables() # confirm tables exist
-        if args.recreate_vs:
+
+        if args.clear_tables:
             reset_tables()
             print("database talbes reseted")
+
+        if args.recreate_vs:
             print("recreating all vector stores")
             folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
         elif args.update_in_db:
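
A note on usage: with these flags, rebuilding vector stores and clearing the
tables are now independent operations. Typical invocations might look like the
following (model names are illustrative; any model known to the configs works):

    # rebuild all vector stores with a local embedding model; tables are kept
    python init_database.py --recreate-vs --embed-model bge-large-zh

    # drop the tables first, then rebuild using an online embedding service
    python init_database.py --recreate-vs --clear-tables --embed-model zhipu-api
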
diff --git a/requirements_lite.txt b/requirements_lite.txt
index a22d012..a20cb8b 100644
--- a/requirements_lite.txt
+++ b/requirements_lite.txt
@@ -7,18 +7,19 @@
 openai
 # torchvision
 # torchaudio
 fastapi>=0.103.1
+python-multipart
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11
-# unstructured[all-docs]>=0.10.4
-# python-magic-bin; sys_platform == 'win32'
+unstructured[docx,csv]>=0.10.4 # add pdf if needed
+python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
-# faiss-cpu
+faiss-cpu
 # accelerate
 # spacy
-# PyMuPDF==1.22.5
-# rapidocr_onnxruntime>=1.3.2
+# PyMuPDF==1.22.5 # install if PDF support is needed
+# rapidocr_onnxruntime>=1.3.2 # install if PDF support is needed
 requests
 pathlib
diff --git a/server/api.py b/server/api.py
index be795ba..093f81b 100644
--- a/server/api.py
+++ b/server/api.py
@@ -74,8 +74,7 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              )(search_engine_chat)
 
     # 知识库相关接口
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_knowledge_routes(app)
 
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 01744c1..c977db1 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -19,7 +19,7 @@ from server.knowledge_base.kb_doc_api import search_docs
 async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-2之间,SCORE越小,相关度越高,取到2相当于不筛选,建议设置在0.5左右", ge=0, le=2),
                               history: List[History] = Body([],
                                                             description="历史对话",
                                                             examples=[[
diff --git a/server/embeddings_api.py b/server/embeddings_api.py
index 9b08e21..80cd289 100644
--- a/server/embeddings_api.py
+++ b/server/embeddings_api.py
@@ -16,6 +16,7 @@ def embed_texts(
 ) -> BaseResponse:
     '''
     对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    TODO: 也许需要加入缓存机制,减少 token 消耗
     '''
     try:
         if embed_model in list_embed_models(): # 使用本地Embeddings模型
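
For smoke-testing the new endpoint, a minimal client call might look like the
sketch below. The URL assumes the project's usual default API address, and the
JSON fields simply mirror embed_texts' parameters, so treat both as assumptions
and adjust them to your deployment:

    import requests

    # hypothetical call to /other/embed_texts; 127.0.0.1:7861 is an assumed
    # default API address, not something this patch guarantees
    resp = requests.post(
        "http://127.0.0.1:7861/other/embed_texts",
        json={
            "texts": ["你好", "知识库"],
            "embed_model": "zhipu-api",  # any local or online embedding model
            "to_query": False,           # True when embedding a search query
        },
    )
    vectors = resp.json()["data"]        # List[List[float]]
    print(len(vectors), len(vectors[0]))
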
diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py
index 23f6e84..b65e4ff 100644
--- a/server/knowledge_base/kb_cache/base.py
+++ b/server/knowledge_base/kb_cache/base.py
@@ -1,11 +1,8 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE, logger, log_verbose)
-from server.utils import embedding_device, get_model_path
+from server.utils import embedding_device, get_model_path, list_online_embed_models
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -100,13 +97,22 @@ class CachePool:
         else:
             return cache
 
-    def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
+    def load_kb_embeddings(
+            self,
+            kb_name: str,
+            embed_device: str = embedding_device(),
+            default_embed_model: str = EMBEDDING_MODEL,
+    ) -> Embeddings:
         from server.db.repository.knowledge_base_repository import get_kb_detail
+        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
 
-        kb_detail = get_kb_detail(kb_name=kb_name)
-        print(kb_detail)
-        embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
-        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
+        kb_detail = get_kb_detail(kb_name)
+        embed_model = kb_detail.get("embed_model", default_embed_model)
+
+        if embed_model in list_online_embed_models():
+            return EmbeddingsFunAdapter(embed_model)
+        else:
+            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
 
 
 class EmbeddingsPool(CachePool):
@@ -121,10 +127,12 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
+                from langchain.embeddings.openai import OpenAIEmbeddings
                 embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
+                from langchain.embeddings import HuggingFaceBgeEmbeddings
                 if 'zh' in model:
                     # for chinese model
                     query_instruction = "为这个句子生成表示以用于检索相关文章:"
@@ -140,6 +148,7 @@ class EmbeddingsPool(CachePool):
                 if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                     embeddings.query_instruction = ""
             else:
+                from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                 embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
                                                    model_kwargs={'device': device})
             item.obj = embeddings
             item.finish_loading()
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index b7dd6d7..bcd89a8 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -64,23 +64,24 @@ class KBFaissPool(_FaissPool):
     def load_vector_store(
             self,
             kb_name: str,
-            vector_name: str = "vector_store",
+            vector_name: str = None,
             create: bool = True,
             embed_model: str = EMBEDDING_MODEL,
             embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
+        vector_name = vector_name or embed_model
         cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
         if cache is None:
             item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
             self.set((kb_name, vector_name), item)
             with item.acquire(msg="初始化"):
                 self.atomic.release()
-                logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
+                logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                 vs_path = get_vs_path(kb_name, vector_name)
 
                 if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
+                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                     vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                 elif create:
                     # create an empty vector store
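
With get_vs_path now inserting a vector_store directory (see the
server/knowledge_base/utils.py hunk below) and vector_name defaulting to the
embedding model, one FAISS knowledge base can carry several vector stores side
by side. The resulting on-disk layout looks roughly like this (directory names
are illustrative):

    knowledge_base/samples/
    ├── content/                # source documents
    └── vector_store/
        ├── bge-large-zh/       # index.faiss + index.pkl built locally
        └── zhipu-api/          # a second store built with an online model

The cache mirrors this: entries are keyed by the (kb_name, vector_name) tuple,
so stores built with different models for the same knowledge base can no
longer shadow each other in kb_faiss_pool.
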
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index 1a3d551..a81474f 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -6,7 +6,6 @@ import os
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
-from sklearn.preprocessing import normalize
 
 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
@@ -31,6 +30,16 @@ from server.embeddings_api import embed_texts
 from server.embeddings_api import embed_documents
 
 
+def normalize(embeddings: List[List[float]]) -> np.ndarray:
+    '''
+    sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
+    '''
+    norm = np.linalg.norm(embeddings, axis=1)
+    norm = np.reshape(norm, (norm.shape[0], 1))
+    norm = np.tile(norm, (1, len(embeddings[0])))
+    return np.divide(embeddings, norm)
+
+
 class SupportedVSType:
     FAISS = 'faiss'
     MILVUS = 'milvus'
@@ -52,6 +61,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
+    def __repr__(self) -> str:
+        return f"{self.kb_name} @ {self.embed_model}"
+
     def save_vector_store(self):
         '''
         保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持
@@ -209,7 +221,6 @@ class KBService(ABC):
                   query: str,
                   top_k: int,
                   score_threshold: float,
-                  embeddings: Embeddings,
                   ) -> List[Document]:
         """
         搜索知识库子类实自己逻辑
         """
@@ -267,11 +278,15 @@ class KBServiceFactory:
         return DefaultKBService(kb_name)
 
     @staticmethod
-    def get_service_by_name(kb_name: str
+    def get_service_by_name(kb_name: str,
+                            default_vs_type: SupportedVSType = SupportedVSType.FAISS,
+                            default_embed_model: str = EMBEDDING_MODEL,
                             ) -> KBService:
         _, vs_type, embed_model = load_kb_from_db(kb_name)
-        if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
-            vs_type = "faiss"
+        if vs_type is None:  # faiss knowledge base not in db
+            vs_type = default_vs_type
+        if embed_model is None:
+            embed_model = default_embed_model
         return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
 
     @staticmethod
@@ -357,7 +372,7 @@ class EmbeddingsFunAdapter(Embeddings):
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
-        return normalize(embeddings)
+        return normalize(embeddings).tolist()
 
     def embed_query(self, text: str) -> List[float]:
         embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
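
The normalize replacement above reproduces sklearn.preprocessing.normalize's
default L2 behaviour with plain numpy. A quick sanity check with made-up
vectors, plus an equivalent one-liner: broadcasting with keepdims=True would
make the reshape/tile steps unnecessary:

    import numpy as np

    embeddings = [[3.0, 4.0], [1.0, 0.0]]   # illustrative 2-D embeddings

    # the patch's implementation, condensed
    norm = np.linalg.norm(embeddings, axis=1)
    norm = np.tile(np.reshape(norm, (norm.shape[0], 1)), (1, len(embeddings[0])))
    print(np.divide(embeddings, norm))       # [[0.6 0.8] [1.  0. ]]

    # same result via broadcasting; no reshape/tile needed
    print(embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True))
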
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 10e26c1..d118620 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -1,14 +1,10 @@
 import os
 import shutil
 
-from configs import (
-    KB_ROOT_PATH,
-    SCORE_THRESHOLD,
-    logger, log_verbose,
-)
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from configs import SCORE_THRESHOLD
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
 from langchain.docstore.document import Document
 from typing import List, Dict, Optional
@@ -17,16 +13,16 @@ from typing import List, Dict, Optional
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
-    vector_name: str = "vector_store"
-
+    vector_name: str = None
+
     def vs_type(self) -> str:
         return SupportedVSType.FAISS
 
     def get_vs_path(self):
-        return os.path.join(self.get_kb_path(), self.vector_name)
+        return get_vs_path(self.kb_name, self.vector_name)
 
     def get_kb_path(self):
-        return os.path.join(KB_ROOT_PATH, self.kb_name)
+        return get_kb_path(self.kb_name)
 
     def load_vector_store(self) -> ThreadSafeFaiss:
         return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
@@ -41,6 +37,7 @@ class FaissKBService(KBService):
             return vs.docstore._dict.get(id)
 
     def do_init(self):
+        self.vector_name = self.vector_name or self.embed_model
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()
 
@@ -58,8 +55,10 @@ class FaissKBService(KBService):
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
                   ) -> List[Document]:
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs
 
     def do_add_doc(self,
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index 7257494..5e27040 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -59,7 +59,10 @@ class MilvusKBService(KBService):
 
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_milvus()
-        return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         # TODO: workaround for bug #10492 in langchain
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 66ad9c1..6337e2b 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -53,8 +53,10 @@ class PGKBService(KBService):
 
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_pg_vector()
-        return score_threshold_process(score_threshold, top_k,
-                                       self.pg_vector.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         ids = self.pg_vector.add_documents(docs)
diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py
index 8bf10ef..bd8b3e9 100644
--- a/server/knowledge_base/kb_service/zilliz_kb_service.py
+++ b/server/knowledge_base/kb_service/zilliz_kb_service.py
@@ -59,7 +59,10 @@ class ZillizKBService(KBService):
 
     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_zilliz()
-        return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         for doc in docs:
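
All four backends now share one search shape: embed the query once through
EmbeddingsFunAdapter (which handles both local and online models and returns
normalized vectors), then search by vector. Stripped of backend specifics, the
common pattern is roughly the sketch below, where self.vector_store stands in
for the concrete FAISS/Milvus/PGVector/Zilliz client:

    def do_search(self, query: str, top_k: int, score_threshold: float):
        # one embedding call, local or online, L2-normalized by the adapter
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        # backend-specific vector search; FAISS applies score_threshold
        # itself, the others filter afterwards via score_threshold_process
        docs = self.vector_store.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)
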
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index 584e543..d0fa620 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -67,7 +67,8 @@ def folder2db(
     kb_names = kb_names or list_kbs_from_folder()
     for kb_name in kb_names:
         kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
-        kb.create_kb()
+        if not kb.exists():
+            kb.create_kb()
 
     # 清除向量库,从本地文件重建
     if mode == "recreate_vs":
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index ef3f001..2b4cde9 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -39,7 +39,7 @@ def get_doc_path(knowledge_base_name: str):
 
 
 def get_vs_path(knowledge_base_name: str, vector_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), vector_name)
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)
 
 
 def get_file_path(knowledge_base_name: str, doc_name: str):
diff --git a/server/utils.py b/server/utils.py
index 23713c4..c31bdc1 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -237,20 +237,23 @@ class ChatMessage(BaseModel):
 
 
 def torch_gc():
-    import torch
-    if torch.cuda.is_available():
-        # with torch.cuda.device(DEVICE):
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-    elif torch.backends.mps.is_available():
-        try:
-            from torch.mps import empty_cache
-            empty_cache()
-        except Exception as e:
-            msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
-                   "以支持及时清理 torch 产生的内存占用。")
-            logger.error(f'{e.__class__.__name__}: {msg}',
-                         exc_info=e if log_verbose else None)
+    try:
+        import torch
+        if torch.cuda.is_available():
+            # with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        elif torch.backends.mps.is_available():
+            try:
+                from torch.mps import empty_cache
+                empty_cache()
+            except Exception as e:
+                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
+                       "以支持及时清理 torch 产生的内存占用。")
+                logger.error(f'{e.__class__.__name__}: {msg}',
+                             exc_info=e if log_verbose else None)
+    except Exception:
+        ...
 
 
 def run_async(cor):
@@ -719,10 +722,10 @@ def list_online_embed_models() -> List[str]:
 
     ret = []
     for k, v in list_config_llm_models()["online"].items():
-        provider = v.get("provider")
-        worker_class = getattr(model_workers, provider, None)
-        if worker_class is not None and worker_class.can_embedding():
-            ret.append(k)
+        if provider := v.get("provider"):
+            worker_class = getattr(model_workers, provider, None)
+            if worker_class is not None and worker_class.can_embedding():
+                ret.append(k)
     return ret
 
diff --git a/webui.py b/webui.py
index fe33939..b0c3097 100644
--- a/webui.py
+++ b/webui.py
@@ -2,6 +2,7 @@ import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
 from webui_pages.dialogue.dialogue import dialogue_page, chat_box
+from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
 import os
 import sys
 from configs import VERSION
@@ -29,15 +30,11 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
+        "知识库管理": {
+            "icon": "hdd-stack",
+            "func": knowledge_base_page,
+        },
     }
-    if not is_lite:
-        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
-        pages.update({
-            "知识库管理": {
-                "icon": "hdd-stack",
-                "func": knowledge_base_page,
-            },
-        })
 
     with st.sidebar:
         st.image(
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 5ff5a78..594b14e 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -54,16 +54,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = f"{text} 当前知识库: `{cur_kb}`。"
         st.toast(text)
 
-    if is_lite:
-        dialogue_modes = ["LLM 对话",
-                          "搜索引擎问答",
-                          ]
-    else:
-        dialogue_modes = ["LLM 对话",
-                          "知识库问答",
-                          "搜索引擎问答",
-                          "自定义Agent问答",
-                          ]
+    dialogue_modes = ["LLM 对话",
+                      "知识库问答",
+                      "搜索引擎问答",
+                      "自定义Agent问答",
+                      ]
     dialogue_mode = st.selectbox("请选择对话模式:",
                                  dialogue_modes,
                                  index=0,
@@ -116,18 +111,12 @@
             st.success(msg)
             st.session_state["prev_llm_model"] = llm_model
 
-        if is_lite:
-            index_prompt = {
-                "LLM 对话": "llm_chat",
-                "搜索引擎问答": "search_engine_chat",
-            }
-        else:
-            index_prompt = {
-                "LLM 对话": "llm_chat",
-                "自定义Agent问答": "agent_chat",
-                "搜索引擎问答": "search_engine_chat",
-                "知识库问答": "knowledge_base_chat",
-            }
+        index_prompt = {
+            "LLM 对话": "llm_chat",
+            "自定义Agent问答": "agent_chat",
+            "搜索引擎问答": "search_engine_chat",
+            "知识库问答": "knowledge_base_chat",
+        }
         prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
         prompt_template_name = prompt_templates_kb_list[0]
         if "prompt_template_select" not in st.session_state:
@@ -167,7 +156,7 @@
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
 
             ## Bge 模型会超过1
-            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
+            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
 
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
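
On the widened slider range: the adapter L2-normalizes every embedding, and
for unit vectors the L2 distance ||a - b|| = sqrt(2 - 2*cos(theta)) peaks at 2
when the vectors point in opposite directions, so 2 is the 'never filter' end
of the scale (and, as the comment above notes, BGE scores can exceed the old
cap of 1). A quick check with toy vectors:

    import numpy as np

    a = np.array([1.0, 0.0])       # unit vectors
    b = np.array([-1.0, 0.0])      # worst case: opposite directions
    print(np.linalg.norm(a - b))   # 2.0, the maximum possible L2 distance
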
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 2a097f1..ca8425c 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -103,7 +103,10 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )
 
-        embed_models = list_embed_models() + list_online_embed_models()
+        if is_lite:
+            embed_models = list_online_embed_models()
+        else:
+            embed_models = list_embed_models() + list_online_embed_models()
 
         embed_model = cols[1].selectbox(
             "Embedding 模型",