新功能:知识库管理界面支持查看、编辑、删除向量库文档 (#2471)

* 新功能:
- 知识库管理界面支持查看、编辑、删除向量库文档。暂不支持增加(aggrid添加新行比较麻烦,需要另外实现)
- 去除知识库管理界面中重建知识库和删除知识库按钮,建议用户到终端命令操作

修复:
- 所有与知识库名称、文件名称有关的数据库操作函数都改成大小写不敏感,所有路径统一为 posix 风格,避免因路径文本不一致导致数据重复和操作失效 (close #2232)

开发者:
- 添加 update_docs_by_id 函数与 API 接口。当前仅支持 FAISS,暂时未用到,未将来对知识库做更细致的修改做准备
- 统一 DocumentWithScore 与 DocumentWithVsId
- FAISS 返回的 Document.metadata 中包含 ID, 方便后续查找比对
- /knowledge_base/search_docs 接口支持 file_name, metadata 参数,可以据此检索文档

* fix bug
This commit is contained in:
liunux4odoo 2023-12-26 13:44:36 +08:00 committed by GitHub
parent 2e1442a5c1
commit 9ff7bef2c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 248 additions and 85 deletions

View File

@ -149,7 +149,8 @@ def mount_knowledge_routes(app: FastAPI):
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore, update_info)
search_docs, DocumentWithVSId, update_info,
update_docs_by_id,)
app.post("/chat/knowledge_base_chat",
tags=["Chat"],
@ -190,10 +191,17 @@ def mount_knowledge_routes(app: FastAPI):
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
response_model=List[DocumentWithVSId],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/update_docs_by_id",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="直接更新知识库文档"
)(update_docs_by_id)
app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,

View File

@ -29,7 +29,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),

View File

@ -5,7 +5,7 @@ from server.db.session import with_session
@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
# 创建知识库实例
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if not kb:
kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
session.add(kb)
@ -25,14 +25,14 @@ def list_kbs_from_db(session, min_file_count: int = -1):
@with_session
def kb_exists(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
status = True if kb else False
return status
@with_session
def load_kb_from_db(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
else:
@ -42,7 +42,7 @@ def load_kb_from_db(session, kb_name):
@with_session
def delete_kb_from_db(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
session.delete(kb)
return True
@ -50,7 +50,7 @@ def delete_kb_from_db(session, kb_name):
@with_session
def get_kb_detail(session, kb_name: str) -> dict:
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
return {
"kb_name": kb.kb_name,

View File

@ -15,7 +15,7 @@ def list_docs_from_db(session,
列出某知识库某文件对应的所有Document
返回形式[{"id": str, "metadata": dict}, ...]
'''
docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
docs = docs.filter(FileDocModel.file_name.ilike(file_name))
for k, v in metadata.items():
@ -34,10 +34,10 @@ def delete_docs_from_db(session,
返回形式[{"id": str, "metadata": dict}, ...]
'''
docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
query = session.query(FileDocModel).filter_by(kb_name=kb_name)
query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
query = query.filter_by(file_name=file_name)
query.delete()
query = query.filter(FileDocModel.file_name.ilike(file_name))
query.delete(synchronize_session=False)
session.commit()
return docs
@ -68,12 +68,12 @@ def add_docs_to_db(session,
@with_session
def count_files_from_db(session, kb_name: str) -> int:
return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
return session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).count()
@with_session
def list_files_from_db(session, kb_name):
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
files = session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).all()
docs = [f.file_name for f in files]
return docs
@ -89,8 +89,8 @@ def add_file_to_db(session,
if kb:
# 如果已经存在该文件,则更新文件信息与版本号
existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name)
.filter(KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
KnowledgeFileModel.file_name.ilike(kb_file.filename))
.first())
mtime = kb_file.get_mtime()
size = kb_file.get_size()
@ -122,14 +122,16 @@ def add_file_to_db(session,
@with_session
def delete_file_from_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
existing_file = (session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name))
.first())
if existing_file:
session.delete(existing_file)
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)).first()
if kb:
kb.file_count -= 1
session.commit()
@ -138,9 +140,9 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False)
session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False)
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)).first()
if kb:
kb.file_count = 0
@ -150,16 +152,19 @@ def delete_files_from_db(session, knowledge_base_name: str):
@with_session
def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
existing_file = (session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name))
.first())
return True if existing_file else False
@with_session
def get_file_detail(session, kb_name: str, filename: str) -> dict:
file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=filename,
kb_name=kb_name).first())
.filter(KnowledgeFileModel.file_name.ilike(filename),
KnowledgeFileModel.kb_name.ilike(kb_name))
.first())
if file:
return {
"kb_name": file.kb_name,

View File

@ -12,7 +12,7 @@ def list_summary_from_db(session,
列出某知识库chunk summary
返回形式[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
docs = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name))
for k, v in metadata.items():
docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
@ -33,8 +33,8 @@ def delete_summary_from_db(session,
返回形式[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = list_summary_from_db(kb_name=kb_name)
query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
query.delete()
query = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name))
query.delete(synchronize_session=False)
session.commit()
return docs
@ -63,4 +63,4 @@ def add_summary_to_db(session,
@with_session
def count_summary_from_db(session, kb_name: str) -> int:
return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
return session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)).count()

View File

@ -4,10 +4,24 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os
from langchain.schema import Document
# patch FAISS to include doc id in Document.metadata
def _new_ds_search(self, search: str) -> Union[str, Document]:
if search not in self._dict:
return f"ID {search} not found."
else:
doc = self._dict[search]
if isinstance(doc, Document):
doc.metadata["id"] = search
return doc
InMemoryDocstore.search = _new_ds_search
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__

View File

@ -15,15 +15,12 @@ import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from langchain.docstore.document import Document
from typing import List
class DocumentWithScore(Document):
score: float = None
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from typing import List, Dict
def search_docs(
query: str = Body(..., description="用户输入", examples=["你好"]),
query: str = Body("", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD,
@ -31,13 +28,34 @@ def search_docs(
"SCORE越小相关度越高"
"取到1相当于不筛选建议设置在0.5左右",
ge=0, le=1),
) -> List[DocumentWithScore]:
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
data = []
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
return data
def update_docs_by_id(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
'''
按照文档 ID 更新文档内容
'''
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data
return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
if kb.update_doc_by_ids(docs=docs):
return BaseResponse(msg=f"文档更新成功")
else:
return BaseResponse(msg=f"文档更新失败")
def list_files(

View File

@ -121,8 +121,9 @@ class KBService(ABC):
for doc in docs:
try:
source = doc.metadata.get("source", "")
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
if os.path.isabs(source):
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
except Exception as e:
print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
self.delete_doc(kb_file)
@ -176,13 +177,33 @@ class KBService(ABC):
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
) ->List[Document]:
docs = self.do_search(query, top_k, score_threshold)
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return []
def del_doc_by_ids(self, ids: List[str]) -> bool:
raise NotImplementedError
def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
'''
传入参数为 {doc_id: Document, ...}
如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档
TODO是否要支持新增 docs
'''
self.del_doc_by_ids(list(docs.keys()))
docs = []
ids = []
for k, v in docs.items():
if not v or not v.page_content.strip():
continue
ids.append(k)
docs.append(v)
self.do_add_doc(docs=docs, ids=ids)
return True
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
'''
通过file_name或metadata检索Document
@ -190,10 +211,10 @@ class KBService(ABC):
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
docs = []
for x in doc_infos:
doc_info_s = self.get_doc_by_ids([x["id"]])
if doc_info_s is not None and doc_info_s != []:
doc_info = self.get_doc_by_ids([x["id"]])[0]
if doc_info is not None:
# 处理非空的情况
doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"])
doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
docs.append(doc_with_id)
else:
# 处理空的情况
@ -249,6 +270,7 @@ class KBService(ABC):
@abstractmethod
def do_add_doc(self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
"""
向知识库添加文档子类实自己逻辑
@ -371,12 +393,13 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
"in_folder": True,
"in_db": False,
}
lower_names = {x.lower(): x for x in result}
for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc in result:
result[doc].update(doc_detail)
if doc.lower() in lower_names:
result[lower_names[doc.lower()]].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail

View File

@ -145,6 +145,14 @@ class ESKBService(KBService):
k=top_k)
return docs
def del_doc_by_ids(self, ids: List[str]) -> bool:
for doc_id in ids:
try:
self.es_client_python.delete(index=self.index_name,
id=doc_id,
refresh=True)
except Exception as e:
logger.error(f"ES Docs Delete Error! {e}")
def do_delete_doc(self, kb_file, **kwargs):
if self.es_client_python.indices.exists(index=self.index_name):
@ -168,7 +176,7 @@ class ESKBService(KBService):
id=doc_id,
refresh=True)
except Exception as e:
logger.error("ES Docs Delete Error!")
logger.error(f"ES Docs Delete Error! {e}")
# self.db_init.delete(ids=delete_list)
#self.es_client_python.indices.refresh(index=self.index_name)

View File

@ -36,6 +36,10 @@ class FaissKBService(KBService):
with self.load_vector_store().acquire() as vs:
return [vs.docstore._dict.get(id) for id in ids]
def del_doc_by_ids(self, ids: List[str]) -> bool:
with self.load_vector_store().acquire() as vs:
vs.delete(ids)
def do_init(self):
self.vector_name = self.vector_name or self.embed_model
self.kb_path = self.get_kb_path()
@ -72,7 +76,8 @@ class FaissKBService(KBService):
with self.load_vector_store().acquire() as vs:
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
metadatas=data["metadatas"])
metadatas=data["metadatas"],
ids=kwargs.get("ids"))
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
@ -83,7 +88,7 @@ class FaissKBService(KBService):
kb_file: KnowledgeFile,
**kwargs):
with self.load_vector_store().acquire() as vs:
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename]
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source").lower() == kb_file.filename.lower()]
if len(ids) > 0:
vs.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):

View File

@ -31,6 +31,9 @@ class MilvusKBService(KBService):
result.append(Document(page_content=text, metadata=data))
return result
def del_doc_by_ids(self, ids: List[str]) -> bool:
self.milvus.col.delete(expr=f'pk in {ids}')
@staticmethod
def search(milvus_name, content, limit=3):
search_params = {

View File

@ -29,6 +29,10 @@ class PGKBService(KBService):
connect.execute(stmt, parameters={'ids': ids}).fetchall()]
return results
# TODO:
def del_doc_by_ids(self, ids: List[str]) -> bool:
return super().del_doc_by_ids(ids)
def do_init(self):
self._load_pg_vector()

View File

@ -29,6 +29,9 @@ class ZillizKBService(KBService):
result.append(Document(page_content=text, metadata=data))
return result
def del_doc_by_ids(self, ids: List[str]) -> bool:
self.zilliz.col.delete(expr=f'pk in {ids}')
@staticmethod
def search(zilliz_name, content, limit=3):
search_params = {

View File

@ -7,4 +7,4 @@ class DocumentWithVSId(Document):
矢量化后的文档
"""
id: str = None
score: float = 3.0

View File

@ -71,7 +71,8 @@ def list_files_from_folder(kb_name: str):
for target_entry in target_it:
process_entry(target_entry)
elif entry.is_file():
result.append(os.path.relpath(entry.path, doc_path))
file_path = (Path(os.path.relpath(entry.path, doc_path)).as_posix()) # 路径统一为 posix 格式
result.append(file_path)
elif entry.is_dir():
with os.scandir(entry.path) as it:
for sub_entry in it:
@ -272,7 +273,7 @@ class KnowledgeFile:
对应知识库目录中的文件必须是磁盘上存在的才能进行向量化等操作
'''
self.kb_name = knowledge_base_name
self.filename = filename
self.filename = str(Path(filename).as_posix())
self.ext = os.path.splitext(filename)[-1].lower()
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.filename}")

View File

@ -224,7 +224,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
key="prompt_template_select",
)
prompt_template_name = st.session_state.prompt_template_select
temperature = st.slider("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
temperature = st.slider("Temperature", 0.0, 2.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
def on_kb_change():

View File

@ -1 +1 @@
from .knowledge_base import knowledge_base_page
from .knowledge_base import knowledge_base_page

View File

@ -190,6 +190,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
# 知识库详情
# st.info("请选择文件,点击按钮进行操作。")
doc_details = pd.DataFrame(get_kb_file_details(kb))
selected_rows = []
if not len(doc_details):
st.info(f"知识库 `{kb}` 中暂无文件")
else:
@ -284,32 +285,80 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.divider()
cols = st.columns(3)
# cols = st.columns(3)
if cols[0].button(
"依据源文件重建向量库",
# help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
use_container_width=True,
type="primary",
):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if msg := check_error_msg(d):
st.toast(msg)
# if cols[0].button(
# "依据源文件重建向量库",
# # help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
# use_container_width=True,
# type="primary",
# ):
# with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
# empty = st.empty()
# empty.progress(0.0, "")
# for d in api.recreate_vector_store(kb,
# chunk_size=chunk_size,
# chunk_overlap=chunk_overlap,
# zh_title_enhance=zh_title_enhance):
# if msg := check_error_msg(d):
# st.toast(msg)
# else:
# empty.progress(d["finished"] / d["total"], d["msg"])
# st.rerun()
# if cols[2].button(
# "删除知识库",
# use_container_width=True,
# ):
# ret = api.delete_knowledge_base(kb)
# st.toast(ret.get("msg", " "))
# time.sleep(1)
# st.rerun()
# with st.sidebar:
# keyword = st.text_input("查询关键字")
# top_k = st.slider("匹配条数", 1, 100, 3)
st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
docs = []
df = pd.DataFrame([], columns=["seq", "id", "content", "source"])
if selected_rows:
file_name = selected_rows[0]["file_name"]
docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
"type": x["type"],
"metadata": json.dumps(x["metadata"], ensure_ascii=False),
"to_del": "",
} for i, x in enumerate(docs)]
df = pd.DataFrame(data)
gb = GridOptionsBuilder.from_dataframe(df)
gb.configure_columns(["id", "source", "type", "metadata"], hide=True)
gb.configure_column("seq", "No.", width=50)
gb.configure_column("page_content", "内容", editable=True, autoHeight=True, wrapText=True, flex=1,
cellEditor="agLargeTextCellEditor", cellEditorPopup=True)
gb.configure_column("to_del", "删除", editable=True, width=50, wrapHeaderText=True,
cellEditor="agCheckboxCellEditor", cellRender="agCheckboxCellRenderer")
gb.configure_selection()
edit_docs = AgGrid(df, gb.build())
if st.button("保存更改"):
# origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs}
changed_docs = []
for index, row in edit_docs.data.iterrows():
# origin_doc = origin_docs[row["id"]]
# if row["page_content"] != origin_doc["page_content"]:
if row["to_del"] not in ["Y", "y", 1]:
changed_docs.append({
"page_content": row["page_content"],
"type": row["type"],
"metadata": json.loads(row["metadata"]),
})
if changed_docs:
if api.update_kb_docs(knowledge_base_name=selected_kb,
file_names=[file_name],
docs={file_name: changed_docs}):
st.toast("更新文档成功")
else:
empty.progress(d["finished"] / d["total"], d["msg"])
st.rerun()
if cols[2].button(
"删除知识库",
use_container_width=True,
):
ret = api.delete_knowledge_base(kb)
st.toast(ret.get("msg", " "))
time.sleep(1)
st.rerun()
st.toast("更新文档失败")

View File

@ -571,10 +571,12 @@ class ApiRequest:
def search_kb_docs(
self,
query: str,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
) -> List:
'''
对应api.py/knowledge_base/search_docs接口
@ -584,6 +586,8 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"file_name": file_name,
"metadata": metadata,
}
response = self.post(
@ -592,6 +596,24 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True)
def update_docs_by_id(
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
) -> bool:
'''
对应api.py/knowledge_base/update_docs_by_id接口
'''
data = {
"knowledge_base_name": knowledge_base_name,
"docs": docs,
}
response = self.post(
"/knowledge_base/update_docs_by_id",
json=data
)
return self._get_response_value(response)
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],