Merge branch 'dev_summary' into dev_tmp
# Conflicts:
#	server/api.py
#	server/knowledge_base/kb_doc_api.py
#	server/knowledge_base/kb_service/base.py
#	server/knowledge_base/migrate.py
commit f57837c07a
@@ -0,0 +1,28 @@
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func

from server.db.base import Base


class SummaryChunkModel(Base):
    """
    Chunk summary model, used to store the chunk fragments associated with each doc_id in file_doc.
    Data sources:
        User input: when uploading a file, the user may provide a description of the file; the
            doc_ids generated in file_doc are stored in summary_chunk.
        Automatic splitting: the page numbers stored in the meta_data field of the file_doc table
            are used to split the content per page, a custom prompt generates the summary text, and
            the doc_ids associated with the corresponding pages are stored in summary_chunk.
    Follow-up tasks:
        Vector store construction: index the summary_context column of the summary_chunk table to
            build the vector store, with meta_data as the vector store metadata (doc_ids).
        Semantic association: compute the semantic similarity between the user-provided description
            and the automatically generated summary texts.
    """
    __tablename__ = 'summary_chunk'
    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
    kb_name = Column(String(50), comment='knowledge base name')
    summary_context = Column(String(255), comment='summary text')
    summary_id = Column(String(255), comment='summary vector id')
    doc_ids = Column(String(1024), comment="list of associated vector store ids")
    meta_data = Column(JSON, default={})

    def __repr__(self):
        return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
                f" doc_ids='{self.doc_ids}', metadata='{self.meta_data}')>")
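For context, here is a minimal sketch of how a summary_chunk row could be written and read back with plain SQLAlchemy. It is not part of this commit; the SQLite URL, the knowledge-base name and all id values are placeholders.

# Illustrative sketch only: the database URL and all values below are placeholders.
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from server.db.base import Base
from server.db.models.knowledge_metadata_model import SummaryChunkModel

engine = create_engine("sqlite:///knowledge_base.db")
Base.metadata.create_all(engine)   # creates the summary_chunk table if it does not exist

with Session(engine) as session:
    row = SummaryChunkModel(
        kb_name="samples",
        summary_context="User-provided description of the uploaded file.",
        summary_id="faiss-id-0001",            # id assigned by the summary vector store
        doc_ids="doc-1,doc-2",                 # doc_ids taken from file_doc
        meta_data={"source": "user_input"},
    )
    session.add(row)
    session.commit()

    count = session.query(SummaryChunkModel).filter_by(kb_name="samples").count()
    print(count)   # -> 1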
@@ -0,0 +1,66 @@
from server.db.models.knowledge_metadata_model import SummaryChunkModel
from server.db.session import with_session
from typing import List, Dict


@with_session
def list_summary_from_db(session,
                         kb_name: str,
                         metadata: Dict = {},
                         ) -> List[Dict]:
    '''
    List the chunk summaries of a knowledge base.
    Return format: [{"id": str, "summary_context": str, "summary_id": str, "doc_ids": str, "metadata": dict}, ...]
    '''
    docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)

    for k, v in metadata.items():
        docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))

    return [{"id": x.id,
             "summary_context": x.summary_context,
             "summary_id": x.summary_id,
             "doc_ids": x.doc_ids,
             "metadata": x.meta_data} for x in docs.all()]


@with_session
def delete_summary_from_db(session,
                           kb_name: str
                           ) -> List[Dict]:
    '''
    Delete the chunk summaries of a knowledge base and return the deleted chunk summaries.
    Return format: [{"id": str, "summary_context": str, "summary_id": str, "doc_ids": str, "metadata": dict}, ...]
    '''
    docs = list_summary_from_db(kb_name=kb_name)
    query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
    query.delete()
    session.commit()
    return docs


@with_session
def add_summary_to_db(session,
                      kb_name: str,
                      summary_infos: List[Dict]):
    '''
    Add summary information to the database.
    summary_infos format: [{"summary_context": str, "summary_id": str, "doc_ids": str, "metadata": dict}, ...]
    '''
    for summary in summary_infos:
        obj = SummaryChunkModel(
            kb_name=kb_name,
            summary_context=summary["summary_context"],
            summary_id=summary["summary_id"],
            doc_ids=summary["doc_ids"],
            meta_data=summary["metadata"],
        )
        session.add(obj)

    session.commit()
    return True


@with_session
def count_summary_from_db(session, kb_name: str) -> int:
    return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
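A minimal sketch of how these helpers are meant to be combined. It assumes that @with_session injects the session argument, so callers pass only kb_name (the kb_name-only call to list_summary_from_db inside delete_summary_from_db suggests this); all values are placeholders.

# Illustrative sketch only; all values are placeholders.
from server.db.repository.knowledge_metadata_repository import (
    add_summary_to_db, list_summary_from_db, count_summary_from_db, delete_summary_from_db)

summary_infos = [{
    "summary_context": "Summary generated for page 1.",
    "summary_id": "faiss-id-0001",
    "doc_ids": "doc-1,doc-2",
    "metadata": {"file_description": "sample file", "doc_ids": "doc-1,doc-2"},
}]

add_summary_to_db(kb_name="samples", summary_infos=summary_infos)
print(count_summary_from_db(kb_name="samples"))       # -> 1
print(list_summary_from_db(kb_name="samples"))        # filters by kb_name, optionally by metadata
deleted = delete_summary_from_db(kb_name="samples")   # returns the rows that were removed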
@@ -0,0 +1,79 @@
from typing import List

from configs import (
    EMBEDDING_MODEL,
    KB_ROOT_PATH)

from abc import ABC, abstractmethod
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
import os
import shutil
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db

from langchain.docstore.document import Document


# TODO: file updates are not handled yet; the related documents must be deleted and re-added.
class KBSummaryService(ABC):
    kb_name: str
    embed_model: str
    vs_path: str
    kb_path: str

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL
                 ):
        self.kb_name = knowledge_base_name
        self.embed_model = embed_model

        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)

    def get_vs_path(self):
        return os.path.join(self.get_kb_path(), "summary_vector_store")

    def get_kb_path(self):
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name="summary_vector_store",
                                               embed_model=self.embed_model,
                                               create=True)

    def add_kb_summary(self, summary_combine_docs: List[Document]):
        with self.load_vector_store().acquire() as vs:
            ids = vs.add_documents(documents=summary_combine_docs)
            vs.save_local(self.vs_path)

        summary_infos = [{"summary_context": doc.page_content,
                          "summary_id": id,
                          "doc_ids": doc.metadata.get('doc_ids'),
                          "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)]
        status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
        return status

    def create_kb_summary(self):
        """
        Create the chunk-summary store for the knowledge base.
        :return:
        """
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)

    def drop_kb_summary(self):
        """
        Drop the chunk summaries of the knowledge base.
        :return:
        """
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop(self.kb_name)
            shutil.rmtree(self.vs_path)
        delete_summary_from_db(kb_name=self.kb_name)
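A minimal usage sketch for the service. KBSummaryService inherits ABC but declares no abstract methods in this diff, so the sketch instantiates it directly; the import path, knowledge-base name and document contents are assumptions.

# Illustrative sketch only; the module path and all values are assumptions.
from langchain.docstore.document import Document
from server.knowledge_base.kb_summary.base import KBSummaryService   # assumed module path

kb_summary = KBSummaryService(knowledge_base_name="samples")
kb_summary.create_kb_summary()   # make sure the summary_vector_store directory exists

summary_docs = [
    Document(page_content="Combined summary of the uploaded file.",
             metadata={"doc_ids": "doc-1,doc-2", "file_description": "sample file"})
]
status = kb_summary.add_kb_summary(summary_docs)   # embeds the summaries, saves the FAISS index,
print(status)                                      # and writes the summary_chunk rows

# kb_summary.drop_kb_summary()   # would remove the FAISS store and the DB rows for this kb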
@@ -0,0 +1,250 @@
from typing import List, Optional

from langchain.schema.language_model import BaseLanguageModel

from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import (logger)
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate

from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain

import sys
import asyncio


class SummaryAdapter:
    _OVERLAP_SIZE: int
    token_max: int
    _separator: str = "\n\n"
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """
        Get an instance.
        :param reduce_llm: LLM used to merge the summaries
        :param llm: LLM used to generate the summaries
        :param overlap_size: size of the overlap between chunks
        :param token_max: maximum number of chunks; each chunk must be shorter than token_max.
                          During the first summarization pass, summaries longer than token_max raise an error.
        :return:
        """

        # This controls how each document will be formatted. Specifically,
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )

        # The prompt here should take as an input variable the
        # `document_variable_name`
        prompt_template = (
            "Perform the task based on the text. Task details: "
            "{task_briefing}"
            "The text is as follows: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        # We now define how to combine these summaries
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # return the intermediate steps
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)

    def summarize(self,
                  kb_name: str,
                  file_description: str,
                  docs: List[DocumentWithVSId] = []
                  ) -> List[Document]:

        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()

            asyncio.set_event_loop(loop)
        # Call the coroutine synchronously
        return loop.run_until_complete(self.asummarize(kb_name=kb_name,
                                                       file_description=file_description,
                                                       docs=docs))

    async def asummarize(self,
                         kb_name: str,
                         file_description: str,
                         docs: List[DocumentWithVSId] = []) -> List[Document]:

        logger.info("start summary")
        # TODO: semantic duplication, missing context, and "document was longer than the context length"
        #       errors in the documents are not handled yet.
        # merge_docs = self._drop_overlap(docs)
        # # Join the sentences in merge_docs into a single document
        # text = self._join_docs(merge_docs)
        # Split the document into chunks at the paragraph and sentence separators;
        # each chunk must be shorter than token_max.

        """
        This process has two parts:
        1. Process each document to obtain its summary:
           map_results = self.llm_chain.apply(
               # FYI - this is parallelized and so it is fast.
               [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
               callbacks=callbacks,
           )
        2. Combine the per-document summaries into the final summary; with
           return_intermediate_steps=True the intermediate steps are returned as well:
           result, extra_return_dict = self.reduce_documents_chain.combine_docs(
               result_docs, token_max=token_max, callbacks=callbacks, **kwargs
           )
        """
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(
            docs=docs,
            task_briefing="Describe the proximity and similarity between the different methods "
                          "to help the reader understand the relationship between them."
        )
        print(summary_combine)
        print(summary_intermediate_steps)

        # if len(summary_combine) == 0:
        #     # If the result is empty, regenerate with half as many documents.
        #     result_docs = [
        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
        #         # This uses metadata from the docs, and the textual results from `results`
        #         for i, question_result_key in enumerate(
        #             summary_intermediate_steps["intermediate_steps"][
        #                 :len(summary_intermediate_steps["intermediate_steps"]) // 2
        #             ])
        #     ]
        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
        #         result_docs, token_max=self.token_max
        #     )
        logger.info("end summary")
        doc_ids = ",".join([doc.id for doc in docs])
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Remove the overlapping parts of the page_content sentences across documents.
        :param docs:
        :param separator:
        :return:
        """
        merge_docs = []

        pre_doc = None
        for doc in docs:
            # Add the first document directly.
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue

            # Where the end of the previous document overlaps with the start of the next one,
            # remove the overlapping part from the start of the next document.
            # Iteratively shrink pre_doc by dropping its leading character and look for the overlap,
            # until the length of pre_doc is less than self._OVERLAP_SIZE // 2 - 2 * len(separator).
            for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                # Drop the leading character on each iteration.
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # Remove the overlapping part from the start of the next document.
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break

            pre_doc = doc.page_content

        return merge_docs

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text


if __name__ == '__main__':

    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    # Remove the overlapping parts of the page_content sentences across documents:
    # where the end of the previous document overlaps with the start of the next one,
    # remove the overlapping part from the start of the next document.
    pre_doc = None
    for doc in docs:
        # Add the first document directly.
        if len(merge_docs) == 0:
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # Where the end of the previous document overlaps with the start of the next one,
        # remove the overlapping part from the start of the next document.
        # Iteratively shrink pre_doc by dropping its leading character and look for the overlap,
        # until the length of pre_doc is less than _OVERLAP_SIZE - 2 * len(separator).
        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
            # Drop the leading character on each iteration.
            pre_doc = pre_doc[1:]
            if doc[:len(pre_doc)] == pre_doc:
                # Remove the overlapping part from the start of the next document.
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)

                pre_doc = doc
                break

    # Join the sentences in merge_docs into a single document.
    text = separator.join(merge_docs)
    text = text.strip()

    print(text)
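A minimal end-to-end sketch of how SummaryAdapter is meant to be driven. It assumes langchain's ChatOpenAI as both the map and reduce LLMs (any BaseLanguageModel would do); the module path for SummaryAdapter, the model name, ids and texts are all placeholders or assumptions.

# Illustrative sketch only; the LLM setup, module path and all values are assumptions.
from langchain.chat_models import ChatOpenAI

from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter   # assumed module path
from server.knowledge_base.model.kb_document_model import DocumentWithVSId

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)          # map-step LLM (placeholder)
reduce_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)   # reduce-step LLM (placeholder)

adapter = SummaryAdapter.form_summary(llm=llm,
                                      reduce_llm=reduce_llm,
                                      overlap_size=50,
                                      token_max=1300)

docs = [
    DocumentWithVSId(id="doc-1", page_content="First chunk of the uploaded file ..."),
    DocumentWithVSId(id="doc-2", page_content="Second chunk of the uploaded file ..."),
]

summary_docs = adapter.summarize(kb_name="samples",
                                 file_description="sample file",
                                 docs=docs)

# summary_docs[0].metadata carries file_description, doc_ids and the intermediate steps;
# it can then be persisted with KBSummaryService.add_kb_summary(summary_docs).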
@@ -0,0 +1,10 @@
from langchain.docstore.document import Document


class DocumentWithVSId(Document):
    """
    A document that has been vectorized, carrying the id it has in the vector store.
    """
    id: str = None
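A small sketch of how this model carries the vector store id through the summary flow (SummaryAdapter.asummarize joins these ids into the doc_ids metadata of the resulting summary document); the values are placeholders.

# Illustrative values only.
doc = DocumentWithVSId(id="faiss-id-0001",
                       page_content="Chunk text retrieved from the vector store.",
                       metadata={"source": "sample.pdf"})

doc_ids = ",".join([d.id for d in [doc]])
print(doc_ids)   # -> "faiss-id-0001"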