Merge branch 'dev_summary' into dev_command_summary
实现 summary_chunk 文档分段总结业务,使用 MapReduceDocumentsChain 生成摘要。 # Conflicts: # server/api.py # server/knowledge_base/kb_doc_api.py # server/knowledge_base/kb_service/base.py # server/knowledge_base/migrate.py
This commit is contained in:
parent
f57837c07a
commit
248db46187
|
|
@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||||
|
|
||||||
# 知识库相关接口
|
# 知识库相关接口
|
||||||
mount_knowledge_routes(app)
|
mount_knowledge_routes(app)
|
||||||
|
# 摘要相关接口
|
||||||
|
mount_filename_summary_routes(app)
|
||||||
|
|
||||||
# LLM模型相关接口
|
# LLM模型相关接口
|
||||||
app.post("/llm_model/list_running_models",
|
app.post("/llm_model/list_running_models",
|
||||||
|
|
@ -230,6 +232,20 @@ def mount_knowledge_routes(app: FastAPI):
|
||||||
)(upload_temp_docs)
|
)(upload_temp_docs)
|
||||||
|
|
||||||
|
|
||||||
|
def mount_filename_summary_routes(app: FastAPI):
    """Register the file-summary POST endpoints on ``app``.

    Two routes are mounted, in this order:

    * ``/knowledge_base/kb_summary_api/summary_file_to_vector_store``
    * ``/knowledge_base/kb_summary_api/recreate_summary_vector_store``

    :param app: the FastAPI application the routes are attached to.
    """
    # The handlers are imported inside the function, exactly as before, so the
    # summary module is only loaded when the routes are actually mounted.
    from server.knowledge_base.kb_summary_api import (
        recreate_summary_vector_store,
        summary_file_to_vector_store,
    )

    # (path, OpenAPI summary text, handler) for every route to register.
    routes = (
        ("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
         "文件摘要",
         summary_file_to_vector_store),
        ("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
         "重建文件摘要",
         recreate_summary_vector_store),
    )
    for path, title, handler in routes:
        register = app.post(path,
                            tags=["Knowledge kb_summary_api Management"],
                            summary=title)
        register(handler)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_api(host, port, **kwargs):
