Update API and ApiRequest:
1. Reorganize the utility dependencies between webui_pages/utils and server/knowledge_base.
2. Change the delete_knowledge_base and delete_doc endpoints from DELETE to POST, since DELETE does not support a request body.
3. Fix update_doc.
4. Fix assorted bugs.
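For context on item 2: many HTTP stacks ignore or reject a request body on DELETE, so the two endpoints now receive their parameters as a POST JSON body instead. A minimal client-side sketch of the new call, assuming the API server from this commit is running locally ("samples" is a placeholder knowledge base name):

import httpx

# POST carries the knowledge base name in the JSON body, which DELETE
# could not do reliably; the endpoint path is unchanged.
resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/delete_knowledge_base",
    json={"knowledge_base_name": "samples"},
)
print(resp.json())  # BaseResponse payload, e.g. {"code": 200, "msg": "..."}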
This commit is contained in:
parent a261fda20b
commit a08fe994c2
@@ -71,7 +71,7 @@ def create_app():
               summary="创建知识库"
               )(create_kb)
 
-    app.delete("/knowledge_base/delete_knowledge_base",
+    app.post("/knowledge_base/delete_knowledge_base",
               tags=["Knowledge Base Management"],
               response_model=BaseResponse,
               summary="删除知识库"
@@ -89,7 +89,7 @@ def create_app():
               summary="上传文件到知识库"
               )(upload_doc)
 
-    app.delete("/knowledge_base/delete_doc",
+    app.post("/knowledge_base/delete_doc",
               tags=["Knowledge Base Management"],
               response_model=BaseResponse,
               summary="删除知识库内指定文件"
@@ -106,10 +106,6 @@ def create_app():
               summary="根据content中文档重建向量库,流式输出处理进度。"
               )(recreate_vector_store)
 
-    # init local vector store info to database
-    from webui_pages.utils import init_vs_database
-    init_vs_database()
-
     return app
@@ -9,6 +9,9 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
     if not kb:
         kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
         session.add(kb)
+    else:  # update kb with new vs_type and embed_model
+        kb.vs_type = vs_type
+        kb.embed_model = embed_model
     return True
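The new else branch makes add_kb_to_db an upsert: re-registering an existing knowledge base now refreshes its vs_type and embed_model instead of keeping stale values. A hypothetical call, assuming the repository supplies the session (the kb name and model name are placeholders):

add_kb_to_db(session, "samples", "faiss", "text2vec")  # updates the row if "samples" already exists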
@@ -1,16 +1,25 @@
 from abc import ABC, abstractmethod
 
 import os
+import pandas as pd
 
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
-from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
-from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
-    list_docs_from_db
+from server.db.repository.knowledge_base_repository import (
+    add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
+    load_kb_from_db, get_kb_detail,
+)
+from server.db.repository.knowledge_file_repository import (
+    add_doc_to_db, delete_file_from_db, doc_exists,
+    list_docs_from_db, get_file_detail
+)
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
                                   EMBEDDING_DEVICE, EMBEDDING_MODEL)
-from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
+from server.knowledge_base.utils import (
+    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
+    list_kbs_from_folder, list_docs_from_folder,
+)
 from typing import List, Union
@@ -77,10 +86,10 @@ class KBService(ABC):
         """
         从知识库删除文件
         """
-        if delete_content and os.path.exists(kb_file.filepath):
-            os.remove(kb_file.filepath)
         self.do_delete_doc(kb_file)
         status = delete_file_from_db(kb_file)
+        if delete_content and os.path.exists(kb_file.filepath):
+            os.remove(kb_file.filepath)
         return status
 
     def update_doc(self, kb_file: KnowledgeFile):
@@ -121,10 +130,9 @@ class KBService(ABC):
     def list_kbs(cls):
         return list_kbs_from_db()
 
-    def exists(self, kb_name: str = None):
-        kb_name = kb_name or self.kb_name
-        return kb_exists(kb_name)
+    @classmethod
+    def exists(cls,
+               knowledge_base_name: str):
+        return kb_exists(knowledge_base_name)
 
     @abstractmethod
     def vs_type(self) -> str:
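Turning exists into a classmethod means callers can check for a knowledge base without constructing a service instance. A hypothetical check ("samples" is a placeholder name):

from server.knowledge_base.kb_service.base import KBService

if KBService.exists("samples"):
    print("knowledge base already registered")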
@@ -212,3 +220,86 @@ class KBServiceFactory:
     def get_default():
         return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
 
+
+def get_kb_details() -> pd.DataFrame:
+    kbs_in_folder = list_kbs_from_folder()
+    kbs_in_db = KBService.list_kbs()
+    result = {}
+
+    for kb in kbs_in_folder:
+        result[kb] = {
+            "kb_name": kb,
+            "vs_type": "",
+            "embed_model": "",
+            "file_count": 0,
+            "create_time": None,
+            "in_folder": True,
+            "in_db": False,
+        }
+
+    for kb in kbs_in_db:
+        kb_detail = get_kb_detail(kb)
+        if kb_detail:
+            kb_detail["in_db"] = True
+            if kb in result:
+                result[kb].update(kb_detail)
+            else:
+                kb_detail["in_folder"] = False
+                result[kb] = kb_detail
+
+    df = pd.DataFrame(result.values(), columns=[
+        "kb_name",
+        "vs_type",
+        "embed_model",
+        "file_count",
+        "create_time",
+        "in_folder",
+        "in_db",
+    ])
+    df.insert(0, "No", range(1, len(df) + 1))
+    return df
+
+
+def get_kb_doc_details(kb_name: str) -> pd.DataFrame:
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    docs_in_folder = list_docs_from_folder(kb_name)
+    docs_in_db = kb.list_docs()
+    result = {}
+
+    for doc in docs_in_folder:
+        result[doc] = {
+            "kb_name": kb_name,
+            "file_name": doc,
+            "file_ext": os.path.splitext(doc)[-1],
+            "file_version": 0,
+            "document_loader": "",
+            "text_splitter": "",
+            "create_time": None,
+            "in_folder": True,
+            "in_db": False,
+        }
+
+    for doc in docs_in_db:
+        doc_detail = get_file_detail(kb_name, doc)
+        if doc_detail:
+            doc_detail["in_db"] = True
+            if doc in result:
+                result[doc].update(doc_detail)
+            else:
+                doc_detail["in_folder"] = False
+                result[doc] = doc_detail
+
+    df = pd.DataFrame(result.values(), columns=[
+        "kb_name",
+        "file_name",
+        "file_ext",
+        "file_version",
+        "document_loader",
+        "text_splitter",
+        "create_time",
+        "in_folder",
+        "in_db",
+    ])
+    df.insert(0, "No", range(1, len(df) + 1))
+    return df
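The two helpers added above reconcile the knowledge bases and files found on disk with those recorded in the database, flagging each entry with in_folder/in_db. A hedged usage sketch (assumes at least one knowledge base folder exists under KB_ROOT_PATH):

from server.knowledge_base.kb_service.base import get_kb_details

df = get_kb_details()
# Entries present on disk but missing from the DB show in_db == False.
orphans = df[df.in_folder & ~df.in_db]
print(orphans[["kb_name", "file_count"]])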
@@ -70,7 +70,6 @@ def folder2db(
 
 
 def recreate_all_vs(
-    mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_mode: str = EMBEDDING_MODEL,
 ):
@@ -78,7 +77,7 @@ def recreate_all_vs(
     used to recreate a vector store or change current vector store to another type or embed_model
     '''
     for kb_name in list_kbs_from_folder():
-        folder2db(kb_name, mode, vs_type, embed_mode)
+        folder2db(kb_name, "recreate_vs", vs_type, embed_mode)
 
 
 def prune_db_docs(kb_name: str):
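With the mode parameter dropped, recreate_all_vs always runs folder2db in "recreate_vs" mode over every knowledge base folder. Hypothetical invocations, relying on the defaults shown in the signature:

recreate_all_vs()                  # rebuild every vector store with the default type and model
recreate_all_vs(vs_type="milvus")  # or migrate them all to another store type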
@@ -52,7 +52,6 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
                                           '.ppt', '.pptx', '.tsv'],  # '.pdf', '.xlsx', '.csv'
                "CSVLoader": [".csv"],
                "PyPDFLoader": [".pdf"],
-
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@@ -5,6 +5,7 @@ import torch
 from fastapi_offline import FastAPIOffline
 import fastapi_offline
 from pathlib import Path
+import asyncio
 
 
 # patch fastapi_offline to use local static assests
@@ -82,3 +83,32 @@ def torch_gc():
     except Exception as e:
         print(e)
         print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
+
+
+def run_async(cor):
+    '''
+    在同步环境中运行异步代码.
+    '''
+    try:
+        loop = asyncio.get_event_loop()
+    except:
+        loop = asyncio.new_event_loop()
+    return loop.run_until_complete(cor)
+
+
+def iter_over_async(ait, loop):
+    '''
+    将异步生成器封装成同步生成器.
+    '''
+    ait = ait.__aiter__()
+    async def get_next():
+        try:
+            obj = await ait.__anext__()
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+    while True:
+        done, obj = loop.run_until_complete(get_next())
+        if done:
+            break
+        yield obj
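run_async and iter_over_async let synchronous Streamlit code drive coroutines and async generators through a manually managed event loop. A usage sketch with a hypothetical async generator standing in for a streaming API call:

import asyncio

async def numbers():  # hypothetical stand-in for a streaming response
    for i in range(3):
        yield i

loop = asyncio.new_event_loop()
for n in iter_over_async(numbers(), loop):
    print(n)  # 0, 1, 2 -- one run_until_complete round-trip per item

print(run_async(asyncio.sleep(0, result="done")))  # prints "done"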
webui.py
 
@@ -12,9 +12,6 @@ from webui_pages import *
 api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
 
 if __name__ == "__main__":
-    # init local vector store info to database
-    init_vs_database()
-
     st.set_page_config("langchain-chatglm WebUI")
 
     if not chat_box.chat_inited:
@@ -90,8 +90,7 @@ def dialogue_page(api: ApiRequest):
         on_change=on_mode_change,
         key="dialogue_mode",
     )
-    history_len = st.slider("历史对话轮数:", 1, 10, 3)
+    history_len = st.slider("历史对话轮数:", 0, 10, 3)
 
     # todo: support history len
-
 
     def on_kb_change():
@@ -1,11 +1,10 @@
 import streamlit as st
 from webui_pages.utils import *
-# import streamlit_antd_components as sac
 from st_aggrid import AgGrid
 from st_aggrid.grid_options_builder import GridOptionsBuilder
 import pandas as pd
 from server.knowledge_base.utils import get_file_path
-# from streamlit_chatbox import *
+from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
 from typing import Literal, Dict, Tuple
 
 SENTENCE_SIZE = 100
@@ -33,7 +32,7 @@ def config_aggrid(
 
 def knowledge_base_page(api: ApiRequest):
     # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
-    kb_details = get_kb_details(api)
+    kb_details = get_kb_details()
     kb_list = list(kb_details.kb_name)
 
     cols = st.columns([3, 1, 1])
@@ -127,7 +126,7 @@ def knowledge_base_page(api: ApiRequest):
             # 知识库详情
             st.write(f"知识库 `{kb}` 详情:")
             st.info("请选择文件")
-            doc_details = get_kb_doc_details(api, kb)
+            doc_details = get_kb_doc_details(kb)
             doc_details.drop(columns=["kb_name"], inplace=True)
 
             gb = config_aggrid(
@@ -1,7 +1,6 @@
 # 该文件包含webui通用工具,可以被不同的webui使用
 from typing import *
 from pathlib import Path
-import os
 from configs.model_config import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
@@ -17,10 +16,9 @@ from fastapi.responses import StreamingResponse
 import contextlib
 import json
 from io import BytesIO
-import pandas as pd
-from server.knowledge_base.utils import list_kbs_from_folder, list_docs_from_folder
 from server.db.repository.knowledge_base_repository import get_kb_detail
 from server.db.repository.knowledge_file_repository import get_file_detail
+from server.utils import run_async, iter_over_async
 
 
 def set_httpx_timeout(timeout=60.0):
@@ -37,35 +35,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
 set_httpx_timeout()
 
 
-def run_async(cor):
-    '''
-    在同步环境中运行异步代码.
-    '''
-    try:
-        loop = asyncio.get_event_loop()
-    except:
-        loop = asyncio.new_event_loop()
-    return loop.run_until_complete(cor)
-
-
-def iter_over_async(ait, loop):
-    '''
-    将异步生成器封装成同步生成器.
-    '''
-    ait = ait.__aiter__()
-    async def get_next():
-        try:
-            obj = await ait.__anext__()
-            return False, obj
-        except StopAsyncIteration:
-            return True, None
-    while True:
-        done, obj = loop.run_until_complete(get_next())
-        if done:
-            break
-        yield obj
-
-
 class ApiRequest:
     '''
     api.py调用的封装,主要实现:
@@ -97,12 +66,16 @@ class ApiRequest:
         url: str,
         params: Union[Dict, List[Tuple], bytes] = None,
         retry: int = 3,
+        stream: bool = False,
         **kwargs: Any,
     ) -> Union[httpx.Response, None]:
         url = self._parse_url(url)
         kwargs.setdefault("timeout", self.timeout)
         while retry > 0:
             try:
+                if stream:
+                    return httpx.stream("GET", url, params=params, **kwargs)
+                else:
                     return httpx.get(url, params=params, **kwargs)
             except:
                 retry -= 1
@@ -112,6 +85,7 @@ class ApiRequest:
         url: str,
         params: Union[Dict, List[Tuple], bytes] = None,
         retry: int = 3,
+        stream: bool = False,
         **kwargs: Any,
     ) -> Union[httpx.Response, None]:
         rl = self._parse_url(url)
@@ -119,6 +93,9 @@ class ApiRequest:
         async with httpx.AsyncClient() as client:
             while retry > 0:
                 try:
+                    if stream:
+                        return await client.stream("GET", url, params=params, **kwargs)
+                    else:
                         return await client.get(url, params=params, **kwargs)
                 except:
                     retry -= 1
@@ -150,6 +127,7 @@ class ApiRequest:
         data: Dict = None,
         json: Dict = None,
         retry: int = 3,
+        stream: bool = False,
         **kwargs: Any
     ) -> Union[httpx.Response, None]:
         rl = self._parse_url(url)
@@ -157,10 +135,54 @@ class ApiRequest:
         async with httpx.AsyncClient() as client:
             while retry > 0:
                 try:
+                    if stream:
+                        return await client.stream("POST", url, data=data, json=json, **kwargs)
+                    else:
                         return await client.post(url, data=data, json=json, **kwargs)
                 except:
                     retry -= 1
 
+    def delete(
+        self,
+        url: str,
+        data: Dict = None,
+        json: Dict = None,
+        retry: int = 3,
+        stream: bool = False,
+        **kwargs: Any
+    ) -> Union[httpx.Response, None]:
+        url = self._parse_url(url)
+        kwargs.setdefault("timeout", self.timeout)
+        while retry > 0:
+            try:
+                if stream:
+                    return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
+                else:
+                    return httpx.delete(url, data=data, json=json, **kwargs)
+            except:
+                retry -= 1
+
+    async def adelete(
+        self,
+        url: str,
+        data: Dict = None,
+        json: Dict = None,
+        retry: int = 3,
+        stream: bool = False,
+        **kwargs: Any
+    ) -> Union[httpx.Response, None]:
+        rl = self._parse_url(url)
+        kwargs.setdefault("timeout", self.timeout)
+        async with httpx.AsyncClient() as client:
+            while retry > 0:
+                try:
+                    if stream:
+                        return await client.stream("DELETE", url, data=data, json=json, **kwargs)
+                    else:
+                        return await client.delete(url, data=data, json=json, **kwargs)
+                except:
+                    retry -= 1
+
     def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
         '''
         将api.py中视图函数返回的StreamingResponse转化为同步生成器
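A note on the stream=True branches above: httpx.stream() and AsyncClient.stream() return (async) context managers rather than plain responses, so a caller is expected to enter them before iterating. A hedged sketch of the synchronous case (the endpoint path is a placeholder):

import httpx

with httpx.stream("GET", "http://127.0.0.1:7861/some/streaming/endpoint") as r:
    for chunk in r.iter_text():  # incremental reading of the body
        print(chunk, end="")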
@@ -384,7 +406,7 @@ class ApiRequest:
             response = run_async(delete_kb(knowledge_base_name))
             return response.dict()
         else:
-            response = self.delete(
+            response = self.post(
                 "/knowledge_base/delete_knowledge_base",
                 json={"knowledge_base_name": knowledge_base_name},
             )
@@ -480,7 +502,7 @@ class ApiRequest:
             response = run_async(delete_doc(**data))
             return response.dict()
         else:
-            response = self.delete(
+            response = self.post(
                 "/knowledge_base/delete_doc",
                 json=data,
             )
@@ -489,7 +511,7 @@ class ApiRequest:
     def update_kb_doc(
         self,
         knowledge_base_name: str,
-        doc_name: str,
+        file_name: str,
         no_remote_api: bool = None,
     ):
         '''
@@ -500,12 +522,12 @@ class ApiRequest:
 
         if no_remote_api:
             from server.knowledge_base.kb_doc_api import update_doc
-            response = run_async(update_doc(knowledge_base_name, doc_name))
+            response = run_async(update_doc(knowledge_base_name, file_name))
             return response.dict()
         else:
-            response = self.delete(
+            response = self.post(
                 "/knowledge_base/update_doc",
-                json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
+                json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
             )
             return response.json()
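With the parameter renamed from doc_name to file_name end to end, callers now pass the file name under the new JSON key. A hypothetical call against a local server ("samples" and "test.txt" are placeholders):

api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
resp = api.update_kb_doc("samples", "test.txt")
print(resp)  # BaseResponse-style dict returned by /knowledge_base/update_doc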
@@ -532,120 +554,8 @@ class ApiRequest:
         return self._httpx_stream2generator(response, as_json=True)
 
 
-def get_kb_details(api: ApiRequest) -> pd.DataFrame:
-    kbs_in_folder = list_kbs_from_folder()
-    kbs_in_db = api.list_knowledge_bases()
-    result = {}
-
-    for kb in kbs_in_folder:
-        result[kb] = {
-            "kb_name": kb,
-            "vs_type": "",
-            "embed_model": "",
-            "file_count": 0,
-            "create_time": None,
-            "in_folder": True,
-            "in_db": False,
-        }
-
-    for kb in kbs_in_db:
-        kb_detail = get_kb_detail(kb)
-        if kb_detail:
-            kb_detail["in_db"] = True
-            if kb in result:
-                result[kb].update(kb_detail)
-            else:
-                kb_detail["in_folder"] = False
-                result[kb] = kb_detail
-
-    df = pd.DataFrame(result.values(), columns=[
-        "kb_name",
-        "vs_type",
-        "embed_model",
-        "file_count",
-        "create_time",
-        "in_folder",
-        "in_db",
-    ])
-    df.insert(0, "No", range(1, len(df) + 1))
-    return df
-
-
-def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame:
-    docs_in_folder = list_docs_from_folder(kb)
-    docs_in_db = api.list_kb_docs(kb)
-    result = {}
-
-    for doc in docs_in_folder:
-        result[doc] = {
-            "kb_name": kb,
-            "file_name": doc,
-            "file_ext": os.path.splitext(doc)[-1],
-            "file_version": 0,
-            "document_loader": "",
-            "text_splitter": "",
-            "create_time": None,
-            "in_folder": True,
-            "in_db": False,
-        }
-
-    for doc in docs_in_db:
-        doc_detail = get_file_detail(kb, doc)
-        if doc_detail:
-            doc_detail["in_db"] = True
-            if doc in result:
-                result[doc].update(doc_detail)
-            else:
-                doc_detail["in_folder"] = False
-                result[doc] = doc_detail
-
-    df = pd.DataFrame(result.values(), columns=[
-        "kb_name",
-        "file_name",
-        "file_ext",
-        "file_version",
-        "document_loader",
-        "text_splitter",
-        "create_time",
-        "in_folder",
-        "in_db",
-    ])
-    df.insert(0, "No", range(1, len(df) + 1))
-    return df
-
-
-def init_vs_database(recreate_vs: bool = False):
-    '''
-    init local vector store info to database
-    '''
-    from server.db.base import Base, engine
-    from server.db.repository.knowledge_base_repository import add_kb_to_db, kb_exists
-    from server.db.repository.knowledge_file_repository import add_doc_to_db
-    from server.knowledge_base.utils import KnowledgeFile
-
-    Base.metadata.create_all(bind=engine)
-
-    if recreate_vs:
-        api = ApiRequest(no_remote_api=True)
-        for kb in list_kbs_from_folder():
-            for t in api.recreate_vector_store(kb):
-                print(t)
-    else:  # add vs info to db only
-        for kb in list_kbs_from_folder():
-            if not kb_exists(kb):
-                add_kb_to_db(kb, "faiss", EMBEDDING_MODEL)
-            for doc in list_docs_from_folder(kb):
-                try:
-                    kb_file = KnowledgeFile(doc, kb)
-                    add_doc_to_db(kb_file)
-                except Exception as e:
-                    print(e)
-
-
 if __name__ == "__main__":
     api = ApiRequest(no_remote_api=True)
-    # init vector store database
-    init_vs_database()
 
     # print(api.chat_fastchat(
     #     messages=[{"role": "user", "content": "hello"}]