Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api (#1907)
* New features:
  - Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
  - Add a /other/embed_texts endpoint to the API
  - Add an --embed-model argument to init_database.py to specify the embedding model (local or online)
* Bug fixes:
  - list_config_models in the API deleted sensitive keys from ONLINE_LLM_MODEL in place, which broke the second API request
* Developer notes:
  - Clean up the Embeddings handling in kb_service:
    - Unified loading interface: server.utils.load_local_embeddings, using a global cache to avoid passing Embeddings objects around
    - Unified text embedding interface: server.embeddings_api.[embed_texts, embed_documents]
This commit is contained in:
parent
aa7c580974
commit
deed92169f
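For reference, a minimal sketch of calling the new endpoint over HTTP (assumptions: the API server is running at its default 127.0.0.1:7861 address and "m3e-base" is a configured local embedding model; any of the online API names above can be passed instead):

    import httpx

    resp = httpx.post(
        "http://127.0.0.1:7861/other/embed_texts",  # default api.py address; adjust if changed
        json={"texts": ["hello", "world"], "embed_model": "m3e-base", "to_query": False},
    )
    vectors = resp.json()["data"]  # List[List[float]], one vector per input text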
@@ -1,7 +1,7 @@
 import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
-from configs.model_config import NLTK_DATA_PATH
+from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from datetime import datetime

@@ -62,12 +62,20 @@ if __name__ == "__main__":
              )
     )
     parser.add_argument(
         "-n",
         "--kb-name",
         type=str,
         nargs="+",
         default=[],
         help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
     )
+    parser.add_argument(
+        "-e",
+        "--embed-model",
+        type=str,
+        default=EMBEDDING_MODEL,
+        help=("specify the embedding model to use (local or online). default is the EMBEDDING_MODEL set in model_config.")
+    )
+
     if len(sys.argv) <= 1:
         parser.print_help()

@@ -80,11 +88,11 @@ if __name__ == "__main__":
         reset_tables()
         print("database tables reset")
         print("recreating all vector stores")
-        folder2db(kb_names=args.kb_name, mode="recreate_vs")
+        folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
-        folder2db(kb_names=args.kb_name, mode="update_in_db")
+        folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
     elif args.increament:
-        folder2db(kb_names=args.kb_name, mode="increament")
+        folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
     elif args.prune_db:
         prune_db_docs(args.kb_name)
     elif args.prune_folder:
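The new flag simply threads the chosen model through to folder2db. A rough in-process equivalent of rebuilding the vector stores with a chosen model (the model name and the exact CLI flag spelling are assumptions; run from the project root):

    import sys
    sys.path.append(".")
    from server.knowledge_base.migrate import folder2db

    # kb_names=[] means all folders under KB_ROOT_PATH, matching the CLI default
    folder2db(kb_names=[], mode="recreate_vs", embed_model="m3e-base")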
@@ -30,7 +30,7 @@ pytest
 # online api libs
 zhipuai
 dashscope>=1.10.0 # qwen
-qianfan
+# qianfan
 # volcengine>=1.0.106 # fangzhou

 # uncomment libs if you want to use corresponding vector store
@@ -16,6 +16,7 @@ from server.chat.chat import chat
 from server.chat.openai_chat import openai_chat
 from server.chat.search_engine_chat import search_engine_chat
 from server.chat.completion import completion
+from server.embeddings_api import embed_texts_endpoint
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)

@@ -47,33 +48,34 @@ def create_app(run_mode: str = None):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    mount_basic_routes(app)
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_app_routes(app, run_mode=run_mode)
     return app


-def mount_basic_routes(app: FastAPI):
+def mount_app_routes(app: FastAPI, run_mode: str = None):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)

-    app.post("/completion",
-             tags=["Completion"],
-             summary="要求llm模型补全(通过LLMChain)")(completion)
-
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
-             summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
+             summary="与llm模型对话(直接与fastchat api对话)",
+             )(openai_chat)

     app.post("/chat/chat",
              tags=["Chat"],
-             summary="与llm模型对话(通过LLMChain)")(chat)
+             summary="与llm模型对话(通过LLMChain)",
+             )(chat)

     app.post("/chat/search_engine_chat",
              tags=["Chat"],
-             summary="与搜索引擎对话")(search_engine_chat)
+             summary="与搜索引擎对话",
+             )(search_engine_chat)

+    # 知识库相关接口
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
+
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",

@@ -121,6 +123,17 @@ def mount_basic_routes(app: FastAPI):
     ) -> str:
         return get_prompt_template(type=type, name=name)

+    # 其它接口
+    app.post("/other/completion",
+             tags=["Other"],
+             summary="要求llm模型补全(通过LLMChain)",
+             )(completion)
+
+    app.post("/other/embed_texts",
+             tags=["Other"],
+             summary="将文本向量化,支持本地模型和在线模型",
+             )(embed_texts_endpoint)
+

 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat
@@ -0,0 +1,68 @@
+from langchain.docstore.document import Document
+from configs import EMBEDDING_MODEL, logger
+from server.model_workers.base import ApiEmbeddingsParams
+from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
+from fastapi import Body
+from typing import Dict, List
+
+
+online_embed_models = list_online_embed_models()
+
+
+def embed_texts(
+    texts: List[str],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> BaseResponse:
+    '''
+    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models():  # 使用本地Embeddings模型
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        if embed_model in list_online_embed_models():  # 使用在线API
+            config = get_model_worker_config(embed_model)
+            worker_class = config.get("worker_class")
+            worker = worker_class()
+            if worker_class.can_embedding():
+                params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
+                resp = worker.do_embeddings(params)
+                return BaseResponse(**resp)
+
+        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+
+def embed_texts_endpoint(
+    texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
+    embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
+    to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
+) -> BaseResponse:
+    '''
+    对文本进行向量化,返回 BaseResponse(data=List[List[float]])
+    '''
+    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
+
+
+def embed_documents(
+    docs: List[Document],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> Dict:
+    """
+    将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
+    """
+    texts = [x.page_content for x in docs]
+    metadatas = [x.metadata for x in docs]
+    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
+    if embeddings is not None:
+        return {
+            "texts": texts,
+            "embeddings": embeddings,
+            "metadatas": metadatas,
+        }
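A short in-process usage sketch of the two new helpers (no HTTP involved; the model name, and BaseResponse exposing .code/.data attributes, are assumptions):

    from langchain.docstore.document import Document
    from server.embeddings_api import embed_texts, embed_documents

    resp = embed_texts(texts=["你好", "世界"], embed_model="m3e-base", to_query=False)
    assert resp.code == 200          # BaseResponse(data=List[List[float]])
    vectors = resp.data

    # embed_documents returns kwargs ready for VectorStore.add_embeddings
    data = embed_documents(docs=[Document(page_content="你好", metadata={})])
    # data["texts"], data["embeddings"], data["metadatas"]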
@@ -110,7 +110,7 @@ class CachePool:


 class EmbeddingsPool(CachePool):
-    def load_embeddings(self, model: str, device: str) -> Embeddings:
+    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
         self.atomic.acquire()
         model = model or EMBEDDING_MODEL
         device = device or embedding_device()

@@ -121,7 +121,7 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure
+                embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
@@ -1,7 +1,10 @@
 from configs import CACHED_VS_NUM
 from server.knowledge_base.kb_cache.base import *
+from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
+from server.utils import load_local_embeddings
 from server.knowledge_base.utils import get_vs_path
-from langchain.vectorstores import FAISS
+from langchain.vectorstores.faiss import FAISS
-from langchain.schema import Document
 import os
+from langchain.schema import Document


@@ -38,9 +41,9 @@ class _FaissPool(CachePool):
                 embed_model: str = EMBEDDING_MODEL,
                 embed_device: str = embedding_device(),
                 ) -> FAISS:
-        embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
-
+        # TODO: 整个Embeddings加载逻辑有些混乱,待清理
         # create an empty vector store
+        embeddings = EmbeddingsFunAdapter(embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
         ids = list(vector_store.docstore._dict.keys())

@@ -133,7 +136,7 @@ if __name__ == "__main__":
     def worker(vs_name: str, name: str):
         vs_name = "samples"
         time.sleep(random.randint(1, 5))
-        embeddings = embeddings_pool.load_embeddings()
+        embeddings = load_local_embeddings()
         r = random.randint(1, 3)

         with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
@@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import (
 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
-    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
+    get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
 from server.utils import embedding_device

 from typing import List, Union, Dict, Optional

+from server.embeddings_api import embed_texts
+from server.embeddings_api import embed_documents


 class SupportedVSType:
     FAISS = 'faiss'

@@ -48,8 +51,6 @@ class KBService(ABC):
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
-    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
-        return load_embeddings(self.embed_model, embed_device)

     def save_vector_store(self):
         '''

@@ -83,6 +84,12 @@ class KBService(ABC):
         status = delete_kb_from_db(self.kb_name)
         return status

+    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
+        '''
+        将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
+        '''
+        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
+
     def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         向知识库添加文件

@@ -149,8 +156,7 @@ class KBService(ABC):
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ):
-        embeddings = self._load_embeddings()
-        docs = self.do_search(query, top_k, score_threshold, embeddings)
+        docs = self.do_search(query, top_k, score_threshold)
         return docs

     def get_doc_by_id(self, id: str) -> Optional[Document]:

@@ -346,24 +352,26 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:


 class EmbeddingsFunAdapter(Embeddings):
-
-    def __init__(self, embeddings: Embeddings):
-        self.embeddings = embeddings
+    def __init__(self, embed_model: str = EMBEDDING_MODEL):
+        self.embed_model = embed_model

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return normalize(self.embeddings.embed_documents(texts))
+        embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
+        return normalize(embeddings)

     def embed_query(self, text: str) -> List[float]:
-        query_embed = self.embeddings.embed_query(text)
+        embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
+        query_embed = embeddings[0]
         query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
         normalized_query_embed = normalize(query_embed_2d)
         return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回

-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        return await normalize(self.embeddings.aembed_documents(texts))
-
-    async def aembed_query(self, text: str) -> List[float]:
-        return await normalize(self.embeddings.aembed_query(text))
+    # TODO: 暂不支持异步
+    # async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    #     return normalize(await self.embeddings.aembed_documents(texts))
+
+    # async def aembed_query(self, text: str) -> List[float]:
+    #     return normalize(await self.embeddings.aembed_query(text))


 def score_threshold_process(score_threshold, k, docs):
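Since EmbeddingsFunAdapter is now constructed from a model name rather than an Embeddings instance, local and online models share one embedding code path in every vector store. A small sketch (the model name is an assumption):

    from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

    adapter = EmbeddingsFunAdapter("m3e-base")  # any configured local or online embed model
    doc_vecs = adapter.embed_documents(["第一段", "第二段"])  # routed via embed_texts(to_query=False)
    query_vec = adapter.embed_query("查询")                   # routed via embed_texts(to_query=True)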
@@ -9,10 +9,9 @@ from configs import (
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
 from server.knowledge_base.utils import KnowledgeFile
-from langchain.embeddings.base import Embeddings
-from typing import List, Dict, Optional
-from langchain.docstore.document import Document
 from server.utils import torch_gc
+from langchain.docstore.document import Document
+from typing import List, Dict, Optional


 class FaissKBService(KBService):

@@ -58,7 +57,6 @@ class FaissKBService(KBService):
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
-                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)

@@ -68,8 +66,11 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    **kwargs,
                    ) -> List[Dict]:
+        data = self._docs_to_embeddings(docs)  # 将向量化单独出来可以减少向量库的锁定时间
+
         with self.load_vector_store().acquire() as vs:
-            ids = vs.add_documents(docs)
+            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
+                                    metadatas=data["metadatas"])
             if not kwargs.get("not_refresh_vs_cache"):
                 vs.save_local(self.vs_path)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
@@ -1,8 +1,7 @@
 from typing import List, Dict, Optional

-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import Milvus
+from langchain.vectorstores.milvus import Milvus

 from configs import kbs_config


@@ -46,10 +45,8 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS

-    def _load_milvus(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
+    def _load_milvus(self):
+        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

     def do_init(self):

@@ -60,8 +57,8 @@ class MilvusKBService(KBService):
             self.milvus.col.release()
             self.milvus.col.drop()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_milvus()
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
@@ -1,28 +1,22 @@
 import json
 from typing import List, Dict, Optional

-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import PGVector
-from langchain.vectorstores.pgvector import DistanceStrategy
+from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
 from sqlalchemy import text

 from configs import kbs_config

 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
-from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-from server.utils import embedding_device as get_embedding_device
+from server.knowledge_base.utils import KnowledgeFile


 class PGKBService(KBService):
     pg_vector: PGVector

-    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
+    def _load_pg_vector(self):
+        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                   collection_name=self.kb_name,
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))

@@ -57,8 +51,8 @@ class PGKBService(KBService):
             '''))
             connect.commit()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_pg_vector(embeddings=embeddings)
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_pg_vector()
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))
@@ -43,11 +43,9 @@ class ZillizKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ZILLIZ

-    def _load_zilliz(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
+    def _load_zilliz(self):
         zilliz_args = kbs_config.get("zilliz")
-        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
+        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=zilliz_args)


@@ -59,8 +57,8 @@ class ZillizKBService(KBService):
             self.zilliz.col.release()
             self.zilliz.col.drop()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_zilliz()
         return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
@@ -1,8 +1,4 @@
 import os
-import sys
-
-sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
-from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,

@@ -22,7 +18,6 @@ from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor
 from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator

@@ -62,14 +57,6 @@ def list_files_from_folder(kb_name: str):
             if os.path.isfile(os.path.join(doc_path, file))]


-def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
-    '''
-    从缓存中加载embeddings,可以避免多线程时竞争加载。
-    '''
-    from server.knowledge_base.kb_cache.base import embeddings_pool
-    return embeddings_pool.load_embeddings(model=model, device=device)
-
-
 LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],

@@ -239,6 +226,7 @@ def make_text_splitter(
                 from langchain.text_splitter import CharacterTextSplitter
                 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
             else: ## 字符长度加载
+                from transformers import AutoTokenizer
                 tokenizer = AutoTokenizer.from_pretrained(
                     text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                     trust_remote_code=True)

@@ -358,7 +346,6 @@ def files2docs_in_thread(
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     利用多线程批量将磁盘文件转化成langchain Document.

@@ -396,7 +383,7 @@ def files2docs_in_thread(
         except Exception as e:
             yield False, (kb_name, filename, str(e))

-    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
         yield result
@@ -2,6 +2,7 @@ from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
+from copy import deepcopy


 def list_running_models(

@@ -31,16 +32,16 @@ def list_config_models() -> BaseResponse:
     '''
     从本地获取configs中配置的模型列表
     '''
-    configs = list_config_llm_models()
+    configs = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    for config in configs["online"].values():
-        del_keys = set(["worker_class"])
-        for k in config:
-            if "key" in k.lower() or "secret" in k.lower():
-                del_keys.add(k)
-        for k in del_keys:
-            config.pop(k, None)
-
+    for name, config in list_config_llm_models()["online"].items():
+        configs[name] = {}
+        for k, v in config.items():
+            if not (k == "worker_class"
+                    or "key" in k.lower()
+                    or "secret" in k.lower()
+                    or k.lower().endswith("id")):
+                configs[name][k] = v
     return BaseResponse(data=configs)


@@ -51,14 +52,14 @@ def get_model_config(
     '''
     获取LLM模型配置项(合并后的)
     '''
-    config = get_model_worker_config(model_name=model_name)
+    config = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    del_keys = set(["worker_class"])
-    for k in config:
-        if "key" in k.lower() or "secret" in k.lower():
-            del_keys.add(k)
-    for k in del_keys:
-        config.pop(k, None)
+    for k, v in get_model_worker_config(model_name=model_name).items():
+        if not (k == "worker_class"
+                or "key" in k.lower()
+                or "secret" in k.lower()
+                or k.lower().endswith("id")):
+            config[k] = v

     return BaseResponse(data=config)
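The root cause of the fixed bug: the old code popped keys from dicts that were references into the global config, so the first request stripped api_key for every request after it. A minimal illustration (hypothetical config, not the project's real ONLINE_LLM_MODEL):

    ONLINE = {"zhipu-api": {"api_key": "xxx", "version": "chatglm_pro"}}

    def list_models_buggy():
        config = ONLINE["zhipu-api"]   # a reference, not a copy
        config.pop("api_key", None)    # mutates the shared config in place
        return config

    def list_models_fixed():
        # rebuild a filtered copy, leaving the original untouched
        return {k: v for k, v in ONLINE["zhipu-api"].items()
                if "key" not in k.lower() and "secret" not in k.lower()}

    list_models_buggy()
    assert "api_key" not in ONLINE["zhipu-api"]  # the second request now lacks credentials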
@@ -180,7 +180,7 @@ class ApiModelWorker(BaseModelWorker):
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
         '''
         执行Embeddings的方法,默认使用模块里面的embed_documents函数。
-        要求返回形式:{"code": int, "embeddings": List[List[float]], "msg": str}
+        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
         '''
         return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
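For orientation, a minimal sketch of a worker honoring this contract (hypothetical class, not part of the project; real workers such as QianFanWorker call their provider's HTTP API instead):

    from typing import Dict, List

    class FakeEmbeddingsWorker:
        model_names = ["fake-api"]

        @classmethod
        def can_embedding(cls) -> bool:
            # embed_texts only dispatches to worker classes returning True here
            return True

        def do_embeddings(self, params) -> Dict:
            # params is an ApiEmbeddingsParams carrying .texts and .to_query
            vectors: List[List[float]] = [[0.0, 1.0] for _ in params.texts]
            return {"code": 200, "data": vectors}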
@@ -99,7 +99,7 @@ class MiniMaxWorker(ApiModelWorker):
         with get_httpx_client() as client:
             r = client.post(url, headers=headers, json=data).json()
             if embeddings := r.get("vectors"):
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             elif error := r.get("base_resp"):
                 return {"code": error["status_code"], "msg": error["status_msg"]}
@@ -172,7 +172,7 @@ class QianFanWorker(ApiModelWorker):
             resp = client.post(url, json={"input": params.texts}).json()
             if "error_code" not in resp:
                 embeddings = [x["embedding"] for x in resp.get("data", [])]
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             else:
                 return {"code": resp["error_code"], "msg": resp["error_msg"]}
@@ -68,7 +68,7 @@ class QwenWorker(ApiModelWorker):
             return {"code": resp["status_code"], "msg": resp.message}
         else:
             embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
-            return {"code": 200, "embeddings": embeddings}
+            return {"code": 200, "data": embeddings}

     def get_embeddings(self, params):
         # TODO: 支持embeddings
@@ -59,7 +59,7 @@ class ChatGLMWorker(ApiModelWorker):
         except Exception as e:
             return {"code": 500, "msg": f"对文本向量化时出错:{e}"}

-        return {"code": 200, "embeddings": embeddings}
+        return {"code": 200, "data": embeddings}

     def get_embeddings(self, params):
         # TODO: 支持embeddings
@@ -14,7 +14,6 @@ from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union

-thread_pool = ThreadPoolExecutor(os.cpu_count())


 async def wrap_done(fn: Awaitable, event: asyncio.Event):

@@ -368,8 +367,8 @@ def MakeFastAPIOffline(
         redoc_favicon_url=favicon,
     )

-# 从model_config中获取模型信息

+# 从model_config中获取模型信息

 def list_embed_models() -> List[str]:
     '''

@@ -432,8 +431,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
     from server import model_workers

     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
-    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
+    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())

     if model_name in ONLINE_LLM_MODEL:

@@ -611,21 +610,19 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
 def run_in_thread_pool(
         func: Callable,
         params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     在线程池中批量运行任务,并将运行结果以生成器的形式返回。
     请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
     '''
     tasks = []
-    pool = pool or thread_pool
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
+    with ThreadPoolExecutor() as pool:
+        for kwargs in params:
+            thread = pool.submit(func, **kwargs)
+            tasks.append(thread)

-    for obj in as_completed(tasks):
-        yield obj.result()
+        for obj in as_completed(tasks):  # TODO: Ctrl+c无法停止
+            yield obj.result()


 def get_httpx_client(

@@ -703,7 +700,6 @@ def get_server_configs() -> Dict:
     )
     from configs.model_config import (
         LLM_MODEL,
-        EMBEDDING_MODEL,
         HISTORY_LEN,
         TEMPERATURE,
     )

@@ -728,3 +724,14 @@ def list_online_embed_models() -> List[str]:
         if worker_class is not None and worker_class.can_embedding():
             ret.append(k)
     return ret
+
+
+def load_local_embeddings(model: str = None, device: str = embedding_device()):
+    '''
+    从缓存中加载embeddings,可以避免多线程时竞争加载。
+    '''
+    from server.knowledge_base.kb_cache.base import embeddings_pool
+    from configs import EMBEDDING_MODEL
+
+    model = model or EMBEDDING_MODEL
+    return embeddings_pool.load_embeddings(model=model, device=device)
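A small usage sketch for the reworked run_in_thread_pool (the task function is hypothetical; keyword arguments only, as the docstring requires):

    from server.utils import run_in_thread_pool

    def square(n: int) -> int:
        return n * n

    for result in run_in_thread_pool(func=square, params=[{"n": i} for i in range(5)]):
        print(result)  # results arrive in completion order, not submission order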
@ -16,6 +16,8 @@ for x in list_config_llm_models()["online"]:
|
|||
workers.append(x)
|
||||
print(f"all workers to test: {workers}")
|
||||
|
||||
# workers = ["qianfan-api"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
def test_chat(worker):
|
||||
|
|
@ -49,8 +51,8 @@ def test_embeddings(worker):
|
|||
|
||||
pprint(resp, depth=2)
|
||||
assert resp["code"] == 200
|
||||
assert "embeddings" in resp
|
||||
embeddings = resp["embeddings"]
|
||||
assert "data" in resp
|
||||
embeddings = resp["data"]
|
||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||
assert isinstance(embeddings[0][0], float)
|
||||
|
|
|
|||
|
|
@@ -9,7 +9,7 @@ from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
                      EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.utils import list_embed_models
+from server.utils import list_embed_models, list_online_embed_models
 import os
 import time

@@ -103,7 +103,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
         key="vs_type",
     )

-    embed_models = list_embed_models()
+    embed_models = list_embed_models() + list_online_embed_models()

     embed_model = cols[1].selectbox(
         "Embedding 模型",
@@ -853,6 +853,26 @@ class ApiRequest:
         else:
             return ret_sync()

+    def embed_texts(
+        self,
+        texts: List[str],
+        embed_model: str = EMBEDDING_MODEL,
+        to_query: bool = False,
+    ) -> List[List[float]]:
+        '''
+        对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型
+        '''
+        data = {
+            "texts": texts,
+            "embed_model": embed_model,
+            "to_query": to_query,
+        }
+        resp = self.post(
+            "/other/embed_texts",
+            json=data,
+        )
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+

 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
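Client-side, the new method reduces the endpoint to a one-liner (a sketch assuming the API server is reachable at the default address and the model name is configured):

    from webui_pages.utils import ApiRequest

    api = ApiRequest()
    vectors = api.embed_texts(texts=["hello"], embed_model="m3e-base")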