Langchain-Chatchat/server/knowledge_base/kb_summary_api.py

221 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
OVERLAP_SIZE,
logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from sse_starlette import EventSourceResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Rebuild the summary vector store for every file of a knowledge base.

    The existing summary store is dropped and recreated, then every file returned by
    ``list_files_from_folder`` is summarized and added to it. Progress is streamed to
    the client as server-sent events, each event being a JSON-encoded string.

    :param knowledge_base_name: name of the knowledge base to summarize.
    :param allow_empty_kb: when False, a single 404 event is streamed if the knowledge
        base does not exist; when True, processing continues regardless.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model for the knowledge base and the summary store.
    :param file_description: free-text description forwarded to the summarizer.
    :param model_name: LLM name used for both the map and the reduce summarization LLMs.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; None means no explicit cap is passed.
    :return: an ``EventSourceResponse`` streaming one JSON event per processed file.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Encode like every other event: sse_starlette interprets a yielded dict as
            # ServerSentEvent keyword arguments rather than as the event payload.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                             ensure_ascii=False)
            return

        # Recreate the summary store from scratch.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summarization adapter (map step uses `llm`, reduce step uses `reduce_llm`).
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        files = list_files_from_folder(knowledge_base_name)
        total = len(files)
        for finished, file_name in enumerate(files, start=1):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)
            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({finished} / {total}): {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({finished} / {total}): {file_name}",
                    "total": total,
                    "finished": finished,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                # A failed file is reported and skipped; the loop continues with the rest.
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Summarize a single file of a knowledge base and add it to the summary store.

    The result is streamed to the client as a server-sent event containing a
    JSON-encoded string.

    :param knowledge_base_name: name of the knowledge base containing the file.
    :param file_name: name of the file (as known to ``kb.list_docs``) to summarize.
    :param allow_empty_kb: when False, a 404 event is streamed if the knowledge base
        does not exist; when True, processing continues regardless.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model for the knowledge base and the summary store.
    :param file_description: free-text description forwarded to the summarizer.
    :param model_name: LLM name used for both the map and the reduce summarization LLMs.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; None means no explicit cap is passed.
    :return: an ``EventSourceResponse`` streaming a single JSON event.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Encode like every other event: sse_starlette interprets a yielded dict as
            # ServerSentEvent keyword arguments rather than as the event payload.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                             ensure_ascii=False)
            return

        # Make sure the summary store exists. Existing summaries are NOT dropped here
        # (presumably create_kb_summary is a no-op when the store exists — verify).
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summarization adapter (map step uses `llm`, reduce step uses `reduce_llm`).
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)
        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(file_description=file_description,
                                 docs=doc_infos)
        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
) -> BaseResponse:
    """
    Summarize the documents of a knowledge base identified by their vector-store ids.

    Unlike the file-based endpoints, the summaries are returned in the response and
    are not written to the summary store.

    :param knowledge_base_name: name of the knowledge base holding the documents.
    :param doc_ids: vector-store ids of the documents to summarize.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model of the knowledge base.
    :param file_description: free-text description forwarded to the summarizer.
    :param model_name: LLM name used for both the map and the reduce summarization LLMs.
    :param temperature: LLM sampling temperature.
    :param max_tokens: cap on generated tokens; None means no explicit cap is passed.
    :return: ``BaseResponse`` with code 404 if the knowledge base does not exist,
        otherwise code 200 and ``data={"summarize": [<summary doc as dict>, ...]}``.
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})

    llm = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    reduce_llm = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # Text summarization adapter (map step uses `llm`, reduce step uses `reduce_llm`).
    summary = SummaryAdapter.form_summary(llm=llm,
                                          reduce_llm=reduce_llm,
                                          overlap_size=OVERLAP_SIZE)
    doc_infos = kb.get_doc_by_ids(ids=doc_ids)
    # Wrap each document as DocumentWithVSId, attaching the id it was requested by.
    # NOTE(review): pairing via zip assumes get_doc_by_ids returns results in the same
    # order and length as `ids`; verify for every vector-store backend.
    # Entries returned as None (e.g. an unknown id, if a backend reports it that way)
    # are skipped instead of crashing on `.dict()`; the zip keeps the remaining
    # id/document pairs aligned.
    doc_info_with_ids = [
        DocumentWithVSId(**doc.dict(), id=doc_id)
        for doc_id, doc in zip(doc_ids, doc_infos)
        if doc is not None
    ]
    docs = summary.summarize(file_description=file_description,
                             docs=doc_info_with_ids)
    # Convert the summary documents to plain dicts for the JSON response.
    resp_summarize = [doc.dict() for doc in docs]
    return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})