diff --git a/server/api.py b/server/api.py index 7444d4b..aaf333e 100644 --- a/server/api.py +++ b/server/api.py @@ -149,7 +149,8 @@ def mount_knowledge_routes(app: FastAPI): from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, - search_docs, DocumentWithScore, update_info) + search_docs, DocumentWithVSId, update_info, + update_docs_by_id,) app.post("/chat/knowledge_base_chat", tags=["Chat"], @@ -190,10 +191,17 @@ def mount_knowledge_routes(app: FastAPI): app.post("/knowledge_base/search_docs", tags=["Knowledge Base Management"], - response_model=List[DocumentWithScore], + response_model=List[DocumentWithVSId], summary="搜索知识库" )(search_docs) + app.post("/knowledge_base/update_docs_by_id", + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="直接更新知识库文档" + )(update_docs_by_id) + + app.post("/knowledge_base/upload_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/chat/chat.py b/server/chat/chat.py index 0c566dd..5783829 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -29,7 +29,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 ), stream: bool = Body(False, description="流式输出"), model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index d20a973..b39c8c5 100644 --- a/server/db/repository/knowledge_base_repository.py +++ b/server/db/repository/knowledge_base_repository.py @@ -5,7 +5,7 @@ from server.db.session import with_session @with_session def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): # 创建知识库实例 - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() if not kb: kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model) session.add(kb) @@ -25,14 +25,14 @@ def list_kbs_from_db(session, min_file_count: int = -1): @with_session def kb_exists(session, kb_name): - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() status = True if kb else False return status @with_session def load_kb_from_db(session, kb_name): - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() if kb: kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model else: @@ -42,7 +42,7 @@ def load_kb_from_db(session, kb_name): @with_session def delete_kb_from_db(session, kb_name): - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() if kb: session.delete(kb) return True @@ -50,7 +50,7 @@ def delete_kb_from_db(session, kb_name): @with_session def get_kb_detail(session, kb_name: str) -> dict: - kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() if kb: return { "kb_name": kb.kb_name, diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 4a37441..0a6b782 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -15,7 +15,7 @@ def list_docs_from_db(session, 列出某知识库某文件对应的所有Document。 返回形式:[{"id": str, "metadata": dict}, ...] ''' - docs = session.query(FileDocModel).filter_by(kb_name=kb_name) + docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) if file_name: docs = docs.filter(FileDocModel.file_name.ilike(file_name)) for k, v in metadata.items(): @@ -34,10 +34,10 @@ def delete_docs_from_db(session, 返回形式:[{"id": str, "metadata": dict}, ...] ''' docs = list_docs_from_db(kb_name=kb_name, file_name=file_name) - query = session.query(FileDocModel).filter_by(kb_name=kb_name) + query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) if file_name: - query = query.filter_by(file_name=file_name) - query.delete() + query = query.filter(FileDocModel.file_name.ilike(file_name)) + query.delete(synchronize_session=False) session.commit() return docs @@ -68,12 +68,12 @@ def add_docs_to_db(session, @with_session def count_files_from_db(session, kb_name: str) -> int: - return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count() + return session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).count() @with_session def list_files_from_db(session, kb_name): - files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all() + files = session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).all() docs = [f.file_name for f in files] return docs @@ -89,8 +89,8 @@ def add_file_to_db(session, if kb: # 如果已经存在该文件,则更新文件信息与版本号 existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel) - .filter_by(file_name=kb_file.filename, - kb_name=kb_file.kb_name) + .filter(KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + KnowledgeFileModel.file_name.ilike(kb_file.filename)) .first()) mtime = kb_file.get_mtime() size = kb_file.get_size() @@ -122,14 +122,16 @@ def add_file_to_db(session, @with_session def delete_file_from_db(session, kb_file: KnowledgeFile): - existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, - kb_name=kb_file.kb_name).first() + existing_file = (session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name)) + .first()) if existing_file: session.delete(existing_file) delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename) session.commit() - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)).first() if kb: kb.file_count -= 1 session.commit() @@ -138,9 +140,9 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): @with_session def delete_files_from_db(session, knowledge_base_name: str): - session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete() - session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete() - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first() + session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False) + session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False) + kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)).first() if kb: kb.file_count = 0 @@ -150,16 +152,19 @@ def delete_files_from_db(session, knowledge_base_name: str): @with_session def file_exists_in_db(session, kb_file: KnowledgeFile): - existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, - kb_name=kb_file.kb_name).first() + existing_file = (session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name)) + .first()) return True if existing_file else False @with_session def get_file_detail(session, kb_name: str, filename: str) -> dict: file: KnowledgeFileModel = (session.query(KnowledgeFileModel) - .filter_by(file_name=filename, - kb_name=kb_name).first()) + .filter(KnowledgeFileModel.file_name.ilike(filename), + KnowledgeFileModel.kb_name.ilike(kb_name)) + .first()) if file: return { "kb_name": file.kb_name, diff --git a/server/db/repository/knowledge_metadata_repository.py b/server/db/repository/knowledge_metadata_repository.py index 4158e70..20725e3 100644 --- a/server/db/repository/knowledge_metadata_repository.py +++ b/server/db/repository/knowledge_metadata_repository.py @@ -12,7 +12,7 @@ def list_summary_from_db(session, 列出某知识库chunk summary。 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] ''' - docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name) + docs = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) for k, v in metadata.items(): docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v)) @@ -33,8 +33,8 @@ def delete_summary_from_db(session, 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] ''' docs = list_summary_from_db(kb_name=kb_name) - query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name) - query.delete() + query = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) + query.delete(synchronize_session=False) session.commit() return docs @@ -63,4 +63,4 @@ def add_summary_to_db(session, @with_session def count_summary_from_db(session, kb_name: str) -> int: - return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count() + return session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)).count() diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 68b5494..ed48b5d 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -4,10 +4,24 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter from server.utils import load_local_embeddings from server.knowledge_base.utils import get_vs_path from langchain.vectorstores.faiss import FAISS +from langchain.docstore.in_memory import InMemoryDocstore from langchain.schema import Document import os from langchain.schema import Document + +# patch FAISS to include doc id in Document.metadata +def _new_ds_search(self, search: str) -> Union[str, Document]: + if search not in self._dict: + return f"ID {search} not found." + else: + doc = self._dict[search] + if isinstance(doc, Document): + doc.metadata["id"] = search + return doc +InMemoryDocstore.search = _new_ds_search + + class ThreadSafeFaiss(ThreadSafeObject): def __repr__(self) -> str: cls = type(self).__name__ diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index ff68a6b..09a264f 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -15,15 +15,12 @@ import json from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import get_file_detail from langchain.docstore.document import Document -from typing import List - - -class DocumentWithScore(Document): - score: float = None +from server.knowledge_base.model.kb_document_model import DocumentWithVSId +from typing import List, Dict def search_docs( - query: str = Body(..., description="用户输入", examples=["你好"]), + query: str = Body("", description="用户输入", examples=["你好"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), score_threshold: float = Body(SCORE_THRESHOLD, @@ -31,13 +28,34 @@ def search_docs( "SCORE越小,相关度越高," "取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), -) -> List[DocumentWithScore]: + file_name: str = Body("", description="文件名称,支持 sql 通配符"), + metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), +) -> List[DocumentWithVSId]: + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + data = [] + if kb is not None: + if query: + docs = kb.search_docs(query, top_k, score_threshold) + data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + elif file_name or metadata: + data = kb.list_docs(file_name=file_name, metadata=metadata) + return data + + +def update_docs_by_id( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}") +) -> BaseResponse: + ''' + 按照文档 ID 更新文档内容 + ''' kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: - return [] - docs = kb.search_docs(query, top_k, score_threshold) - data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] - return data + return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在") + if kb.update_doc_by_ids(docs=docs): + return BaseResponse(msg=f"文档更新成功") + else: + return BaseResponse(msg=f"文档更新失败") def list_files( diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 1d86d30..86b9905 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -121,8 +121,9 @@ class KBService(ABC): for doc in docs: try: source = doc.metadata.get("source", "") - rel_path = Path(source).relative_to(self.doc_path) - doc.metadata["source"] = str(rel_path.as_posix().strip("/")) + if os.path.isabs(source): + rel_path = Path(source).relative_to(self.doc_path) + doc.metadata["source"] = str(rel_path.as_posix().strip("/")) except Exception as e: print(f"cannot convert absolute path ({source}) to relative path. error is : {e}") self.delete_doc(kb_file) @@ -176,13 +177,33 @@ class KBService(ABC): query: str, top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, - ): + ) ->List[Document]: docs = self.do_search(query, top_k, score_threshold) return docs def get_doc_by_ids(self, ids: List[str]) -> List[Document]: return [] + def del_doc_by_ids(self, ids: List[str]) -> bool: + raise NotImplementedError + + def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool: + ''' + 传入参数为: {doc_id: Document, ...} + 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 + TODO:是否要支持新增 docs ? + ''' + self.del_doc_by_ids(list(docs.keys())) + docs = [] + ids = [] + for k, v in docs.items(): + if not v or not v.page_content.strip(): + continue + ids.append(k) + docs.append(v) + self.do_add_doc(docs=docs, ids=ids) + return True + def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]: ''' 通过file_name或metadata检索Document @@ -190,10 +211,10 @@ class KBService(ABC): doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) docs = [] for x in doc_infos: - doc_info_s = self.get_doc_by_ids([x["id"]]) - if doc_info_s is not None and doc_info_s != []: + doc_info = self.get_doc_by_ids([x["id"]])[0] + if doc_info is not None: # 处理非空的情况 - doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"]) + doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"]) docs.append(doc_with_id) else: # 处理空的情况 @@ -249,6 +270,7 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], + **kwargs, ) -> List[Dict]: """ 向知识库添加文档子类实自己逻辑 @@ -371,12 +393,13 @@ def get_kb_file_details(kb_name: str) -> List[Dict]: "in_folder": True, "in_db": False, } + lower_names = {x.lower(): x for x in result} for doc in files_in_db: doc_detail = get_file_detail(kb_name, doc) if doc_detail: doc_detail["in_db"] = True - if doc in result: - result[doc].update(doc_detail) + if doc.lower() in lower_names: + result[lower_names[doc.lower()]].update(doc_detail) else: doc_detail["in_folder"] = False result[doc] = doc_detail diff --git a/server/knowledge_base/kb_service/es_kb_service.py b/server/knowledge_base/kb_service/es_kb_service.py index 1c20707..4a408cb 100644 --- a/server/knowledge_base/kb_service/es_kb_service.py +++ b/server/knowledge_base/kb_service/es_kb_service.py @@ -145,6 +145,14 @@ class ESKBService(KBService): k=top_k) return docs + def del_doc_by_ids(self, ids: List[str]) -> bool: + for doc_id in ids: + try: + self.es_client_python.delete(index=self.index_name, + id=doc_id, + refresh=True) + except Exception as e: + logger.error(f"ES Docs Delete Error! {e}") def do_delete_doc(self, kb_file, **kwargs): if self.es_client_python.indices.exists(index=self.index_name): @@ -168,7 +176,7 @@ class ESKBService(KBService): id=doc_id, refresh=True) except Exception as e: - logger.error("ES Docs Delete Error!") + logger.error(f"ES Docs Delete Error! {e}") # self.db_init.delete(ids=delete_list) #self.es_client_python.indices.refresh(index=self.index_name) diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 231c0e3..f073b4e 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -36,6 +36,10 @@ class FaissKBService(KBService): with self.load_vector_store().acquire() as vs: return [vs.docstore._dict.get(id) for id in ids] + def del_doc_by_ids(self, ids: List[str]) -> bool: + with self.load_vector_store().acquire() as vs: + vs.delete(ids) + def do_init(self): self.vector_name = self.vector_name or self.embed_model self.kb_path = self.get_kb_path() @@ -72,7 +76,8 @@ class FaissKBService(KBService): with self.load_vector_store().acquire() as vs: ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]), - metadatas=data["metadatas"]) + metadatas=data["metadatas"], + ids=kwargs.get("ids")) if not kwargs.get("not_refresh_vs_cache"): vs.save_local(self.vs_path) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] @@ -83,7 +88,7 @@ class FaissKBService(KBService): kb_file: KnowledgeFile, **kwargs): with self.load_vector_store().acquire() as vs: - ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filename] + ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source").lower() == kb_file.filename.lower()] if len(ids) > 0: vs.delete(ids) if not kwargs.get("not_refresh_vs_cache"): diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 0e0633b..dc392ee 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -31,6 +31,9 @@ class MilvusKBService(KBService): result.append(Document(page_content=text, metadata=data)) return result + def del_doc_by_ids(self, ids: List[str]) -> bool: + self.milvus.col.delete(expr=f'pk in {ids}') + @staticmethod def search(milvus_name, content, limit=3): search_params = { diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index cf58ce3..8a4afcb 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -29,6 +29,10 @@ class PGKBService(KBService): connect.execute(stmt, parameters={'ids': ids}).fetchall()] return results + # TODO: + def del_doc_by_ids(self, ids: List[str]) -> bool: + return super().del_doc_by_ids(ids) + def do_init(self): self._load_pg_vector() diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py index d82f873..5d00a49 100644 --- a/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -29,6 +29,9 @@ class ZillizKBService(KBService): result.append(Document(page_content=text, metadata=data)) return result + def del_doc_by_ids(self, ids: List[str]) -> bool: + self.zilliz.col.delete(expr=f'pk in {ids}') + @staticmethod def search(zilliz_name, content, limit=3): search_params = { diff --git a/server/knowledge_base/model/kb_document_model.py b/server/knowledge_base/model/kb_document_model.py index a5d2c6a..662929d 100644 --- a/server/knowledge_base/model/kb_document_model.py +++ b/server/knowledge_base/model/kb_document_model.py @@ -7,4 +7,4 @@ class DocumentWithVSId(Document): 矢量化后的文档 """ id: str = None - + score: float = 3.0 diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index fdfd2e3..7fced47 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -71,7 +71,8 @@ def list_files_from_folder(kb_name: str): for target_entry in target_it: process_entry(target_entry) elif entry.is_file(): - result.append(os.path.relpath(entry.path, doc_path)) + file_path = (Path(os.path.relpath(entry.path, doc_path)).as_posix()) # 路径统一为 posix 格式 + result.append(file_path) elif entry.is_dir(): with os.scandir(entry.path) as it: for sub_entry in it: @@ -272,7 +273,7 @@ class KnowledgeFile: 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 ''' self.kb_name = knowledge_base_name - self.filename = filename + self.filename = str(Path(filename).as_posix()) self.ext = os.path.splitext(filename)[-1].lower() if self.ext not in SUPPORTED_EXTS: raise ValueError(f"暂未支持的文件格式 {self.filename}") diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 064c8a4..325cd5d 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -224,7 +224,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): key="prompt_template_select", ) prompt_template_name = st.session_state.prompt_template_select - temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05) + temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05) history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN) def on_kb_change(): diff --git a/webui_pages/knowledge_base/__init__.py b/webui_pages/knowledge_base/__init__.py index b7b37a0..8eca058 100644 --- a/webui_pages/knowledge_base/__init__.py +++ b/webui_pages/knowledge_base/__init__.py @@ -1 +1 @@ -from .knowledge_base import knowledge_base_page \ No newline at end of file +from .knowledge_base import knowledge_base_page diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 5348491..31fc151 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -190,6 +190,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): # 知识库详情 # st.info("请选择文件,点击按钮进行操作。") doc_details = pd.DataFrame(get_kb_file_details(kb)) + selected_rows = [] if not len(doc_details): st.info(f"知识库 `{kb}` 中暂无文件") else: @@ -284,32 +285,80 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.divider() - cols = st.columns(3) + # cols = st.columns(3) - if cols[0].button( - "依据源文件重建向量库", - # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", - use_container_width=True, - type="primary", - ): - with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): - empty = st.empty() - empty.progress(0.0, "") - for d in api.recreate_vector_store(kb, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): - if msg := check_error_msg(d): - st.toast(msg) + # if cols[0].button( + # "依据源文件重建向量库", + # # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", + # use_container_width=True, + # type="primary", + # ): + # with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): + # empty = st.empty() + # empty.progress(0.0, "") + # for d in api.recreate_vector_store(kb, + # chunk_size=chunk_size, + # chunk_overlap=chunk_overlap, + # zh_title_enhance=zh_title_enhance): + # if msg := check_error_msg(d): + # st.toast(msg) + # else: + # empty.progress(d["finished"] / d["total"], d["msg"]) + # st.rerun() + + # if cols[2].button( + # "删除知识库", + # use_container_width=True, + # ): + # ret = api.delete_knowledge_base(kb) + # st.toast(ret.get("msg", " ")) + # time.sleep(1) + # st.rerun() + + # with st.sidebar: + # keyword = st.text_input("查询关键字") + # top_k = st.slider("匹配条数", 1, 100, 3) + + st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。") + docs = [] + df = pd.DataFrame([], columns=["seq", "id", "content", "source"]) + if selected_rows: + file_name = selected_rows[0]["file_name"] + docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name) + data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"), + "type": x["type"], + "metadata": json.dumps(x["metadata"], ensure_ascii=False), + "to_del": "", + } for i, x in enumerate(docs)] + df = pd.DataFrame(data) + + gb = GridOptionsBuilder.from_dataframe(df) + gb.configure_columns(["id", "source", "type", "metadata"], hide=True) + gb.configure_column("seq", "No.", width=50) + gb.configure_column("page_content", "内容", editable=True, autoHeight=True, wrapText=True, flex=1, + cellEditor="agLargeTextCellEditor", cellEditorPopup=True) + gb.configure_column("to_del", "删除", editable=True, width=50, wrapHeaderText=True, + cellEditor="agCheckboxCellEditor", cellRender="agCheckboxCellRenderer") + gb.configure_selection() + edit_docs = AgGrid(df, gb.build()) + + if st.button("保存更改"): + # origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs} + changed_docs = [] + for index, row in edit_docs.data.iterrows(): + # origin_doc = origin_docs[row["id"]] + # if row["page_content"] != origin_doc["page_content"]: + if row["to_del"] not in ["Y", "y", 1]: + changed_docs.append({ + "page_content": row["page_content"], + "type": row["type"], + "metadata": json.loads(row["metadata"]), + }) + + if changed_docs: + if api.update_kb_docs(knowledge_base_name=selected_kb, + file_names=[file_name], + docs={file_name: changed_docs}): + st.toast("更新文档成功") else: - empty.progress(d["finished"] / d["total"], d["msg"]) - st.rerun() - - if cols[2].button( - "删除知识库", - use_container_width=True, - ): - ret = api.delete_knowledge_base(kb) - st.toast(ret.get("msg", " ")) - time.sleep(1) - st.rerun() + st.toast("更新文档失败") diff --git a/webui_pages/utils.py b/webui_pages/utils.py index f688949..c7eb6d1 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -571,10 +571,12 @@ class ApiRequest: def search_kb_docs( self, - query: str, knowledge_base_name: str, + query: str = "", top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: int = SCORE_THRESHOLD, + file_name: str = "", + metadata: dict = {}, ) -> List: ''' 对应api.py/knowledge_base/search_docs接口 @@ -584,6 +586,8 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "top_k": top_k, "score_threshold": score_threshold, + "file_name": file_name, + "metadata": metadata, } response = self.post( @@ -592,6 +596,24 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) + def update_docs_by_id( + self, + knowledge_base_name: str, + docs: Dict[str, Dict], + ) -> bool: + ''' + 对应api.py/knowledge_base/update_docs_by_id接口 + ''' + data = { + "knowledge_base_name": knowledge_base_name, + "docs": docs, + } + response = self.post( + "/knowledge_base/update_docs_by_id", + json=data + ) + return self._get_response_value(response) + def upload_kb_docs( self, files: List[Union[str, Path, bytes]],