Merge remote-tracking branch 'origin/dev_fastchat' into dev_fastchat

# Conflicts:
#	server/knowledge_base/kb_service/default_kb_service.py
#	server/knowledge_base/kb_service/milvus_kb_service.py
#	server/knowledge_base/knowledge_base_factory.py

commit 41fd1acc9c
@@ -15,9 +15,13 @@ starlette~=0.27.0
 numpy~=1.24.4
 pydantic~=1.10.11
 unstructured[all-docs]
+python-magic-bin; sys_platform == 'win32'

 streamlit>=1.25.0
 streamlit-option-menu
 streamlit-antd-components
 streamlit-chatbox>=1.1.6
 httpx

+faiss-cpu
+pymilvus==2.1.3 # requires milvus==2.1.3

@@ -4,7 +4,11 @@ from typing import Any, Dict, List, Optional
 from langchain.chains.base import Chain
 from langchain.schema import (
     BaseMessage,
-    messages_from_dict,
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
     LLMResult
 )
 from langchain.chat_models import ChatOpenAI

@@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
 from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto


+def _convert_dict_to_message(_dict: dict) -> BaseMessage:
+    role = _dict["role"]
+    if role == "user":
+        return HumanMessage(content=_dict["content"])
+    elif role == "assistant":
+        return AIMessage(content=_dict["content"])
+    elif role == "system":
+        return SystemMessage(content=_dict["content"])
+    else:
+        return ChatMessage(content=_dict["content"], role=role)
+
+
 def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
     """
     Convert front-end message DTOs into chat message DTOs

@@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
     messages = []
     for message_datum in message_data:
         messages.append(message_datum.dict())
-    return messages_from_dict(messages)
+    return _convert_dict_to_message(messages)


 class BaseChatOpenAIChain(Chain, ABC):

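Note: `_convert_dict_to_message` converts a single dict, while the new `return _convert_dict_to_message(messages)` hands it the whole list, so the annotated return type `List[BaseMessage]` no longer matches what is returned. A minimal sketch of a list-aware wrapper (illustrative only, not part of this commit):

from typing import List

# Assumes _convert_dict_to_message and the langchain.schema imports above.
def _convert_dicts_to_messages(dicts: List[dict]) -> "List[BaseMessage]":
    # Apply the single-dict converter element-wise.
    return [_convert_dict_to_message(d) for d in dicts]
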
@@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import KBService
 import json

@@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="user input", examp
                         knowledge_base_name: str = Body(..., description="knowledge base name", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="number of matched vectors"),
                         ):
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)

     async def knowledge_base_chat_iterator(query: str,
-                                           kb: KnowledgeBase,
+                                           kb: KBService,
                                            top_k: int,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()

@@ -1,4 +1,4 @@
 from .kb_api import list_kbs, create_kb, delete_kb
 from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
-from .knowledge_base import KnowledgeBase
 from .knowledge_file import KnowledgeFile
+from .knowledge_base_factory import KBServiceFactory

@@ -1,13 +1,14 @@
 import urllib
 from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import list_kbs_from_db
 from configs.model_config import EMBEDDING_MODEL


 async def list_kbs():
     # Get List of Knowledge Base
-    return ListResponse(data=KnowledgeBase.list_kbs())
+    return ListResponse(data=list_kbs_from_db())


 async def create_kb(knowledge_base_name: str,

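Note: `KBServiceFactory` is imported here, but its implementation (server/knowledge_base/knowledge_base_factory.py, one of the conflicted files) is not shown in this diff. A minimal, purely illustrative sketch of the lookup it appears to perform, with stand-in stubs for the backend class and the sqlite helper (names and behavior are assumptions, not the repo's actual code):

from typing import Optional


class KBService:                     # stand-in for kb_service.base.KBService
    def __init__(self, kb_name: str):
        self.kb_name = kb_name


class FaissKBService(KBService):     # assumed backend class
    pass


def load_kb_from_db(kb_name: str):
    # Stand-in for the sqlite helper in the deleted knowledge_base.py below;
    # it returns (None, None, None) for unknown names.
    return (None, None, None)


class KBServiceFactory:
    @staticmethod
    def get_service(kb_name: str, vs_type: str) -> KBService:
        if vs_type == "faiss":
            return FaissKBService(kb_name)
        raise ValueError(f"unsupported vector store type: {vs_type}")

    @staticmethod
    def get_service_by_name(kb_name: str) -> Optional[KBService]:
        name, vs_type, _embed_model = load_kb_from_db(kb_name)
        if name is None:
            return None  # the callers above turn this into a 404 response
        return KBServiceFactory.get_service(name, vs_type)
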
@@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
         return BaseResponse(code=403, msg="Don't attack me")
     if knowledge_base_name is None or knowledge_base_name.strip() == "":
         return BaseResponse(code=404, msg="Knowledge base name cannot be empty; please enter a valid name")
-    if KnowledgeBase.exists(knowledge_base_name):
+
+    kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
+    if kb is not None:
         return BaseResponse(code=404, msg=f"A knowledge base named {knowledge_base_name} already exists")
-    kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
-                       vector_store_type=vector_store_type,
-                       embed_model=embed_model)
     kb.create()
     return BaseResponse(code=200, msg=f"Knowledge base {knowledge_base_name} created")

@@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
         return BaseResponse(code=403, msg="Don't attack me")
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)

-    if not KnowledgeBase.exists(knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+
+    if kb is None:
         return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

-    status = KnowledgeBase.delete(knowledge_base_name)
+    status = kb.drop_kb()
     if status:
         return BaseResponse(code=200, msg=f"Knowledge base {knowledge_base_name} deleted")
     else:

@@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.knowledge_file import KnowledgeFile
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
+from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache


 async def list_docs(knowledge_base_name: str):

@@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
         return ListResponse(code=403, msg="Don't attack me", data=[])

     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return ListResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found", data=[])
     else:
-        all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
+        all_doc_names = kb.list_docs()
         return ListResponse(data=all_doc_names)

@@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="uploaded file"),
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
-
     file_content = await file.read()  # read the content of the uploaded file

     kb_file = KnowledgeFile(filename=file.filename,

@@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
         return BaseResponse(code=403, msg="Don't attack me")

     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
     if not kb.exist_doc(doc_name):
         return BaseResponse(code=404, msg=f"File {doc_name} not found")
     kb_file = KnowledgeFile(filename=doc_name,

@@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
     recreate vector store from the content.
     this is useful when the user copies files into the content folder directly instead of uploading them through the network.
     '''
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
+        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

-    async def output(kb: KnowledgeBase):
-        kb.recreate_vs()
+    async def output(kb):
+        kb.clear_vs()
         print(f"start to recreate vector store of {kb.kb_name}")
-        docs = kb.list_docs()
+        docs = list_docs_from_folder(knowledge_base_name)
+        print(docs)
         for i, filename in enumerate(docs):
+            yield json.dumps({
+                "total": len(docs),
+                "finished": i,
+                "doc": filename,
+            })
             kb_file = KnowledgeFile(filename=filename,
                                     knowledge_base_name=kb.kb_name)
             print(f"processing {kb_file.filepath} to vector store.")
             kb.add_doc(kb_file)
-            yield json.dumps({
-                "total": len(docs),
-                "finished": i + 1,
-                "doc": filename,
-            })
+        if kb.vs_type == SupportedVSType.FAISS:
+            refresh_vs_cache(knowledge_base_name)

     return StreamingResponse(output(kb), media_type="text/event-stream")

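Note: `output()` now yields one `json.dumps(...)` progress record per file before it is processed ("finished": i rather than i + 1 as before), wrapped in a StreamingResponse. A hedged sketch of consuming that stream, assuming the route is mounted at a hypothetical /knowledge_base/recreate_vector_store on http://127.0.0.1:7861:

import json
import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/recreate_vector_store",  # assumed URL
    json={"knowledge_base_name": "samples"},
    stream=True,
)
# The yielded JSON objects carry no separators, so this assumes each
# progress record arrives as its own transport chunk.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    if chunk:
        d = json.loads(chunk)
        print(f"{d['finished']}/{d['total']} {d['doc']}")
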
@@ -125,6 +125,10 @@ def list_docs_from_db(kb_name):
     conn.close()
     return kbs

+def list_docs_from_folder(kb_name: str):
+    doc_path = get_doc_path(kb_name)
+    return [file for file in os.listdir(doc_path)
+            if os.path.isfile(os.path.join(doc_path, file))]

 def add_doc_to_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)

@@ -122,3 +122,4 @@ class FaissKBService(KBService):

     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
+        os.makedirs(self.vs_path)

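Note: `refresh_vs_cache`, imported into kb_doc_api above, busts the `@lru_cache` on the vector-store loader by bumping a per-kb counter that is passed as an extra `tick` argument (see `load_vector_store` in the deleted knowledge_base.py below): a different tick means a different cache key. The idiom in isolation, as a minimal sketch:

from functools import lru_cache

_TICKS = {}

@lru_cache(maxsize=8)
def load_store(name: str, tick: int):
    # Expensive load; tick is unused inside -- it only varies the cache key.
    print(f"loading {name} (tick={tick})")
    return object()

def refresh(name: str):
    _TICKS[name] = _TICKS.get(name, 0) + 1

load_store("samples", _TICKS.get("samples", 0))  # loads
load_store("samples", _TICKS.get("samples", 0))  # cache hit
refresh("samples")
load_store("samples", _TICKS.get("samples", 0))  # tick changed: reloads
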
@@ -1,407 +0,0 @@
-import os
-import sqlite3
-import datetime
-import shutil
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
-                                  DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
-from server.utils import torch_gc
-from functools import lru_cache
-from server.knowledge_base.knowledge_file import KnowledgeFile
-from typing import List
-import numpy as np
-from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)
-
-SUPPORTED_VS_TYPES = ["faiss", "milvus"]
-
-_VECTOR_STORE_TICKS = {}
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
-    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def refresh_vs_cache(kb_name: str):
-    """
-    make vector store cache refreshed when next loading
-    """
-    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
-
-
-def list_kbs_from_db():
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  kb_name TEXT,
-                  vs_type TEXT,
-                  embed_model TEXT,
-                  file_count INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''SELECT kb_name
-                  FROM knowledge_base
-                  WHERE file_count>0 ''')
-    kbs = [i[0] for i in c.fetchall() if i]
-    conn.commit()
-    conn.close()
-    return kbs
-
-
-def add_kb_to_db(kb_name, vs_type, embed_model):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # Create table
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  kb_name TEXT,
-                  vs_type TEXT,
-                  embed_model TEXT,
-                  file_count INTEGER,
-                  create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""INSERT INTO knowledge_base
-                  (kb_name, vs_type, embed_model, file_count, create_time)
-                  VALUES
-                  ('{kb_name}','{vs_type}','{embed_model}',
-                  0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
-    conn.commit()
-    conn.close()
-
-
-def kb_exists(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  kb_name TEXT,
-                  vs_type TEXT,
-                  embed_model TEXT,
-                  file_count INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''SELECT COUNT(*)
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    status = True if c.fetchone()[0] else False
-    conn.commit()
-    conn.close()
-    return status
-
-
-def load_kb_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  kb_name TEXT,
-                  vs_type TEXT,
-                  embed_model TEXT,
-                  file_count INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''SELECT kb_name, vs_type, embed_model
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    resp = c.fetchone()
-    if resp:
-        kb_name, vs_type, embed_model = resp
-    else:
-        kb_name, vs_type, embed_model = None, None, None
-    conn.commit()
-    conn.close()
-    return kb_name, vs_type, embed_model
-
-
-def delete_kb_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # delete kb from table knowledge_base
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  kb_name TEXT,
-                  vs_type TEXT,
-                  embed_model TEXT,
-                  file_count INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''DELETE
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    # delete files in kb from table knowledge_files
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  file_name TEXT,
-                  file_ext TEXT,
-                  kb_name TEXT,
-                  document_loader_name TEXT,
-                  text_splitter_name TEXT,
-                  file_version INTEGER,
-                  create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""DELETE
-                  FROM knowledge_files
-                  WHERE kb_name="{kb_name}"
-               """)
-    conn.commit()
-    conn.close()
-    return True
-
-
-def list_docs_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  file_name TEXT,
-                  file_ext TEXT,
-                  kb_name TEXT,
-                  document_loader_name TEXT,
-                  text_splitter_name TEXT,
-                  file_version INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''SELECT file_name
-                  FROM knowledge_files
-                  WHERE kb_name="{kb_name}" ''')
-    kbs = [i[0] for i in c.fetchall() if i]
-    conn.commit()
-    conn.close()
-    return kbs
-
-
-def add_doc_to_db(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # Create table
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  file_name TEXT,
-                  file_ext TEXT,
-                  kb_name TEXT,
-                  document_loader_name TEXT,
-                  text_splitter_name TEXT,
-                  file_version INTEGER,
-                  create_time DATETIME) ''')
-    # Insert a row of data
-    # TODO: increment file_version when a file with the same name is added to the kb
-    c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
-    record_exist = c.fetchone()
-    if record_exist is not None:
-        c.execute(f"""UPDATE knowledge_files
-                      SET file_version = file_version + 1
-                      WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
-                   """)
-    else:
-        c.execute(f"""INSERT INTO knowledge_files
-                      (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
-                      VALUES
-                      ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
-                      '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
-    conn.commit()
-    conn.close()
-
-
-def delete_file_from_db(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # delete files in kb from table knowledge_files
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  file_name TEXT,
-                  file_ext TEXT,
-                  kb_name TEXT,
-                  document_loader_name TEXT,
-                  text_splitter_name TEXT,
-                  file_version INTEGER,
-                  create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""DELETE
-                  FROM knowledge_files
-                  WHERE file_name="{kb_file.filename}"
-                  AND kb_name="{kb_file.kb_name}"
-               """)
-    conn.commit()
-    conn.close()
-    return True
-
-
-def doc_exists(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                  file_name TEXT,
-                  file_ext TEXT,
-                  kb_name TEXT,
-                  document_loader_name TEXT,
-                  text_splitter_name TEXT,
-                  file_version INTEGER,
-                  create_time DATETIME) ''')
-    c.execute(f'''SELECT COUNT(*)
-                  FROM knowledge_files
-                  WHERE file_name="{kb_file.filename}"
-                  AND kb_name="{kb_file.kb_name}" ''')
-    status = True if c.fetchone()[0] else False
-    conn.commit()
-    conn.close()
-    return status
-
-
-def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
-    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
-    if not overlapping:
-        raise ValueError("ids do not exist in the current object")
-    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
-    index_to_delete = [_reversed_index[i] for i in ids]
-    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
-    for _id in index_to_delete:
-        del vector_store.index_to_docstore_id[_id]
-    # Remove items from docstore.
-    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
-    if not overlapping2:
-        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
-    for _id in ids:
-        vector_store.docstore._dict.pop(_id)
-    return vector_store
-
-
-class KnowledgeBase:
-    def __init__(self,
-                 knowledge_base_name: str,
-                 vector_store_type: str = "faiss",
-                 embed_model: str = EMBEDDING_MODEL,
-                 ):
-        self.kb_name = knowledge_base_name
-        if vector_store_type not in SUPPORTED_VS_TYPES:
-            raise ValueError(f"Unsupported vector store type {vector_store_type}")
-        self.vs_type = vector_store_type
-        if embed_model not in embedding_model_dict.keys():
-            raise ValueError(f"Unsupported embedding model {embed_model}")
-        self.embed_model = embed_model
-        self.kb_path = get_kb_path(self.kb_name)
-        self.doc_path = get_doc_path(self.kb_name)
-        if self.vs_type in ["faiss"]:
-            self.vs_path = get_vs_path(self.kb_name)
-        elif self.vs_type in ["milvus"]:
-            pass
-
-    def create(self):
-        if not os.path.exists(self.doc_path):
-            os.makedirs(self.doc_path)
-        if self.vs_type in ["faiss"]:
-            if not os.path.exists(self.vs_path):
-                os.makedirs(self.vs_path)
-            add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
-        elif self.vs_type in ["milvus"]:
-            # TODO: create the milvus collection
-            pass
-        return True
-
-    def recreate_vs(self):
-        if self.vs_type in ["faiss"]:
-            shutil.rmtree(self.vs_path)
-            self.create()
-
-    def add_doc(self, kb_file: KnowledgeFile):
-        docs = kb_file.file2text()
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
-        if self.vs_type in ["faiss"]:
-            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-                vector_store = FAISS.load_local(self.vs_path, embeddings)
-                vector_store.add_documents(docs)
-                torch_gc()
-            else:
-                if not os.path.exists(self.vs_path):
-                    os.makedirs(self.vs_path)
-                vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document objects
-                torch_gc()
-            vector_store.save_local(self.vs_path)
-            add_doc_to_db(kb_file)
-            refresh_vs_cache(self.kb_name)
-        elif self.vs_type in ["milvus"]:
-            # TODO: add files to the milvus collection
-            pass
-
-    def delete_doc(self, kb_file: KnowledgeFile):
-        if os.path.exists(kb_file.filepath):
-            os.remove(kb_file.filepath)
-        if self.vs_type in ["faiss"]:
-            # TODO: delete the document from the FAISS vector store
-            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
-            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-                vector_store = FAISS.load_local(self.vs_path, embeddings)
-                ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
-                if len(ids) == 0:
-                    return None
-                print(len(ids))
-                vector_store = delete_doc_from_faiss(vector_store, ids)
-                vector_store.save_local(self.vs_path)
-                refresh_vs_cache(self.kb_name)
-        delete_file_from_db(kb_file)
-        return True
-
-    def exist_doc(self, file_name: str):
-        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
-                                        filename=file_name))
-
-    def list_docs(self):
-        return list_docs_from_db(self.kb_name)
-
-    def search_docs(self,
-                    query: str,
-                    top_k: int = VECTOR_SEARCH_TOP_K,
-                    embedding_device: str = EMBEDDING_DEVICE, ):
-        search_index = load_vector_store(self.kb_name,
-                                         self.embed_model,
-                                         embedding_device,
-                                         _VECTOR_STORE_TICKS.get(self.kb_name))
-        docs = search_index.similarity_search(query, k=top_k)
-        return docs
-
-    @classmethod
-    def exists(cls,
-               knowledge_base_name: str):
-        return kb_exists(knowledge_base_name)
-
-    @classmethod
-    def load(cls,
-             knowledge_base_name: str):
-        kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
-        return cls(kb_name, vs_type, embed_model)
-
-    @classmethod
-    def delete(cls,
-               knowledge_base_name: str):
-        kb = cls.load(knowledge_base_name)
-        if kb.vs_type in ["faiss"]:
-            shutil.rmtree(kb.kb_path)
-        elif kb.vs_type in ["milvus"]:
-            # TODO: drop the milvus collection
-            pass
-        status = delete_kb_from_db(knowledge_base_name)
-        return status
-
-    @classmethod
-    def list_kbs(cls):
-        return list_kbs_from_db()
-
-
-if __name__ == "__main__":
-    # kb = KnowledgeBase("123", "faiss")
-    # kb.create()
-    kb = KnowledgeBase.load(knowledge_base_name="123")
-    kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
-    print()

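Note: the DB helpers in this deleted file interpolate values into SQL via f-strings, which breaks on names containing double quotes and is injection-prone. A minimal sketch of the same kb_exists query using sqlite3 parameter binding instead (illustrative, not part of the commit):

import sqlite3

def kb_exists(conn: sqlite3.Connection, kb_name: str) -> bool:
    c = conn.cursor()
    # Placeholder binding quotes the value safely.
    c.execute("SELECT COUNT(*) FROM knowledge_base WHERE kb_name = ?", (kb_name,))
    return c.fetchone()[0] > 0
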
@@ -1,5 +1,8 @@
-import os.path
-from configs.model_config import KB_ROOT_PATH
+import os
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
+from functools import lru_cache
+

 def validate_kb_name(knowledge_base_id: str) -> bool:
     # Check for unexpected characters or path-traversal keywords

@@ -17,4 +20,10 @@ def get_vs_path(knowledge_base_name: str):
     return os.path.join(get_kb_path(knowledge_base_name), "vector_store")

 def get_file_path(knowledge_base_name: str, doc_name: str):
     return os.path.join(get_doc_path(knowledge_base_name), doc_name)

+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings

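Note: `@lru_cache(1)` keeps only the most recently used (model, device) pair, so alternating between two embedding models reloads the weights on every switch. A short sketch of that behavior (model names are illustrative):

from functools import lru_cache

@lru_cache(1)
def load(model: str):
    print(f"loading {model}")
    return model

load("model-a")  # loads
load("model-a")  # cache hit
load("model-b")  # loads, evicting model-a
load("model-a")  # loads again
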
@@ -7,11 +7,157 @@
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser
 import subprocess
 import re
+import logging
 import argparse

+LOG_PATH = "./logs/"
+LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+logging.basicConfig(format=LOG_FORMAT)
+
+
+parser = argparse.ArgumentParser()
+# ------multi worker-----------------
+parser.add_argument('--model-path-address',
+                    default="THUDM/chatglm2-6b@localhost@20002",
+                    nargs="+",
+                    type=str,
+                    help="model path, host, and port, formatted as model-path@host@port")
+# ---------------controller-------------------------
+
+parser.add_argument("--controller-host", type=str, default="localhost")
+parser.add_argument("--controller-port", type=int, default=21001)
+parser.add_argument(
+    "--dispatch-method",
+    type=str,
+    choices=["lottery", "shortest_queue"],
+    default="shortest_queue",
+)
+controller_args = ["controller-host", "controller-port", "dispatch-method"]
+
+# ----------------------worker------------------------------------------
+
+parser.add_argument("--worker-host", type=str, default="localhost")
+parser.add_argument("--worker-port", type=int, default=21002)
+# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+# parser.add_argument(
+#     "--controller-address", type=str, default="http://localhost:21001"
+# )
+parser.add_argument(
+    "--model-path",
+    type=str,
+    default="lmsys/vicuna-7b-v1.3",
+    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
+)
+parser.add_argument(
+    "--revision",
+    type=str,
+    default="main",
+    help="Hugging Face Hub model revision identifier",
+)
+parser.add_argument(
+    "--device",
+    type=str,
+    choices=["cpu", "cuda", "mps", "xpu"],
+    default="cuda",
+    help="The device type",
+)
+parser.add_argument(
+    "--gpus",
+    type=str,
+    default="0",
+    help="A single GPU like 1 or multiple GPUs like 0,2",
+)
+parser.add_argument("--num-gpus", type=int, default=1)
+parser.add_argument(
+    "--max-gpu-memory",
+    type=str,
+    help="The maximum memory per gpu. Use a string like '13Gib'",
+)
+parser.add_argument(
+    "--load-8bit", action="store_true", help="Use 8-bit quantization"
+)
+parser.add_argument(
+    "--cpu-offloading",
+    action="store_true",
+    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
+)
+parser.add_argument(
+    "--gptq-ckpt",
+    type=str,
+    default=None,
+    help="Load quantized model. The path to the local GPTQ checkpoint.",
+)
+parser.add_argument(
+    "--gptq-wbits",
+    type=int,
+    default=16,
+    choices=[2, 3, 4, 8, 16],
+    help="#bits to use for quantization",
+)
+parser.add_argument(
+    "--gptq-groupsize",
+    type=int,
+    default=-1,
+    help="Groupsize to use for quantization; default uses full row.",
+)
+parser.add_argument(
+    "--gptq-act-order",
+    action="store_true",
+    help="Whether to apply the activation order GPTQ heuristic",
+)
+parser.add_argument(
+    "--model-names",
+    type=lambda s: s.split(","),
+    help="Optional display comma separated names",
+)
+parser.add_argument(
+    "--limit-worker-concurrency",
+    type=int,
+    default=5,
+    help="Limit the model concurrency to prevent OOM.",
+)
+parser.add_argument("--stream-interval", type=int, default=2)
+parser.add_argument("--no-register", action="store_true")
+
+worker_args = [
+    "worker-host", "worker-port",
+    "model-path", "revision", "device", "gpus", "num-gpus",
+    "max-gpu-memory", "load-8bit", "cpu-offloading",
+    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
+    "gptq-act-order", "model-names", "limit-worker-concurrency",
+    "stream-interval", "no-register",
+    "controller-address"
+]
+# -----------------openai server---------------------------
+
+parser.add_argument("--server-host", type=str, default="localhost", help="host name")
+parser.add_argument("--server-port", type=int, default=8001, help="port number")
+parser.add_argument(
+    "--allow-credentials", action="store_true", help="allow credentials"
+)
+# parser.add_argument(
+#     "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
+# )
+# parser.add_argument(
+#     "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
+# )
+# parser.add_argument(
+#     "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
+# )
+parser.add_argument(
+    "--api-keys",
+    type=lambda s: s.split(","),
+    help="Optional list of comma separated API keys",
+)
+server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
+               "controller-address"
+               ]
+
 args = parser.parse_args()
 # The "http://" prefix is required; otherwise requests raises InvalidSchema: No connection adapters were found
 args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})

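Note: the merged key "controller-address" contains a dash, so it is not reachable as an attribute: `args.controller-address` parses as a subtraction and fails. It has to be read back through `vars()` or `getattr()`. A small sketch:

import argparse

args = argparse.Namespace(**{"controller-address": "http://localhost:21001"})
print(vars(args)["controller-address"])     # works
print(getattr(args, "controller-address"))  # works
# args.controller-address                   # AttributeError: parsed as args.controller - address
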
@@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest):
         chat_box.update_msg(text, 0, streaming=False)

         now = datetime.now()
-        cols[0].download_button(
+        export_btn.download_button(
            "Export",
            "".join(chat_box.export2md(cur_chat_name)),
            file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",

@@ -1,6 +1,124 @@
+from pydoc import Helper
 import streamlit as st
 from webui_pages.utils import *
+import streamlit_antd_components as sac
+from st_aggrid import AgGrid
+from st_aggrid.grid_options_builder import GridOptionsBuilder
+import pandas as pd
+from server.knowledge_base.utils import get_file_path
+from streamlit_chatbox import *
+
+
+SENTENCE_SIZE = 100


 def knowledge_base_page(api: ApiRequest):
-    st.write(123)
-    pass
+    api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
+    chat_box = ChatBox(session_key="kb_messages")
+
+    kb_list = api.list_knowledge_bases()
+    kb_docs = {}
+    for kb in kb_list:
+        kb_docs[kb] = api.list_kb_docs(kb)
+
+    with st.sidebar:
+        def on_new_kb():
+            if name := st.session_state.get("new_kb_name"):
+                if name in kb_list:
+                    st.error(f"A knowledge base named {name} already exists!")
+                else:
+                    ret = api.create_knowledge_base(name)
+                    st.toast(ret["msg"])
+
+        def on_del_kb():
+            if name := st.session_state.get("new_kb_name"):
+                if name in kb_list:
+                    ret = api.delete_knowledge_base(name)
+                    st.toast(ret["msg"])
+                else:
+                    st.error(f"No knowledge base named {name} exists!")
+
+        cols = st.columns([2, 1, 1])
+        new_kb_name = cols[0].text_input(
+            "New knowledge base name",
+            placeholder="New knowledge base name",
+            label_visibility="collapsed",
+            key="new_kb_name",
+        )
+        cols[1].button("Create", on_click=on_new_kb, disabled=not bool(new_kb_name))
+        cols[2].button("Delete", on_click=on_del_kb, disabled=not bool(new_kb_name))
+
+        st.write("Knowledge bases:")
+        if kb_list:
+            try:
+                index = kb_list.index(st.session_state.get("cur_kb"))
+            except:
+                index = 0
+            kb = sac.buttons(
+                kb_list,
+                index,
+                format_func=lambda x: f"{x} ({len(kb_docs[x])})",
+            )
+        st.session_state["cur_kb"] = kb
+        sentence_size = st.slider("Sentence length limit for ingestion", 1, 1000, SENTENCE_SIZE, disabled=True)
+        files = st.file_uploader("Upload knowledge files",
+                                 ["docx", "txt", "md", "csv", "xlsx", "pdf"],
+                                 accept_multiple_files=True,
+                                 key="files",
+                                 )
+        if st.button(
+                "Add files to knowledge base",
+                help="Upload files first, then click to add them",
+                use_container_width=True,
+                disabled=len(files) == 0,
+        ):
+            for f in files:
+                ret = api.upload_kb_doc(f, kb)
+                if ret["code"] == 200:
+                    st.toast(ret["msg"], icon="✔")
+                else:
+                    st.toast(ret["msg"], icon="❌")
+            st.session_state.files = []
+
+        if st.button(
+                "Rebuild knowledge base",
+                help="No upload needed: copy documents into the knowledge base's content directory by other means, then click this button to rebuild it.",
+                use_container_width=True,
+                disabled=True,
+        ):
+            progress = st.progress(0.0, "")
+            for d in api.recreate_vector_store(kb):
+                progress.progress(d["finished"] / d["total"], f"Processing: {d['doc']}")
+
+    if kb_list:
+        # knowledge base details
+        st.subheader(f"Details of knowledge base {kb}")
+        df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "Document Name"])
+        gb = GridOptionsBuilder.from_dataframe(df)
+        gb.configure_column("No", width=50)
+        gb.configure_selection()
+
+        cols = st.columns([1, 2])
+
+        with cols[0]:
+            docs = AgGrid(df, gb.build())
+
+        with cols[1]:
+            cols = st.columns(3)
+            selected_rows = docs.get("selected_rows", [])
+
+            cols = st.columns([2, 3, 2])
+            if selected_rows:
+                file_name = selected_rows[0]["Document Name"]
+                file_path = get_file_path(kb, file_name)
+                with open(file_path, "rb") as fp:
+                    cols[0].download_button("Download selected document", fp, file_name=file_name)
+            else:
+                cols[0].download_button("Download selected document", "", disabled=True)
+            if cols[2].button("Delete selected documents!", type="primary"):
+                for row in selected_rows:
+                    ret = api.delete_kb_doc(kb, row["Document Name"])
+                    st.toast(ret["msg"])
+                st.experimental_rerun()
+
+        st.write("This document contains the following knowledge entries: (content TBD)")