diff --git a/server/chat/chat.py b/server/chat/chat.py index 254a6dd..bac82f8 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -40,13 +40,12 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callbacks = [callback] memory = None - if conversation_id: - message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) - # 负责保存llm response到message db - conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id, - chat_type="llm_chat", - query=query) - callbacks.append(conversation_callback) + # 负责保存llm response到message db + message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) + conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id, + chat_type="llm_chat", + query=query) + callbacks.append(conversation_callback) if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index ef4a522..a58cb29 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -125,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= callbacks=[callback], ) embed_func = EmbeddingsFunAdapter() - embeddings = embed_func.embed_query(query) + embeddings = await embed_func.aembed_query(query) with memo_faiss_pool.acquire(knowledge_id) as vs: docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) docs = [x[0] for x in docs] diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index cd0aca3..a3ab68b 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,5 +1,6 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse +from fastapi.concurrency import run_in_threadpool from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) from server.utils import wrap_done, get_ChatOpenAI from server.utils import BaseResponse, get_prompt_template @@ -10,9 +11,7 @@ import asyncio from langchain.prompts.chat import ChatPromptTemplate from server.chat.utils import History from server.knowledge_base.kb_service.base import KBServiceFactory -from server.knowledge_base.utils import get_doc_path import json -from pathlib import Path from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs @@ -72,7 +71,11 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", max_tokens=max_tokens, callbacks=[callback], ) - docs = search_docs(query, knowledge_base_name, top_k, score_threshold) + docs = await run_in_threadpool(search_docs, + query=query, + knowledge_base_name=knowledge_base_name, + top_k=top_k, + score_threshold=score_threshold) context = "\n".join([doc.page_content for doc in docs]) if len(docs) == 0: # 如果没有找到相关文档,使用empty模板 prompt_template = get_prompt_template("knowledge_base_chat", "empty") diff --git a/server/embeddings_api.py b/server/embeddings_api.py index 80cd289..93555a3 100644 --- a/server/embeddings_api.py +++ b/server/embeddings_api.py @@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger from server.model_workers.base import ApiEmbeddingsParams from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models from fastapi import Body +from fastapi.concurrency import run_in_threadpool from typing import Dict, List @@ -39,6 +40,32 @@ def embed_texts( logger.error(e) return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") + +async def aembed_texts( + texts: List[str], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, +) -> BaseResponse: + ''' + 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) + ''' + try: + if embed_model in list_embed_models(): # 使用本地Embeddings模型 + from server.utils import load_local_embeddings + + embeddings = load_local_embeddings(model=embed_model) + return BaseResponse(data=await embeddings.aembed_documents(texts)) + + if embed_model in list_online_embed_models(): # 使用在线API + return await run_in_threadpool(embed_texts, + texts=texts, + embed_model=embed_model, + to_query=to_query) + except Exception as e: + logger.error(e) + return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") + + def embed_texts_endpoint( texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"), diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index a357bd7..1d86d30 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -26,8 +26,7 @@ from server.knowledge_base.utils import ( from typing import List, Union, Dict, Optional -from server.embeddings_api import embed_texts -from server.embeddings_api import embed_documents +from server.embeddings_api import embed_texts, aembed_texts, embed_documents from server.knowledge_base.model.kb_document_model import DocumentWithVSId @@ -405,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings): normalized_query_embed = normalize(query_embed_2d) return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 - # TODO: 暂不支持异步 - # async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - # return normalize(await self.embeddings.aembed_documents(texts)) + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data + return normalize(embeddings).tolist() - # async def aembed_query(self, text: str) -> List[float]: - # return normalize(await self.embeddings.aembed_query(text)) + async def aembed_query(self, text: str) -> List[float]: + embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data + query_embed = embeddings[0] + query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 + normalized_query_embed = normalize(query_embed_2d) + return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 def score_threshold_process(score_threshold, k, docs): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 4fef821..38565fa 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str): for target_entry in target_it: process_entry(target_entry) elif entry.is_file(): - result.append(entry.path) + result.append(os.path.relpath(entry.path, doc_path)) elif entry.is_dir(): with os.scandir(entry.path) as it: for sub_entry in it: