更新API与ApiReuest:

1. 重新整理webui_pages/utils与server/knowledge_base间的工具依赖
2.
将delete_knowledge_base与delete_doc接口从delete改为post.delete不支持body参数
3. 修复update_doc
4. 修复部分bug
This commit is contained in:
liunux4odoo 2023-08-11 08:37:07 +08:00
parent a261fda20b
commit a08fe994c2
10 changed files with 208 additions and 185 deletions

View File

@ -71,7 +71,7 @@ def create_app():
summary="创建知识库"
)(create_kb)
app.delete("/knowledge_base/delete_knowledge_base",
app.post("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
@ -89,7 +89,7 @@ def create_app():
summary="上传文件到知识库"
)(upload_doc)
app.delete("/knowledge_base/delete_doc",
app.post("/knowledge_base/delete_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
@ -106,10 +106,6 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# init local vector store info to database
from webui_pages.utils import init_vs_database
init_vs_database()
return app

View File

@ -9,6 +9,9 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
if not kb:
kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
session.add(kb)
else: # update kb with new vs_type and embed_model
kb.vs_type = vs_type
kb.embed_model = embed_model
return True

View File

@ -1,16 +1,25 @@
from abc import ABC, abstractmethod
import os
import pandas as pd
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_doc_to_db, delete_file_from_db, doc_exists,
list_docs_from_db, get_file_detail
)
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
list_docs_from_db
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_docs_from_folder,
)
from typing import List, Union
@ -77,10 +86,10 @@ class KBService(ABC):
"""
从知识库删除文件
"""
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
self.do_delete_doc(kb_file)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile):
@ -121,10 +130,9 @@ class KBService(ABC):
def list_kbs(cls):
return list_kbs_from_db()
@classmethod
def exists(cls,
knowledge_base_name: str):
return kb_exists(knowledge_base_name)
def exists(self, kb_name: str = None):
kb_name = kb_name or self.kb_name
return kb_exists(kb_name)
@abstractmethod
def vs_type(self) -> str:
@ -212,3 +220,86 @@ class KBServiceFactory:
def get_default():
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> pd.DataFrame:
kbs_in_folder = list_kbs_from_folder()
kbs_in_db = KBService.list_kbs()
result = {}
for kb in kbs_in_folder:
result[kb] = {
"kb_name": kb,
"vs_type": "",
"embed_model": "",
"file_count": 0,
"create_time": None,
"in_folder": True,
"in_db": False,
}
for kb in kbs_in_db:
kb_detail = get_kb_detail(kb)
if kb_detail:
kb_detail["in_db"] = True
if kb in result:
result[kb].update(kb_detail)
else:
kb_detail["in_folder"] = False
result[kb] = kb_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"vs_type",
"embed_model",
"file_count",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df
def get_kb_doc_details(kb_name: str) -> pd.DataFrame:
kb = KBServiceFactory.get_service_by_name(kb_name)
docs_in_folder = list_docs_from_folder(kb_name)
docs_in_db = kb.list_docs()
result = {}
for doc in docs_in_folder:
result[doc] = {
"kb_name": kb_name,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
for doc in docs_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc in result:
result[doc].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"file_name",
"file_ext",
"file_version",
"document_loader",
"text_splitter",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df

View File

@ -70,7 +70,6 @@ def folder2db(
def recreate_all_vs(
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_mode: str = EMBEDDING_MODEL,
):
@ -78,7 +77,7 @@ def recreate_all_vs(
used to recreate a vector store or change current vector store to another type or embed_model
'''
for kb_name in list_kbs_from_folder():
folder2db(kb_name, mode, vs_type, embed_mode)
folder2db(kb_name, "recreate_vs", vs_type, embed_mode)
def prune_db_docs(kb_name: str):

View File

@ -52,7 +52,6 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
"CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

View File

@ -5,6 +5,7 @@ import torch
from fastapi_offline import FastAPIOffline
import fastapi_offline
from pathlib import Path
import asyncio
# patch fastapi_offline to use local static assests
@ -82,3 +83,32 @@ def torch_gc():
except Exception as e:
print(e)
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
def run_async(cor):
'''
在同步环境中运行异步代码.
'''
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
return loop.run_until_complete(cor)
def iter_over_async(ait, loop):
'''
将异步生成器封装成同步生成器.
'''
ait = ait.__aiter__()
async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj

View File

@ -12,9 +12,6 @@ from webui_pages import *
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
if __name__ == "__main__":
# init local vector store info to database
init_vs_database()
st.set_page_config("langchain-chatglm WebUI")
if not chat_box.chat_inited:

View File

@ -90,8 +90,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
history_len = st.slider("历史对话轮数:", 1, 10, 3)
history_len = st.slider("历史对话轮数:", 0, 10, 3)
# todo: support history len
def on_kb_change():

View File

@ -1,11 +1,10 @@
import streamlit as st
from webui_pages.utils import *
# import streamlit_antd_components as sac
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path
# from streamlit_chatbox import *
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
from typing import Literal, Dict, Tuple
SENTENCE_SIZE = 100
@ -33,7 +32,7 @@ def config_aggrid(
def knowledge_base_page(api: ApiRequest):
# api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
kb_details = get_kb_details(api)
kb_details = get_kb_details()
kb_list = list(kb_details.kb_name)
cols = st.columns([3, 1, 1])
@ -127,7 +126,7 @@ def knowledge_base_page(api: ApiRequest):
# 知识库详情
st.write(f"知识库 `{kb}` 详情:")
st.info("请选择文件")
doc_details = get_kb_doc_details(api, kb)
doc_details = get_kb_doc_details(kb)
doc_details.drop(columns=["kb_name"], inplace=True)
gb = config_aggrid(

View File

@ -1,7 +1,6 @@
# 该文件包含webui通用工具可以被不同的webui使用
from typing import *
from pathlib import Path
import os
from configs.model_config import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
@ -17,10 +16,9 @@ from fastapi.responses import StreamingResponse
import contextlib
import json
from io import BytesIO
import pandas as pd
from server.knowledge_base.utils import list_kbs_from_folder, list_docs_from_folder
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
def set_httpx_timeout(timeout=60.0):
@ -37,35 +35,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
def run_async(cor):
'''
在同步环境中运行异步代码.
'''
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
return loop.run_until_complete(cor)
def iter_over_async(ait, loop):
'''
将异步生成器封装成同步生成器.
'''
ait = ait.__aiter__()
async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj
class ApiRequest:
'''
api.py调用的封装,主要实现:
@ -97,12 +66,16 @@ class ApiRequest:
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
if stream:
return httpx.stream("GET", url, params=params, **kwargs)
else:
return httpx.get(url, params=params, **kwargs)
except:
retry -= 1
@ -112,6 +85,7 @@ class ApiRequest:
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, None]:
rl = self._parse_url(url)
@ -119,6 +93,9 @@ class ApiRequest:
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("GET", url, params=params, **kwargs)
else:
return await client.get(url, params=params, **kwargs)
except:
retry -= 1
@ -150,6 +127,7 @@ class ApiRequest:
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, None]:
rl = self._parse_url(url)
@ -157,10 +135,54 @@ class ApiRequest:
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("POST", url, data=data, json=json, **kwargs)
else:
return await client.post(url, data=data, json=json, **kwargs)
except:
retry -= 1
def delete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
if stream:
return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return httpx.delete(url, data=data, json=json, **kwargs)
except:
retry -= 1
async def adelete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, None]:
rl = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return await client.delete(url, data=data, json=json, **kwargs)
except:
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
'''
将api.py中视图函数返回的StreamingResponse转化为同步生成器
@ -384,7 +406,7 @@ class ApiRequest:
response = run_async(delete_kb(knowledge_base_name))
return response.dict()
else:
response = self.delete(
response = self.post(
"/knowledge_base/delete_knowledge_base",
json={"knowledge_base_name": knowledge_base_name},
)
@ -480,7 +502,7 @@ class ApiRequest:
response = run_async(delete_doc(**data))
return response.dict()
else:
response = self.delete(
response = self.post(
"/knowledge_base/delete_doc",
json=data,
)
@ -489,7 +511,7 @@ class ApiRequest:
def update_kb_doc(
self,
knowledge_base_name: str,
doc_name: str,
file_name: str,
no_remote_api: bool = None,
):
'''
@ -500,12 +522,12 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_doc
response = run_async(update_doc(knowledge_base_name, doc_name))
response = run_async(update_doc(knowledge_base_name, file_name))
return response.dict()
else:
response = self.delete(
response = self.post(
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
)
return response.json()
@ -532,120 +554,8 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)
def get_kb_details(api: ApiRequest) -> pd.DataFrame:
kbs_in_folder = list_kbs_from_folder()
kbs_in_db = api.list_knowledge_bases()
result = {}
for kb in kbs_in_folder:
result[kb] = {
"kb_name": kb,
"vs_type": "",
"embed_model": "",
"file_count": 0,
"create_time": None,
"in_folder": True,
"in_db": False,
}
for kb in kbs_in_db:
kb_detail = get_kb_detail(kb)
if kb_detail:
kb_detail["in_db"] = True
if kb in result:
result[kb].update(kb_detail)
else:
kb_detail["in_folder"] = False
result[kb] = kb_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"vs_type",
"embed_model",
"file_count",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df
def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame:
docs_in_folder = list_docs_from_folder(kb)
docs_in_db = api.list_kb_docs(kb)
result = {}
for doc in docs_in_folder:
result[doc] = {
"kb_name": kb,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
for doc in docs_in_db:
doc_detail = get_file_detail(kb, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc in result:
result[doc].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"file_name",
"file_ext",
"file_version",
"document_loader",
"text_splitter",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df
def init_vs_database(recreate_vs: bool = False):
'''
init local vector store info to database
'''
from server.db.base import Base, engine
from server.db.repository.knowledge_base_repository import add_kb_to_db, kb_exists
from server.db.repository.knowledge_file_repository import add_doc_to_db
from server.knowledge_base.utils import KnowledgeFile
Base.metadata.create_all(bind=engine)
if recreate_vs:
api = ApiRequest(no_remote_api=True)
for kb in list_kbs_from_folder():
for t in api.recreate_vector_store(kb):
print(t)
else: # add vs info to db only
for kb in list_kbs_from_folder():
if not kb_exists(kb):
add_kb_to_db(kb, "faiss", EMBEDDING_MODEL)
for doc in list_docs_from_folder(kb):
try:
kb_file = KnowledgeFile(doc, kb)
add_doc_to_db(kb_file)
except Exception as e:
print(e)
if __name__ == "__main__":
api = ApiRequest(no_remote_api=True)
# init vector store database
init_vs_database()
# print(api.chat_fastchat(
# messages=[{"role": "user", "content": "hello"}]