更新API与ApiRequest:

1. 重新整理webui_pages/utils与server/knowledge_base间的工具依赖
2.
将delete_knowledge_base与delete_doc接口从delete改为post(delete不支持body参数)
3. 修复update_doc
4. 修复部分bug
This commit is contained in:
liunux4odoo 2023-08-11 08:37:07 +08:00
parent a261fda20b
commit a08fe994c2
10 changed files with 208 additions and 185 deletions

View File

@ -71,7 +71,7 @@ def create_app():
summary="创建知识库" summary="创建知识库"
)(create_kb) )(create_kb)
app.delete("/knowledge_base/delete_knowledge_base", app.post("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"], tags=["Knowledge Base Management"],
response_model=BaseResponse, response_model=BaseResponse,
summary="删除知识库" summary="删除知识库"
@ -89,7 +89,7 @@ def create_app():
summary="上传文件到知识库" summary="上传文件到知识库"
)(upload_doc) )(upload_doc)
app.delete("/knowledge_base/delete_doc", app.post("/knowledge_base/delete_doc",
tags=["Knowledge Base Management"], tags=["Knowledge Base Management"],
response_model=BaseResponse, response_model=BaseResponse,
summary="删除知识库内指定文件" summary="删除知识库内指定文件"
@ -106,10 +106,6 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。" summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store) )(recreate_vector_store)
# init local vector store info to database
from webui_pages.utils import init_vs_database
init_vs_database()
return app return app

View File

@ -9,6 +9,9 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
if not kb: if not kb:
kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
session.add(kb) session.add(kb)
else: # update kb with new vs_type and embed_model
kb.vs_type = vs_type
kb.embed_model = embed_model
return True return True

View File

@ -1,16 +1,25 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import os import os
import pandas as pd
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_doc_to_db, delete_file_from_db, doc_exists,
list_docs_from_db, get_file_detail
)
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
list_docs_from_db
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
EMBEDDING_DEVICE, EMBEDDING_MODEL) EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile) from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_docs_from_folder,
)
from typing import List, Union from typing import List, Union
@ -77,10 +86,10 @@ class KBService(ABC):
""" """
从知识库删除文件 从知识库删除文件
""" """
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
self.do_delete_doc(kb_file) self.do_delete_doc(kb_file)
status = delete_file_from_db(kb_file) status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status return status
def update_doc(self, kb_file: KnowledgeFile): def update_doc(self, kb_file: KnowledgeFile):
@ -121,10 +130,9 @@ class KBService(ABC):
def list_kbs(cls): def list_kbs(cls):
return list_kbs_from_db() return list_kbs_from_db()
@classmethod def exists(self, kb_name: str = None):
def exists(cls, kb_name = kb_name or self.kb_name
knowledge_base_name: str): return kb_exists(kb_name)
return kb_exists(knowledge_base_name)
@abstractmethod @abstractmethod
def vs_type(self) -> str: def vs_type(self) -> str:
@ -212,3 +220,86 @@ class KBServiceFactory:
def get_default(): def get_default():
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> pd.DataFrame:
    """Merge knowledge-base info from the local folder and the database.

    Returns a DataFrame with one row per knowledge base, flagging whether
    each KB exists on disk (``in_folder``) and/or in the DB (``in_db``),
    plus a 1-based ``No`` index column.
    """
    details = {}

    # Seed with every KB found on disk; DB-backed fields stay blank
    # until merged below.
    for name in list_kbs_from_folder():
        details[name] = {
            "kb_name": name,
            "vs_type": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }

    # Merge in DB records; KBs present only in the DB get in_folder=False.
    for name in KBService.list_kbs():
        record = get_kb_detail(name)
        if not record:
            continue
        record["in_db"] = True
        if name in details:
            details[name].update(record)
        else:
            record["in_folder"] = False
            details[name] = record

    columns = [
        "kb_name",
        "vs_type",
        "embed_model",
        "file_count",
        "create_time",
        "in_folder",
        "in_db",
    ]
    df = pd.DataFrame(details.values(), columns=columns)
    df.insert(0, "No", range(1, len(df) + 1))
    return df
