Langchain-Chatchat/server/knowledge_base/kb_summary_api.py

169 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
OVERLAP_SIZE,
logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI
from configs import LLM_MODELS, TEMPERATURE
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Rebuild the file summaries of a knowledge base and stream the progress.

    The existing summary store is dropped (``drop_kb_summary``) and created
    again (``create_kb_summary``); then every file returned by
    ``list_files_from_folder`` is summarized and added to the store.

    :param knowledge_base_name: name of the knowledge base to process.
    :param allow_empty_kb: when False, a knowledge base for which
        ``kb.exists()`` is falsy is reported as 404 instead of being processed.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param model_name: LLM used for both the summarize and the reduce step.
    :param temperature: sampling temperature for the LLM.
    :param max_tokens: token limit for the LLM; None leaves it to the model.
    :return: a ``StreamingResponse`` whose chunks are JSON strings, one per
        processed file (``code`` 200 on success, 500 on failure), or a single
        ``code`` 404 chunk when the knowledge base is missing.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Every chunk handed to StreamingResponse must be str/bytes, so the
            # error payload is serialized instead of being yielded as a dict.
            yield json.dumps(
                {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                ensure_ascii=False,
            )
            return

        # Start from a clean summary store for this knowledge base.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        files = list_files_from_folder(knowledge_base_name)
        total = len(files)

        for i, file_name in enumerate(files):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(kb_name=knowledge_base_name,
                                     file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({i + 1} / {total}): {file_name} 向量化总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({i + 1} / {total}): {file_name}",
                    "total": total,
                    "finished": i + 1,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                # ensure_ascii=False keeps the message readable, matching the
                # success payload above.
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return StreamingResponse(output(), media_type="text/event-stream")
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Summarize a single file of a knowledge base and stream the result.

    Unlike a full rebuild, the existing summary store is not dropped here;
    only ``create_kb_summary`` is called before the file's summary is added.

    :param knowledge_base_name: name of the knowledge base holding the file.
    :param file_name: name of the file whose documents are summarized.
    :param allow_empty_kb: when False, a knowledge base for which
        ``kb.exists()`` is falsy is reported as 404 instead of being processed.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param model_name: LLM used for both the summarize and the reduce step.
    :param temperature: sampling temperature for the LLM.
    :param max_tokens: token limit for the LLM; None leaves it to the model.
    :return: a ``StreamingResponse`` yielding one JSON string: ``code`` 200 on
        success, 500 when adding the summary fails, 404 when the knowledge
        base is missing.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Every chunk handed to StreamingResponse must be str/bytes, so the
            # error payload is serialized instead of being yielded as a dict.
            yield json.dumps(
                {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                ensure_ascii=False,
            )
            return

        # Make sure the summary store exists (no drop: other files' summaries
        # are left in place).
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(kb_name=knowledge_base_name,
                                 file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 向量化总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 向量化总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            # ensure_ascii=False keeps the message readable, matching the
            # success payload above.
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return StreamingResponse(output(), media_type="text/event-stream")