diff --git a/server/api.py b/server/api.py index e5d9830..873b887 100644 --- a/server/api.py +++ b/server/api.py @@ -71,7 +71,7 @@ def create_app(): summary="创建知识库" )(create_kb) - app.delete("/knowledge_base/delete_knowledge_base", + app.post("/knowledge_base/delete_knowledge_base", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="删除知识库" @@ -89,7 +89,7 @@ def create_app(): summary="上传文件到知识库" )(upload_doc) - app.delete("/knowledge_base/delete_doc", + app.post("/knowledge_base/delete_doc", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="删除知识库内指定文件" @@ -106,10 +106,6 @@ def create_app(): summary="根据content中文档重建向量库,流式输出处理进度。" )(recreate_vector_store) - # init local vector store info to database - from webui_pages.utils import init_vs_database - init_vs_database() - return app diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index 9f5ad40..585fd9b 100644 --- a/server/db/repository/knowledge_base_repository.py +++ b/server/db/repository/knowledge_base_repository.py @@ -9,6 +9,9 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model): if not kb: kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) session.add(kb) + else: # update kb with new vs_type and embed_model + kb.vs_type = vs_type + kb.embed_model = embed_model return True diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index a9b102f..c22bac7 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,16 +1,25 @@ from abc import ABC, abstractmethod import os +import pandas as pd from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document +from server.db.repository.knowledge_base_repository import ( + add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, + load_kb_from_db, get_kb_detail, +) +from server.db.repository.knowledge_file_repository import ( + add_doc_to_db, delete_file_from_db, doc_exists, + list_docs_from_db, get_file_detail +) -from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db -from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \ - list_docs_from_db from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, EMBEDDING_DEVICE, EMBEDDING_MODEL) -from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile) +from server.knowledge_base.utils import ( + get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, + list_kbs_from_folder, list_docs_from_folder, +) from typing import List, Union @@ -77,10 +86,10 @@ class KBService(ABC): """ 从知识库删除文件 """ - if delete_content and os.path.exists(kb_file.filepath): - os.remove(kb_file.filepath) self.do_delete_doc(kb_file) status = delete_file_from_db(kb_file) + if delete_content and os.path.exists(kb_file.filepath): + os.remove(kb_file.filepath) return status def update_doc(self, kb_file: KnowledgeFile): @@ -121,10 +130,9 @@ class KBService(ABC): def list_kbs(cls): return list_kbs_from_db() - @classmethod - def exists(cls, - knowledge_base_name: str): - return kb_exists(knowledge_base_name) + def exists(self, kb_name: str = None): + kb_name = kb_name or self.kb_name + return kb_exists(kb_name) @abstractmethod def vs_type(self) -> str: @@ -212,3 +220,86 @@ class KBServiceFactory: def get_default(): return 
KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) + +def get_kb_details() -> pd.DataFrame: + kbs_in_folder = list_kbs_from_folder() + kbs_in_db = KBService.list_kbs() + result = {} + + for kb in kbs_in_folder: + result[kb] = { + "kb_name": kb, + "vs_type": "", + "embed_model": "", + "file_count": 0, + "create_time": None, + "in_folder": True, + "in_db": False, + } + + for kb in kbs_in_db: + kb_detail = get_kb_detail(kb) + if kb_detail: + kb_detail["in_db"] = True + if kb in result: + result[kb].update(kb_detail) + else: + kb_detail["in_folder"] = False + result[kb] = kb_detail + + df = pd.DataFrame(result.values(), columns=[ + "kb_name", + "vs_type", + "embed_model", + "file_count", + "create_time", + "in_folder", + "in_db", + ]) + df.insert(0, "No", range(1, len(df) + 1)) + return df + + +def get_kb_doc_details(kb_name: str) -> pd.DataFrame: + kb = KBServiceFactory.get_service_by_name(kb_name) + docs_in_folder = list_docs_from_folder(kb_name) + docs_in_db = kb.list_docs() + result = {} + + for doc in docs_in_folder: + result[doc] = { + "kb_name": kb_name, + "file_name": doc, + "file_ext": os.path.splitext(doc)[-1], + "file_version": 0, + "document_loader": "", + "text_splitter": "", + "create_time": None, + "in_folder": True, + "in_db": False, + } + + for doc in docs_in_db: + doc_detail = get_file_detail(kb_name, doc) + if doc_detail: + doc_detail["in_db"] = True + if doc in result: + result[doc].update(doc_detail) + else: + doc_detail["in_folder"] = False + result[doc] = doc_detail + + df = pd.DataFrame(result.values(), columns=[ + "kb_name", + "file_name", + "file_ext", + "file_version", + "document_loader", + "text_splitter", + "create_time", + "in_folder", + "in_db", + ]) + df.insert(0, "No", range(1, len(df) + 1)) + return df + diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 721862a..29f274e 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -70,7 +70,6 @@ def folder2db( def recreate_all_vs( - mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_mode: str = EMBEDDING_MODEL, ): @@ -78,7 +77,7 @@ def recreate_all_vs( used to recreate a vector store or change current vector store to another type or embed_model ''' for kb_name in list_kbs_from_folder(): - folder2db(kb_name, mode, vs_type, embed_mode) + folder2db(kb_name, "recreate_vs", vs_type, embed_mode) def prune_db_docs(kb_name: str): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index a6ff4b8..3e8be26 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -52,7 +52,6 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' "CSVLoader": [".csv"], "PyPDFLoader": [".pdf"], - } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] diff --git a/server/utils.py b/server/utils.py index 7228930..e1a23d1 100644 --- a/server/utils.py +++ b/server/utils.py @@ -5,6 +5,7 @@ import torch from fastapi_offline import FastAPIOffline import fastapi_offline from pathlib import Path +import asyncio # patch fastapi_offline to use local static assests @@ -82,3 +83,32 @@ def torch_gc(): except Exception as e: print(e) print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") + + +def run_async(cor): + ''' + 在同步环境中运行异步代码. 
+ ''' + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + return loop.run_until_complete(cor) + + +def iter_over_async(ait, loop): + ''' + 将异步生成器封装成同步生成器. + ''' + ait = ait.__aiter__() + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + while True: + done, obj = loop.run_until_complete(get_next()) + if done: + break + yield obj diff --git a/webui.py b/webui.py index 559b74a..c742437 100644 --- a/webui.py +++ b/webui.py @@ -12,9 +12,6 @@ from webui_pages import * api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) if __name__ == "__main__": - # init local vector store info to database - init_vs_database() - st.set_page_config("langchain-chatglm WebUI") if not chat_box.chat_inited: diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 3b7a445..86bf499 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -86,12 +86,11 @@ def dialogue_page(api: ApiRequest): ["LLM 对话", "知识库问答", "搜索引擎问答", - ], - on_change=on_mode_change, - key="dialogue_mode", - ) - history_len = st.slider("历史对话轮数:", 1, 10, 3) - + ], + on_change=on_mode_change, + key="dialogue_mode", + ) + history_len = st.slider("历史对话轮数:", 0, 10, 3) # todo: support history len def on_kb_change(): diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index cb61178..a790bcf 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,11 +1,10 @@ import streamlit as st from webui_pages.utils import * -# import streamlit_antd_components as sac from st_aggrid import AgGrid from st_aggrid.grid_options_builder import GridOptionsBuilder import pandas as pd from server.knowledge_base.utils import get_file_path -# from streamlit_chatbox import * +from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details from typing import Literal, Dict, Tuple SENTENCE_SIZE = 100 @@ -33,7 +32,7 @@ def config_aggrid( def knowledge_base_page(api: ApiRequest): # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) - kb_details = get_kb_details(api) + kb_details = get_kb_details() kb_list = list(kb_details.kb_name) cols = st.columns([3, 1, 1]) @@ -127,7 +126,7 @@ def knowledge_base_page(api: ApiRequest): # 知识库详情 st.write(f"知识库 `{kb}` 详情:") st.info("请选择文件") - doc_details = get_kb_doc_details(api, kb) + doc_details = get_kb_doc_details(kb) doc_details.drop(columns=["kb_name"], inplace=True) gb = config_aggrid( diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 6dac5a9..1432326 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -1,7 +1,6 @@ # 该文件包含webui通用工具,可以被不同的webui使用 from typing import * from pathlib import Path -import os from configs.model_config import ( EMBEDDING_MODEL, KB_ROOT_PATH, @@ -17,10 +16,9 @@ from fastapi.responses import StreamingResponse import contextlib import json from io import BytesIO -import pandas as pd -from server.knowledge_base.utils import list_kbs_from_folder, list_docs_from_folder from server.db.repository.knowledge_base_repository import get_kb_detail from server.db.repository.knowledge_file_repository import get_file_detail +from server.utils import run_async, iter_over_async def set_httpx_timeout(timeout=60.0): @@ -37,35 +35,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH) set_httpx_timeout() -def run_async(cor): - ''' - 在同步环境中运行异步代码. 
- ''' - try: - loop = asyncio.get_event_loop() - except: - loop = asyncio.new_event_loop() - return loop.run_until_complete(cor) - - -def iter_over_async(ait, loop): - ''' - 将异步生成器封装成同步生成器. - ''' - ait = ait.__aiter__() - async def get_next(): - try: - obj = await ait.__anext__() - return False, obj - except StopAsyncIteration: - return True, None - while True: - done, obj = loop.run_until_complete(get_next()) - if done: - break - yield obj - - class ApiRequest: ''' api.py调用的封装,主要实现: @@ -97,13 +66,17 @@ class ApiRequest: url: str, params: Union[Dict, List[Tuple], bytes] = None, retry: int = 3, + stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, None]: url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) while retry > 0: try: - return httpx.get(url, params=params, **kwargs) + if stream: + return httpx.stream("GET", url, params=params, **kwargs) + else: + return httpx.get(url, params=params, **kwargs) except: retry -= 1 @@ -112,6 +85,7 @@ class ApiRequest: url: str, params: Union[Dict, List[Tuple], bytes] = None, retry: int = 3, + stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, None]: rl = self._parse_url(url) @@ -119,7 +93,10 @@ class ApiRequest: async with httpx.AsyncClient() as client: while retry > 0: try: - return await client.get(url, params=params, **kwargs) + if stream: + return await client.stream("GET", url, params=params, **kwargs) + else: + return await client.get(url, params=params, **kwargs) except: retry -= 1 @@ -150,6 +127,7 @@ class ApiRequest: data: Dict = None, json: Dict = None, retry: int = 3, + stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, None]: rl = self._parse_url(url) @@ -157,7 +135,51 @@ class ApiRequest: async with httpx.AsyncClient() as client: while retry > 0: try: - return await client.post(url, data=data, json=json, **kwargs) + if stream: + return await client.stream("POST", url, data=data, json=json, **kwargs) + else: + return await client.post(url, data=data, json=json, **kwargs) + except: + retry -= 1 + + def delete( + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any + ) -> Union[httpx.Response, None]: + url = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + while retry > 0: + try: + if stream: + return httpx.stream("DELETE", url, data=data, json=json, **kwargs) + else: + return httpx.delete(url, data=data, json=json, **kwargs) + except: + retry -= 1 + + async def adelete( + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any + ) -> Union[httpx.Response, None]: + rl = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + async with httpx.AsyncClient() as client: + while retry > 0: + try: + if stream: + return await client.stream("DELETE", url, data=data, json=json, **kwargs) + else: + return await client.delete(url, data=data, json=json, **kwargs) except: retry -= 1 @@ -384,7 +406,7 @@ class ApiRequest: response = run_async(delete_kb(knowledge_base_name)) return response.dict() else: - response = self.delete( + response = self.post( "/knowledge_base/delete_knowledge_base", json={"knowledge_base_name": knowledge_base_name}, ) @@ -480,7 +502,7 @@ class ApiRequest: response = run_async(delete_doc(**data)) return response.dict() else: - response = self.delete( + response = self.post( "/knowledge_base/delete_doc", json=data, ) @@ -489,7 +511,7 @@ class ApiRequest: def update_kb_doc( self, knowledge_base_name: str, - doc_name: str, + 
file_name: str, no_remote_api: bool = None, ): ''' @@ -500,12 +522,12 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import update_doc - response = run_async(update_doc(knowledge_base_name, doc_name)) + response = run_async(update_doc(knowledge_base_name, file_name)) return response.dict() else: - response = self.delete( + response = self.post( "/knowledge_base/update_doc", - json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name}, + json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, ) return response.json() @@ -532,120 +554,8 @@ class ApiRequest: return self._httpx_stream2generator(response, as_json=True) -def get_kb_details(api: ApiRequest) -> pd.DataFrame: - kbs_in_folder = list_kbs_from_folder() - kbs_in_db = api.list_knowledge_bases() - result = {} - - for kb in kbs_in_folder: - result[kb] = { - "kb_name": kb, - "vs_type": "", - "embed_model": "", - "file_count": 0, - "create_time": None, - "in_folder": True, - "in_db": False, - } - - for kb in kbs_in_db: - kb_detail = get_kb_detail(kb) - if kb_detail: - kb_detail["in_db"] = True - if kb in result: - result[kb].update(kb_detail) - else: - kb_detail["in_folder"] = False - result[kb] = kb_detail - - df = pd.DataFrame(result.values(), columns=[ - "kb_name", - "vs_type", - "embed_model", - "file_count", - "create_time", - "in_folder", - "in_db", - ]) - df.insert(0, "No", range(1, len(df) + 1)) - return df - - -def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame: - docs_in_folder = list_docs_from_folder(kb) - docs_in_db = api.list_kb_docs(kb) - result = {} - - for doc in docs_in_folder: - result[doc] = { - "kb_name": kb, - "file_name": doc, - "file_ext": os.path.splitext(doc)[-1], - "file_version": 0, - "document_loader": "", - "text_splitter": "", - "create_time": None, - "in_folder": True, - "in_db": False, - } - - for doc in docs_in_db: - doc_detail = get_file_detail(kb, doc) - if doc_detail: - doc_detail["in_db"] = True - if doc in result: - result[doc].update(doc_detail) - else: - doc_detail["in_folder"] = False - result[doc] = doc_detail - - df = pd.DataFrame(result.values(), columns=[ - "kb_name", - "file_name", - "file_ext", - "file_version", - "document_loader", - "text_splitter", - "create_time", - "in_folder", - "in_db", - ]) - df.insert(0, "No", range(1, len(df) + 1)) - return df - - -def init_vs_database(recreate_vs: bool = False): - ''' - init local vector store info to database - ''' - from server.db.base import Base, engine - from server.db.repository.knowledge_base_repository import add_kb_to_db, kb_exists - from server.db.repository.knowledge_file_repository import add_doc_to_db - from server.knowledge_base.utils import KnowledgeFile - - Base.metadata.create_all(bind=engine) - - if recreate_vs: - api = ApiRequest(no_remote_api=True) - for kb in list_kbs_from_folder(): - for t in api.recreate_vector_store(kb): - print(t) - else: # add vs info to db only - for kb in list_kbs_from_folder(): - if not kb_exists(kb): - add_kb_to_db(kb, "faiss", EMBEDDING_MODEL) - for doc in list_docs_from_folder(kb): - try: - kb_file = KnowledgeFile(doc, kb) - add_doc_to_db(kb_file) - except Exception as e: - print(e) - - if __name__ == "__main__": api = ApiRequest(no_remote_api=True) - # init vector store database - init_vs_database() # print(api.chat_fastchat( # messages=[{"role": "user", "content": "hello"}]
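
Since api.py now registers the delete routes as POST (matching the ApiRequest.post calls in this patch), the endpoints can be exercised with a plain JSON body. A minimal sketch, assuming the API is running on the default base URL used in webui.py; the knowledge base name "samples" is made up for illustration:

import httpx

# POST replaces DELETE for this route in the patch; the JSON key matches
# the body sent by ApiRequest.delete_knowledge_base above.
resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/delete_knowledge_base",
    json={"knowledge_base_name": "samples"},  # hypothetical knowledge base name
)
print(resp.json())  # expected to be a BaseResponse-style payload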
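
With get_kb_details and get_kb_doc_details moved into server/knowledge_base/kb_service/base.py, the WebUI no longer needs an ApiRequest instance to build its overview tables. A quick local sanity check might look like the following sketch, assuming the project's configs and database have already been initialized:

from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details

# columns ("No", "kb_name", "in_folder", "in_db", ...) are the ones defined in the patch
kb_df = get_kb_details()
print(kb_df[["No", "kb_name", "in_folder", "in_db"]])

if len(kb_df):
    # inspect the first knowledge base found on disk or in the database
    doc_df = get_kb_doc_details(kb_df.kb_name.iloc[0])
    print(doc_df[["file_name", "file_version", "in_db"]])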
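
run_async and iter_over_async now live in server/utils.py so both the API server and the WebUI can share them. A minimal usage sketch with a stand-in async generator (the generator and its values are illustrative only, not part of the project):

import asyncio
from server.utils import run_async, iter_over_async

async def fake_stream():
    # stand-in for the async streaming responses ApiRequest works with
    for chunk in ("hello", "world"):
        yield chunk

loop = asyncio.new_event_loop()
for piece in iter_over_async(fake_stream(), loop):  # consume an async generator synchronously
    print(piece)

async def answer():
    return 42

print(run_async(answer()))  # run a single coroutine from synchronous code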