Merge remote-tracking branch 'origin/dev_fastchat' into dev_fastchat

# Conflicts:
#	server/knowledge_base/kb_service/default_kb_service.py
#	server/knowledge_base/kb_service/milvus_kb_service.py
#	server/knowledge_base/knowledge_base_factory.py
This commit is contained in:
zqt 2023-08-08 14:06:13 +08:00
commit 41fd1acc9c
13 changed files with 347 additions and 446 deletions

View File

@ -15,9 +15,13 @@ starlette~=0.27.0
numpy~=1.24.4 numpy~=1.24.4
pydantic~=1.10.11 pydantic~=1.10.11
unstructured[all-docs] unstructured[all-docs]
python-magic-bin; sys_platform == 'win32'
streamlit>=1.25.0 streamlit>=1.25.0
streamlit-option-menu streamlit-option-menu
streamlit-antd-components streamlit-antd-components
streamlit-chatbox>=1.1.6 streamlit-chatbox>=1.1.6
httpx httpx
faiss-cpu
pymilvus==2.1.3 # requires milvus==2.1.3

View File

@ -4,7 +4,11 @@ from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import ( from langchain.schema import (
BaseMessage, BaseMessage,
messages_from_dict, AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
LLMResult LLMResult
) )
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
    """Turn one role/content dict into the matching langchain message object.

    Args:
        _dict: Mapping with a ``"role"`` key and a ``"content"`` key.

    Returns:
        ``HumanMessage`` for role ``"user"``, ``AIMessage`` for ``"assistant"``,
        ``SystemMessage`` for ``"system"``; any other role yields a
        ``ChatMessage`` that carries the role through unchanged.

    Raises:
        KeyError: If ``"role"`` or ``"content"`` is missing from ``_dict``.
    """
    sender = _dict["role"]
    body = _dict["content"]
    role_classes = (
        ("user", HumanMessage),
        ("assistant", AIMessage),
        ("system", SystemMessage),
    )
    for known_role, message_cls in role_classes:
        if sender == known_role:
            return message_cls(content=body)
    # Unknown roles are preserved verbatim on a generic chat message.
    return ChatMessage(content=body, role=sender)
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
""" """
前端消息传输对象DTO转换为chat消息传输对象DTO 前端消息传输对象DTO转换为chat消息传输对象DTO
@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
messages = [] messages = []
for message_datum in message_data: for message_datum in message_data:
messages.append(message_datum.dict()) messages.append(message_datum.dict())
return messages_from_dict(messages) return _convert_dict_to_message(messages)
class BaseChatOpenAIChain(Chain, ABC): class BaseChatOpenAIChain(Chain, ABC):

View File

