diff --git a/server/api.py b/server/api.py
index 5fc80c4..b0c2ba2 100644
--- a/server/api.py
+++ b/server/api.py
@@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
 
     # 知识库相关接口
     mount_knowledge_routes(app)
+    # 摘要相关接口
+    mount_filename_summary_routes(app)
 
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",
@@ -230,6 +232,20 @@ def mount_knowledge_routes(app: FastAPI):
              )(upload_temp_docs)
 
 
+def mount_filename_summary_routes(app: FastAPI):
+    # Lazy import so the summary stack is only loaded when these routes are mounted.
+    from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store)
+
+    app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
+             tags=["Knowledge kb_summary_api Management"],
+             summary="文件摘要"
+             )(summary_file_to_vector_store)
+    app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
+             tags=["Knowledge kb_summary_api Management"],
+             summary="重建文件摘要"
+             )(recreate_summary_vector_store)
+
+
 def run_api(host, port, **kwargs):
     if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
         uvicorn.run(app,
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index 0af8de5..a357bd7 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -28,6 +28,7 @@
 from typing import List, Union, Dict, Optional
 from server.embeddings_api import embed_texts
 from server.embeddings_api import embed_documents
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 
 
 def normalize(embeddings: List[List[float]]) -> np.ndarray:
@@ -183,12 +184,17 @@
     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         return []
 
-    def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
+    def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
         '''
         通过file_name或metadata检索Document
         '''
         doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
-        docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
+        docs = []
+        for x in doc_infos:
+            # get_doc_by_ids may return None or [] when the id is absent from the vector store; skip those.
+            doc_info_s = self.get_doc_by_ids([x["id"]])
+            if doc_info_s:
+                docs.append(DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"]))
         return docs
 
     @abstractmethod
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
new file mode 100644
index 0000000..9e06380
--- /dev/null
+++ b/server/knowledge_base/kb_summary_api.py
@@ -0,0 +1,165 @@
+from fastapi import Body
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+                     OVERLAP_SIZE,
+                     logger, log_verbose, )
+from server.knowledge_base.utils import (list_files_from_folder)
+from fastapi.responses import StreamingResponse
+import json
+from server.knowledge_base.kb_service.base import KBServiceFactory
+from typing import List, Optional
+from server.knowledge_base.kb_summary.base import KBSummaryService
+from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
+from server.utils import wrap_done, get_ChatOpenAI
+from configs import LLM_MODELS, TEMPERATURE
+
+
+def recreate_summary_vector_store(
+        knowledge_base_name: str = Body(..., examples=["samples"]),
+        allow_empty_kb: bool = Body(True),
+        vs_type: str = Body(DEFAULT_VS_TYPE),
+        embed_model: str = Body(EMBEDDING_MODEL),
+        file_description: str = Body(''),
+        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+):
+    """
+    Rebuild the summary vector store for every file of a knowledge base.
+
+    :param knowledge_base_name: name of the knowledge base to summarize
+    :param allow_empty_kb: when True, a missing/empty knowledge base is not an error
+    :param vs_type: vector store type
+    :param embed_model: embedding model name
+    :param file_description: extra description fed to the summarizer
+    :param model_name: LLM used to produce the summaries
+    :param temperature: LLM sampling temperature
+    :param max_tokens: LLM token limit (None = model maximum)
+    :return: StreamingResponse emitting one JSON progress object per file
+    """
+
+    def output():
+        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+        if not kb.exists() and not allow_empty_kb:
+            # StreamingResponse chunks must be str/bytes, not dict.
+            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}, ensure_ascii=False)
+        else:
+            # Drop and recreate the summary store before re-summarizing everything.
+            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
+            kb_summary.drop_kb_summary()
+            kb_summary.create_kb_summary()
+
+            llm = get_ChatOpenAI(
+                model_name=model_name,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            reduce_llm = get_ChatOpenAI(
+                model_name=model_name,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            # Map-reduce summarization adapter
+            summary = SummaryAdapter.form_summary(llm=llm,
+                                                  reduce_llm=reduce_llm,
+                                                  overlap_size=OVERLAP_SIZE)
+            files = list_files_from_folder(knowledge_base_name)
+
+            for i, file_name in enumerate(files):
+                doc_infos = kb.list_docs(file_name=file_name)
+                docs = summary.summarize(kb_name=knowledge_base_name,
+                                         file_description=file_description,
+                                         docs=doc_infos)
+
+                status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
+                if status_kb_summary:
+                    logger.info(f"({i + 1} / {len(files)}): {file_name} 向量化总结完成")
+                    yield json.dumps({
+                        "code": 200,
+                        "msg": f"({i + 1} / {len(files)}): {file_name}",
+                        "total": len(files),
+                        "finished": i + 1,
+                        "doc": file_name,
+                    }, ensure_ascii=False)
+                else:
+                    msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
+                    logger.error(msg)
+                    yield json.dumps({
+                        "code": 500,
+                        "msg": msg,
+                    }, ensure_ascii=False)
+
+    return StreamingResponse(output(), media_type="text/event-stream")
+
+
+def summary_file_to_vector_store(
+        knowledge_base_name: str = Body(..., examples=["samples"]),
+        file_name: str = Body(..., examples=["test.pdf"]),
+        allow_empty_kb: bool = Body(True),
+        vs_type: str = Body(DEFAULT_VS_TYPE),
+        embed_model: str = Body(EMBEDDING_MODEL),
+        file_description: str = Body(''),
+        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+):
+    """
+    Summarize a single file and store the summary in the vector store.
+
+    :param knowledge_base_name: name of the knowledge base
+    :param file_name: file to summarize
+    :param allow_empty_kb: when True, a missing/empty knowledge base is not an error
+    :param vs_type: vector store type
+    :param embed_model: embedding model name
+    :param file_description: extra description fed to the summarizer
+    :param model_name: LLM used to produce the summaries
+    :param temperature: LLM sampling temperature
+    :param max_tokens: LLM token limit (None = model maximum)
+    :return: StreamingResponse emitting a JSON result object
+    """
+
+    def output():
+        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+        if not kb.exists() and not allow_empty_kb:
+            # StreamingResponse chunks must be str/bytes, not dict.
+            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}, ensure_ascii=False)
+        else:
+            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
+            kb_summary.create_kb_summary()
+
+            llm = get_ChatOpenAI(
+                model_name=model_name,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            reduce_llm = get_ChatOpenAI(
+                model_name=model_name,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            # Map-reduce summarization adapter
+            summary = SummaryAdapter.form_summary(llm=llm,
+                                                  reduce_llm=reduce_llm,
+                                                  overlap_size=OVERLAP_SIZE)
+
+            doc_infos = kb.list_docs(file_name=file_name)
+            docs = summary.summarize(kb_name=knowledge_base_name,
+                                     file_description=file_description,
+                                     docs=doc_infos)
+
+            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
+            if status_kb_summary:
+                logger.info(f" {file_name} 向量化总结完成")
+                yield json.dumps({
+                    "code": 200,
+                    "msg": f"{file_name} 向量化总结完成",
+                    "doc": file_name,
+                }, ensure_ascii=False)
+            else:
+                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
+                logger.error(msg)
+                yield json.dumps({
+                    "code": 500,
+                    "msg": msg,
+                }, ensure_ascii=False)
+
+    return StreamingResponse(output(), media_type="text/event-stream")
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index b360147..bde6e1f 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.models.conversation_model import ConversationModel
 from server.db.models.message_model import MessageModel
 from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
+from server.db.repository.knowledge_metadata_repository import add_summary_to_db
+
 from server.db.base import Base, engine
 from server.db.session import session_scope
 import os
diff --git a/tests/api/test_kb_summary_api.py b/tests/api/test_kb_summary_api.py
new file mode 100644
index 0000000..4a84498
--- /dev/null
+++ b/tests/api/test_kb_summary_api.py
@@ -0,0 +1,26 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from server.utils import api_address
+
+api_base_url = api_address()
+
+kb = "samples"
+file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
+
+
+def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
+    url = api_base_url + api
+    print("\n文件摘要:")
+    r = requests.post(url, json={"knowledge_base_name": kb,
+                                 "file_name": file_name
+                                 }, stream=True)
+    for chunk in r.iter_content(None):
+        data = json.loads(chunk)
+        assert isinstance(data, dict)
+        assert data["code"] == 200
+        print(data["msg"])