def get_kb_doc_details(kb_name: str) -> pd.DataFrame:
    """Merge document info for one knowledge base from disk and the DB.

    Returns a DataFrame with one row per document, flagging whether each
    file exists on disk (``in_folder``) and/or in the DB (``in_db``),
    plus a 1-based ``No`` index column.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    details = {}

    # Seed with every file found in the KB folder; DB-backed fields stay
    # blank until merged below.
    for fname in list_docs_from_folder(kb_name):
        details[fname] = {
            "kb_name": kb_name,
            "file_name": fname,
            "file_ext": os.path.splitext(fname)[-1],
            "file_version": 0,
            "document_loader": "",
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }

    # Merge in DB records; docs present only in the DB get in_folder=False.
    for fname in service.list_docs():
        record = get_file_detail(kb_name, fname)
        if not record:
            continue
        record["in_db"] = True
        if fname in details:
            details[fname].update(record)
        else:
            record["in_folder"] = False
            details[fname] = record

    columns = [
        "kb_name",
        "file_name",
        "file_ext",
        "file_version",
        "document_loader",
        "text_splitter",
        "create_time",
        "in_folder",
        "in_db",
    ]
    df = pd.DataFrame(details.values(), columns=columns)
    df.insert(0, "No", range(1, len(df) + 1))
    return df

View File

@ -70,7 +70,6 @@ def folder2db(
def recreate_all_vs( def recreate_all_vs(
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_mode: str = EMBEDDING_MODEL, embed_mode: str = EMBEDDING_MODEL,
): ):
@ -78,7 +77,7 @@ def recreate_all_vs(
used to recreate a vector store or change current vector store to another type or embed_model used to recreate a vector store or change current vector store to another type or embed_model
''' '''
for kb_name in list_kbs_from_folder(): for kb_name in list_kbs_from_folder():
folder2db(kb_name, mode, vs_type, embed_mode) folder2db(kb_name, "recreate_vs", vs_type, embed_mode)
def prune_db_docs(kb_name: str): def prune_db_docs(kb_name: str):

View File

@ -52,7 +52,6 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
"CSVLoader": [".csv"], "CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"], "PyPDFLoader": [".pdf"],
} }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

View File

@ -5,6 +5,7 @@ import torch
from fastapi_offline import FastAPIOffline from fastapi_offline import FastAPIOffline
import fastapi_offline import fastapi_offline
from pathlib import Path from pathlib import Path
import asyncio
# patch fastapi_offline to use local static assests # patch fastapi_offline to use local static assests
@ -82,3 +83,32 @@ def torch_gc():
except Exception as e: except Exception as e:
print(e) print(e)
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
def run_async(cor):
    '''
    Run a coroutine to completion from synchronous code.

    Reuses the current thread's event loop when one is available;
    otherwise creates a fresh loop and installs it with
    ``asyncio.set_event_loop`` so subsequent calls on this thread
    can reuse it.

    :param cor: the coroutine (awaitable) to run
    :return: whatever the coroutine returns
    '''
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop set for this thread (bare `except:` previously hid
        # unrelated errors); create one and register it for reuse.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(cor)
def iter_over_async(ait, loop):
    '''
    Wrap an async generator as a plain synchronous generator.

    Each item is produced by driving ``__anext__()`` on *loop* until the
    async iterator is exhausted.

    :param ait: an async iterable (e.g. an async generator)
    :param loop: the event loop used to drive each step
    '''
    iterator = ait.__aiter__()

    async def fetch_one():
        # Returns (exhausted, item); exhausted=True marks StopAsyncIteration.
        try:
            return False, await iterator.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        exhausted, item = loop.run_until_complete(fetch_one())
        if exhausted:
            return
        yield item

View File

@ -12,9 +12,6 @@ from webui_pages import *
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
if __name__ == "__main__": if __name__ == "__main__":
# init local vector store info to database
init_vs_database()
st.set_page_config("langchain-chatglm WebUI") st.set_page_config("langchain-chatglm WebUI")
if not chat_box.chat_inited: if not chat_box.chat_inited:

View File

@ -90,8 +90,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change, on_change=on_mode_change,
key="dialogue_mode", key="dialogue_mode",
) )
history_len = st.slider("历史对话轮数:", 1, 10, 3) history_len = st.slider("历史对话轮数:", 0, 10, 3)
# todo: support history len # todo: support history len
def on_kb_change(): def on_kb_change():

View File

@ -1,11 +1,10 @@
import streamlit as st import streamlit as st
from webui_pages.utils import * from webui_pages.utils import *
# import streamlit_antd_components as sac
from st_aggrid import AgGrid from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd import pandas as pd
from server.knowledge_base.utils import get_file_path from server.knowledge_base.utils import get_file_path
# from streamlit_chatbox import * from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
from typing import Literal, Dict, Tuple from typing import Literal, Dict, Tuple
SENTENCE_SIZE = 100 SENTENCE_SIZE = 100
@ -33,7 +32,7 @@ def config_aggrid(
def knowledge_base_page(api: ApiRequest): def knowledge_base_page(api: ApiRequest):
# api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
kb_details = get_kb_details(api) kb_details = get_kb_details()
kb_list = list(kb_details.kb_name) kb_list = list(kb_details.kb_name)
cols = st.columns([3, 1, 1]) cols = st.columns([3, 1, 1])
@ -127,7 +126,7 @@ def knowledge_base_page(api: ApiRequest):
# 知识库详情 # 知识库详情
st.write(f"知识库 `{kb}` 详情:") st.write(f"知识库 `{kb}` 详情:")
st.info("请选择文件") st.info("请选择文件")
doc_details = get_kb_doc_details(api, kb) doc_details = get_kb_doc_details(kb)
doc_details.drop(columns=["kb_name"], inplace=True) doc_details.drop(columns=["kb_name"], inplace=True)
gb = config_aggrid( gb = config_aggrid(

View File

@ -1,7 +1,6 @@
# 该文件包含webui通用工具可以被不同的webui使用 # 该文件包含webui通用工具可以被不同的webui使用
from typing import * from typing import *
from pathlib import Path from pathlib import Path
import os
from configs.model_config import ( from configs.model_config import (
EMBEDDING_MODEL, EMBEDDING_MODEL,
KB_ROOT_PATH, KB_ROOT_PATH,
@ -17,10 +16,9 @@ from fastapi.responses import StreamingResponse
import contextlib import contextlib
import json import json
from io import BytesIO from io import BytesIO
import pandas as pd
from server.knowledge_base.utils import list_kbs_from_folder, list_docs_from_folder
from server.db.repository.knowledge_base_repository import get_kb_detail from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
def set_httpx_timeout(timeout=60.0): def set_httpx_timeout(timeout=60.0):
@ -37,35 +35,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout() set_httpx_timeout()
def run_async(cor):
'''
在同步环境中运行异步代码.
'''
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
return loop.run_until_complete(cor)
def iter_over_async(ait, loop):
'''
将异步生成器封装成同步生成器.
'''
ait = ait.__aiter__()
async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj
class ApiRequest: class ApiRequest:
''' '''
api.py调用的封装,主要实现: api.py调用的封装,主要实现:
@ -97,12 +66,16 @@ class ApiRequest:
url: str, url: str,
params: Union[Dict, List[Tuple], bytes] = None, params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3, retry: int = 3,
stream: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
url = self._parse_url(url) url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
while retry > 0: while retry > 0:
try: try:
if stream:
return httpx.stream("GET", url, params=params, **kwargs)
else:
return httpx.get(url, params=params, **kwargs) return httpx.get(url, params=params, **kwargs)
except: except:
retry -= 1 retry -= 1
@ -112,6 +85,7 @@ class ApiRequest:
url: str, url: str,
params: Union[Dict, List[Tuple], bytes] = None, params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3, retry: int = 3,
stream: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
rl = self._parse_url(url) rl = self._parse_url(url)
@ -119,6 +93,9 @@ class ApiRequest:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
while retry > 0: while retry > 0:
try: try:
if stream:
return await client.stream("GET", url, params=params, **kwargs)
else:
return await client.get(url, params=params, **kwargs) return await client.get(url, params=params, **kwargs)
except: except:
retry -= 1 retry -= 1
@ -150,6 +127,7 @@ class ApiRequest:
data: Dict = None, data: Dict = None,
json: Dict = None, json: Dict = None,
retry: int = 3, retry: int = 3,
stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
rl = self._parse_url(url) rl = self._parse_url(url)
@ -157,10 +135,54 @@ class ApiRequest:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
while retry > 0: while retry > 0:
try: try:
if stream:
return await client.stream("POST", url, data=data, json=json, **kwargs)
else:
return await client.post(url, data=data, json=json, **kwargs) return await client.post(url, data=data, json=json, **kwargs)
except: except:
retry -= 1 retry -= 1
def delete(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    '''
    Issue a DELETE request that may carry a body, retrying on any error.

    :param url: path or absolute URL, resolved via self._parse_url
    :param data: form body to send with the request
    :param json: JSON body to send with the request
    :param retry: number of attempts before giving up
    :param stream: when True, return the httpx stream context manager
    :return: the response (or stream context manager), or None after all
             retries fail
    '''
    url = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)
    while retry > 0:
        try:
            if stream:
                return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
            # httpx.delete() does not accept body kwargs (data/json);
            # issue the DELETE through httpx.request(), which does.
            return httpx.request("DELETE", url, data=data, json=json, **kwargs)
        except Exception:
            # Best-effort retry, matching the sibling get/post helpers;
            # returns None implicitly once retries are exhausted.
            retry -= 1
async def adelete(
    self,
    url: str,
    data: Dict = None,
    json: Dict = None,
    retry: int = 3,
    stream: bool = False,
    **kwargs: Any
) -> Union[httpx.Response, None]:
    '''
    Async DELETE request that may carry a body, retrying on any error.

    :param url: path or absolute URL, resolved via self._parse_url
    :param data: form body to send with the request
    :param json: JSON body to send with the request
    :param retry: number of attempts before giving up
    :param stream: when True, return a streaming response
    :return: the response, or None after all retries fail
    '''
    # Previously the parsed URL was bound to an unused name (`rl`) and the
    # raw `url` was sent instead; use the parsed URL like the sync helpers.
    url = self._parse_url(url)
    kwargs.setdefault("timeout", self.timeout)
    async with httpx.AsyncClient() as client:
        while retry > 0:
            try:
                if stream:
                    # client.stream() is an async context manager, not an
                    # awaitable; build and send the request explicitly to
                    # hand a streaming response back to the caller.
                    # NOTE(review): the client closes when this block exits,
                    # before the caller consumes the stream — confirm callers
                    # read it promptly (same flaw as the original design).
                    request = client.build_request(
                        "DELETE", url, data=data, json=json, **kwargs
                    )
                    return await client.send(request, stream=True)
                # AsyncClient.delete() does not accept body kwargs
                # (data/json); use client.request(), which does.
                return await client.request(
                    "DELETE", url, data=data, json=json, **kwargs
                )
            except Exception:
                # Best-effort retry, matching the sibling aget/apost helpers.
                retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
''' '''
将api.py中视图函数返回的StreamingResponse转化为同步生成器 将api.py中视图函数返回的StreamingResponse转化为同步生成器
@ -384,7 +406,7 @@ class ApiRequest:
response = run_async(delete_kb(knowledge_base_name)) response = run_async(delete_kb(knowledge_base_name))
return response.dict() return response.dict()
else: else:
response = self.delete( response = self.post(
"/knowledge_base/delete_knowledge_base", "/knowledge_base/delete_knowledge_base",
json={"knowledge_base_name": knowledge_base_name}, json={"knowledge_base_name": knowledge_base_name},
) )
@ -480,7 +502,7 @@ class ApiRequest:
response = run_async(delete_doc(**data)) response = run_async(delete_doc(**data))
return response.dict() return response.dict()
else: else:
response = self.delete( response = self.post(
"/knowledge_base/delete_doc", "/knowledge_base/delete_doc",
json=data, json=data,
) )
@ -489,7 +511,7 @@ class ApiRequest:
def update_kb_doc( def update_kb_doc(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
doc_name: str, file_name: str,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -500,12 +522,12 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import update_doc from server.knowledge_base.kb_doc_api import update_doc
response = run_async(update_doc(knowledge_base_name, doc_name)) response = run_async(update_doc(knowledge_base_name, file_name))
return response.dict() return response.dict()
else: else:
response = self.delete( response = self.post(
"/knowledge_base/update_doc", "/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name}, json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
) )
return response.json() return response.json()
@ -532,120 +554,8 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def get_kb_details(api: ApiRequest) -> pd.DataFrame:
kbs_in_folder = list_kbs_from_folder()
kbs_in_db = api.list_knowledge_bases()
result = {}
for kb in kbs_in_folder:
result[kb] = {
"kb_name": kb,
"vs_type": "",
"embed_model": "",
"file_count": 0,
"create_time": None,
"in_folder": True,
"in_db": False,
}
for kb in kbs_in_db:
kb_detail = get_kb_detail(kb)
if kb_detail:
kb_detail["in_db"] = True
if kb in result:
result[kb].update(kb_detail)
else:
kb_detail["in_folder"] = False
result[kb] = kb_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"vs_type",
"embed_model",
"file_count",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df
def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame:
docs_in_folder = list_docs_from_folder(kb)
docs_in_db = api.list_kb_docs(kb)
result = {}
for doc in docs_in_folder:
result[doc] = {
"kb_name": kb,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
for doc in docs_in_db:
doc_detail = get_file_detail(kb, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc in result:
result[doc].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail
df = pd.DataFrame(result.values(), columns=[
"kb_name",
"file_name",
"file_ext",
"file_version",
"document_loader",
"text_splitter",
"create_time",
"in_folder",
"in_db",
])
df.insert(0, "No", range(1, len(df) + 1))
return df
def init_vs_database(recreate_vs: bool = False):
'''
init local vector store info to database
'''
from server.db.base import Base, engine
from server.db.repository.knowledge_base_repository import add_kb_to_db, kb_exists
from server.db.repository.knowledge_file_repository import add_doc_to_db
from server.knowledge_base.utils import KnowledgeFile
Base.metadata.create_all(bind=engine)
if recreate_vs:
api = ApiRequest(no_remote_api=True)
for kb in list_kbs_from_folder():
for t in api.recreate_vector_store(kb):
print(t)
else: # add vs info to db only
for kb in list_kbs_from_folder():
if not kb_exists(kb):
add_kb_to_db(kb, "faiss", EMBEDDING_MODEL)
for doc in list_docs_from_folder(kb):
try:
kb_file = KnowledgeFile(doc, kb)
add_doc_to_db(kb_file)
except Exception as e:
print(e)
if __name__ == "__main__": if __name__ == "__main__":
api = ApiRequest(no_remote_api=True) api = ApiRequest(no_remote_api=True)
# init vector store database
init_vs_database()
# print(api.chat_fastchat( # print(api.chat_fastchat(
# messages=[{"role": "user", "content": "hello"}] # messages=[{"role": "user", "content": "hello"}]