@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable from typing import AsyncIterable
import asyncio import asyncio
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import KBService
import json import json
@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
): ):
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
async def knowledge_base_chat_iterator(query: str, async def knowledge_base_chat_iterator(query: str,
kb: KnowledgeBase, kb: KBService,
top_k: int, top_k: int,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()

View File

@ -1,4 +1,4 @@
from .kb_api import list_kbs, create_kb, delete_kb from .kb_api import list_kbs, create_kb, delete_kb
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
from .knowledge_base import KnowledgeBase
from .knowledge_file import KnowledgeFile from .knowledge_file import KnowledgeFile
from .knowledge_base_factory import KBServiceFactory

View File

@ -1,13 +1,14 @@
import urllib import urllib
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL from configs.model_config import EMBEDDING_MODEL
async def list_kbs(): async def list_kbs():
# Get List of Knowledge Base # Get List of Knowledge Base
return ListResponse(data=KnowledgeBase.list_kbs()) return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str, async def create_kb(knowledge_base_name: str,
@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
if knowledge_base_name is None or knowledge_base_name.strip() == "": if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
if KnowledgeBase.exists(knowledge_base_name):
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
vector_store_type=vector_store_type,
embed_model=embed_model)
kb.create() kb.create()
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
status = KnowledgeBase.delete(knowledge_base_name) status = kb.drop_kb()
if status: if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
else: else:

View File

@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import json import json
from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.knowledge_file import KnowledgeFile
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
async def list_docs(knowledge_base_name: str): async def list_docs(knowledge_base_name: str):
@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
return ListResponse(code=403, msg="Don't attack me", data=[]) return ListResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else: else:
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs() all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names) return ListResponse(data=all_doc_names)
@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
file_content = await file.read() # 读取上传文件的内容 file_content = await file.read() # 读取上传文件的内容
kb_file = KnowledgeFile(filename=file.filename, kb_file = KnowledgeFile(filename=file.filename,
@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
if not kb.exist_doc(doc_name): if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name, kb_file = KnowledgeFile(filename=doc_name,
@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
recreate vector store from the content. recreate vector store from the content.
this is usefull when user can copy files to content folder directly instead of upload through network. this is usefull when user can copy files to content folder directly instead of upload through network.
''' '''
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb: KnowledgeBase): async def output(kb):
kb.recreate_vs() kb.clear_vs()
print(f"start to recreate vector store of {kb.kb_name}") print(f"start to recreate vector store of {kb.kb_name}")
docs = kb.list_docs() docs = list_docs_from_folder(knowledge_base_name)
print(docs)
for i, filename in enumerate(docs): for i, filename in enumerate(docs):
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": filename,
})
kb_file = KnowledgeFile(filename=filename, kb_file = KnowledgeFile(filename=filename,
knowledge_base_name=kb.kb_name) knowledge_base_name=kb.kb_name)
print(f"processing {kb_file.filepath} to vector store.") print(f"processing {kb_file.filepath} to vector store.")
kb.add_doc(kb_file) kb.add_doc(kb_file)
yield json.dumps({ if kb.vs_type == SupportedVSType.FAISS:
"total": len(docs), refresh_vs_cache(knowledge_base_name)
"finished": i + 1,
"doc": filename,
})
return StreamingResponse(output(kb), media_type="text/event-stream") return StreamingResponse(output(kb), media_type="text/event-stream")

View File

@ -125,6 +125,10 @@ def list_docs_from_db(kb_name):
conn.close() conn.close()
return kbs return kbs
def list_docs_from_folder(kb_name: str):
    """List the regular files that sit directly in a knowledge base's document folder.

    Args:
        kb_name: Knowledge base name, resolved to a folder via ``get_doc_path``.

    Returns:
        File names (not full paths) in ``os.listdir`` order; sub-directories
        and other non-file entries are left out.
    """
    folder = get_doc_path(kb_name)
    file_names = []
    for entry in os.listdir(folder):
        if os.path.isfile(os.path.join(folder, entry)):
            file_names.append(entry)
    return file_names
def add_doc_to_db(kb_file: KnowledgeFile): def add_doc_to_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)

View File

@ -122,3 +122,4 @@ class FaissKBService(KBService):
def do_clear_vs(self): def do_clear_vs(self):
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)

View File

