from typing import List, Optional

from langchain.schema.language_model import BaseLanguageModel

from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import logger
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate

from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain

import sys
import asyncio


class SummaryAdapter:
    _OVERLAP_SIZE: int
    token_max: int
    _separator: str = "\n\n"
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """
        Build an instance.
        :param llm: LLM used to summarize each document
        :param reduce_llm: LLM used to merge the per-document summaries
        :param overlap_size: size of the overlap between adjacent documents
        :param token_max: maximum chunk length; every chunk must be shorter
            than token_max, and on the first summarization pass any summary
            longer than token_max raises an error
        :return: a SummaryAdapter wrapping the configured map-reduce chain
        """

        # This controls how each document will be formatted. Specifically,
        # it will be passed to `format_document` - see that function for details.
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )

        # The prompt here should take as an input variable the
        # `document_variable_name`
        prompt_template = (
            "Perform the task on the text. Task briefing: "
            "{task_briefing}"
            "The text is as follows: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)

        # We now define how to combine these summaries
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # Return the intermediate (per-document) steps as well
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)
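
    # Illustrative usage (a hedged sketch, not part of this module's API):
    # `map_llm` and `reduce_llm` are assumed names for BaseLanguageModel
    # instances obtained elsewhere; the adapter would be built and invoked
    # roughly like this:
    #
    #   adapter = SummaryAdapter.form_summary(llm=map_llm,
    #                                         reduce_llm=reduce_llm,
    #                                         overlap_size=50)
    #   summary_docs = adapter.summarize(kb_name="samples",
    #                                    file_description="a demo file",
    #                                    docs=docs)
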
    def summarize(self,
                  kb_name: str,
                  file_description: str,
                  docs: List[DocumentWithVSId] = []
                  ) -> List[Document]:

        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()

            asyncio.set_event_loop(loop)
        # Invoke the coroutine synchronously
        return loop.run_until_complete(self.asummarize(kb_name=kb_name,
                                                       file_description=file_description,
                                                       docs=docs))
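
    # Note on summarize(): since Python 3.10, asyncio.get_event_loop() emits
    # a DeprecationWarning when no event loop is running, which is why the
    # branch above prefers get_running_loop() and falls back to creating a
    # fresh loop before running the coroutine synchronously.
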
    async def asummarize(self,
                         kb_name: str,
                         file_description: str,
                         docs: List[DocumentWithVSId] = []) -> List[Document]:

        logger.info("start summary")
        # TODO: not yet handled: semantic repetition across documents, missing
        #  context, and "document was longer than the context length" errors
        # merge_docs = self._drop_overlap(docs)
        # # Join the sentences in merge_docs into a single document
        # text = self._join_docs(merge_docs)
        # Split the document into chunks at the paragraph and sentence
        # separators, each chunk shorter than token_max

        """
        The process runs in two stages:
        1. Summarize each document individually:
        map_results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        2. Merge the per-document summaries into the final summary; with
        return_intermediate_steps=True the intermediate steps are returned too:
        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
            result_docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        """
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(
            docs=docs,
            task_briefing="Describe the proximity and similarity between the "
                          "different methods to help the reader understand "
                          "how they relate to each other."
        )
        print(summary_combine)
        print(summary_intermediate_steps)

        # if len(summary_combine) == 0:
        #     # Empty result: regenerate from half as many intermediate summaries
        #     result_docs = [
        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
        #         # This uses metadata from the docs, and the textual results from `results`
        #         for i, question_result_key in enumerate(
        #             summary_intermediate_steps["intermediate_steps"][
        #             :len(summary_intermediate_steps["intermediate_steps"]) // 2
        #             ])
        #     ]
        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
        #         result_docs, token_max=self.token_max
        #     )
        logger.info("end summary")
        doc_ids = ",".join([doc.id for doc in docs])
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

        return [summary_combine_doc]
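
    # Shape of the returned value (illustrative only, mirroring the code
    # above): a single-element list whose Document bundles the final summary
    # with the description, the per-chunk intermediate summaries, and the
    # comma-joined source ids:
    #
    #   [Document(page_content="<combined summary>",
    #             metadata={"file_description": "...",
    #                       "summary_intermediate_steps": {...},
    #                       "doc_ids": "id1,id2"})]
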
    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Strip the part where consecutive page_content sentences overlap.
        :param docs:
        :return:
        """
        merge_docs = []

        pre_doc = None
        for doc in docs:
            # Append the first document as-is
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue

            # Where the tail of the previous document overlaps the head of
            # the next one, drop the overlapping head of the next document.
            # Shrink pre_doc by one leading character per iteration and look
            # for the overlap, until its length falls below
            # self._OVERLAP_SIZE // 2 - 2 * len(self._separator)
            for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                # Drop the leading character on each iteration
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # Drop the overlapping head of the next document
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break
            else:
                # No overlap found within the window: keep the full document
                # rather than silently dropping it
                merge_docs.append(doc.page_content)

            pre_doc = doc.page_content

        return merge_docs
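
    # Worked example of the trimming above (a sketch, assuming _OVERLAP_SIZE
    # and _separator leave the scan room to reach the shared fragment):
    # with pre_doc == "abc def" and the next page_content == "def ghi",
    # the scan shortens pre_doc one character at a time until it equals
    # "def", which matches the next document's prefix, so only " ghi" is
    # appended to merge_docs.
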
    def _join_docs(self, docs: List[str]) -> Optional[str]:
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text


if __name__ == '__main__':

    docs = [
        # Sample chunks from a Chinese text on dream interpretation
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    # Strip the part where consecutive page_content sentences overlap:
    # where the tail of the previous item overlaps the head of the next,
    # drop the overlapping head of the next item
    pre_doc = None
    for doc in docs:
        # Append the first document as-is
        if len(merge_docs) == 0:
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # Where the tail of the previous item overlaps the head of the next,
        # drop the overlapping head of the next item.
        # Shrink pre_doc by one leading character per iteration and look for
        # the overlap, until its length falls below _OVERLAP_SIZE - 2 * len(separator)
        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
            # Drop the leading character on each iteration
            pre_doc = pre_doc[1:]
            if doc[:len(pre_doc)] == pre_doc:
                # Drop the overlapping head of the next item
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)

                pre_doc = doc
                break

    # Join the sentences in merge_docs into a single document
    text = separator.join(merge_docs)
    text = text.strip()

    print(text)
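
    # With these sample chunks every scan finds the repeated fragment
    # ("梦内容的", "使他们很难想象", "值与可靠性作各") and strips it from the
    # following chunk, so the printed text reads straight through without
    # the duplicated pieces.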