parent
bd0164ea62
commit
0586f94c5a
|
|
@ -16,7 +16,8 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||||
update_docs, download_doc, recreate_vector_store,
|
update_docs, download_doc, recreate_vector_store,
|
||||||
search_docs, DocumentWithScore)
|
search_docs, DocumentWithScore,
|
||||||
|
recreate_summary_vector_store)
|
||||||
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
|
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
|
||||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
@ -128,6 +129,16 @@ def create_app():
|
||||||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||||
)(recreate_vector_store)
|
)(recreate_vector_store)
|
||||||
|
|
||||||
|
app.post("/knowledge_base/recreate_summary_vector_store",
|
||||||
|
tags=["Knowledge Base Management"],
|
||||||
|
summary="重建项目库所有文件的摘要信息,并生成向量入库。"
|
||||||
|
)(recreate_summary_vector_store)
|
||||||
|
|
||||||
|
app.post("/knowledge_base/summary_file_to_vector_store",
|
||||||
|
tags=["Knowledge Base Management"],
|
||||||
|
summary="获取文件的摘要信息,并生成向量入库。"
|
||||||
|
)(recreate_summary_vector_store)
|
||||||
|
|
||||||
# LLM模型相关接口
|
# LLM模型相关接口
|
||||||
app.post("/llm_model/list_running_models",
|
app.post("/llm_model/list_running_models",
|
||||||
tags=["LLM Model Management"],
|
tags=["LLM Model Management"],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
|
||||||
|
|
||||||
|
from server.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryChunkModel(Base):
    """
    Chunk-summary model: stores, for a group of file_doc chunks (doc_ids), the summary text
    generated for them.

    Data sources:
        User input: the user uploads a file and may fill in a description of it; the doc_ids
            of the generated file_doc rows are stored in summary_chunk.
        Automatic splitting: the page information stored in the meta_data field of the
            file_doc table is split per page, a custom prompt generates the summary text,
            and the doc_ids associated with each page are stored in summary_chunk.
    Follow-up tasks:
        Vector store construction: index summary_context of the summary_chunk table and
            build a vector store whose metadata is meta_data (doc_ids).
        Semantic association: compute the semantic similarity between the user-provided
            description and the automatically generated summary text.
    """
    __tablename__ = 'summary_chunk'
    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
    kb_name = Column(String(50), comment='知识库名称')
    summary_context = Column(String(255), comment='总结文本')
    summary_id = Column(String(255), comment='总结矢量id')
    doc_ids = Column(String(1024), comment="向量库id关联列表")
    # Named ``meta_data`` because ``metadata`` is reserved by SQLAlchemy's declarative base.
    meta_data = Column(JSON, default={})

    def __repr__(self):
        # Read the ``meta_data`` column; ``self.metadata`` would be the declarative
        # MetaData registry of the base class, not this row's JSON payload.
        return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
                f" doc_ids='{self.doc_ids}', metadata='{self.meta_data}')>")
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
from server.db.models.knowledge_metadata_model import SummaryChunkModel
|
||||||
|
from server.db.session import with_session
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
def list_summary_from_db(session,
                         kb_name: str,
                         metadata: Dict = None,
                         ) -> List[Dict]:
    '''
    List the chunk summaries of a knowledge base.

    :param session: database session injected by ``with_session``.
    :param kb_name: name of the knowledge base whose summaries are listed.
    :param metadata: optional key/value pairs; only rows whose ``meta_data`` JSON holds the
        same value (compared as strings) for every key are returned.
    :return: [{"id": ..., "summary_context": str, "summary_id": str, "doc_ids": str,
        "metadata": dict}, ...]
    '''
    docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)

    # ``metadata`` defaults to None rather than a shared mutable ``{}``.
    for k, v in (metadata or {}).items():
        docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))

    return [{"id": x.id,
             "summary_context": x.summary_context,
             "summary_id": x.summary_id,
             "doc_ids": x.doc_ids,
             # The column is ``meta_data``; ``x.metadata`` is SQLAlchemy's MetaData object.
             "metadata": x.meta_data} for x in docs.all()]
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
def delete_summary_from_db(session,
                           kb_name: str
                           ) -> List[Dict]:
    '''
    Delete every chunk summary of a knowledge base and return the deleted entries.

    :param session: database session injected by ``with_session``.
    :param kb_name: name of the knowledge base whose summaries are removed.
    :return: the entries as listed before deletion, in the shape produced by
        ``list_summary_from_db``.
    '''
    # Take the snapshot before deleting so the caller gets what was removed.
    removed = list_summary_from_db(kb_name=kb_name)
    session.query(SummaryChunkModel).filter_by(kb_name=kb_name).delete()
    session.commit()
    return removed
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
def add_summary_to_db(session,
                      kb_name: str,
                      summary_infos: List[Dict]):
    '''
    Persist summary records for a knowledge base.

    :param session: database session injected by ``with_session``.
    :param kb_name: name of the knowledge base the summaries belong to.
    :param summary_infos: one dict per summary; the keys read here are
        "summary_context", "summary_id", "doc_ids" and "metadata".
    :return: True once all rows have been added and committed.
    '''
    for info in summary_infos:
        session.add(SummaryChunkModel(
            kb_name=kb_name,
            summary_context=info["summary_context"],
            summary_id=info["summary_id"],
            doc_ids=info["doc_ids"],
            meta_data=info["metadata"],
        ))

    session.commit()
    return True
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
def count_summary_from_db(session, kb_name: str) -> int:
    '''
    Return the number of chunk-summary rows stored for the knowledge base ``kb_name``.
    '''
    matching = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
    return matching.count()