@ -1,407 +0,0 @@
import os
import sqlite3
import datetime
import shutil
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
from server.utils import torch_gc
from functools import lru_cache
from server.knowledge_base.knowledge_file import KnowledgeFile
from typing import List
import numpy as np
from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
_VECTOR_STORE_TICKS = {}
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build the HuggingFace embedding object for a configured model name.

    Only the most recent ``(model, device)`` pair is kept by the
    ``lru_cache(1)`` decorator, so repeated calls with the same arguments
    return the same object.

    Args:
        model: Key into ``embedding_model_dict``.
        device: Device string passed through as ``model_kwargs["device"]``.
    """
    return HuggingFaceEmbeddings(
        model_name=embedding_model_dict[model],
        model_kwargs={"device": device},
    )
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
        knowledge_base_name: str,
        embedding_model: str,
        embedding_device: str,
        tick: int,
):
    """Load the FAISS index stored for a knowledge base, memoised by ``lru_cache``.

    Args:
        knowledge_base_name: Knowledge base whose vector-store folder is read.
        embedding_model: Embedding model key handed to ``load_embeddings``.
        embedding_device: Device handed to ``load_embeddings``.
        tick: Not used in the body; it is part of the cache key only, so a
            caller passing a new value forces the index to be reloaded.

    Returns:
        The object returned by ``FAISS.load_local``.
    """
    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
    embedder = load_embeddings(embedding_model, embedding_device)
    return FAISS.load_local(get_vs_path(knowledge_base_name), embedder)
def refresh_vs_cache(kb_name: str):
    """Bump the tick counter of a knowledge base so its cached vector store is reloaded next time.

    The counter lives in the module-level ``_VECTOR_STORE_TICKS`` dict and
    starts from 0 for a name that has not been seen yet.
    """
    previous_tick = _VECTOR_STORE_TICKS.get(kb_name, 0)
    _VECTOR_STORE_TICKS[kb_name] = previous_tick + 1
def list_kbs_from_db():
    """Return the names of all knowledge bases whose ``file_count`` is above zero.

    The ``knowledge_base`` table is created on first use. The connection is
    closed even when a statement raises.

    Returns:
        List of ``kb_name`` strings.
    """
    # NOTE(review): the only visible writer of this table inserts
    # file_count = 0 and nothing shown here updates it, so this filter may
    # hide every knowledge base -- confirm where file_count is maintained.
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME) ''')
        c.execute('''SELECT kb_name
                     FROM knowledge_base
                     WHERE file_count>0 ''')
        kbs = [row[0] for row in c.fetchall() if row]
        conn.commit()
    finally:
        conn.close()
    return kbs
def add_kb_to_db(kb_name, vs_type, embed_model):
    """Insert a row for a new knowledge base, creating the table first if needed.

    The row is stored with ``file_count`` 0 and the current local time as
    ``create_time``. Values are bound as SQL parameters, so names containing
    quotes are stored verbatim instead of breaking (or altering) the statement.

    Args:
        kb_name: Knowledge base name.
        vs_type: Vector store type string.
        embed_model: Embedding model name.
    """
    created_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME) ''')
        c.execute("""INSERT INTO knowledge_base
                     (kb_name, vs_type, embed_model, file_count, create_time)
                     VALUES (?, ?, ?, 0, ?)""",
                  (kb_name, vs_type, embed_model, created_at))
        conn.commit()
    finally:
        # Release the connection even if the insert fails.
        conn.close()
def kb_exists(kb_name):
    """Report whether a knowledge base with this exact name is recorded in the database.

    The name is bound as a SQL parameter. The previous string-built
    ``WHERE kb_name="..."`` was parsed by SQLite as a column reference when
    the value happened to be a column name (e.g. ``kb_name``), matching every
    row, and it allowed SQL injection through the name.

    Args:
        kb_name: Knowledge base name to look up.

    Returns:
        ``True`` if at least one matching row exists, otherwise ``False``.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME) ''')
        c.execute('''SELECT COUNT(*)
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        status = bool(c.fetchone()[0])
        conn.commit()
    finally:
        conn.close()
    return status
def load_kb_from_db(kb_name):
    """Fetch the stored settings of a knowledge base.

    The name is bound as a SQL parameter rather than formatted into the
    statement, so quotes or column-like names cannot change the query.

    Args:
        kb_name: Knowledge base name to look up.

    Returns:
        ``(kb_name, vs_type, embed_model)`` from the first matching row, or
        ``(None, None, None)`` when no row matches.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME) ''')
        c.execute('''SELECT kb_name, vs_type, embed_model
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        resp = c.fetchone()
        conn.commit()
    finally:
        conn.close()
    if resp:
        kb_name, vs_type, embed_model = resp
    else:
        kb_name, vs_type, embed_model = None, None, None
    return kb_name, vs_type, embed_model
def delete_kb_from_db(kb_name):
    """Remove a knowledge base and all of its file records from the database.

    Both tables are created if missing, then rows whose ``kb_name`` equals the
    argument are deleted from ``knowledge_base`` and ``knowledge_files``. The
    name is bound as a SQL parameter: with the previous string-built
    ``WHERE kb_name="..."`` a value such as ``kb_name`` was parsed as a column
    reference and wiped every row.

    Args:
        kb_name: Knowledge base name whose records are deleted.

    Returns:
        Always ``True`` (an exception propagates on database errors).
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        # Delete the knowledge base row itself.
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME) ''')
        c.execute('''DELETE
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        # Delete the file records that belong to it.
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME) ''')
        c.execute('''DELETE
                     FROM knowledge_files
                     WHERE kb_name=? ''', (kb_name,))
        conn.commit()
    finally:
        conn.close()
    return True
def list_docs_from_db(kb_name):
    """Return the file names recorded for one knowledge base.

    The name is bound as a SQL parameter instead of being formatted into the
    query, so it cannot be misread as a column reference or inject SQL.

    Args:
        kb_name: Knowledge base whose ``knowledge_files`` rows are listed.

    Returns:
        List of ``file_name`` strings (empty when nothing is recorded).
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME) ''')
        c.execute('''SELECT file_name
                     FROM knowledge_files
                     WHERE kb_name=? ''', (kb_name,))
        file_names = [row[0] for row in c.fetchall() if row]
        conn.commit()
    finally:
        conn.close()
    return file_names
