diff --git a/server/api.py b/server/api.py index b0c2ba2..7444d4b 100644 --- a/server/api.py +++ b/server/api.py @@ -233,15 +233,21 @@ def mount_knowledge_routes(app: FastAPI): def mount_filename_summary_routes(app: FastAPI): - from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store) + from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, + summary_doc_ids_to_vector_store) app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store", tags=["Knowledge kb_summary_api Management"], - summary="文件摘要" + summary="单个知识库根据文件名称摘要" )(summary_file_to_vector_store) + app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store", + tags=["Knowledge kb_summary_api Management"], + summary="单个知识库根据doc_ids摘要", + response_model=BaseResponse, + )(summary_doc_ids_to_vector_store) app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store", tags=["Knowledge kb_summary_api Management"], - summary="重建文件摘要" + summary="重建单个知识库文件摘要" )(recreate_summary_vector_store) diff --git a/server/knowledge_base/kb_summary/summary_chunk.py b/server/knowledge_base/kb_summary/summary_chunk.py index b35adc1..0b88f23 100644 --- a/server/knowledge_base/kb_summary/summary_chunk.py +++ b/server/knowledge_base/kb_summary/summary_chunk.py @@ -90,7 +90,6 @@ class SummaryAdapter: token_max=token_max) def summarize(self, - kb_name: str, file_description: str, docs: List[DocumentWithVSId] = [] ) -> List[Document]: @@ -105,12 +104,10 @@ class SummaryAdapter: asyncio.set_event_loop(loop) # 同步调用协程代码 - return loop.run_until_complete(self.asummarize(kb_name=kb_name, - file_description=file_description, + return loop.run_until_complete(self.asummarize(file_description=file_description, docs=docs)) async def asummarize(self, - kb_name: str, file_description: str, docs: List[DocumentWithVSId] = []) -> List[Document]: diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 9e06380..aac4de7 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -9,9 +9,9 @@ from server.knowledge_base.kb_service.base import KBServiceFactory from typing import List, Optional from server.knowledge_base.kb_summary.base import KBSummaryService from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter -from server.utils import wrap_done, get_ChatOpenAI +from server.utils import wrap_done, get_ChatOpenAI, BaseResponse from configs import LLM_MODELS, TEMPERATURE - +from server.knowledge_base.model.kb_document_model import DocumentWithVSId def recreate_summary_vector_store( knowledge_base_name: str = Body(..., examples=["samples"]), @@ -24,7 +24,7 @@ def recreate_summary_vector_store( max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ): """ - 重建文件摘要 + 重建单个知识库文件摘要 :param max_tokens: :param model_name: :param temperature: @@ -67,13 +67,12 @@ def recreate_summary_vector_store( for i, file_name in enumerate(files): doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(kb_name=knowledge_base_name, - file_description=file_description, + docs = summary.summarize(file_description=file_description, docs=doc_infos) status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) if status_kb_summary: - logger.info(f"({i + 1} / {len(files)}): {file_name} 向量化总结完成") + logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") yield json.dumps({ "code": 200, "msg": f"({i + 1} / {len(files)}): {file_name}", @@ -106,7 +105,7 @@ def summary_file_to_vector_store( max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ): """ - 文件摘要 + 单个知识库根据文件名称摘要 :param model_name: :param max_tokens: :param temperature: @@ -144,16 +143,15 @@ def summary_file_to_vector_store( overlap_size=OVERLAP_SIZE) doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(kb_name=knowledge_base_name, - file_description=file_description, + docs = summary.summarize(file_description=file_description, docs=doc_infos) status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) if status_kb_summary: - logger.info(f" {file_name} 向量化总结完成") + logger.info(f" {file_name} 总结完成") yield json.dumps({ "code": 200, - "msg": f"{file_name} 向量化总结完成", + "msg": f"{file_name} 总结完成", "doc": file_name, }, ensure_ascii=False) else: @@ -166,3 +164,57 @@ def summary_file_to_vector_store( }) return StreamingResponse(output(), media_type="text/event-stream") + + +def summary_doc_ids_to_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + doc_ids: List = Body([], examples=[["uuid"]]), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + file_description: str = Body(''), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), +) -> BaseResponse: + """ + 单个知识库根据doc_ids摘要 + :param knowledge_base_name: + :param doc_ids: + :param model_name: + :param max_tokens: + :param temperature: + :param file_description: + :param vs_type: + :param embed_model: + :return: + """ + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists(): + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}) + else: + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + + doc_infos = kb.get_doc_by_ids(ids=doc_ids) + # doc_infos转换成DocumentWithVSId包装的对象 + doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)] + + docs = summary.summarize(file_description=file_description, + docs=doc_info_with_ids) + + # 将docs转换成dict + resp_summarize = [{**doc.dict()} for doc in docs] + + return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize}) diff --git a/tests/api/test_kb_summary_api.py b/tests/api/test_kb_summary_api.py index 4a84498..d59c203 100644 --- a/tests/api/test_kb_summary_api.py +++ b/tests/api/test_kb_summary_api.py @@ -11,6 +11,11 @@ api_base_url = api_address() kb = "samples" file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md" +doc_ids = [ + "357d580f-fdf7-495c-b58b-595a398284e8", + "c7338773-2e83-4671-b237-1ad20335b0f0", + "6da613d1-327d-466f-8c1a-b32e6f461f47" +] def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"): @@ -24,3 +29,16 @@ def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summar assert isinstance(data, dict) assert data["code"] == 200 print(data["msg"]) + + +def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"): + url = api_base_url + api + print("\n文件摘要:") + r = requests.post(url, json={"knowledge_base_name": kb, + "doc_ids": doc_ids + }, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data)