|
||||||
|
|
@ -4,8 +4,14 @@ from fastapi import File, Form, Body, Query, UploadFile
|
||||||
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||||
logger, log_verbose,)
|
logger, log_verbose, LLM_MODEL, TEMPERATURE)
|
||||||
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
from server.utils import (
|
||||||
|
BaseResponse,
|
||||||
|
ListResponse,
|
||||||
|
run_in_thread_pool,
|
||||||
|
get_model_worker_config,
|
||||||
|
fschat_openai_api_address
|
||||||
|
)
|
||||||
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
|
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
|
||||||
files2docs_in_thread, KnowledgeFile)
|
files2docs_in_thread, KnowledgeFile)
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
|
|
@ -16,6 +22,11 @@ from server.db.repository.knowledge_file_repository import get_file_detail
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
from server.knowledge_base.kb_summary.base import KBSummaryService
|
||||||
|
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
|
||||||
|
|
||||||
|
|
||||||
class DocumentWithScore(Document):
|
class DocumentWithScore(Document):
|
||||||
score: float = None
|
score: float = None
|
||||||
|
|
@ -371,3 +382,164 @@ def recreate_vector_store(
|
||||||
kb.save_vector_store()
|
kb.save_vector_store()
|
||||||
|
|
||||||
return StreamingResponse(output(), media_type="text/event-stream")
|
return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
):
    """
    Rebuild the summaries of every file of a knowledge base.

    The existing summary store of the knowledge base is dropped and recreated, then each
    file returned by ``list_files_from_folder`` is summarised and the result is added to
    the summary store. Failures are streamed back as JSON records.

    :param knowledge_base_name: name of the knowledge base to process.
    :param allow_empty_kb: when False, a knowledge base that does not exist yields a 404
        record and nothing else is done.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name used for the knowledge base and its summaries.
    :param file_description: description forwarded to the summariser for every file.
    :return: a ``StreamingResponse`` (text/event-stream) of JSON-encoded status records.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # StreamingResponse chunks must be str/bytes, so serialise the record like the
            # error branch below instead of yielding a raw dict.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"})
            return

        # Recreate the summary store from scratch.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        config = get_model_worker_config(LLM_MODEL)
        # Settings shared by the map (per-chunk) LLM and the reduce (merge) LLM.
        shared_llm_kwargs = dict(
            streaming=False,
            verbose=True,
            openai_api_key=config.get("api_key", "EMPTY"),
            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
            model_name=LLM_MODEL,
            temperature=TEMPERATURE,
            openai_proxy=config.get("openai_proxy"),
        )
        # The per-chunk LLM additionally carries penalties and stop tokens.
        llm = ChatOpenAI(
            **shared_llm_kwargs,
            frequency_penalty=2.0,
            presence_penalty=-1.0,
            stop=['.', '。']
        )
        reduce_llm = ChatOpenAI(**shared_llm_kwargs)

        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)
        files = list_files_from_folder(knowledge_base_name)

        for i, file_name in enumerate(files):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(kb_name=knowledge_base_name,
                                     file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({i + 1} / {len(files)}): {file_name} 向量化总结完成")
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                })

    return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
):
    """
    Summarise a single file of a knowledge base and store the summary vectors.

    :param knowledge_base_name: name of the knowledge base that owns the file.
    :param file_name: name of the file to summarise (a plain string, e.g. "test.pdf").
    :param allow_empty_kb: when False, a knowledge base that does not exist yields a 404
        record and nothing else is done.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name used for the knowledge base and its summaries.
    :param file_description: description forwarded to the summariser.
    :return: a ``StreamingResponse`` (text/event-stream) of JSON-encoded status records.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # StreamingResponse chunks must be str/bytes, so serialise the record like the
            # error branch below instead of yielding a raw dict.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"})
            return

        # NOTE(review): this drops the summary store of the whole knowledge base before
        # summarising one file — confirm that this is intended for a single-file endpoint.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        config = get_model_worker_config(LLM_MODEL)
        # Settings shared by the map (per-chunk) LLM and the reduce (merge) LLM.
        shared_llm_kwargs = dict(
            streaming=False,
            verbose=True,
            openai_api_key=config.get("api_key", "EMPTY"),
            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
            model_name=LLM_MODEL,
            temperature=TEMPERATURE,
            openai_proxy=config.get("openai_proxy"),
        )
        # The per-chunk LLM additionally carries penalties and stop tokens.
        llm = ChatOpenAI(
            **shared_llm_kwargs,
            frequency_penalty=2.0,
            presence_penalty=-1.0,
            stop=['.', '。']
        )
        reduce_llm = ChatOpenAI(**shared_llm_kwargs)

        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(kb_name=knowledge_base_name,
                                 file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            # Only one file is processed here, so there is no "(i / total)" counter; the
            # original message referenced a loop index that does not exist in this scope.
            logger.info(f"{file_name} 向量化总结完成")
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            })

    return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ from server.knowledge_base.utils import (
|
||||||
from server.utils import embedding_device
|
from server.utils import embedding_device
|
||||||
from typing import List, Union, Dict, Optional
|
from typing import List, Union, Dict, Optional
|
||||||
|
|
||||||
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
|
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
FAISS = 'faiss'
|
FAISS = 'faiss'
|
||||||
|
|
@ -147,12 +149,22 @@ class KBService(ABC):
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
|
||||||
'''
|
'''
|
||||||
通过file_name或metadata检索Document
|
通过file_name或metadata检索Document
|
||||||
'''
|
'''
|
||||||
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
||||||
docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
|
docs = []
|
||||||
|
for x in doc_infos:
|
||||||
|
doc = self.get_doc_by_id(x["id"])
|
||||||
|
if doc is not None:
|
||||||
|
# 处理非空的情况
|
||||||
|
doc_with_id = DocumentWithVSId(**doc.dict(), id=x["id"])
|
||||||
|
docs.append(doc_with_id)
|
||||||
|
else:
|
||||||
|
# 处理空的情况
|
||||||
|
# 可以选择跳过当前循环迭代或执行其他操作
|
||||||
|
pass
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from configs import (
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
KB_ROOT_PATH)
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
|
||||||
|
class KBSummaryService(ABC):
    """
    Manages the chunk-summary vector store of one knowledge base.

    The store lives in the "summary_vector_store" directory under the knowledge base
    folder; summary records are mirrored into the database through
    ``add_summary_to_db`` / ``delete_summary_from_db``.
    """
    kb_name: str
    embed_model: str
    vs_path: str
    kb_path: str

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL
                 ):
        """
        :param knowledge_base_name: name of the knowledge base.
        :param embed_model: embedding model name used when loading the vector store.
        """
        self.kb_name = knowledge_base_name
        self.embed_model = embed_model

        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(self.vs_path, exist_ok=True)

    def get_vs_path(self):
        """Return the directory that holds the summary vector store."""
        return os.path.join(self.get_kb_path(), "summary_vector_store")

    def get_kb_path(self):
        """Return the knowledge base directory under KB_ROOT_PATH."""
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load the summary vector store from the shared FAISS pool (``create=True``)."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name="summary_vector_store",
                                               embed_model=self.embed_model,
                                               create=True)

    def add_kb_summary(self, summary_combine_docs: List[Document]):
        """
        Add summary documents to the vector store and record them in the database.

        :param summary_combine_docs: summary documents; ``metadata['doc_ids']`` of each one
            is stored as the associated doc id list.
        :return: the status returned by ``add_summary_to_db``.
        """
        with self.load_vector_store().acquire() as vs:
            ids = vs.add_documents(documents=summary_combine_docs)
            vs.save_local(self.vs_path)

        summary_infos = [{"summary_context": doc.page_content,
                          "summary_id": summary_id,
                          "doc_ids": doc.metadata.get('doc_ids'),
                          "metadata": doc.metadata}
                         for summary_id, doc in zip(ids, summary_combine_docs)]
        status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
        return status

    def create_kb_summary(self):
        """
        Create the chunk-summary storage directory of the knowledge base.
        :return:
        """
        os.makedirs(self.vs_path, exist_ok=True)

    def drop_kb_summary(self):
        """
        Delete the chunk summaries of the knowledge base: the pooled vector store entry,
        the on-disk store directory and the database records.
        :return:
        """
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop(self.kb_name)
            try:
                shutil.rmtree(self.vs_path)
            except FileNotFoundError:
                # Already removed (e.g. dropped twice); nothing left to delete on disk.
                pass
        delete_summary_from_db(kb_name=self.kb_name)