def add_doc_to_db(kb_file: KnowledgeFile):
    """Record a file in ``knowledge_files``, or bump its version if already recorded.

    A row is identified by ``(file_name, kb_name)``. When one exists its
    ``file_version`` is incremented; otherwise a new row is inserted with
    version 0 and the current local time. All values are bound as SQL
    parameters, so file names containing quotes (e.g. ``John's notes.txt``)
    no longer break the statement.

    Args:
        kb_file: Object providing ``filename``, ``ext``, ``kb_name``,
            ``document_loader_name`` and ``text_splitter_name``.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME) ''')
        c.execute("""SELECT 1 FROM knowledge_files WHERE file_name=? AND kb_name=? """,
                  (kb_file.filename, kb_file.kb_name))
        record_exist = c.fetchone()
        if record_exist is not None:
            # Same file added again to the same knowledge base: new version.
            c.execute("""UPDATE knowledge_files
                         SET file_version = file_version + 1
                         WHERE file_name=? AND kb_name=?
                      """, (kb_file.filename, kb_file.kb_name))
        else:
            c.execute("""INSERT INTO knowledge_files
                         (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
                         VALUES (?, ?, ?, ?, ?, 0, ?)""",
                      (kb_file.filename, kb_file.ext, kb_file.kb_name,
                       kb_file.document_loader_name, kb_file.text_splitter_name,
                       datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        conn.commit()
    finally:
        conn.close()
def delete_file_from_db(kb_file: KnowledgeFile):
    """Delete the ``knowledge_files`` record of one file in one knowledge base.

    File name and knowledge base name are bound as SQL parameters instead of
    being formatted into the statement, so quotes in a file name cannot break
    or widen the DELETE.

    Args:
        kb_file: Object providing ``filename`` and ``kb_name``.

    Returns:
        Always ``True`` (an exception propagates on database errors).
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME) ''')
        c.execute("""DELETE
                     FROM knowledge_files
                     WHERE file_name=?
                     AND kb_name=?
                  """, (kb_file.filename, kb_file.kb_name))
        conn.commit()
    finally:
        conn.close()
    return True
def doc_exists(kb_file: KnowledgeFile):
    """Report whether a file is recorded for its knowledge base.

    File name and knowledge base name are bound as SQL parameters, which
    avoids both SQL injection and SQLite reading a double-quoted value as a
    column reference.

    Args:
        kb_file: Object providing ``filename`` and ``kb_name``.

    Returns:
        ``True`` if at least one matching row exists, otherwise ``False``.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        c = conn.cursor()
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME) ''')
        c.execute('''SELECT COUNT(*)
                     FROM knowledge_files
                     WHERE file_name=?
                     AND kb_name=? ''', (kb_file.filename, kb_file.kb_name))
        status = bool(c.fetchone()[0])
        conn.commit()
    finally:
        conn.close()
    return status
def delete_doc_from_faiss(vector_store: "FAISS", ids: List[str]):
    """Remove documents from a langchain FAISS store by docstore id.

    Every id is validated before anything is modified, so a failed call
    leaves the store untouched. The earlier version only checked that *some*
    id was known, then raised ``KeyError`` on the first unknown one, or
    raised after the index had already been altered.

    After the vectors are removed, the position -> docstore-id map is
    renumbered to consecutive positions, mirroring what langchain's own
    ``FAISS.delete`` does.
    NOTE(review): this assumes the underlying faiss index compacts positions
    on ``remove_ids`` (true for flat indexes) -- confirm for other index types.

    Args:
        vector_store: Store exposing ``index``, ``index_to_docstore_id`` and
            ``docstore._dict``.
        ids: Docstore ids to delete; duplicates are ignored.

    Returns:
        The same ``vector_store`` object, modified in place.

    Raises:
        ValueError: If ``ids`` is empty or contains an id unknown to the
            index map or the docstore.
    """
    position_to_id = vector_store.index_to_docstore_id
    documents = vector_store.docstore._dict
    targets = set(ids)

    if not targets or not targets.issubset(position_to_id.values()):
        raise ValueError("ids do not exist in the current object")
    if not targets.issubset(documents):
        raise ValueError(f"Tried to delete ids that does not exist: {ids}")

    id_to_position = {doc_id: pos for pos, doc_id in position_to_id.items()}
    index_to_delete = sorted(id_to_position[doc_id] for doc_id in targets)
    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))

    # Renumber the surviving entries in their original order; the dict is
    # updated in place so existing references to it stay valid.
    survivors = [doc_id for _, doc_id in sorted(position_to_id.items())
                 if doc_id not in targets]
    position_to_id.clear()
    position_to_id.update(enumerate(survivors))

    # Remove items from docstore.
    for doc_id in targets:
        documents.pop(doc_id)
    return vector_store