|
def run_api(host, port, **kwargs):
|
||||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||||
uvicorn.run(app,
|
uvicorn.run(app,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from typing import List, Union, Dict, Optional
|
||||||
|
|
||||||
from server.embeddings_api import embed_texts
|
from server.embeddings_api import embed_texts
|
||||||
from server.embeddings_api import embed_documents
|
from server.embeddings_api import embed_documents
|
||||||
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
|
|
||||||
|
|
||||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
|
def normalize(embeddings: List[List[float]]) -> np.ndarray:
|
||||||
|
|
@ -183,12 +184,22 @@ class KBService(ABC):
|
||||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
    """
    Retrieve the stored documents matching ``file_name`` and/or ``metadata``.

    The matching records are read with ``list_docs_from_db``; each record id is
    then resolved individually through ``get_doc_by_ids`` so that every returned
    document can be paired with the id it was looked up by.

    :param file_name: file name passed to ``list_docs_from_db`` as a filter.
    :param metadata: metadata filter passed to ``list_docs_from_db``;
        ``None`` (the default) is treated as an empty dict.
    :return: one ``DocumentWithVSId`` per record whose id could be resolved.
        Records for which ``get_doc_by_ids`` returns nothing are skipped.
    """
    # A fresh dict per call: a literal ``{}`` default would be shared between calls.
    metadata_filter = {} if metadata is None else metadata
    doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata_filter)

    docs = []
    for info in doc_infos:
        # Look ids up one at a time so a missing id cannot shift the
        # id <-> document pairing.
        found = self.get_doc_by_ids([info["id"]])
        if not found:
            # No document came back for this id (None or empty list): skip it.
            continue
        docs.append(DocumentWithVSId(**found[0].dict(), id=info["id"]))
    return docs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
from fastapi import Body
|
||||||
|
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||||
|
OVERLAP_SIZE,
|
||||||
|
logger, log_verbose, )
|
||||||
|
from server.knowledge_base.utils import (list_files_from_folder)
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
|
from typing import List, Optional
|
||||||
|
from server.knowledge_base.kb_summary.base import KBSummaryService
|
||||||
|
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
|
||||||
|
from server.utils import wrap_done, get_ChatOpenAI
|
||||||
|
from configs import LLM_MODELS, TEMPERATURE
|
||||||
|
|
||||||
|
|
||||||
|
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Rebuild the summaries of every file in a knowledge base.

    The existing summary store is dropped and created again, then each file
    returned by ``list_files_from_folder`` is summarized and added to it.
    Progress is streamed back as one JSON-encoded object per file.

    :param knowledge_base_name: name of the knowledge base to process.
    :param allow_empty_kb: when False, a knowledge base for which
        ``kb.exists()`` is false produces a single 404 message and nothing else.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB and summary services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param model_name: LLM model name used for both the map and the reduce step.
    :param temperature: LLM sampling temperature.
    :param max_tokens: limit on generated tokens; ``None`` is passed through unchanged.
    :return: a ``StreamingResponse`` with media type ``text/event-stream``.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Every chunk handed to StreamingResponse must be str/bytes, so the
            # error is JSON-encoded like all other messages (a raw dict cannot
            # be encoded by the response).
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        # Start from an empty summary store.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)
        files = list_files_from_folder(knowledge_base_name)
        total = len(files)

        for finished, file_name in enumerate(files, start=1):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(kb_name=knowledge_base_name,
                                     file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({finished} / {total}): {file_name} 向量化总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({finished} / {total}): {file_name}",
                    "total": total,
                    "finished": finished,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                # A failed file is reported and skipped; the loop continues.
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Summarize a single file of a knowledge base and store the summary.

    The documents of ``file_name`` are read with ``kb.list_docs``, summarized
    and added to the knowledge base's summary store. The outcome is streamed
    back as a single JSON-encoded object.

    :param knowledge_base_name: name of the knowledge base that holds the file.
    :param file_name: file whose documents are summarized.
    :param allow_empty_kb: when False, a knowledge base for which
        ``kb.exists()`` is false produces a single 404 message and nothing else.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB and summary services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param model_name: LLM model name used for both the map and the reduce step.
    :param temperature: LLM sampling temperature.
    :param max_tokens: limit on generated tokens; ``None`` is passed through unchanged.
    :return: a ``StreamingResponse`` with media type ``text/event-stream``.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Every chunk handed to StreamingResponse must be str/bytes, so the
            # error is JSON-encoded like all other messages (a raw dict cannot
            # be encoded by the response).
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        # Unlike the full rebuild, the existing summary store is not dropped here.
        # NOTE(review): presumably create_kb_summary() is safe to call when the
        # store already exists - confirm.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(kb_name=knowledge_base_name,
                                 file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 向量化总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 向量化总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.db.models.conversation_model import ConversationModel
|
from server.db.models.conversation_model import ConversationModel
|
||||||
from server.db.models.message_model import MessageModel
|
from server.db.models.message_model import MessageModel
|
||||||
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||||
|
from server.db.repository.knowledge_metadata_repository import add_summary_to_db
|
||||||
|
|
||||||
from server.db.base import Base, engine
|
from server.db.base import Base, engine
|
||||||
from server.db.session import session_scope
|
from server.db.session import session_scope
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root_path = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(root_path))
|
||||||
|
from server.utils import api_address
|
||||||
|
|
||||||
|
api_base_url = api_address()
|
||||||
|
|
||||||
|
kb = "samples"
|
||||||
|
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
    """POST a summary request for the sample file and check every streamed chunk.

    Each chunk of the streamed response must decode to a JSON object whose
    ``code`` is 200; its ``msg`` is printed.

    :param api: path of the endpoint, appended to the module-level ``api_base_url``.
    """
    endpoint = api_base_url + api
    print("\n文件摘要:")

    payload = {
        "knowledge_base_name": kb,
        "file_name": file_name,
    }
    response = requests.post(endpoint, json=payload, stream=True)

    # chunk_size=None: chunks are consumed in whatever size they arrive.
    for raw in response.iter_content(None):
        message = json.loads(raw)
        assert isinstance(message, dict)
        assert message["code"] == 200
        print(message["msg"])
|
||||||
Loading…
Reference in New Issue