diff --git a/server/db/models/knowledge_metadata_model.py b/server/db/models/knowledge_metadata_model.py
new file mode 100644
index 0000000..03f4200
--- /dev/null
+++ b/server/db/models/knowledge_metadata_model.py
@@ -0,0 +1,28 @@
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
+
+from server.db.base import Base
+
+
+class SummaryChunkModel(Base):
+    """
+    Chunk summary model: stores the summary chunks for the doc_ids kept in file_doc.
+    Data sources:
+        User input: when uploading a file the user may supply a description; it is stored in
+        summary_chunk together with the doc_ids generated for file_doc.
+        Automatic splitting: the page numbers recorded in the meta_data field of the file_doc
+        table are used to split the file page by page; a custom prompt generates a summary text
+        per page, and the doc_ids linked to those pages are stored in summary_chunk.
+    Follow-up tasks:
+        Vector store construction: index summary_context from the summary_chunk table and build
+        a vector store, with meta_data (doc_ids) as the vector store metadata.
+        Semantic association: compute semantic similarity between the user-supplied description
+        and the automatically generated summary texts.
+    """
+    __tablename__ = 'summary_chunk'
+    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
+    kb_name = Column(String(50), comment='knowledge base name')
+    summary_context = Column(String(255), comment='summary text')
+    summary_id = Column(String(255), comment='summary vector id')
+    doc_ids = Column(String(1024), comment="list of associated vector store ids")
+    meta_data = Column(JSON, default={})
+
+    def __repr__(self):
+        return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', "
+                f"summary_context='{self.summary_context}', doc_ids='{self.doc_ids}', "
+                f"meta_data='{self.meta_data}')>")
diff --git a/server/db/repository/knowledge_metadata_repository.py b/server/db/repository/knowledge_metadata_repository.py
new file mode 100644
index 0000000..4158e70
--- /dev/null
+++ b/server/db/repository/knowledge_metadata_repository.py
@@ -0,0 +1,66 @@
+from server.db.models.knowledge_metadata_model import SummaryChunkModel
+from server.db.session import with_session
+from typing import List, Dict
+
+
+@with_session
+def list_summary_from_db(session,
+                         kb_name: str,
+                         metadata: Dict = {},
+                         ) -> List[Dict]:
+    '''
+    List the chunk summaries of a knowledge base.
+    Return format: [{"id": str, "summary_context": str, "doc_ids": str}, ...]
+    '''
+    docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
+
+    for k, v in metadata.items():
+        docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
+
+    return [{"id": x.id,
+             "summary_context": x.summary_context,
+             "summary_id": x.summary_id,
+             "doc_ids": x.doc_ids,
+             "metadata": x.meta_data} for x in docs.all()]
+
+
+@with_session
+def delete_summary_from_db(session,
+                           kb_name: str
+                           ) -> List[Dict]:
+    '''
+    Delete the chunk summaries of a knowledge base and return the deleted entries.
+    Return format: [{"id": str, "summary_context": str, "doc_ids": str}, ...]
+    '''
+    docs = list_summary_from_db(kb_name=kb_name)
+    query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
+    query.delete()
+    session.commit()
+    return docs
+
+
+@with_session
+def add_summary_to_db(session,
+                      kb_name: str,
+                      summary_infos: List[Dict]):
+    '''
+    Add summary information to the database.
+    summary_infos format: [{"summary_context": str, "doc_ids": str}, ...]
+    '''
+    for summary in summary_infos:
+        obj = SummaryChunkModel(
+            kb_name=kb_name,
+            summary_context=summary["summary_context"],
+            summary_id=summary["summary_id"],
+            doc_ids=summary["doc_ids"],
+            meta_data=summary["metadata"],
+        )
+        session.add(obj)
+
+    session.commit()
+    return True
+
+
+@with_session
+def count_summary_from_db(session, kb_name: str) -> int:
+    return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
diff --git a/server/knowledge_base/kb_summary/__init__.py b/server/knowledge_base/kb_summary/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/server/knowledge_base/kb_summary/base.py b/server/knowledge_base/kb_summary/base.py
new file mode 100644
index 0000000..00dcea6
--- /dev/null
+++ b/server/knowledge_base/kb_summary/base.py
@@ -0,0 +1,79 @@
+from typing import List
+
+from configs import (
+    EMBEDDING_MODEL,
+    KB_ROOT_PATH)
+
+from abc import ABC, abstractmethod
+from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
+import os
+import shutil
+from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
+
+from langchain.docstore.document import Document
+
+
+# TODO: file updates are not handled yet; the related documents have to be deleted and re-added.
+class KBSummaryService(ABC):
+    kb_name: str
+    embed_model: str
+    vs_path: str
+    kb_path: str
+
+    def __init__(self,
+                 knowledge_base_name: str,
+                 embed_model: str = EMBEDDING_MODEL
+                 ):
+        self.kb_name = knowledge_base_name
+        self.embed_model = embed_model
+
+        self.kb_path = self.get_kb_path()
+        self.vs_path = self.get_vs_path()
+
+        if not os.path.exists(self.vs_path):
+            os.makedirs(self.vs_path)
+
+    def get_vs_path(self):
+        return os.path.join(self.get_kb_path(), "summary_vector_store")
+
+    def get_kb_path(self):
+        return os.path.join(KB_ROOT_PATH, self.kb_name)
+
+    def load_vector_store(self) -> ThreadSafeFaiss:
+        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
+                                               vector_name="summary_vector_store",
+                                               embed_model=self.embed_model,
+                                               create=True)
+
+    def add_kb_summary(self, summary_combine_docs: List[Document]):
+        with self.load_vector_store().acquire() as vs:
+            ids = vs.add_documents(documents=summary_combine_docs)
+            vs.save_local(self.vs_path)
+
+        summary_infos = [{"summary_context": doc.page_content,
+                          "summary_id": id,
+                          "doc_ids": doc.metadata.get('doc_ids'),
+                          "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)]
+        status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
+        return status
+
+    def create_kb_summary(self):
+        """
+        Create the chunk summary store for the knowledge base.
+        :return:
+        """
+        if not os.path.exists(self.vs_path):
+            os.makedirs(self.vs_path)
+
+    def drop_kb_summary(self):
+        """
+        Drop the chunk summary store of the knowledge base.
+        :return:
+        """
+        with kb_faiss_pool.atomic:
+            kb_faiss_pool.pop(self.kb_name)
+            shutil.rmtree(self.vs_path)
+        delete_summary_from_db(kb_name=self.kb_name)
diff --git a/server/knowledge_base/kb_summary/summary_chunk.py b/server/knowledge_base/kb_summary/summary_chunk.py
new file mode 100644
index 0000000..b35adc1
--- /dev/null
+++ b/server/knowledge_base/kb_summary/summary_chunk.py
@@ -0,0 +1,250 @@
+from typing import List, Optional
+
+from langchain.schema.language_model import BaseLanguageModel
+
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
+from configs import (logger)
+from langchain.chains import StuffDocumentsChain, LLMChain
+from langchain.prompts import PromptTemplate
+
+from langchain.docstore.document import Document
+from langchain.output_parsers.regex import RegexParser
+from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
+
+import sys
+import asyncio
+
+
+class SummaryAdapter:
+    _OVERLAP_SIZE: int
+    token_max: int
+    _separator: str = "\n\n"
+    chain: MapReduceDocumentsChain
+
+    def __init__(self, overlap_size: int, token_max: int,
+                 chain: MapReduceDocumentsChain):
+        self._OVERLAP_SIZE = overlap_size
+        self.chain = chain
+        self.token_max = token_max
+
+    @classmethod
+    def form_summary(cls,
+                     llm: BaseLanguageModel,
+                     reduce_llm: BaseLanguageModel,
+                     overlap_size: int,
+                     token_max: int = 1300):
+        """
+        Build an instance.
+        :param reduce_llm: LLM used to merge the per-chunk summaries
+        :param llm: LLM used to generate the per-chunk summaries
+        :param overlap_size: size of the overlap between chunks
+        :param token_max: maximum token length of a chunk; every chunk is kept shorter than
+                          token_max, and a summary longer than token_max raises an error during
+                          the first summarization pass
+        :return:
+        """
+
+        # This controls how each document will be formatted. Specifically,
+        # it will be passed to `format_document`.
+        document_prompt = PromptTemplate(
+            input_variables=["page_content"],
+            template="{page_content}"
+        )
+
+        # The prompt here should take as an input variable the
+        # `document_variable_name`
+        prompt_template = (
+            "根据文本执行任务。以下任务信息"
+            "{task_briefing}"
+            "文本内容如下: "
+            "\r\n"
+            "{context}"
+        )
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["task_briefing", "context"]
+        )
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+
+        # We now define how to combine these summaries
+        reduce_prompt = PromptTemplate.from_template(
+            "Combine these summaries: {context}"
+        )
+        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)
+
+        document_variable_name = "context"
+        combine_documents_chain = StuffDocumentsChain(
+            llm_chain=reduce_llm_chain,
+            document_prompt=document_prompt,
+            document_variable_name=document_variable_name
+        )
+        reduce_documents_chain = ReduceDocumentsChain(
+            token_max=token_max,
+            combine_documents_chain=combine_documents_chain,
+        )
+        chain = MapReduceDocumentsChain(
+            llm_chain=llm_chain,
+            document_variable_name=document_variable_name,
+            reduce_documents_chain=reduce_documents_chain,
+            # also return the intermediate (per-document) results
+            return_intermediate_steps=True
+        )
+        return cls(overlap_size=overlap_size,
+                   chain=chain,
+                   token_max=token_max)
+
+    def summarize(self,
+                  kb_name: str,
+                  file_description: str,
+                  docs: List[DocumentWithVSId] = []
+                  ) -> List[Document]:
+
+        if sys.version_info < (3, 10):
+            loop = asyncio.get_event_loop()
+        else:
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+
+            asyncio.set_event_loop(loop)
+        # run the coroutine synchronously
+        return loop.run_until_complete(self.asummarize(kb_name=kb_name,
+                                                       file_description=file_description,
+                                                       docs=docs))
+
+    async def asummarize(self,
+                         kb_name: str,
+                         file_description: str,
+                         docs: List[DocumentWithVSId] = []) -> List[Document]:
+
+        logger.info("start summary")
+        # TODO: semantic duplication, missing context and
+        #  "document was longer than the context length" are not handled yet.
+        #  merge_docs = self._drop_overlap(docs)
+        #  # join the sentences in merge_docs into a single document
+        #  text = self._join_docs(merge_docs)
+        # Split the document into chunks at paragraph and sentence separators;
+        # each chunk must be shorter than token_max.
+
+        """
+        The process consists of two steps:
+        1. Each document is summarized on its own (the map step):
+            map_results = self.llm_chain.apply(
+                # FYI - this is parallelized and so it is fast.
+                [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
+                callbacks=callbacks,
+            )
+        2. The per-document summaries are merged into the final summary (the reduce step);
+           with return_intermediate_steps=True the intermediate steps are returned as well:
+            result, extra_return_dict = self.reduce_documents_chain.combine_docs(
+                result_docs, token_max=token_max, callbacks=callbacks, **kwargs
+            )
+        """
+        summary_combine, summary_intermediate_steps = self.chain.combine_docs(
+            docs=docs,
+            task_briefing="描述不同方法之间的接近度和相似性,"
+                          "以帮助读者理解它们之间的关系。"
+        )
+        print(summary_combine)
+        print(summary_intermediate_steps)
+
+        # if len(summary_combine) == 0:
+        #     # the result was empty: regenerate with half of the intermediate summaries
+        #     result_docs = [
+        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
+        #         # This uses metadata from the docs, and the textual results from `results`
+        #         for i, question_result_key in enumerate(
+        #             summary_intermediate_steps["intermediate_steps"][
+        #                 :len(summary_intermediate_steps["intermediate_steps"]) // 2
+        #             ])
+        #     ]
+        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
+        #         result_docs, token_max=self.token_max
+        #     )
+        logger.info("end summary")
+        doc_ids = ",".join([doc.id for doc in docs])
+        _metadata = {
+            "file_description": file_description,
+            "summary_intermediate_steps": summary_intermediate_steps,
+            "doc_ids": doc_ids
+        }
+        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)
+
+        return [summary_combine_doc]
+
+    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
+        """
+        Remove the part of each document's page_content that overlaps with the end
+        of the previous document.
+        :param docs:
+        :return:
+        """
+        merge_docs = []
+
+        pre_doc = None
+        for doc in docs:
+            # the first document is added as-is
+            if len(merge_docs) == 0:
+                pre_doc = doc.page_content
+                merge_docs.append(doc.page_content)
+                continue
+
+            # If the end of the previous document overlaps with the start of the next one,
+            # drop the overlapping prefix from the next document.
+            # Shrink pre_doc from the front one character per iteration and look for a match,
+            # until its length falls below self._OVERLAP_SIZE // 2 - 2 * len(separator).
+            for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
+                # drop the leading character on each iteration
+                pre_doc = pre_doc[1:]
+                if doc.page_content[:len(pre_doc)] == pre_doc:
+                    # remove the overlapping prefix from the next document
+                    merge_docs.append(doc.page_content[len(pre_doc):])
+                    break
+
+            pre_doc = doc.page_content
+
+        return merge_docs
+
+    def _join_docs(self, docs: List[str]) -> Optional[str]:
+        text = self._separator.join(docs)
+        text = text.strip()
+        if text == "":
+            return None
+        else:
+            return text
+
+
+if __name__ == '__main__':
+
+    docs = [
+        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
+        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
+        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
+        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
+    ]
+    _OVERLAP_SIZE = 1
+    separator: str = "\n\n"
+    merge_docs = []
+    # Remove the overlapping parts of consecutive strings:
+    # if the end of the previous string overlaps with the start of the next one,
+    # drop the overlapping prefix from the next string.
+    pre_doc = None
+    for doc in docs:
+        # the first document is added as-is
+        if len(merge_docs) == 0:
+            pre_doc = doc
+            merge_docs.append(doc)
+            continue
+
+        # Shrink pre_doc from the front one character per iteration and look for a match,
+        # until its length falls below _OVERLAP_SIZE - 2 * len(separator).
+        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
+            # drop the leading character on each iteration
+            pre_doc = pre_doc[1:]
+            if doc[:len(pre_doc)] == pre_doc:
+                # remove the overlapping prefix from the next string
+                page_content = doc[len(pre_doc):]
+                merge_docs.append(page_content)
+
+                pre_doc = doc
+                break
+
+    # join the sentences in merge_docs into a single document
+    text = separator.join(merge_docs)
+    text = text.strip()
+
+    print(text)
diff --git a/server/knowledge_base/model/kb_document_model.py b/server/knowledge_base/model/kb_document_model.py
new file mode 100644
index 0000000..a5d2c6a
--- /dev/null
+++ b/server/knowledge_base/model/kb_document_model.py
@@ -0,0 +1,10 @@
+
+from langchain.docstore.document import Document
+
+
+class DocumentWithVSId(Document):
+    """
+    A document that has been vectorized, together with its vector store id.
""" + id: str = None +