Support online Embeddings; Lite mode now supports all knowledge-base features (#1924)
New features:
- Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add a /other/embed_texts endpoint to the API (see the request sketch after this list)
- Add an --embed-model parameter to init_database.py to specify which embedding model to use (local or online)
- For FAISS knowledge bases, support multiple vector stores per knowledge base, located by default at {KB_PATH}/vector_store/{embed_model}
- Lite mode supports all knowledge-base features. Its main remaining limits are:
  - local LLM and Embeddings models cannot be used
  - knowledge bases do not support PDF files
- init_database.py no longer clears the database tables by default when rebuilding knowledge bases; pass the new --clear-tables flag to do so explicitly
- The score_threshold range in the API and WEBUI is now [0, 2] to better fit online embedding models
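As a quick illustration of the new endpoint, a minimal request sketch follows. The field names (texts, embed_model, to_query) mirror the embed_texts signature visible in the diff below; the host/port and the exact response shape are the project's usual API defaults and are assumptions here.

# Hedged sketch: calling the new /other/embed_texts endpoint.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/other/embed_texts",  # assumed default API address
    json={
        "texts": ["hello world", "知识库问答"],
        "embed_model": "zhipu-api",   # any local or online embedding model
        "to_query": False,            # True when embedding a search query
    },
)
result = resp.json()        # BaseResponse: {"code": ..., "msg": ..., "data": ...}
vectors = result["data"]    # List[List[float]], one vector per input text
print(len(vectors), len(vectors[0]))

The new CLI options combine the same way: for example, python init_database.py --recreate-vs --embed-model zhipu-api rebuilds every vector store with an online model, while --clear-tables must now be passed explicitly to drop the tables first.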
Bug fixes:
- list_config_models in the API deleted sensitive fields from ONLINE_LLM_MODEL in place, breaking the second API request
Developer changes:
- Unify vector-store identification: (kb_name, embed_model) is now the key that makes a vector store unique, avoiding a bug in the FAISS knowledge-base cache-loading logic
- Add a default_embed_model parameter to KBServiceFactory.get_service_by_name, used to set embed_model when building a new knowledge base
- Optimize Embeddings handling in kb_service:
  - unified loading interface: server.utils.load_embeddings, which uses a global cache so Embeddings objects no longer need to be passed around
  - unified text-embedding interface: server.knowledge_base.kb_service.base.[embed_texts, embed_documents]
- Rewrite the normalize function to remove the scikit-learn/scipy dependency (see the sketch after this list)
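For reference, a self-contained sketch of the rewritten normalize; the function body is copied from this commit, while the test values and assertion are illustrative additions.

# Dependency-free L2 normalization, body as added to
# server.knowledge_base.kb_service.base by this commit.
import numpy as np
from typing import List

def normalize(embeddings: List[List[float]]) -> np.ndarray:
    norm = np.linalg.norm(embeddings, axis=1)        # per-row L2 norm
    norm = np.reshape(norm, (norm.shape[0], 1))      # as a column vector
    norm = np.tile(norm, (1, len(embeddings[0])))    # expand to input shape
    return np.divide(embeddings, norm)

out = normalize([[3.0, 4.0], [1.0, 0.0]])
assert np.allclose(np.linalg.norm(out, axis=1), 1.0)  # rows are unit length

EmbeddingsFunAdapter.embed_documents returns normalize(embeddings).tolist(), so callers receive plain JSON-serializable lists rather than a numpy array.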
This commit is contained in:
parent deed92169f
commit 65592a45c3
@@ -23,6 +23,11 @@ if __name__ == "__main__":
     '''
     )
     )
+    parser.add_argument(
+        "--clear-tables",
+        action="store_true",
+        help=("drop the database tables before recreate vector stores")
+    )
     parser.add_argument(
         "-u",
         "--update-in-db",
@@ -74,7 +79,7 @@ if __name__ == "__main__":
         "--embed-model",
         type=str,
         default=EMBEDDING_MODEL,
-        help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+        help=("specify embeddings model.")
     )

     if len(sys.argv) <= 1:
@@ -84,9 +89,12 @@ if __name__ == "__main__":
     start_time = datetime.now()

     create_tables() # confirm tables exist
-    if args.recreate_vs:
+    if args.clear_tables:
         reset_tables()
         print("database tables reset")

+    if args.recreate_vs:
         print("recreating all vector stores")
         folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
@@ -7,18 +7,19 @@ openai
 # torchvision
 # torchaudio
 fastapi>=0.103.1
+python-multipart
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
 pydantic~=1.10.11
-# unstructured[all-docs]>=0.10.4
-# python-magic-bin; sys_platform == 'win32'
+unstructured[docx,csv]>=0.10.4 # add pdf if need
+python-magic-bin; sys_platform == 'win32'
 SQLAlchemy==2.0.19
-# faiss-cpu
+faiss-cpu
 # accelerate
 # spacy
-# PyMuPDF==1.22.5
-# rapidocr_onnxruntime>=1.3.2
+# PyMuPDF==1.22.5 # install if need pdf
+# rapidocr_onnxruntime>=1.3.2 # install if need pdf

 requests
 pathlib
@@ -74,8 +74,7 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              )(search_engine_chat)

    # knowledge base related endpoints
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_knowledge_routes(app)

    # LLM model related endpoints
    app.post("/llm_model/list_running_models",
@@ -19,7 +19,7 @@ from server.knowledge_base.kb_doc_api import search_docs
 async def knowledge_base_chat(query: str = Body(..., description="user input", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="knowledge base name", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="number of matched vectors"),
-                              score_threshold: float = Body(SCORE_THRESHOLD, description="knowledge base relevance threshold, range 0-1; the smaller the SCORE, the higher the relevance, and 1 disables filtering; around 0.5 is suggested", ge=0, le=1),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="knowledge base relevance threshold, range 0-2; the smaller the SCORE, the higher the relevance, and 2 disables filtering; around 0.5 is suggested", ge=0, le=2),
                               history: List[History] = Body([],
                                                             description="chat history",
                                                             examples=[[
@@ -16,6 +16,7 @@ def embed_texts(
 ) -> BaseResponse:
     '''
     Vectorize texts. Returned format: BaseResponse(data=List[List[float]])
+    TODO: maybe add a caching mechanism to reduce token consumption
     '''
     try:
         if embed_model in list_embed_models(): # use a local Embeddings model
@@ -1,11 +1,8 @@
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path
+from server.utils import embedding_device, get_model_path, list_online_embed_models
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -100,13 +97,22 @@ class CachePool:
         else:
             return cache

-    def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
+    def load_kb_embeddings(
+        self,
+        kb_name: str,
+        embed_device: str = embedding_device(),
+        default_embed_model: str = EMBEDDING_MODEL,
+    ) -> Embeddings:
         from server.db.repository.knowledge_base_repository import get_kb_detail
+        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

-        kb_detail = get_kb_detail(kb_name=kb_name)
-        print(kb_detail)
-        embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
-        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
+        kb_detail = get_kb_detail(kb_name)
+        embed_model = kb_detail.get("embed_model", default_embed_model)
+        if embed_model in list_online_embed_models():
+            return EmbeddingsFunAdapter(embed_model)
+        else:
+            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)


 class EmbeddingsPool(CachePool):
@@ -121,10 +127,12 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002": # openai text-embedding-ada-002
+                from langchain.embeddings.openai import OpenAIEmbeddings
                 embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
+                from langchain.embeddings import HuggingFaceBgeEmbeddings
                 if 'zh' in model:
                     # for chinese model
                     query_instruction = "为这个句子生成表示以用于检索相关文章:"
@@ -140,6 +148,7 @@ class EmbeddingsPool(CachePool):
                 if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
                     embeddings.query_instruction = ""
             else:
+                from langchain.embeddings.huggingface import HuggingFaceEmbeddings
                 embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
             item.obj = embeddings
             item.finish_loading()
@@ -64,23 +64,24 @@ class KBFaissPool(_FaissPool):
     def load_vector_store(
             self,
             kb_name: str,
-            vector_name: str = "vector_store",
+            vector_name: str = None,
             create: bool = True,
             embed_model: str = EMBEDDING_MODEL,
             embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
+        vector_name = vector_name or embed_model
         cache = self.get((kb_name, vector_name)) # a tuple key works better than string concatenation
         if cache is None:
             item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
             self.set((kb_name, vector_name), item)
             with item.acquire(msg="初始化"):
                 self.atomic.release()
-                logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
+                logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                 vs_path = get_vs_path(kb_name, vector_name)

                 if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
+                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                     vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
                 elif create:
                     # create an empty vector store
@@ -6,7 +6,6 @@ import os
 import numpy as np
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
-from sklearn.preprocessing import normalize

 from server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
@@ -31,6 +30,16 @@ from server.embeddings_api import embed_texts
 from server.embeddings_api import embed_documents


+def normalize(embeddings: List[List[float]]) -> np.ndarray:
+    '''
+    Replacement for sklearn.preprocessing.normalize (L2 norm), avoiding the scipy and scikit-learn dependencies
+    '''
+    norm = np.linalg.norm(embeddings, axis=1)
+    norm = np.reshape(norm, (norm.shape[0], 1))
+    norm = np.tile(norm, (1, len(embeddings[0])))
+    return np.divide(embeddings, norm)
+
+
 class SupportedVSType:
     FAISS = 'faiss'
     MILVUS = 'milvus'
@@ -52,6 +61,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def __repr__(self) -> str:
+        return f"{self.kb_name} @ {self.embed_model}"
+
     def save_vector_store(self):
         '''
         Save the vector store: FAISS saves to disk, Milvus saves to its database. PGVector is not supported yet
@@ -209,7 +221,6 @@ class KBService(ABC):
                   query: str,
                   top_k: int,
                   score_threshold: float,
-                  embeddings: Embeddings,
                   ) -> List[Document]:
         """
         Search the knowledge base; subclasses implement their own logic
@@ -267,11 +278,15 @@ class KBServiceFactory:
             return DefaultKBService(kb_name)

     @staticmethod
-    def get_service_by_name(kb_name: str
+    def get_service_by_name(kb_name: str,
+                            default_vs_type: SupportedVSType = SupportedVSType.FAISS,
+                            default_embed_model: str = EMBEDDING_MODEL,
                             ) -> KBService:
         _, vs_type, embed_model = load_kb_from_db(kb_name)
-        if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
-            vs_type = "faiss"
+        if vs_type is None: # faiss knowledge base not in db
+            vs_type = default_vs_type
+        if embed_model is None:
+            embed_model = default_embed_model
         return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

     @staticmethod
@@ -357,7 +372,7 @@ class EmbeddingsFunAdapter(Embeddings):

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
-        return normalize(embeddings)
+        return normalize(embeddings).tolist()

     def embed_query(self, text: str) -> List[float]:
         embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
@@ -1,14 +1,10 @@
 import os
 import shutil

-from configs import (
-    KB_ROOT_PATH,
-    SCORE_THRESHOLD,
-    logger, log_verbose,
-)
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from configs import SCORE_THRESHOLD
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
-from server.knowledge_base.utils import KnowledgeFile
+from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
 from langchain.docstore.document import Document
 from typing import List, Dict, Optional
@@ -17,16 +13,16 @@ from typing import List, Dict, Optional
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
-    vector_name: str = "vector_store"
+    vector_name: str = None

     def vs_type(self) -> str:
         return SupportedVSType.FAISS

     def get_vs_path(self):
-        return os.path.join(self.get_kb_path(), self.vector_name)
+        return get_vs_path(self.kb_name, self.vector_name)

     def get_kb_path(self):
-        return os.path.join(KB_ROOT_PATH, self.kb_name)
+        return get_kb_path(self.kb_name)

     def load_vector_store(self) -> ThreadSafeFaiss:
         return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
@@ -41,6 +37,7 @@ class FaissKBService(KBService):
         return vs.docstore._dict.get(id)

     def do_init(self):
+        self.vector_name = self.vector_name or self.embed_model
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()
@@ -58,8 +55,10 @@ class FaissKBService(KBService):
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
                   ) -> List[Document]:
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
-            docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs

     def do_add_doc(self,
@@ -59,7 +59,10 @@ class MilvusKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_milvus()
-        return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         # TODO: workaround for bug #10492 in langchain
@@ -53,8 +53,10 @@ class PGKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_pg_vector()
-        return score_threshold_process(score_threshold, top_k,
-                                       self.pg_vector.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         ids = self.pg_vector.add_documents(docs)
@@ -59,7 +59,10 @@ class ZillizKBService(KBService):

     def do_search(self, query: str, top_k: int, score_threshold: float):
         self._load_zilliz()
-        return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
+        embed_func = EmbeddingsFunAdapter(self.embed_model)
+        embeddings = embed_func.embed_query(query)
+        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
+        return score_threshold_process(score_threshold, top_k, docs)

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
         for doc in docs:
@@ -67,7 +67,8 @@ def folder2db(
     kb_names = kb_names or list_kbs_from_folder()
     for kb_name in kb_names:
         kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
-        kb.create_kb()
+        if not kb.exists():
+            kb.create_kb()

         # clear the vector store and rebuild it from local files
         if mode == "recreate_vs":
@@ -39,7 +39,7 @@ def get_doc_path(knowledge_base_name: str):


 def get_vs_path(knowledge_base_name: str, vector_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), vector_name)
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


 def get_file_path(knowledge_base_name: str, doc_name: str):
@@ -237,20 +237,23 @@ class ChatMessage(BaseModel):


 def torch_gc():
-    import torch
-    if torch.cuda.is_available():
-        # with torch.cuda.device(DEVICE):
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-    elif torch.backends.mps.is_available():
-        try:
-            from torch.mps import empty_cache
-            empty_cache()
-        except Exception as e:
-            msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
-                   "以支持及时清理 torch 产生的内存占用。")
-            logger.error(f'{e.__class__.__name__}: {msg}',
-                         exc_info=e if log_verbose else None)
+    try:
+        import torch
+        if torch.cuda.is_available():
+            # with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        elif torch.backends.mps.is_available():
+            try:
+                from torch.mps import empty_cache
+                empty_cache()
+            except Exception as e:
+                msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
+                       "以支持及时清理 torch 产生的内存占用。")
+                logger.error(f'{e.__class__.__name__}: {msg}',
+                             exc_info=e if log_verbose else None)
+    except Exception:
+        ...


 def run_async(cor):
@@ -719,10 +722,10 @@ def list_online_embed_models() -> List[str]:

     ret = []
     for k, v in list_config_llm_models()["online"].items():
-        provider = v.get("provider")
-        worker_class = getattr(model_workers, provider, None)
-        if worker_class is not None and worker_class.can_embedding():
-            ret.append(k)
+        if provider := v.get("provider"):
+            worker_class = getattr(model_workers, provider, None)
+            if worker_class is not None and worker_class.can_embedding():
+                ret.append(k)
     return ret
webui.py (13 lines changed)
@@ -2,6 +2,7 @@ import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
 from webui_pages.dialogue.dialogue import dialogue_page, chat_box
+from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
 import os
 import sys
 from configs import VERSION
@@ -29,15 +30,11 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
+        "知识库管理": {
+            "icon": "hdd-stack",
+            "func": knowledge_base_page,
+        },
     }
-    if not is_lite:
-        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
-        pages.update({
-            "知识库管理": {
-                "icon": "hdd-stack",
-                "func": knowledge_base_page,
-            },
-        })

     with st.sidebar:
         st.image(
@@ -54,16 +54,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = f"{text} 当前知识库: `{cur_kb}`。"
         st.toast(text)

-    if is_lite:
-        dialogue_modes = ["LLM 对话",
-                          "搜索引擎问答",
-                          ]
-    else:
-        dialogue_modes = ["LLM 对话",
-                          "知识库问答",
-                          "搜索引擎问答",
-                          "自定义Agent问答",
-                          ]
+    dialogue_modes = ["LLM 对话",
+                      "知识库问答",
+                      "搜索引擎问答",
+                      "自定义Agent问答",
+                      ]
     dialogue_mode = st.selectbox("请选择对话模式:",
                                  dialogue_modes,
                                  index=0,
@@ -116,18 +111,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.success(msg)
             st.session_state["prev_llm_model"] = llm_model

-    if is_lite:
-        index_prompt = {
-            "LLM 对话": "llm_chat",
-            "搜索引擎问答": "search_engine_chat",
-        }
-    else:
-        index_prompt = {
-            "LLM 对话": "llm_chat",
-            "自定义Agent问答": "agent_chat",
-            "搜索引擎问答": "search_engine_chat",
-            "知识库问答": "knowledge_base_chat",
-        }
+    index_prompt = {
+        "LLM 对话": "llm_chat",
+        "自定义Agent问答": "agent_chat",
+        "搜索引擎问答": "search_engine_chat",
+        "知识库问答": "knowledge_base_chat",
+    }
     prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
     prompt_template_name = prompt_templates_kb_list[0]
     if "prompt_template_select" not in st.session_state:
@@ -167,7 +156,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

             ## Bge models can score above 1
-            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
+            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
@@ -103,7 +103,10 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )

-        embed_models = list_embed_models() + list_online_embed_models()
+        if is_lite:
+            embed_models = list_online_embed_models()
+        else:
+            embed_models = list_embed_models() + list_online_embed_models()

         embed_model = cols[1].selectbox(
             "Embedding 模型",