class KnowledgeBase:
    """A named knowledge base: a document folder plus a vector store, tracked in SQLite.

    Only the ``"faiss"`` vector store is implemented; every ``"milvus"``
    branch is still a placeholder.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 vector_store_type: str = "faiss",
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Validate the settings and resolve the on-disk paths.

        Raises:
            ValueError: If the vector store type is not in
                ``SUPPORTED_VS_TYPES`` or the embedding model is not a key of
                ``embedding_model_dict``.
        """
        self.kb_name = knowledge_base_name
        if vector_store_type not in SUPPORTED_VS_TYPES:
            raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
        self.vs_type = vector_store_type
        if embed_model not in embedding_model_dict.keys():
            raise ValueError(f"暂未支持embedding模型 {embed_model}")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        if self.vs_type in ["faiss"]:
            self.vs_path = get_vs_path(self.kb_name)
        elif self.vs_type in ["milvus"]:
            pass

    def create(self):
        """Create the folders for this knowledge base and register it in the database."""
        os.makedirs(self.doc_path, exist_ok=True)
        if self.vs_type in ["faiss"]:
            os.makedirs(self.vs_path, exist_ok=True)
            add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
        elif self.vs_type in ["milvus"]:
            # TODO: create the milvus collection
            pass
        return True

    def recreate_vs(self):
        """Empty the vector store folder so the index can be rebuilt from scratch.

        Unlike before, this does not go through ``create()``: that call
        re-inserted the knowledge base row on every rebuild, leaving
        duplicate rows in the database. A missing folder is tolerated.
        """
        os.makedirs(self.doc_path, exist_ok=True)
        if self.vs_type in ["faiss"]:
            if os.path.exists(self.vs_path):
                shutil.rmtree(self.vs_path)
            os.makedirs(self.vs_path)

    def add_doc(self, kb_file: KnowledgeFile):
        """Split a file into documents, add them to the vector store and record the file."""
        docs = kb_file.file2text()
        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
        if self.vs_type in ["faiss"]:
            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
                # Extend the index that is already saved on disk.
                vector_store = FAISS.load_local(self.vs_path, embeddings)
                vector_store.add_documents(docs)
                torch_gc()
            else:
                os.makedirs(self.vs_path, exist_ok=True)
                vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document
                torch_gc()
            vector_store.save_local(self.vs_path)
            add_doc_to_db(kb_file)
            refresh_vs_cache(self.kb_name)
        elif self.vs_type in ["milvus"]:
            # TODO: add the file to the milvus collection
            pass

    def delete_doc(self, kb_file: KnowledgeFile):
        """Delete a file from disk, from the vector store and from the database.

        The database record is now removed even when the vector store holds
        no entries for the file. Previously that case returned ``None``
        after the file had already been deleted from disk, leaving a stale
        record behind.

        Returns:
            Always ``True``.
        """
        if os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        if self.vs_type in ["faiss"]:
            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
                vector_store = FAISS.load_local(self.vs_path, embeddings)
                # Entries are matched by the source path stored in their metadata.
                ids = [k for k, v in vector_store.docstore._dict.items()
                       if v.metadata["source"] == kb_file.filepath]
                if ids:
                    vector_store = delete_doc_from_faiss(vector_store, ids)
                    vector_store.save_local(self.vs_path)
                    refresh_vs_cache(self.kb_name)
        delete_file_from_db(kb_file)
        return True

    def exist_doc(self, file_name: str):
        """Return whether ``file_name`` is recorded for this knowledge base."""
        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
                                        filename=file_name))

    def list_docs(self):
        """Return the file names recorded for this knowledge base."""
        return list_docs_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    embedding_device: str = EMBEDDING_DEVICE, ):
        """Return the ``top_k`` documents most similar to ``query``.

        The current tick of this knowledge base is passed to
        ``load_vector_store`` so a refreshed store is reloaded, not served
        from cache.
        """
        search_index = load_vector_store(self.kb_name,
                                         self.embed_model,
                                         embedding_device,
                                         _VECTOR_STORE_TICKS.get(self.kb_name))
        docs = search_index.similarity_search(query, k=top_k)
        return docs

    @classmethod
    def exists(cls,
               knowledge_base_name: str):
        """Return whether a knowledge base with this name is recorded in the database."""
        return kb_exists(knowledge_base_name)

    @classmethod
    def load(cls,
             knowledge_base_name: str):
        """Build an instance from the settings stored in the database.

        Raises:
            ValueError: If no such knowledge base is recorded. The exception
                type is unchanged, but the message now names the missing
                knowledge base, not an unsupported vector store ``None``.
        """
        kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
        if kb_name is None:
            raise ValueError(f"未找到知识库 {knowledge_base_name}")
        return cls(kb_name, vs_type, embed_model)

    @classmethod
    def delete(cls,
               knowledge_base_name: str):
        """Delete a knowledge base's folder and its database records."""
        kb = cls.load(knowledge_base_name)
        if kb.vs_type in ["faiss"]:
            shutil.rmtree(kb.kb_path)
        elif kb.vs_type in ["milvus"]:
            # TODO: drop the milvus collection
            pass
        status = delete_kb_from_db(knowledge_base_name)
        return status

    @classmethod
    def list_kbs(cls):
        """Return the knowledge base names reported by ``list_kbs_from_db``."""
        return list_kbs_from_db()
