Merge remote-tracking branch 'origin/dev_fastchat' into dev_fastchat

# Conflicts:
#	server/knowledge_base/kb_service/default_kb_service.py
#	server/knowledge_base/kb_service/milvus_kb_service.py
#	server/knowledge_base/knowledge_base_factory.py
commit 41fd1acc9c by zqt, 2023-08-08 14:06:13 +08:00
13 changed files with 347 additions and 446 deletions

View File

@@ -15,9 +15,13 @@ starlette~=0.27.0
numpy~=1.24.4
pydantic~=1.10.11
unstructured[all-docs]
python-magic-bin; sys_platform == 'win32'
streamlit>=1.25.0
streamlit-option-menu
streamlit-antd-components
streamlit-chatbox>=1.1.6
httpx
faiss-cpu
pymilvus==2.1.3 # requires milvus==2.1.3

View File

@@ -4,7 +4,11 @@ from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.schema import (
BaseMessage,
messages_from_dict,
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
LLMResult
)
from langchain.chat_models import ChatOpenAI
@@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
"""
Convert front-end message transfer objects (DTOs) into chat message transfer objects (DTOs)
@@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
messages = []
for message_datum in message_data:
messages.append(message_datum.dict())
return messages_from_dict(messages)
return [_convert_dict_to_message(message) for message in messages]
class BaseChatOpenAIChain(Chain, ABC):
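
A minimal usage sketch of the converter above (an editorial aside, not part of the diff; it assumes the dicts follow the OpenAI role convention that _convert_dict_to_message expects):

from langchain.schema import AIMessage, HumanMessage, SystemMessage

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
messages = [_convert_dict_to_message(m) for m in history]
assert isinstance(messages[0], SystemMessage)
assert isinstance(messages[1], HumanMessage)
assert isinstance(messages[2], AIMessage)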

View File

@@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import KBService
import json
@@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="user input", examp
knowledge_base_name: str = Body(..., description="knowledge base name", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="number of matching vectors"),
):
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
async def knowledge_base_chat_iterator(query: str,
kb: KnowledgeBase,
kb: KBService,
top_k: int,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
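
For context, a sketch of the streaming pattern this handler enables (a hypothetical standalone version; the actual chain, prompt, and model wiring sit outside this hunk):

import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

async def stream_tokens(question: str):
    # the callback exposes generated tokens as an async iterator
    callback = AsyncIteratorCallbackHandler()
    model = ChatOpenAI(streaming=True, callbacks=[callback])
    task = asyncio.create_task(model.agenerate([[HumanMessage(content=question)]]))
    async for token in callback.aiter():
        yield token  # each token can be forwarded into the HTTP response as it arrives
    await task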

View File

@@ -1,4 +1,4 @@
from .kb_api import list_kbs, create_kb, delete_kb
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
from .knowledge_base import KnowledgeBase
from .knowledge_file import KnowledgeFile
from .knowledge_base_factory import KBServiceFactory

View File

@@ -1,13 +1,14 @@
import urllib
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL
async def list_kbs():
# Get List of Knowledge Base
return ListResponse(data=KnowledgeBase.list_kbs())
return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str,
@@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me")
if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
if KnowledgeBase.exists(knowledge_base_name):
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
vector_store_type=vector_store_type,
embed_model=embed_model)
kb.create()
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
@@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
status = KnowledgeBase.delete(knowledge_base_name)
status = kb.drop_kb()
if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
else:

View File

@@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.knowledge_file import KnowledgeFile
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
async def list_docs(knowledge_base_name: str):
@@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
return ListResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return ListResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found", data=[])
else:
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names)
@@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="file to upload"),
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
file_content = await file.read()  # read the content of the uploaded file
kb_file = KnowledgeFile(filename=file.filename,
@@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name,
@@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
recreate the vector store from the content folder.
This is useful when users copy files into the content folder directly instead of uploading them over the network.
'''
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
async def output(kb: KnowledgeBase):
kb.recreate_vs()
async def output(kb):
kb.clear_vs()
print(f"start to recreate vector store of {kb.kb_name}")
docs = kb.list_docs()
docs = list_docs_from_folder(knowledge_base_name)
print(docs)
for i, filename in enumerate(docs):
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": filename,
})
kb_file = KnowledgeFile(filename=filename,
knowledge_base_name=kb.kb_name)
print(f"processing {kb_file.filepath} to vector store.")
kb.add_doc(kb_file)
yield json.dumps({
"total": len(docs),
"finished": i + 1,
"doc": filename,
})
if kb.vs_type == SupportedVSType.FAISS:
refresh_vs_cache(knowledge_base_name)
return StreamingResponse(output(kb), media_type="text/event-stream")
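
The endpoint yields one JSON progress object per file; a hypothetical client loop (the URL path and port are assumptions, not taken from this diff) can consume the stream with httpx, which is already in requirements:

import json
import httpx

url = "http://127.0.0.1:7861/knowledge_base/recreate_vector_store"  # hypothetical address
with httpx.stream("POST", url, params={"knowledge_base_name": "samples"}, timeout=None) as r:
    for chunk in r.iter_text():
        if chunk.strip():
            d = json.loads(chunk)  # assumes one JSON object arrives per chunk
            print(f"{d['finished']}/{d['total']} {d['doc']}")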

View File

@@ -125,6 +125,10 @@ def list_docs_from_db(kb_name):
conn.close()
return kbs
def list_docs_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
def add_doc_to_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH)

View File

@@ -122,3 +122,4 @@ class FaissKBService(KBService):
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)

View File

@@ -1,407 +0,0 @@
import os
import sqlite3
import datetime
import shutil
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
from server.utils import torch_gc
from functools import lru_cache
from server.knowledge_base.knowledge_file import KnowledgeFile
from typing import List
import numpy as np
from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
_VECTOR_STORE_TICKS = {}
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device})
return embeddings
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
knowledge_base_name: str,
embedding_model: str,
embedding_device: str,
tick: int, # tick will be changed by upload_doc etc. and make cache refreshed.
):
print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
embeddings = load_embeddings(embedding_model, embedding_device)
vs_path = get_vs_path(knowledge_base_name)
search_index = FAISS.load_local(vs_path, embeddings)
return search_index
def refresh_vs_cache(kb_name: str):
"""
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
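
Because the tick is part of the lru_cache key, bumping it forces the next load_vector_store call to miss the cache; an editorial sketch of the trick in isolation:

from functools import lru_cache

_TICKS = {}

@lru_cache(maxsize=8)
def load(name: str, tick: int):
    print(f"(re)loading {name}")  # runs only on a cache miss
    return object()

load("samples", _TICKS.get("samples", 0))   # loads
load("samples", _TICKS.get("samples", 0))   # served from cache
_TICKS["samples"] = _TICKS.get("samples", 0) + 1
load("samples", _TICKS.get("samples", 0))   # new tick -> reloads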
def list_kbs_from_db():
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT kb_name
FROM knowledge_base
WHERE file_count>0 ''')
kbs = [i[0] for i in c.fetchall() if i]
conn.commit()
conn.close()
return kbs
def add_kb_to_db(kb_name, vs_type, embed_model):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# Create table
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
# Insert a row of data
c.execute(f"""INSERT INTO knowledge_base
(kb_name, vs_type, embed_model, file_count, create_time)
VALUES
('{kb_name}','{vs_type}','{embed_model}',
0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
conn.commit()
conn.close()
def kb_exists(kb_name):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT COUNT(*)
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
status = True if c.fetchone()[0] else False
conn.commit()
conn.close()
return status
def load_kb_from_db(kb_name):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT kb_name, vs_type, embed_model
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
resp = c.fetchone()
if resp:
kb_name, vs_type, embed_model = resp
else:
kb_name, vs_type, embed_model = None, None, None
conn.commit()
conn.close()
return kb_name, vs_type, embed_model
def delete_kb_from_db(kb_name):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# delete kb from table knowledge_base
c.execute('''CREATE TABLE if not exists knowledge_base
(id INTEGER PRIMARY KEY AUTOINCREMENT,
kb_name TEXT,
vs_type TEXT,
embed_model TEXT,
file_count INTEGER,
create_time DATETIME) ''')
c.execute(f'''DELETE
FROM knowledge_base
WHERE kb_name="{kb_name}" ''')
# delete files in kb from table knowledge_files
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
# Insert a row of data
c.execute(f"""DELETE
FROM knowledge_files
WHERE kb_name="{kb_name}"
""")
conn.commit()
conn.close()
return True
def list_docs_from_db(kb_name):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT file_name
FROM knowledge_files
WHERE kb_name="{kb_name}" ''')
kbs = [i[0] for i in c.fetchall() if i]
conn.commit()
conn.close()
return kbs
def add_doc_to_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# Create table
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
# Insert a row of data
# TODO: increment file_version when a file with the same name is added to the knowledge base
c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
record_exist = c.fetchone()
if record_exist is not None:
c.execute(f"""UPDATE knowledge_files
SET file_version = file_version + 1
WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
""")
else:
c.execute(f"""INSERT INTO knowledge_files
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
VALUES
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
conn.commit()
conn.close()
def delete_file_from_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# delete files in kb from table knowledge_files
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
# Insert a row of data
c.execute(f"""DELETE
FROM knowledge_files
WHERE file_name="{kb_file.filename}"
AND kb_name="{kb_file.kb_name}"
""")
conn.commit()
conn.close()
return True
def doc_exists(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
c.execute('''CREATE TABLE if not exists knowledge_files
(id INTEGER PRIMARY KEY AUTOINCREMENT,
file_name TEXT,
file_ext TEXT,
kb_name TEXT,
document_loader_name TEXT,
text_splitter_name TEXT,
file_version INTEGER,
create_time DATETIME) ''')
c.execute(f'''SELECT COUNT(*)
FROM knowledge_files
WHERE file_name="{kb_file.filename}"
AND kb_name="{kb_file.kb_name}" ''')
status = True if c.fetchone()[0] else False
conn.commit()
conn.close()
return status
def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
if not overlapping:
raise ValueError("ids do not exist in the current object")
_reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
index_to_delete = [_reversed_index[i] for i in ids]
vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
for _id in index_to_delete:
del vector_store.index_to_docstore_id[_id]
# Remove items from docstore.
overlapping2 = set(ids).intersection(vector_store.docstore._dict)
if not overlapping2:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
for _id in ids:
vector_store.docstore._dict.pop(_id)
return vector_store
class KnowledgeBase:
def __init__(self,
knowledge_base_name: str,
vector_store_type: str = "faiss",
embed_model: str = EMBEDDING_MODEL,
):
self.kb_name = knowledge_base_name
if vector_store_type not in SUPPORTED_VS_TYPES:
raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
self.vs_type = vector_store_type
if embed_model not in embedding_model_dict.keys():
raise ValueError(f"暂未支持embedding模型 {embed_model}")
self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
if self.vs_type in ["faiss"]:
self.vs_path = get_vs_path(self.kb_name)
elif self.vs_type in ["milvus"]:
pass
def create(self):
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)
if self.vs_type in ["faiss"]:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
elif self.vs_type in ["milvus"]:
# TODO: create the milvus collection
pass
return True
def recreate_vs(self):
if self.vs_type in ["faiss"]:
shutil.rmtree(self.vs_path)
self.create()
def add_doc(self, kb_file: KnowledgeFile):
docs = kb_file.file2text()
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if self.vs_type in ["faiss"]:
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(docs, embeddings) # docs is a list of Document objects
torch_gc()
vector_store.save_local(self.vs_path)
add_doc_to_db(kb_file)
refresh_vs_cache(self.kb_name)
elif self.vs_type in ["milvus"]:
# TODO: add the file to the milvus collection
pass
def delete_doc(self, kb_file: KnowledgeFile):
if os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
if self.vs_type in ["faiss"]:
# TODO: delete the document from the FAISS vector store
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
print(len(ids))
vector_store = delete_doc_from_faiss(vector_store, ids)
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
delete_file_from_db(kb_file)
return True
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name))
def list_docs(self):
return list_docs_from_db(self.kb_name)
def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
embedding_device: str = EMBEDDING_DEVICE, ):
search_index = load_vector_store(self.kb_name,
self.embed_model,
embedding_device,
_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k)
return docs
@classmethod
def exists(cls,
knowledge_base_name: str):
return kb_exists(knowledge_base_name)
@classmethod
def load(cls,
knowledge_base_name: str):
kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
return cls(kb_name, vs_type, embed_model)
@classmethod
def delete(cls,
knowledge_base_name: str):
kb = cls.load(knowledge_base_name)
if kb.vs_type in ["faiss"]:
shutil.rmtree(kb.kb_path)
elif kb.vs_type in ["milvus"]:
# TODO: drop the milvus collection
pass
status = delete_kb_from_db(knowledge_base_name)
return status
@classmethod
def list_kbs(cls):
return list_kbs_from_db()
if __name__ == "__main__":
# kb = KnowledgeBase("123", "faiss")
# kb.create()
kb = KnowledgeBase.load(knowledge_base_name="123")
kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
print()
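
One note on the module removed above: it interpolated values into SQL with f-strings, which breaks on names containing quotes and is injection-prone; a parameterized sqlite3 sketch of the load_kb_from_db lookup (kb_name below is a hypothetical example value) avoids both:

import sqlite3
from configs.model_config import DB_ROOT_PATH

kb_name = "samples"  # hypothetical knowledge base name
conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor()
# the ? placeholder lets sqlite3 quote and escape the value safely
c.execute("SELECT kb_name, vs_type, embed_model FROM knowledge_base WHERE kb_name=?",
          (kb_name,))
row = c.fetchone()
conn.close()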

View File

@@ -1,5 +1,8 @@
import os.path
from configs.model_config import KB_ROOT_PATH
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
from functools import lru_cache
def validate_kb_name(knowledge_base_id: str) -> bool:
# check for unexpected characters or path-traversal keywords
@@ -18,3 +21,9 @@ def get_vs_path(knowledge_base_name: str):
def get_file_path(knowledge_base_name: str, doc_name: str):
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device})
return embeddings

View File

@@ -7,11 +7,157 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser
import subprocess
import re
import logging
import argparse
LOG_PATH = "./logs/"
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
parser = argparse.ArgumentParser()
#------multi worker-----------------
parser.add_argument('--model-path-address',
default="THUDM/chatglm2-6b@localhost@20002",
nargs="+",
type=str,
help="model path, host, and port, formatted as model-path@host@path")
#---------------controller-------------------------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
"--dispatch-method",
type=str,
choices=["lottery", "shortest_queue"],
default="shortest_queue",
)
controller_args = ["controller-host","controller-port","dispatch-method"]
#----------------------worker------------------------------------------
parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
# parser.add_argument(
# "--controller-address", type=str, default="http://localhost:21001"
# )
parser.add_argument(
"--model-path",
type=str,
default="lmsys/vicuna-7b-v1.3",
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
"--revision",
type=str,
default="main",
help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "mps", "xpu"],
default="cuda",
help="The device type",
)
parser.add_argument(
"--gpus",
type=str,
default="0",
help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
"--max-gpu-memory",
type=str,
help="The maximum memory per gpu. Use a string like '13Gib'",
)
parser.add_argument(
"--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
"--cpu-offloading",
action="store_true",
help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
)
parser.add_argument(
"--gptq-ckpt",
type=str,
default=None,
help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
"--gptq-wbits",
type=int,
default=16,
choices=[2, 3, 4, 8, 16],
help="#bits to use for quantization",
)
parser.add_argument(
"--gptq-groupsize",
type=int,
default=-1,
help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
"--gptq-act-order",
action="store_true",
help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument(
"--limit-worker-concurrency",
type=int,
default=5,
help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
worker_args = [
"worker-host","worker-port",
"model-path","revision","device","gpus","num-gpus",
"max-gpu-memory","load-8bit","cpu-offloading",
"gptq-ckpt","gptq-wbits","gptq-groupsize",
"gptq-act-order","model-names","limit-worker-concurrency",
"stream-interval","no-register",
"controller-address"
]
#-----------------openai server---------------------------
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8001, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
# parser.add_argument(
# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
# )
# parser.add_argument(
# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
# )
# parser.add_argument(
# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
# )
parser.add_argument(
"--api-keys",
type=lambda s: s.split(","),
help="Optional list of comma separated API keys",
)
server_args = ["server-host","server-port","allow-credentials","api-keys",
"controller-address"
]
args = parser.parse_args()
# the http:// prefix is required, otherwise requests raises "InvalidSchema: No connection adapters were found"
args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})
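
One detail in the line above: the injected "controller-address" key contains a hyphen, so it is not a valid attribute name; a short sketch of how it has to be read back:

# args.controller-address would be a syntax error, so the value must be
# fetched from the namespace's dict (or via getattr) instead
controller_address = vars(args)["controller-address"]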

View File

@@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, 0, streaming=False)
now = datetime.now()
cols[0].download_button(
export_btn.download_button(
"Export",
"".join(chat_box.export2md(cur_chat_name)),
file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",

View File

@@ -1,6 +1,124 @@
from pydoc import Helper
import streamlit as st
from webui_pages.utils import *
import streamlit_antd_components as sac
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path
from streamlit_chatbox import *
SENTENCE_SIZE = 100
def knowledge_base_page(api: ApiRequest):
st.write(123)
pass
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
chat_box = ChatBox(session_key="kb_messages")
kb_list = api.list_knowledge_bases()
kb_docs = {}
for kb in kb_list:
kb_docs[kb] = api.list_kb_docs(kb)
with st.sidebar:
def on_new_kb():
if name := st.session_state.get("new_kb_name"):
if name in kb_list:
st.error(f"名为 {name} 的知识库已经存在!")
else:
ret = api.create_knowledge_base(name)
st.toast(ret["msg"])
def on_del_kb():
if name := st.session_state.get("new_kb_name"):
if name in kb_list:
ret = api.delete_knowledge_base(name)
st.toast(ret["msg"])
else:
st.error(f"名为 {name} 的知识库不存在!")
cols = st.columns([2, 1, 1])
new_kb_name = cols[0].text_input(
"新知识库名称",
placeholder="新知识库名称",
label_visibility="collapsed",
key="new_kb_name",
)
cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name))
cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name))
st.write("知识库:")
if kb_list:
try:
index = kb_list.index(st.session_state.get("cur_kb"))
except ValueError:
index = 0
kb = sac.buttons(
kb_list,
index,
format_func=lambda x: f"{x} ({len(kb_docs[x])})",
)
st.session_state["cur_kb"] = kb
sentence_size = st.slider("Sentence length limit for text ingestion", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
["docx", "txt", "md", "csv", "xlsx", "pdf"],
accept_multiple_files=True,
key="files",
)
if st.button(
"添加文件到知识库",
help="请先上传文件,再点击添加",
use_container_width=True,
disabled=len(files)==0,
):
for f in files:
ret = api.upload_kb_doc(f, kb)
if ret["code"] == 200:
st.toast(ret["msg"], icon="")
else:
st.toast(ret["msg"], icon="")
st.session_state.files = []
if st.button(
"重建知识库",
help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
use_container_width=True,
disabled=True,
):
progress = st.progress(0.0, "")
for d in api.recreate_vector_store(kb):
progress.progress(d["finished"] / d["total"], f"Processing: {d['doc']}")
if kb_list:
# knowledge base details
st.subheader(f"Knowledge base {kb} details")
df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "Document Name"])
gb = GridOptionsBuilder.from_dataframe(df)
gb.configure_column("No", width=50)
gb.configure_selection()
cols = st.columns([1, 2])
with cols[0]:
docs = AgGrid(df, gb.build())
with cols[1]:
selected_rows = docs.get("selected_rows", [])
cols = st.columns([2, 3, 2])
if selected_rows:
file_name = selected_rows[0]["Document Name"]
file_path = get_file_path(kb, file_name)
with open(file_path, "rb") as fp:
cols[0].download_button("下载选中文档", fp, file_name=file_name)
else:
cols[0].download_button("下载选中文档", "", disabled=True)
if cols[2].button("删除选中文档!", type="primary"):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["Document Name"])
st.toast(ret["msg"])
st.experimental_rerun()
st.write("本文档包含以下知识条目:(待定内容)")