|
||||||
|
|
@ -0,0 +1,250 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
|
||||||
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
|
from configs import (logger)
|
||||||
|
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.output_parsers.regex import RegexParser
|
||||||
|
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryAdapter:
    """
    Map-reduce summariser: summarises every document with one LLM chain and merges the
    partial summaries with a second (reduce) chain.
    """
    _OVERLAP_SIZE: int
    token_max: int
    _separator: str = "\n\n"
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """
        Build an adapter instance.

        :param reduce_llm: LLM used to merge the partial summaries.
        :param llm: LLM used to generate the summary of each document.
        :param overlap_size: size of the overlap between neighbouring chunks.
        :param token_max: token limit handed to the reduce step (``ReduceDocumentsChain``).
        :return: a configured ``SummaryAdapter``.
        """

        # This controls how each document will be formatted.
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )

        # The prompt here should take as an input variable the
        # `document_variable_name`
        prompt_template = (
            "根据文本执行任务。以下任务信息"
            "{task_briefing}"
            "文本内容如下: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        # We now define how to combine these summaries
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # Return the intermediate (per-document) steps as well.
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)

    def summarize(self,
                  kb_name: str,
                  file_description: str,
                  docs: Optional[List[DocumentWithVSId]] = None
                  ) -> List[Document]:
        """
        Synchronous wrapper around ``asummarize``.

        :param kb_name: knowledge base name (forwarded to ``asummarize``).
        :param file_description: description stored in the summary metadata.
        :param docs: documents to summarise; ``None`` is treated as an empty list.
        :return: the list returned by ``asummarize``.
        """
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()

            asyncio.set_event_loop(loop)
        # Run the coroutine to completion synchronously.
        # NOTE(review): run_until_complete cannot be used on a loop that is already
        # running, so this is presumably only called from threads without a running loop.
        return loop.run_until_complete(self.asummarize(kb_name=kb_name,
                                                       file_description=file_description,
                                                       docs=docs))

    async def asummarize(self,
                         kb_name: str,
                         file_description: str,
                         docs: Optional[List[DocumentWithVSId]] = None) -> List[Document]:
        """
        Summarise ``docs`` with the map-reduce chain.

        :param kb_name: knowledge base name (not used by the summarisation itself).
        :param file_description: description stored in the summary metadata.
        :param docs: documents to summarise; ``None`` is treated as an empty list.
        :return: a single-element list holding the combined summary document, whose
            metadata carries "file_description", "summary_intermediate_steps" and
            "doc_ids" (comma-joined ids of ``docs``).
        """
        docs = [] if docs is None else docs

        logger.info("start summary")
        # TODO: semantic duplication, missing context and "document was longer than the
        # context length" are not handled yet (see _drop_overlap / _join_docs).

        # combine_docs works in two stages:
        # 1. map: every document is summarised by the per-document LLM chain;
        # 2. reduce: the partial summaries are merged into the final summary; because
        #    return_intermediate_steps=True the per-document results are returned too.
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
                                                                              task_briefing="描述不同方法之间的接近度和相似性,"
                                                                                            "以帮助读者理解它们之间的关系。")
        logger.debug(summary_combine)
        logger.debug(summary_intermediate_steps)

        logger.info("end summary")
        doc_ids = ",".join(doc.id for doc in docs)
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Remove, from every document, the leading text that repeats the end of the
        previous document.

        An overlap only counts when it is at least
        ``_OVERLAP_SIZE // 2 - 2 * len(_separator)`` characters long (never less than one
        character). A document without such an overlap is kept unchanged instead of
        being dropped.

        :param docs: documents in reading order; only ``page_content`` is read.
        :return: one string per input document, with the overlapping prefix removed.
        """
        min_overlap = max(self._OVERLAP_SIZE // 2 - 2 * len(self._separator), 0)
        merge_docs = []
        pre_doc = None
        for doc in docs:
            content = doc.page_content
            if pre_doc is None:
                # The first document is taken as is.
                merge_docs.append(content)
            else:
                overlap = self._overlap_length(pre_doc, content, min_overlap)
                merge_docs.append(content[overlap:])
            pre_doc = content

        return merge_docs

    @staticmethod
    def _overlap_length(previous: str, current: str, min_overlap: int) -> int:
        """Length of the longest suffix of ``previous`` that ``current`` starts with (0 if shorter than ``min_overlap``)."""
        longest = min(len(previous), len(current))
        for size in range(longest, max(min_overlap, 1) - 1, -1):
            if current.startswith(previous[len(previous) - size:]):
                return size
        return 0

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        """Join ``docs`` with the separator and strip the result; return None when it is empty."""
        text = self._separator.join(docs).strip()
        return text if text != "" else None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Ad-hoc demo: strip the text that each chunk repeats from the end of the previous
    # chunk, then join the pieces and print the result.
    samples = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    # Trimming stops once this many characters of the previous chunk would remain.
    floor = _OVERLAP_SIZE - 2 * len(separator)

    pieces = []
    previous = None
    for chunk in samples:
        # The first chunk is taken as is.
        if not pieces:
            previous = chunk
            pieces.append(chunk)
            continue

        # Drop one leading character of the previous chunk per step until what is left
        # is a prefix of the current chunk; that prefix is the overlap to remove.
        remaining = len(previous) - floor
        while remaining > 0:
            remaining -= 1
            previous = previous[1:]
            if chunk.startswith(previous):
                pieces.append(chunk[len(previous):])
                previous = chunk
                break

    # Merge the de-overlapped pieces into a single text.
    print(separator.join(pieces).strip())
|
||||||
|
|
@ -5,6 +5,7 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
||||||
KnowledgeFile,)
|
KnowledgeFile,)
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.db.repository.knowledge_file_repository import add_file_to_db
|
from server.db.repository.knowledge_file_repository import add_file_to_db
|
||||||
|
from server.db.repository.knowledge_metadata_repository import add_summary_to_db
|
||||||
from server.db.base import Base, engine
|
from server.db.base import Base, engine
|
||||||
import os
|
import os
|
||||||
from typing import Literal, Any, List
|
from typing import Literal, Any, List
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentWithVSId(Document):
    """
    A document that has been vectorized.
    """
    # Id attached to the vectorized document; None when no id has been set.
    id: str = None
|
||||||
|
|
||||||
Loading…
Reference in New Issue