Support online embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api (#1907)

New features:
- Support online embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add an /other/embed_texts endpoint to the API
- Add an --embed-model parameter to init_database.py to specify which embedding model to use (local or online)

Bug fixes:
- list_config_models in the API deleted sensitive information from ONLINE_LLM_MODEL in place, causing errors on the second round of API requests

Developer changes:
- Clean up the embeddings handling in kb_service:
  - Unified loading interface: server.utils.load_local_embeddings, which uses the global cache to avoid passing Embeddings objects around
  - Unified text-embedding interface: server.embeddings_api.[embed_texts, embed_documents]
Commit deed92169f (parent aa7c580974) by liunux4odoo, 2023-10-28 23:37:30 +08:00, committed via GitHub.
22 changed files with 228 additions and 121 deletions
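Before the diffs, a quick orientation: the new endpoint can be exercised with any HTTP client. A minimal sketch, assuming the API server runs on the project's default port 7861 and that "m3e-base" is a configured local embedding model (both are assumptions, not fixed by this commit):

# Sketch of calling the new /other/embed_texts endpoint.
# Assumptions: API served at 127.0.0.1:7861 (project default), and
# "m3e-base" is a configured local embedding model.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/other/embed_texts",
    json={"texts": ["hello", "world"], "embed_model": "m3e-base", "to_query": False},
)
result = resp.json()      # BaseResponse shape: {"code": ..., "msg": ..., "data": ...}
vectors = result["data"]  # one embedding vector (List[float]) per input text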


@@ -1,7 +1,7 @@
import sys
sys.path.append(".")
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from datetime import datetime
@@ -62,12 +62,20 @@ if __name__ == "__main__":
)
)
parser.add_argument(
"-n",
"--kb-name",
type=str,
nargs="+",
default=[],
help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
)
parser.add_argument(
"-e",
"--embed-model",
type=str,
default=EMBEDDING_MODEL,
help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
)
if len(sys.argv) <= 1:
parser.print_help()
@@ -80,11 +88,11 @@ if __name__ == "__main__":
reset_tables()
print("database talbes reseted")
print("recreating all vector stores")
folder2db(kb_names=args.kb_name, mode="recreate_vs")
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db")
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increament:
folder2db(kb_names=args.kb_name, mode="increament")
folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
elif args.prune_db:
prune_db_docs(args.kb_name)
elif args.prune_folder:
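With the new flag, rebuilding all vector stores against a chosen model is `python init_database.py --recreate-vs --embed-model m3e-base` (the model name is illustrative). Internally that reduces to a call like this sketch:

# What --recreate-vs --embed-model reduces to (sketch; "m3e-base" is illustrative):
from server.knowledge_base.migrate import folder2db

# kb_names=[] means "all folders under KB_ROOT_PATH", per the --kb-name default
folder2db(kb_names=[], mode="recreate_vs", embed_model="m3e-base")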


@@ -30,7 +30,7 @@ pytest
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
qianfan
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store


@@ -16,6 +16,7 @@ from server.chat.chat import chat
from server.chat.openai_chat import openai_chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
get_model_config, list_search_engines)
@@ -47,33 +48,34 @@ def create_app(run_mode: str = None):
allow_methods=["*"],
allow_headers=["*"],
)
mount_basic_routes(app)
if run_mode != "lite":
mount_knowledge_routes(app)
mount_app_routes(app, run_mode=run_mode)
return app
def mount_basic_routes(app: FastAPI):
def mount_app_routes(app: FastAPI, run_mode: str = None):
app.get("/",
response_model=BaseResponse,
summary="swagger 文档")(document)
app.post("/completion",
tags=["Completion"],
summary="要求llm模型补全(通过LLMChain)")(completion)
# Tag: Chat
app.post("/chat/fastchat",
tags=["Chat"],
summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
summary="与llm模型对话(直接与fastchat api对话)",
)(openai_chat)
app.post("/chat/chat",
tags=["Chat"],
summary="与llm模型对话(通过LLMChain)")(chat)
summary="与llm模型对话(通过LLMChain)",
)(chat)
app.post("/chat/search_engine_chat",
tags=["Chat"],
summary="与搜索引擎对话")(search_engine_chat)
summary="与搜索引擎对话",
)(search_engine_chat)
# Knowledge base endpoints
if run_mode != "lite":
mount_knowledge_routes(app)
# LLM model endpoints
app.post("/llm_model/list_running_models",
@@ -121,6 +123,17 @@ def mount_basic_routes(app: FastAPI):
) -> str:
return get_prompt_template(type=type, name=name)
# Other endpoints
app.post("/other/completion",
tags=["Other"],
summary="要求llm模型补全(通过LLMChain)",
)(completion)
app.post("/other/embed_texts",
tags=["Other"],
summary="将文本向量化,支持本地模型和在线模型",
)(embed_texts_endpoint)
def mount_knowledge_routes(app: FastAPI):
from server.chat.knowledge_base_chat import knowledge_base_chat

server/embeddings_api.py (new file, 68 lines)

@@ -0,0 +1,68 @@
from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
from typing import Dict, List
online_embed_models = list_online_embed_models()
def embed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
Vectorize texts. Return format: BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # use a local embeddings model
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=embeddings.embed_documents(texts))
if embed_model in list_online_embed_models(): # use an online API
config = get_model_worker_config(embed_model)
worker_class = config.get("worker_class")
worker = worker_class()
if worker_class.can_embedding():
params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
resp = worker.do_embeddings(params)
return BaseResponse(**resp)
return BaseResponse(code=500, msg=f"The specified model {embed_model} does not support embeddings.")
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"Error while embedding texts: {e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="List of texts to embed", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL, description=f"Embedding model to use. Besides locally deployed embedding models, online APIs ({online_embed_models}) also provide embedding services."),
to_query: bool = Body(False, description="Whether the vectors are used for queries. Some models, such as MiniMax, optimize stored vs. query vectors differently."),
) -> BaseResponse:
'''
Vectorize texts. Returns BaseResponse(data=List[List[float]])
'''
return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
def embed_documents(
docs: List[Document],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> Dict:
"""
List[Document] 向量化转化为 VectorStore.add_embeddings 可以接受的参数
"""
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
if embeddings is not None:
return {
"texts": texts,
"embeddings": embeddings,
"metadatas": metadatas,
}
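A hypothetical usage sketch of the two helpers defined above (the model name is illustrative):

# Hypothetical usage of server.embeddings_api (model name illustrative):
from langchain.docstore.document import Document
from server.embeddings_api import embed_texts, embed_documents

resp = embed_texts(["hello", "world"], embed_model="zhipu-api", to_query=False)
if resp.code == 200:
    vectors = resp.data  # List[List[float]], one vector per text

data = embed_documents([Document(page_content="hello", metadata={"source": "demo"})])
# data == {"texts": [...], "embeddings": [...], "metadatas": [...]},
# ready to feed into VectorStore.add_embeddings (see the FAISS service diff below).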


@@ -110,7 +110,7 @@ class CachePool:
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str, device: str) -> Embeddings:
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = device or embedding_device()
@@ -121,7 +121,7 @@ class EmbeddingsPool(CachePool):
with item.acquire(msg="initialize"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(model_name=model, # TODO: support Azure
embeddings = OpenAIEmbeddings(model_name=model,
openai_api_key=get_model_path(model),
chunk_size=CHUNK_SIZE)
elif 'bge-' in model:


@@ -1,7 +1,10 @@
from configs import CACHED_VS_NUM
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores import FAISS
from langchain.vectorstores.faiss import FAISS
from langchain.schema import Document
import os
from langchain.schema import Document
@@ -38,9 +41,9 @@ class _FaissPool(CachePool):
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
# TODO: the whole embeddings loading logic is messy and needs cleanup
# create an empty vector store
embeddings = EmbeddingsFunAdapter(embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
@@ -133,7 +136,7 @@ if __name__ == "__main__":
def worker(vs_name: str, name: str):
vs_name = "samples"
time.sleep(random.randint(1, 5))
embeddings = embeddings_pool.load_embeddings()
embeddings = load_local_embeddings()
r = random.randint(1, 3)
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:


@@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import (
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
)
from server.utils import embedding_device
from typing import List, Union, Dict, Optional
from server.embeddings_api import embed_texts
from server.embeddings_api import embed_documents
class SupportedVSType:
FAISS = 'faiss'
@@ -48,8 +51,6 @@ class KBService(ABC):
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
self.do_init()
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def save_vector_store(self):
'''
@@ -83,6 +84,12 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
'''
Turn a List[Document] into the arguments accepted by VectorStore.add_embeddings
'''
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
Add a file to the knowledge base
@@ -149,8 +156,7 @@ class KBService(ABC):
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
embeddings = self._load_embeddings()
docs = self.do_search(query, top_k, score_threshold, embeddings)
docs = self.do_search(query, top_k, score_threshold)
return docs
def get_doc_by_id(self, id: str) -> Optional[Document]:
@@ -346,24 +352,26 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
class EmbeddingsFunAdapter(Embeddings):
def __init__(self, embeddings: Embeddings):
self.embeddings = embeddings
def __init__(self, embed_model: str = EMBEDDING_MODEL):
self.embed_model = embed_model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return normalize(self.embeddings.embed_documents(texts))
embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
return normalize(embeddings)
def embed_query(self, text: str) -> List[float]:
query_embed = self.embeddings.embed_query(text)
embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # reshape the 1-D array into a 2-D array
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # convert the result back to a 1-D list and return it
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return await normalize(self.embeddings.aembed_documents(texts))
# TODO: async is not supported yet
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# return normalize(await self.embeddings.aembed_documents(texts))
async def aembed_query(self, text: str) -> List[float]:
return await normalize(self.embeddings.aembed_query(text))
# async def aembed_query(self, text: str) -> List[float]:
# return normalize(await self.embeddings.aembed_query(text))
def score_threshold_process(score_threshold, k, docs):
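The adapter still L2-normalizes every vector before handing it to the vector store, so inner-product search behaves like cosine similarity. A standalone sketch of that contract, assuming it matches the project's normalize helper (which this diff does not show):

# Sketch of the L2 normalization the adapter relies on (assumption: matches
# the project's normalize helper, which is not shown in this diff).
import numpy as np

def normalize(embeddings: list) -> np.ndarray:
    arr = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    return arr / norms  # each row gets unit L2 norm, so dot product == cosine

print(normalize([[3.0, 4.0]]))  # [[0.6 0.8]]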


@@ -9,10 +9,9 @@ from configs import (
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Optional
from langchain.docstore.document import Document
from server.utils import torch_gc
from langchain.docstore.document import Document
from typing import List, Dict, Optional
class FaissKBService(KBService):
@@ -58,7 +57,6 @@ class FaissKBService(KBService):
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
@@ -68,8 +66,11 @@ class FaissKBService(KBService):
docs: List[Document],
**kwargs,
) -> List[Dict]:
data = self._docs_to_embeddings(docs) # embedding the docs separately shortens the vector store's lock time
with self.load_vector_store().acquire() as vs:
ids = vs.add_documents(docs)
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
metadatas=data["metadatas"])
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]


@@ -1,8 +1,7 @@
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from langchain.vectorstores.milvus import Milvus
from configs import kbs_config
@@ -46,10 +45,8 @@ class MilvusKBService(KBService):
def vs_type(self) -> str:
return SupportedVSType.MILVUS
def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
def _load_milvus(self):
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
def do_init(self):
@@ -60,8 +57,8 @@ class MilvusKBService(KBService):
self.milvus.col.release()
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus()
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:


@@ -1,28 +1,22 @@
import json
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
from sqlalchemy import text
from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
from server.utils import embedding_device as get_embedding_device
from server.knowledge_base.utils import KnowledgeFile
class PGKBService(KBService):
pg_vector: PGVector
def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
def _load_pg_vector(self):
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
@@ -57,8 +51,8 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_pg_vector(embeddings=embeddings)
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_pg_vector()
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))


@@ -43,11 +43,9 @@ class ZillizKBService(KBService):
def vs_type(self) -> str:
return SupportedVSType.ZILLIZ
def _load_zilliz(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
def _load_zilliz(self):
zilliz_args = kbs_config.get("zilliz")
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, connection_args=zilliz_args)
@@ -59,8 +57,8 @@ class ZillizKBService(KBService):
self.zilliz.col.release()
self.zilliz.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz()
return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:


@@ -1,8 +1,4 @@
import os
import sys
sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
from transformers import AutoTokenizer
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
@@ -22,7 +18,6 @@ from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -62,14 +57,6 @@ def list_files_from_folder(kb_name: str):
if os.path.isfile(os.path.join(doc_path, file))]
def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
'''
Load embeddings from the cache to avoid load contention across threads
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
return embeddings_pool.load_embeddings(model=model, device=device)
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
@@ -239,6 +226,7 @@ def make_text_splitter(
from langchain.text_splitter import CharacterTextSplitter
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
else: ## load by character length
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
trust_remote_code=True)
@@ -358,7 +346,6 @@ def files2docs_in_thread(
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Convert disk files into langchain Documents in batch using multiple threads.
@@ -396,7 +383,7 @@ def files2docs_in_thread(
except Exception as e:
yield False, (kb_name, filename, str(e))
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
yield result


@@ -2,6 +2,7 @@ from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from copy import deepcopy
def list_running_models(
@@ -31,16 +32,16 @@ def list_config_models() -> BaseResponse:
'''
Get the list of models configured in the local configs
'''
configs = list_config_llm_models()
configs = {}
# strip sensitive information from the ONLINE_MODEL config
for config in configs["online"].values():
del_keys = set(["worker_class"])
for k in config:
if "key" in k.lower() or "secret" in k.lower():
del_keys.add(k)
for k in del_keys:
config.pop(k, None)
for name, config in list_config_llm_models()["online"].items():
configs[name] = {}
for k, v in config.items():
if not (k == "worker_class"
or "key" in k.lower()
or "secret" in k.lower()
or k.lower().endswith("id")):
configs[name][k] = v
return BaseResponse(data=configs)
@@ -51,14 +52,14 @@ def get_model_config(
'''
Get an LLM model's configuration (merged)
'''
config = get_model_worker_config(model_name=model_name)
config = {}
# strip sensitive information from the ONLINE_MODEL config
del_keys = set(["worker_class"])
for k in config:
if "key" in k.lower() or "secret" in k.lower():
del_keys.add(k)
for k in del_keys:
config.pop(k, None)
for k, v in get_model_worker_config(model_name=model_name).items():
if not (k == "worker_class"
or "key" in k.lower()
or "secret" in k.lower()
or k.lower().endswith("id")):
config[k] = v
return BaseResponse(data=config)
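The same key filter now appears in both functions above; a hypothetical shared helper could capture it (not part of this commit):

# Hypothetical helper factoring out the duplicated filter (not in this commit):
def _is_sensitive(key: str) -> bool:
    k = key.lower()
    return key == "worker_class" or "key" in k or "secret" in k or k.endswith("id")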


@@ -180,7 +180,7 @@ class ApiModelWorker(BaseModelWorker):
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
'''
The method that performs embeddings; by default it uses the module's embed_documents function.
Required return format: {"code": int, "embeddings": List[List[float]], "msg": str}
Required return format: {"code": int, "data": List[List[float]], "msg": str}
'''
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}


@@ -99,7 +99,7 @@ class MiniMaxWorker(ApiModelWorker):
with get_httpx_client() as client:
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
return {"code": 200, "embeddings": embeddings}
return {"code": 200, "data": embeddings}
elif error := r.get("base_resp"):
return {"code": error["status_code"], "msg": error["status_msg"]}


@@ -172,7 +172,7 @@ class QianFanWorker(ApiModelWorker):
resp = client.post(url, json={"input": params.texts}).json()
if "error_cdoe" not in resp:
embeddings = [x["embedding"] for x in resp.get("data", [])]
return {"code": 200, "embeddings": embeddings}
return {"code": 200, "data": embeddings}
else:
return {"code": resp["error_code"], "msg": resp["error_msg"]}


@@ -68,7 +68,7 @@ class QwenWorker(ApiModelWorker):
return {"code": resp["status_code"], "msg": resp.message}
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
return {"code": 200, "embeddings": embeddings}
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: support embeddings


@@ -59,7 +59,7 @@ class ChatGLMWorker(ApiModelWorker):
except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return {"code": 200, "embeddings": embeddings}
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: support embeddings


@@ -14,7 +14,6 @@ from langchain.llms import OpenAI, AzureOpenAI, Anthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
thread_pool = ThreadPoolExecutor(os.cpu_count())
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -368,8 +367,8 @@ def MakeFastAPIOffline(
redoc_favicon_url=favicon,
)
# get model information from model_config
# get model information from model_config
def list_embed_models() -> List[str]:
'''
@@ -432,8 +431,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
from server import model_workers
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(ONLINE_LLM_MODEL.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
if model_name in ONLINE_LLM_MODEL:
@@ -611,20 +610,18 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
def run_in_thread_pool(
func: Callable,
params: List[Dict] = [],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Run tasks in batch in a thread pool, yielding the results as a generator.
Make sure all operations in the tasks are thread-safe; task functions must take keyword arguments only.
'''
tasks = []
pool = pool or thread_pool
with ThreadPoolExecutor() as pool:
for kwargs in params:
thread = pool.submit(func, **kwargs)
tasks.append(thread)
for obj in as_completed(tasks):
for obj in as_completed(tasks): # TODO: cannot be interrupted with Ctrl+C
yield obj.result()
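After this change the executor is created per call and torn down by the with-block, instead of living as a shared module-level pool. A usage sketch (the function and params are illustrative):

# Usage sketch for run_in_thread_pool (square/params are illustrative):
def square(x: int) -> int:
    return x * x

for result in run_in_thread_pool(func=square, params=[{"x": i} for i in range(4)]):
    print(result)  # results arrive in completion order, not submission order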
@@ -703,7 +700,6 @@ def get_server_configs() -> Dict:
)
from configs.model_config import (
LLM_MODEL,
EMBEDDING_MODEL,
HISTORY_LEN,
TEMPERATURE,
)
@@ -728,3 +724,14 @@ def list_online_embed_models() -> List[str]:
if worker_class is not None and worker_class.can_embedding():
ret.append(k)
return ret
def load_local_embeddings(model: str = None, device: str = embedding_device()):
'''
Load embeddings from the cache to avoid load contention across threads
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
from configs import EMBEDDING_MODEL
model = model or EMBEDDING_MODEL
return embeddings_pool.load_embeddings(model=model, device=device)


@@ -16,6 +16,8 @@ for x in list_config_llm_models()["online"]:
workers.append(x)
print(f"all workers to test: {workers}")
# workers = ["qianfan-api"]
@pytest.mark.parametrize("worker", workers)
def test_chat(worker):
@@ -49,8 +51,8 @@ def test_embeddings(worker):
pprint(resp, depth=2)
assert resp["code"] == 200
assert "embeddings" in resp
embeddings = resp["embeddings"]
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)


@@ -9,7 +9,7 @@ from typing import Literal, Dict, Tuple
from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models
from server.utils import list_embed_models, list_online_embed_models
import os
import time
@@ -103,7 +103,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
key="vs_type",
)
embed_models = list_embed_models()
embed_models = list_embed_models() + list_online_embed_models()
embed_model = cols[1].selectbox(
"Embedding 模型",


@@ -853,6 +853,26 @@ class ApiRequest:
else:
return ret_sync()
def embed_texts(
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> List[List[float]]:
'''
Vectorize texts. Available models include local embed_models and online models that support embeddings.
'''
data = {
"texts": texts,
"embed_model": embed_model,
"to_query": to_query,
}
resp = self.post(
"/other/embed_texts",
json=data,
)
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
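Client-side usage of the new method, as a hedged sketch (the import path is assumed from the project layout; the model name is illustrative):

# Hypothetical client usage; import path and model name are assumptions.
from webui_pages.utils import ApiRequest

api = ApiRequest()
vectors = api.embed_texts(["hello", "world"], embed_model="m3e-base")
print(len(vectors), len(vectors[0]))  # number of texts, embedding dimension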
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):