Langchain-Chatchat/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py

import asyncio
import json
from typing import List, Optional
from fastapi import Body
from sse_starlette import EventSourceResponse
from chatchat.settings import Settings
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
from chatchat.server.knowledge_base.kb_summary.base import KBSummaryService
from chatchat.server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.knowledge_base.utils import list_files_from_folder
from chatchat.server.utils import BaseResponse, get_ChatOpenAI, get_default_embedding
from chatchat.utils import build_logger
logger = build_logger()


def recreate_summary_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(Settings.kb_settings.DEFAULT_VS_TYPE),
    embed_model: str = Body(get_default_embedding()),
    file_description: str = Body(""),
    model_name: Optional[str] = Body(None, description="Name of the LLM to use."),
    temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None,
        description="Maximum number of tokens the LLM may generate; None means the model's maximum.",
    ),
):
"""
重建单个知识库文件摘要
:param max_tokens:
:param model_name:
:param temperature:
:param file_description:
:param knowledge_base_name:
:param allow_empty_kb:
:param vs_type:
:param embed_model:
:return:
"""
if max_tokens in [None, 0]:
max_tokens = Settings.model_settings.MAX_TOKENS

    def output():
try:
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
ok, msg = kb.check_embed_model()
if not ok:
yield {"code": 404, "msg": msg}
else:
                    # Drop and recreate the knowledge-base summary store.
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.drop_kb_summary()
kb_summary.create_kb_summary()
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
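                    # A second LLM instance is used for the reduce (combine) stage of the summary chain.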
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
                    # Text-summarization adapter.
summary = SummaryAdapter.form_summary(
llm=llm, reduce_llm=reduce_llm, overlap_size=Settings.kb_settings.OVERLAP_SIZE
)
files = list_files_from_folder(knowledge_base_name)
                    for i, file_name in enumerate(files):
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(
file_description=file_description, docs=doc_infos
)
status_kb_summary = kb_summary.add_kb_summary(
summary_combine_docs=docs
)
if status_kb_summary:
                            logger.info(f"({i + 1} / {len(files)}): {file_name} summarized")
yield json.dumps(
{
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
},
ensure_ascii=False,
)
else:
                            msg = (
                                f"Error summarizing file '{file_name}' in knowledge base "
                                f"'{knowledge_base_name}'; the file was skipped."
                            )
                            logger.error(msg)
                            yield json.dumps(
                                {
                                    "code": 500,
                                    "msg": msg,
                                },
                                ensure_ascii=False,
                            )
except asyncio.exceptions.CancelledError:
            logger.warning("Streaming was interrupted by the user.")
return
return EventSourceResponse(output())
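

# A minimal client sketch (illustrative only) of how a caller might consume the SSE
# stream produced by recreate_summary_vector_store. It assumes the handler is
# mounted at POST /knowledge_base/recreate_summary_vector_store and the API server
# listens on port 7861; adjust the URL and payload to your deployment.
#
#     import httpx
#
#     payload = {"knowledge_base_name": "samples", "model_name": "gpt-4o-mini"}
#     with httpx.stream(
#         "POST",
#         "http://127.0.0.1:7861/knowledge_base/recreate_summary_vector_store",
#         json=payload,
#         timeout=None,
#     ) as resp:
#         for line in resp.iter_lines():
#             # SSE frames arrive as "data: {...}" lines; skip keep-alives.
#             if line.startswith("data:"):
#                 print(line[5:].strip())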


def summary_file_to_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    file_name: str = Body(..., examples=["test.pdf"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(Settings.kb_settings.DEFAULT_VS_TYPE),
    embed_model: str = Body(get_default_embedding()),
    file_description: str = Body(""),
    model_name: Optional[str] = Body(None, description="Name of the LLM to use."),
    temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None,
        description="Maximum number of tokens the LLM may generate; None means the model's maximum.",
    ),
):
"""
单个知识库根据文件名称摘要
:param model_name:
:param max_tokens:
:param temperature:
:param file_description:
:param file_name:
:param knowledge_base_name:
:param allow_empty_kb:
:param vs_type:
:param embed_model:
:return:
"""

    def output():
try:
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
                # Ensure the knowledge-base summary store exists.
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.create_kb_summary()
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
                # Text-summarization adapter.
summary = SummaryAdapter.form_summary(
llm=llm, reduce_llm=reduce_llm, overlap_size=Settings.kb_settings.OVERLAP_SIZE
)
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description, docs=doc_infos)
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary:
                    logger.info(f"{file_name} summarized")
yield json.dumps(
{
"code": 200,
"msg": f"{file_name} 总结完成",
"doc": file_name,
},
ensure_ascii=False,
)
else:
                    msg = (
                        f"Error summarizing file '{file_name}' in knowledge base "
                        f"'{knowledge_base_name}'; the file was skipped."
                    )
                    logger.error(msg)
                    yield json.dumps(
                        {
                            "code": 500,
                            "msg": msg,
                        },
                        ensure_ascii=False,
                    )
except asyncio.exceptions.CancelledError:
            logger.warning("Streaming was interrupted by the user.")
return
return EventSourceResponse(output())
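

# summary_file_to_vector_store is consumed the same way as the SSE sketch above,
# with "file_name" added to the request payload.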


def summary_doc_ids_to_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    doc_ids: List[str] = Body([], examples=[["uuid"]]),
    vs_type: str = Body(Settings.kb_settings.DEFAULT_VS_TYPE),
    embed_model: str = Body(get_default_embedding()),
    file_description: str = Body(""),
    model_name: Optional[str] = Body(None, description="Name of the LLM to use."),
    temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
    max_tokens: Optional[int] = Body(
        None,
        description="Maximum number of tokens the LLM may generate; None means the model's maximum.",
    ),
) -> BaseResponse:
"""
单个知识库根据doc_ids摘要
:param knowledge_base_name:
:param doc_ids:
:param model_name:
:param max_tokens:
:param temperature:
:param file_description:
:param vs_type:
:param embed_model:
:return:
"""
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists():
return BaseResponse(
            code=404, msg=f"Knowledge base {knowledge_base_name} not found", data={}
)
else:
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
        # Text-summarization adapter.
summary = SummaryAdapter.form_summary(
llm=llm, reduce_llm=reduce_llm, overlap_size=Settings.kb_settings.OVERLAP_SIZE
)
doc_infos = kb.get_doc_by_ids(ids=doc_ids)
        # Pair each returned document with its vector-store id.
        doc_info_with_ids = [
            DocumentWithVSId(**{**doc.dict(), "id": with_id})
            for with_id, doc in zip(doc_ids, doc_infos)
        ]
docs = summary.summarize(
file_description=file_description, docs=doc_info_with_ids
)
        # Serialize the summary documents for the JSON response.
resp_summarize = [{**doc.dict()} for doc in docs]
        return BaseResponse(
            code=200, msg="Summarization complete", data={"summarize": resp_summarize}
        )