if __name__ == "__main__":
    # Manual smoke test: load knowledge base "123" (create it beforehand with
    # KnowledgeBase("123", "faiss").create()) and delete README.md from it.
    demo_kb = KnowledgeBase.load(knowledge_base_name="123")
    readme = KnowledgeFile(knowledge_base_name="123", filename="README.md")
    demo_kb.delete_doc(readme)
    print()

View File

@ -1,5 +1,8 @@
import os.path import os
from configs.model_config import KB_ROOT_PATH from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
from functools import lru_cache
def validate_kb_name(knowledge_base_id: str) -> bool: def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字 # 检查是否包含预期外的字符或路径攻击关键字
@ -17,4 +20,10 @@ def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store") return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str): def get_file_path(knowledge_base_name: str, doc_name: str):
return os.path.join(get_doc_path(knowledge_base_name), doc_name) return os.path.join(get_doc_path(knowledge_base_name), doc_name)
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Create the HuggingFace embeddings for ``model``, placed on ``device``.

    ``lru_cache(1)`` keeps only the latest ``(model, device)`` result, so
    asking for the same pair again returns the cached object.

    Args:
        model: Key into ``embedding_model_dict``.
        device: Device string forwarded as ``model_kwargs["device"]``.
    """
    model_name = embedding_model_dict[model]
    return HuggingFaceEmbeddings(model_name=model_name,
                                 model_kwargs={"device": device})

View File

@ -7,11 +7,157 @@
import sys import sys
import os import os
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser
import subprocess import subprocess
import re import re
import logging
import argparse import argparse
# Log location and the line format shared by all log records.
LOG_PATH = "./logs/"
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"

# Configure the root logger: install the default handler with LOG_FORMAT,
# then raise the root level to INFO.
logging.basicConfig(format=LOG_FORMAT)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Command-line options for launching the controller, the model worker(s) and
# the OpenAI-compatible API server. The *_args lists name the options (in
# dashed form) that belong to each of the three launched components.
parser = argparse.ArgumentParser()

# ------ multi worker ------
# NOTE(review): the default is a plain string while values given on the
# command line arrive as a list (nargs="+") -- consumers must handle both.
parser.add_argument('--model-path-address',
                    default="THUDM/chatglm2-6b@localhost@20002",
                    nargs="+",
                    type=str,
                    help="model path, host, and port, formatted as model-path@host@port")

# ------ controller ------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
controller_args = ["controller-host", "controller-port", "dispatch-method"]

# ------ worker ------
parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.3",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
    "--max-gpu-memory",
    type=str,
    help="The maximum memory per gpu. Use a string like '13Gib'",
)
parser.add_argument(
    "--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
    "--cpu-offloading",
    action="store_true",
    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
)
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="#bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional display comma separated names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

# "controller-address" is listed here (and in server_args) although the
# parser defines no such option.
# NOTE(review): it is presumably attached to the parsed namespace after
# parse_args() -- confirm against the code that follows this block.
worker_args = [
    "worker-host", "worker-port",
    "model-path", "revision", "device", "gpus", "num-gpus",
    "max-gpu-memory", "load-8bit", "cpu-offloading",
    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
    "gptq-act-order", "model-names", "limit-worker-concurrency",
    "stream-interval", "no-register",
    "controller-address"
]

# ------ openai server ------
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8001, help="port number")
parser.add_argument(
    "--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
    "--api-keys",
    type=lambda s: s.split(","),
    help="Optional list of comma separated API keys",
)
server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
               "controller-address"
               ]
args = parser.parse_args() args = parser.parse_args()
# 必须要加http//:否则InvalidSchema: No connection adapters were found # 必须要加http//:否则InvalidSchema: No connection adapters were found
args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"}) args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})

View File

@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, 0, streaming=False) chat_box.update_msg(text, 0, streaming=False)
now = datetime.now() now = datetime.now()
cols[0].download_button( export_btn.download_button(
"Export", "Export",
"".join(chat_box.export2md(cur_chat_name)), "".join(chat_box.export2md(cur_chat_name)),
file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md", file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",

View File

@ -1,6 +1,124 @@
from pydoc import Helper
import streamlit as st import streamlit as st
from webui_pages.utils import * from webui_pages.utils import *
import streamlit_antd_components as sac
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path
from streamlit_chatbox import *
SENTENCE_SIZE = 100
def knowledge_base_page(api: ApiRequest): def knowledge_base_page(api: ApiRequest):
st.write(123) api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
pass chat_box = ChatBox(session_key="kb_messages")
kb_list = api.list_knowledge_bases()
kb_docs = {}
for kb in kb_list:
kb_docs[kb] = api.list_kb_docs(kb)
with st.sidebar:
def on_new_kb():
if name := st.session_state.get("new_kb_name"):
if name in kb_list:
st.error(f"名为 {name} 的知识库已经存在!")
else:
ret = api.create_knowledge_base(name)
st.toast(ret["msg"])
def on_del_kb():
if name := st.session_state.get("new_kb_name"):
if name in kb_list:
ret = api.delete_knowledge_base(name)
st.toast(ret["msg"])
else:
st.error(f"名为 {name} 的知识库不存在!")
cols = st.columns([2, 1, 1])
new_kb_name = cols[0].text_input(
"新知识库名称",
placeholder="新知识库名称",
label_visibility="collapsed",
key="new_kb_name",
)
cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name))
cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name))
st.write("知识库:")
if kb_list:
try:
index = kb_list.index(st.session_state.get("cur_kb"))
except:
index = 0
kb = sac.buttons(
kb_list,
index,
format_func=lambda x: f"{x} ({len(kb_docs[x])})",
)
st.session_state["cur_kb"] = kb
sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
["docx", "txt", "md", "csv", "xlsx", "pdf"],
accept_multiple_files=True,
key="files",
)
if st.button(
"添加文件到知识库",
help="请先上传文件,再点击添加",
use_container_width=True,
disabled=len(files)==0,
):
for f in files:
ret = api.upload_kb_doc(f, kb)
if ret["code"] == 200:
st.toast(ret["msg"], icon="")
else:
st.toast(ret["msg"], icon="")
st.session_state.files = []
if st.button(
"重建知识库",
help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
use_container_width=True,
disabled=True,
):
progress = st.progress(0.0, "")
for d in api.recreate_vector_store(kb):
progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
if kb_list:
# 知识库详情
st.subheader(f"知识库 {kb} 详情")
df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "文档名称"])
gb = GridOptionsBuilder.from_dataframe(df)
gb.configure_column("No", width=50)
gb.configure_selection()
cols = st.columns([1, 2])
with cols[0]:
docs = AgGrid(df, gb.build())
with cols[1]:
cols = st.columns(3)
selected_rows = docs.get("selected_rows", [])
cols = st.columns([2, 3, 2])
if selected_rows:
file_name = selected_rows[0]["文档名称"]
file_path = get_file_path(kb, file_name)
with open(file_path, "rb") as fp:
cols[0].download_button("下载选中文档", fp, file_name=file_name)
else:
cols[0].download_button("下载选中文档", "", disabled=True)
if cols[2].button("删除选中文档!", type="primary"):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["文档名称"])
st.toast(ret["msg"])
st.experimental_rerun()
st.write("本文档包含以下知识条目:(待定内容)")