更新API与ApiReuest:
1. 重新整理webui_pages/utils与server/knowledge_base间的工具依赖 2. 将delete_knowledge_base与delete_doc接口从delete改为post.delete不支持body参数 3. 修复update_doc 4. 修复部分bug
This commit is contained in:
parent
a261fda20b
commit
a08fe994c2
|
|
@ -71,7 +71,7 @@ def create_app():
|
|||
summary="创建知识库"
|
||||
)(create_kb)
|
||||
|
||||
app.delete("/knowledge_base/delete_knowledge_base",
|
||||
app.post("/knowledge_base/delete_knowledge_base",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库"
|
||||
|
|
@ -89,7 +89,7 @@ def create_app():
|
|||
summary="上传文件到知识库"
|
||||
)(upload_doc)
|
||||
|
||||
app.delete("/knowledge_base/delete_doc",
|
||||
app.post("/knowledge_base/delete_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内指定文件"
|
||||
|
|
@ -106,10 +106,6 @@ def create_app():
|
|||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||
)(recreate_vector_store)
|
||||
|
||||
# init local vector store info to database
|
||||
from webui_pages.utils import init_vs_database
|
||||
init_vs_database()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
|
|||
if not kb:
|
||||
kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
|
||||
session.add(kb)
|
||||
else: # update kb with new vs_type and embed_model
|
||||
kb.vs_type = vs_type
|
||||
kb.embed_model = embed_model
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,25 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
)
|
||||
from server.db.repository.knowledge_file_repository import (
|
||||
add_doc_to_db, delete_file_from_db, doc_exists,
|
||||
list_docs_from_db, get_file_detail
|
||||
)
|
||||
|
||||
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
|
||||
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
|
||||
list_docs_from_db
|
||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
|
||||
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||
list_kbs_from_folder, list_docs_from_folder,
|
||||
)
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
|
|
@ -77,10 +86,10 @@ class KBService(ABC):
|
|||
"""
|
||||
从知识库删除文件
|
||||
"""
|
||||
if delete_content and os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
self.do_delete_doc(kb_file)
|
||||
status = delete_file_from_db(kb_file)
|
||||
if delete_content and os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile):
|
||||
|
|
@ -121,10 +130,9 @@ class KBService(ABC):
|
|||
def list_kbs(cls):
|
||||
return list_kbs_from_db()
|
||||
|
||||
@classmethod
|
||||
def exists(cls,
|
||||
knowledge_base_name: str):
|
||||
return kb_exists(knowledge_base_name)
|
||||
def exists(self, kb_name: str = None):
|
||||
kb_name = kb_name or self.kb_name
|
||||
return kb_exists(kb_name)
|
||||
|
||||
@abstractmethod
|
||||
def vs_type(self) -> str:
|
||||
|
|
@ -212,3 +220,86 @@ class KBServiceFactory:
|
|||
def get_default():
|
||||
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
||||
|
||||
|
||||
def get_kb_details() -> pd.DataFrame:
|
||||
kbs_in_folder = list_kbs_from_folder()
|
||||
kbs_in_db = KBService.list_kbs()
|
||||
result = {}
|
||||
|
||||
for kb in kbs_in_folder:
|
||||
result[kb] = {
|
||||
"kb_name": kb,
|
||||
"vs_type": "",
|
||||
"embed_model": "",
|
||||
"file_count": 0,
|
||||
"create_time": None,
|
||||
"in_folder": True,
|
||||
"in_db": False,
|
||||
}
|
||||
|
||||
for kb in kbs_in_db:
|
||||
kb_detail = get_kb_detail(kb)
|
||||
if kb_detail:
|
||||
kb_detail["in_db"] = True
|
||||
if kb in result:
|
||||
result[kb].update(kb_detail)
|
||||
else:
|
||||
kb_detail["in_folder"] = False
|
||||
result[kb] = kb_detail
|
||||
|
||||
df = pd.DataFrame(result.values(), columns=[
|
||||
"kb_name",
|
||||
"vs_type",
|
||||
"embed_model",
|
||||
"file_count",
|
||||
"create_time",
|
||||
"in_folder",
|
||||
"in_db",
|
||||
])
|
||||
df.insert(0, "No", range(1, len(df) + 1))
|
||||
return df
|
||||
|
||||
|
||||
def get_kb_doc_details(kb_name: str) -> pd.DataFrame:
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
docs_in_folder = list_docs_from_folder(kb_name)
|
||||
docs_in_db = kb.list_docs()
|
||||
result = {}
|
||||
|
||||
for doc in docs_in_folder:
|
||||
result[doc] = {
|
||||
"kb_name": kb_name,
|
||||
"file_name": doc,
|
||||
"file_ext": os.path.splitext(doc)[-1],
|
||||
"file_version": 0,
|
||||
"document_loader": "",
|
||||
"text_splitter": "",
|
||||
"create_time": None,
|
||||
"in_folder": True,
|
||||
"in_db": False,
|
||||
}
|
||||
|
||||
for doc in docs_in_db:
|
||||
doc_detail = get_file_detail(kb_name, doc)
|
||||
if doc_detail:
|
||||
doc_detail["in_db"] = True
|
||||
if doc in result:
|
||||
result[doc].update(doc_detail)
|
||||
else:
|
||||
doc_detail["in_folder"] = False
|
||||
result[doc] = doc_detail
|
||||
|
||||
df = pd.DataFrame(result.values(), columns=[
|
||||
"kb_name",
|
||||
"file_name",
|
||||
"file_ext",
|
||||
"file_version",
|
||||
"document_loader",
|
||||
"text_splitter",
|
||||
"create_time",
|
||||
"in_folder",
|
||||
"in_db",
|
||||
])
|
||||
df.insert(0, "No", range(1, len(df) + 1))
|
||||
return df
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,6 @@ def folder2db(
|
|||
|
||||
|
||||
def recreate_all_vs(
|
||||
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
|
||||
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
|
||||
embed_mode: str = EMBEDDING_MODEL,
|
||||
):
|
||||
|
|
@ -78,7 +77,7 @@ def recreate_all_vs(
|
|||
used to recreate a vector store or change current vector store to another type or embed_model
|
||||
'''
|
||||
for kb_name in list_kbs_from_folder():
|
||||
folder2db(kb_name, mode, vs_type, embed_mode)
|
||||
folder2db(kb_name, "recreate_vs", vs_type, embed_mode)
|
||||
|
||||
|
||||
def prune_db_docs(kb_name: str):
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg
|
|||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||
"CSVLoader": [".csv"],
|
||||
"PyPDFLoader": [".pdf"],
|
||||
|
||||
}
|
||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from fastapi_offline import FastAPIOffline
|
||||
import fastapi_offline
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
|
||||
# patch fastapi_offline to use local static assests
|
||||
|
|
@ -82,3 +83,32 @@ def torch_gc():
|
|||
except Exception as e:
|
||||
print(e)
|
||||
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
|
||||
|
||||
|
||||
def run_async(cor):
|
||||
'''
|
||||
在同步环境中运行异步代码.
|
||||
'''
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
ait = ait.__aiter__()
|
||||
async def get_next():
|
||||
try:
|
||||
obj = await ait.__anext__()
|
||||
return False, obj
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
break
|
||||
yield obj
|
||||
|
|
|
|||
3
webui.py
3
webui.py
|
|
@ -12,9 +12,6 @@ from webui_pages import *
|
|||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# init local vector store info to database
|
||||
init_vs_database()
|
||||
|
||||
st.set_page_config("langchain-chatglm WebUI")
|
||||
|
||||
if not chat_box.chat_inited:
|
||||
|
|
|
|||
|
|
@ -90,8 +90,7 @@ def dialogue_page(api: ApiRequest):
|
|||
on_change=on_mode_change,
|
||||
key="dialogue_mode",
|
||||
)
|
||||
history_len = st.slider("历史对话轮数:", 1, 10, 3)
|
||||
|
||||
history_len = st.slider("历史对话轮数:", 0, 10, 3)
|
||||
# todo: support history len
|
||||
|
||||
def on_kb_change():
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import streamlit as st
|
||||
from webui_pages.utils import *
|
||||
# import streamlit_antd_components as sac
|
||||
from st_aggrid import AgGrid
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
import pandas as pd
|
||||
from server.knowledge_base.utils import get_file_path
|
||||
# from streamlit_chatbox import *
|
||||
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
|
||||
from typing import Literal, Dict, Tuple
|
||||
|
||||
SENTENCE_SIZE = 100
|
||||
|
|
@ -33,7 +32,7 @@ def config_aggrid(
|
|||
|
||||
def knowledge_base_page(api: ApiRequest):
|
||||
# api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
|
||||
kb_details = get_kb_details(api)
|
||||
kb_details = get_kb_details()
|
||||
kb_list = list(kb_details.kb_name)
|
||||
|
||||
cols = st.columns([3, 1, 1])
|
||||
|
|
@ -127,7 +126,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
# 知识库详情
|
||||
st.write(f"知识库 `{kb}` 详情:")
|
||||
st.info("请选择文件")
|
||||
doc_details = get_kb_doc_details(api, kb)
|
||||
doc_details = get_kb_doc_details(kb)
|
||||
doc_details.drop(columns=["kb_name"], inplace=True)
|
||||
|
||||
gb = config_aggrid(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# 该文件包含webui通用工具,可以被不同的webui使用
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
import os
|
||||
from configs.model_config import (
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH,
|
||||
|
|
@ -17,10 +16,9 @@ from fastapi.responses import StreamingResponse
|
|||
import contextlib
|
||||
import json
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from server.knowledge_base.utils import list_kbs_from_folder, list_docs_from_folder
|
||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from server.utils import run_async, iter_over_async
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
|
|
@ -37,35 +35,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
|
|||
set_httpx_timeout()
|
||||
|
||||
|
||||
def run_async(cor):
|
||||
'''
|
||||
在同步环境中运行异步代码.
|
||||
'''
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
ait = ait.__aiter__()
|
||||
async def get_next():
|
||||
try:
|
||||
obj = await ait.__anext__()
|
||||
return False, obj
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
break
|
||||
yield obj
|
||||
|
||||
|
||||
class ApiRequest:
|
||||
'''
|
||||
api.py调用的封装,主要实现:
|
||||
|
|
@ -97,12 +66,16 @@ class ApiRequest:
|
|||
url: str,
|
||||
params: Union[Dict, List[Tuple], bytes] = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[httpx.Response, None]:
|
||||
url = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
while retry > 0:
|
||||
try:
|
||||
if stream:
|
||||
return httpx.stream("GET", url, params=params, **kwargs)
|
||||
else:
|
||||
return httpx.get(url, params=params, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
|
@ -112,6 +85,7 @@ class ApiRequest:
|
|||
url: str,
|
||||
params: Union[Dict, List[Tuple], bytes] = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[httpx.Response, None]:
|
||||
rl = self._parse_url(url)
|
||||
|
|
@ -119,6 +93,9 @@ class ApiRequest:
|
|||
async with httpx.AsyncClient() as client:
|
||||
while retry > 0:
|
||||
try:
|
||||
if stream:
|
||||
return await client.stream("GET", url, params=params, **kwargs)
|
||||
else:
|
||||
return await client.get(url, params=params, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
|
@ -150,6 +127,7 @@ class ApiRequest:
|
|||
data: Dict = None,
|
||||
json: Dict = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[httpx.Response, None]:
|
||||
rl = self._parse_url(url)
|
||||
|
|
@ -157,10 +135,54 @@ class ApiRequest:
|
|||
async with httpx.AsyncClient() as client:
|
||||
while retry > 0:
|
||||
try:
|
||||
if stream:
|
||||
return await client.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return await client.post(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
def delete(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict = None,
|
||||
json: Dict = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[httpx.Response, None]:
|
||||
url = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
while retry > 0:
|
||||
try:
|
||||
if stream:
|
||||
return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return httpx.delete(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
async def adelete(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict = None,
|
||||
json: Dict = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[httpx.Response, None]:
|
||||
rl = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
async with httpx.AsyncClient() as client:
|
||||
while retry > 0:
|
||||
try:
|
||||
if stream:
|
||||
return await client.stream("DELETE", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return await client.delete(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||
'''
|
||||
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
||||
|
|
@ -384,7 +406,7 @@ class ApiRequest:
|
|||
response = run_async(delete_kb(knowledge_base_name))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.delete(
|
||||
response = self.post(
|
||||
"/knowledge_base/delete_knowledge_base",
|
||||
json={"knowledge_base_name": knowledge_base_name},
|
||||
)
|
||||
|
|
@ -480,7 +502,7 @@ class ApiRequest:
|
|||
response = run_async(delete_doc(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.delete(
|
||||
response = self.post(
|
||||
"/knowledge_base/delete_doc",
|
||||
json=data,
|
||||
)
|
||||
|
|
@ -489,7 +511,7 @@ class ApiRequest:
|
|||
def update_kb_doc(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
file_name: str,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -500,12 +522,12 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import update_doc
|
||||
response = run_async(update_doc(knowledge_base_name, doc_name))
|
||||
response = run_async(update_doc(knowledge_base_name, file_name))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.delete(
|
||||
response = self.post(
|
||||
"/knowledge_base/update_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
|
||||
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
|
@ -532,120 +554,8 @@ class ApiRequest:
|
|||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
|
||||
def get_kb_details(api: ApiRequest) -> pd.DataFrame:
|
||||
kbs_in_folder = list_kbs_from_folder()
|
||||
kbs_in_db = api.list_knowledge_bases()
|
||||
result = {}
|
||||
|
||||
for kb in kbs_in_folder:
|
||||
result[kb] = {
|
||||
"kb_name": kb,
|
||||
"vs_type": "",
|
||||
"embed_model": "",
|
||||
"file_count": 0,
|
||||
"create_time": None,
|
||||
"in_folder": True,
|
||||
"in_db": False,
|
||||
}
|
||||
|
||||
for kb in kbs_in_db:
|
||||
kb_detail = get_kb_detail(kb)
|
||||
if kb_detail:
|
||||
kb_detail["in_db"] = True
|
||||
if kb in result:
|
||||
result[kb].update(kb_detail)
|
||||
else:
|
||||
kb_detail["in_folder"] = False
|
||||
result[kb] = kb_detail
|
||||
|
||||
df = pd.DataFrame(result.values(), columns=[
|
||||
"kb_name",
|
||||
"vs_type",
|
||||
"embed_model",
|
||||
"file_count",
|
||||
"create_time",
|
||||
"in_folder",
|
||||
"in_db",
|
||||
])
|
||||
df.insert(0, "No", range(1, len(df) + 1))
|
||||
return df
|
||||
|
||||
|
||||
def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame:
|
||||
docs_in_folder = list_docs_from_folder(kb)
|
||||
docs_in_db = api.list_kb_docs(kb)
|
||||
result = {}
|
||||
|
||||
for doc in docs_in_folder:
|
||||
result[doc] = {
|
||||
"kb_name": kb,
|
||||
"file_name": doc,
|
||||
"file_ext": os.path.splitext(doc)[-1],
|
||||
"file_version": 0,
|
||||
"document_loader": "",
|
||||
"text_splitter": "",
|
||||
"create_time": None,
|
||||
"in_folder": True,
|
||||
"in_db": False,
|
||||
}
|
||||
|
||||
for doc in docs_in_db:
|
||||
doc_detail = get_file_detail(kb, doc)
|
||||
if doc_detail:
|
||||
doc_detail["in_db"] = True
|
||||
if doc in result:
|
||||
result[doc].update(doc_detail)
|
||||
else:
|
||||
doc_detail["in_folder"] = False
|
||||
result[doc] = doc_detail
|
||||
|
||||
df = pd.DataFrame(result.values(), columns=[
|
||||
"kb_name",
|
||||
"file_name",
|
||||
"file_ext",
|
||||
"file_version",
|
||||
"document_loader",
|
||||
"text_splitter",
|
||||
"create_time",
|
||||
"in_folder",
|
||||
"in_db",
|
||||
])
|
||||
df.insert(0, "No", range(1, len(df) + 1))
|
||||
return df
|
||||
|
||||
|
||||
def init_vs_database(recreate_vs: bool = False):
|
||||
'''
|
||||
init local vector store info to database
|
||||
'''
|
||||
from server.db.base import Base, engine
|
||||
from server.db.repository.knowledge_base_repository import add_kb_to_db, kb_exists
|
||||
from server.db.repository.knowledge_file_repository import add_doc_to_db
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
if recreate_vs:
|
||||
api = ApiRequest(no_remote_api=True)
|
||||
for kb in list_kbs_from_folder():
|
||||
for t in api.recreate_vector_store(kb):
|
||||
print(t)
|
||||
else: # add vs info to db only
|
||||
for kb in list_kbs_from_folder():
|
||||
if not kb_exists(kb):
|
||||
add_kb_to_db(kb, "faiss", EMBEDDING_MODEL)
|
||||
for doc in list_docs_from_folder(kb):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, kb)
|
||||
add_doc_to_db(kb_file)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = ApiRequest(no_remote_api=True)
|
||||
# init vector store database
|
||||
init_vs_database()
|
||||
|
||||
# print(api.chat_fastchat(
|
||||
# messages=[{"role": "user", "content": "hello"}]
|
||||
|
|
|
|||
Loading…
Reference in New Issue