Merge remote-tracking branch 'origin/dev_fastchat' into dev_fastchat

# Conflicts:
#	server/knowledge_base/kb_service/default_kb_service.py
#	server/knowledge_base/kb_service/milvus_kb_service.py
#	server/knowledge_base/knowledge_base_factory.py

commit 41fd1acc9c

@@ -15,9 +15,13 @@ starlette~=0.27.0
numpy~=1.24.4
pydantic~=1.10.11
unstructured[all-docs]
python-magic-bin; sys_platform == 'win32'

streamlit>=1.25.0
streamlit-option-menu
streamlit-antd-components
streamlit-chatbox>=1.1.6
httpx

faiss-cpu
pymilvus==2.1.3 # requires milvus==2.1.3

@@ -4,7 +4,11 @@ from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.schema import (
    BaseMessage,
    messages_from_dict,
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    LLMResult
)
from langchain.chat_models import ChatOpenAI

@@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto


def _convert_dict_to_message(_dict: dict) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        return AIMessage(content=_dict["content"])
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
    """
    Convert front-end message DTOs into chat message DTOs.

@@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
    messages = []
    for message_datum in message_data:
        messages.append(message_datum.dict())
    return messages_from_dict(messages)
    return _convert_dict_to_message(messages)


class BaseChatOpenAIChain(Chain, ABC):

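The _convert_dict_to_message helper added above maps OpenAI-style role dicts onto the corresponding LangChain message classes, and convert_message_processors applies it to the front-end DTOs. For illustration only, a short usage sketch with made-up message data (it assumes the helper and the langchain.schema imports shown above):

example_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好,请问有什么可以帮您?"},
]
converted = [_convert_dict_to_message(m) for m in example_messages]
# converted -> [SystemMessage(...), HumanMessage(content='你好'), AIMessage(...)]
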
@@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import KBService
import json


@@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                        knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                        ):
    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)

    async def knowledge_base_chat_iterator(query: str,
                                           kb: KnowledgeBase,
                                           kb: KBService,
                                           top_k: int,
                                           ) -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()

@@ -1,4 +1,4 @@
from .kb_api import list_kbs, create_kb, delete_kb
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
from .knowledge_base import KnowledgeBase
from .knowledge_file import KnowledgeFile
from .knowledge_base_factory import KBServiceFactory

@@ -1,13 +1,14 @@
import urllib
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL


async def list_kbs():
    # Get List of Knowledge Base
    return ListResponse(data=KnowledgeBase.list_kbs())
    return ListResponse(data=list_kbs_from_db())


async def create_kb(knowledge_base_name: str,

@@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
        return BaseResponse(code=403, msg="Don't attack me")
    if knowledge_base_name is None or knowledge_base_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if KnowledgeBase.exists(knowledge_base_name):

    kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
    if kb is not None:
        return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
    kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
                       vector_store_type=vector_store_type,
                       embed_model=embed_model)
    kb.create()
    return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")

@@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
        return BaseResponse(code=403, msg="Don't attack me")
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)

    if not KnowledgeBase.exists(knowledge_base_name):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    status = KnowledgeBase.delete(knowledge_base_name)
    status = kb.drop_kb()
    if status:
        return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
    else:

@@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.knowledge_file import KnowledgeFile
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache


async def list_docs(knowledge_base_name: str):

@@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
    else:
        all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
        all_doc_names = kb.list_docs()
        return ListResponse(data=all_doc_names)

@@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)

    file_content = await file.read()  # read the content of the uploaded file

    kb_file = KnowledgeFile(filename=file.filename,

@@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
    if not kb.exist_doc(doc_name):
        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
    kb_file = KnowledgeFile(filename=doc_name,

@@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
    recreate vector store from the content.
    This is useful when the user copies files into the content folder directly instead of uploading them through the network.
    '''
    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def output(kb: KnowledgeBase):
        kb.recreate_vs()
    async def output(kb):
        kb.clear_vs()
        print(f"start to recreate vector store of {kb.kb_name}")
        docs = kb.list_docs()
        docs = list_docs_from_folder(knowledge_base_name)
        print(docs)
        for i, filename in enumerate(docs):
            yield json.dumps({
                "total": len(docs),
                "finished": i,
                "doc": filename,
            })
            kb_file = KnowledgeFile(filename=filename,
                                    knowledge_base_name=kb.kb_name)
            print(f"processing {kb_file.filepath} to vector store.")
            kb.add_doc(kb_file)
            yield json.dumps({
                "total": len(docs),
                "finished": i + 1,
                "doc": filename,
            })
        if kb.vs_type == SupportedVSType.FAISS:
            refresh_vs_cache(knowledge_base_name)

    return StreamingResponse(output(kb), media_type="text/event-stream")

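recreate_vector_store streams its progress instead of returning once at the end: the async generator yields one JSON object per processed file and FastAPI's StreamingResponse forwards each yield to the client (the web UI reads these objects to drive a progress bar). A minimal, self-contained sketch of the same pattern; the /demo/progress route and file names below are hypothetical and only illustrate the technique:

import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/demo/progress")
async def demo_progress():
    async def output():
        items = ["a.md", "b.pdf", "c.txt"]  # stand-ins for the files being re-indexed
        for i, name in enumerate(items):
            # one JSON object per processed file; the client consumes them as they arrive
            yield json.dumps({"total": len(items), "finished": i + 1, "doc": name})
    return StreamingResponse(output(), media_type="text/event-stream")
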
@@ -125,6 +125,10 @@ def list_docs_from_db(kb_name):
    conn.close()
    return kbs

def list_docs_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]

def add_doc_to_db(kb_file: KnowledgeFile):
    conn = sqlite3.connect(DB_ROOT_PATH)

@@ -122,3 +122,4 @@ class FaissKBService(KBService):

    def do_clear_vs(self):
        shutil.rmtree(self.vs_path)
        os.makedirs(self.vs_path)

@@ -1,407 +0,0 @@
import os
import sqlite3
import datetime
import shutil
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
                                  DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
from server.utils import torch_gc
from functools import lru_cache
from server.knowledge_base.knowledge_file import KnowledgeFile
from typing import List
import numpy as np
from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)

SUPPORTED_VS_TYPES = ["faiss", "milvus"]

_VECTOR_STORE_TICKS = {}

@lru_cache(1)
def load_embeddings(model: str, device: str):
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
                                       model_kwargs={'device': device})
    return embeddings


@lru_cache(CACHED_VS_NUM)
def load_vector_store(
        knowledge_base_name: str,
        embedding_model: str,
        embedding_device: str,
        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
):
    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
    embeddings = load_embeddings(embedding_model, embedding_device)
    vs_path = get_vs_path(knowledge_base_name)
    search_index = FAISS.load_local(vs_path, embeddings)
    return search_index


def refresh_vs_cache(kb_name: str):
    """
    make vector store cache refreshed when next loading
    """
    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1

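The load_vector_store / refresh_vs_cache pair above (the same pattern is kept in faiss_kb_service.py) relies on an lru_cache trick: the extra tick argument is never used inside the function, but bumping it changes the cache key, so the next call misses the cache and reloads the FAISS index from disk. A standalone sketch of that pattern, with hypothetical names (_TICKS, load_index, invalidate) used only for illustration:

from functools import lru_cache

_TICKS = {}  # hypothetical per-name counter, mirrors _VECTOR_STORE_TICKS

@lru_cache(maxsize=8)
def load_index(name: str, tick: int):
    # tick is unused here; it only participates in the cache key
    print(f"(re)loading index for {name}")
    return object()  # stand-in for an expensive FAISS.load_local(...)

def invalidate(name: str):
    # bump the counter so the next load_index(name, _TICKS[name]) misses the cache
    _TICKS[name] = _TICKS.get(name, 0) + 1

load_index("samples", _TICKS.get("samples", 0))  # loads
load_index("samples", _TICKS.get("samples", 0))  # cache hit
invalidate("samples")
load_index("samples", _TICKS.get("samples", 0))  # reloads
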
def list_kbs_from_db():
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE if not exists knowledge_base
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  kb_name TEXT,
                  vs_type TEXT,
                  embed_model TEXT,
                  file_count INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''SELECT kb_name
                  FROM knowledge_base
                  WHERE file_count>0 ''')
    kbs = [i[0] for i in c.fetchall() if i]
    conn.commit()
    conn.close()
    return kbs


def add_kb_to_db(kb_name, vs_type, embed_model):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    # Create table
    c.execute('''CREATE TABLE if not exists knowledge_base
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  kb_name TEXT,
                  vs_type TEXT,
                  embed_model TEXT,
                  file_count INTEGER,
                  create_time DATETIME) ''')
    # Insert a row of data
    c.execute(f"""INSERT INTO knowledge_base
                  (kb_name, vs_type, embed_model, file_count, create_time)
                  VALUES
                  ('{kb_name}','{vs_type}','{embed_model}',
                  0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
    conn.commit()
    conn.close()


def kb_exists(kb_name):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE if not exists knowledge_base
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  kb_name TEXT,
                  vs_type TEXT,
                  embed_model TEXT,
                  file_count INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''SELECT COUNT(*)
                  FROM knowledge_base
                  WHERE kb_name="{kb_name}" ''')
    status = True if c.fetchone()[0] else False
    conn.commit()
    conn.close()
    return status


def load_kb_from_db(kb_name):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE if not exists knowledge_base
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  kb_name TEXT,
                  vs_type TEXT,
                  embed_model TEXT,
                  file_count INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''SELECT kb_name, vs_type, embed_model
                  FROM knowledge_base
                  WHERE kb_name="{kb_name}" ''')
    resp = c.fetchone()
    if resp:
        kb_name, vs_type, embed_model = resp
    else:
        kb_name, vs_type, embed_model = None, None, None
    conn.commit()
    conn.close()
    return kb_name, vs_type, embed_model


def delete_kb_from_db(kb_name):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    # delete kb from table knowledge_base
    c.execute('''CREATE TABLE if not exists knowledge_base
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  kb_name TEXT,
                  vs_type TEXT,
                  embed_model TEXT,
                  file_count INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''DELETE
                  FROM knowledge_base
                  WHERE kb_name="{kb_name}" ''')
    # delete files in kb from table knowledge_files
    c.execute('''CREATE TABLE if not exists knowledge_files
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  file_name TEXT,
                  file_ext TEXT,
                  kb_name TEXT,
                  document_loader_name TEXT,
                  text_splitter_name TEXT,
                  file_version INTEGER,
                  create_time DATETIME) ''')
    # Insert a row of data
    c.execute(f"""DELETE
                  FROM knowledge_files
                  WHERE kb_name="{kb_name}"
               """)
    conn.commit()
    conn.close()
    return True


def list_docs_from_db(kb_name):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE if not exists knowledge_files
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  file_name TEXT,
                  file_ext TEXT,
                  kb_name TEXT,
                  document_loader_name TEXT,
                  text_splitter_name TEXT,
                  file_version INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''SELECT file_name
                  FROM knowledge_files
                  WHERE kb_name="{kb_name}" ''')
    kbs = [i[0] for i in c.fetchall() if i]
    conn.commit()
    conn.close()
    return kbs

def add_doc_to_db(kb_file: KnowledgeFile):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    # Create table
    c.execute('''CREATE TABLE if not exists knowledge_files
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  file_name TEXT,
                  file_ext TEXT,
                  kb_name TEXT,
                  document_loader_name TEXT,
                  text_splitter_name TEXT,
                  file_version INTEGER,
                  create_time DATETIME) ''')
    # Insert a row of data
    # TODO: when a file with the same name is added to the knowledge base, increment file_version
    c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
    record_exist = c.fetchone()
    if record_exist is not None:
        c.execute(f"""UPDATE knowledge_files
                      SET file_version = file_version + 1
                      WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
                   """)
    else:
        c.execute(f"""INSERT INTO knowledge_files
                      (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
                      VALUES
                      ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
                      '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
    conn.commit()
    conn.close()


def delete_file_from_db(kb_file: KnowledgeFile):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    # delete files in kb from table knowledge_files
    c.execute('''CREATE TABLE if not exists knowledge_files
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  file_name TEXT,
                  file_ext TEXT,
                  kb_name TEXT,
                  document_loader_name TEXT,
                  text_splitter_name TEXT,
                  file_version INTEGER,
                  create_time DATETIME) ''')
    # Insert a row of data
    c.execute(f"""DELETE
                  FROM knowledge_files
                  WHERE file_name="{kb_file.filename}"
                  AND kb_name="{kb_file.kb_name}"
               """)
    conn.commit()
    conn.close()
    return True


def doc_exists(kb_file: KnowledgeFile):
    conn = sqlite3.connect(DB_ROOT_PATH)
    c = conn.cursor()
    c.execute('''CREATE TABLE if not exists knowledge_files
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  file_name TEXT,
                  file_ext TEXT,
                  kb_name TEXT,
                  document_loader_name TEXT,
                  text_splitter_name TEXT,
                  file_version INTEGER,
                  create_time DATETIME) ''')
    c.execute(f'''SELECT COUNT(*)
                  FROM knowledge_files
                  WHERE file_name="{kb_file.filename}"
                  AND kb_name="{kb_file.kb_name}" ''')
    status = True if c.fetchone()[0] else False
    conn.commit()
    conn.close()
    return status


def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
    if not overlapping:
        raise ValueError("ids do not exist in the current object")
    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
    index_to_delete = [_reversed_index[i] for i in ids]
    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
    for _id in index_to_delete:
        del vector_store.index_to_docstore_id[_id]
    # Remove items from docstore.
    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
    if not overlapping2:
        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
    for _id in ids:
        vector_store.docstore._dict.pop(_id)
    return vector_store

class KnowledgeBase:
    def __init__(self,
                 knowledge_base_name: str,
                 vector_store_type: str = "faiss",
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self.kb_name = knowledge_base_name
        if vector_store_type not in SUPPORTED_VS_TYPES:
            raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
        self.vs_type = vector_store_type
        if embed_model not in embedding_model_dict.keys():
            raise ValueError(f"暂未支持embedding模型 {embed_model}")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        if self.vs_type in ["faiss"]:
            self.vs_path = get_vs_path(self.kb_name)
        elif self.vs_type in ["milvus"]:
            pass

    def create(self):
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        if self.vs_type in ["faiss"]:
            if not os.path.exists(self.vs_path):
                os.makedirs(self.vs_path)
            add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
        elif self.vs_type in ["milvus"]:
            # TODO: create the milvus collection
            pass
        return True

    def recreate_vs(self):
        if self.vs_type in ["faiss"]:
            shutil.rmtree(self.vs_path)
            self.create()

    def add_doc(self, kb_file: KnowledgeFile):
        docs = kb_file.file2text()
        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
        if self.vs_type in ["faiss"]:
            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
                vector_store = FAISS.load_local(self.vs_path, embeddings)
                vector_store.add_documents(docs)
                torch_gc()
            else:
                if not os.path.exists(self.vs_path):
                    os.makedirs(self.vs_path)
                vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document objects
                torch_gc()
            vector_store.save_local(self.vs_path)
            add_doc_to_db(kb_file)
            refresh_vs_cache(self.kb_name)
        elif self.vs_type in ["milvus"]:
            # TODO: add the file to the milvus collection
            pass

    def delete_doc(self, kb_file: KnowledgeFile):
        if os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        if self.vs_type in ["faiss"]:
            # TODO: delete the document from the FAISS vector store
            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
                vector_store = FAISS.load_local(self.vs_path, embeddings)
                ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
                if len(ids) == 0:
                    return None
                print(len(ids))
                vector_store = delete_doc_from_faiss(vector_store, ids)
                vector_store.save_local(self.vs_path)
                refresh_vs_cache(self.kb_name)
        delete_file_from_db(kb_file)
        return True

    def exist_doc(self, file_name: str):
        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
                                        filename=file_name))

    def list_docs(self):
        return list_docs_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    embedding_device: str = EMBEDDING_DEVICE, ):
        search_index = load_vector_store(self.kb_name,
                                         self.embed_model,
                                         embedding_device,
                                         _VECTOR_STORE_TICKS.get(self.kb_name))
        docs = search_index.similarity_search(query, k=top_k)
        return docs

    @classmethod
    def exists(cls,
               knowledge_base_name: str):
        return kb_exists(knowledge_base_name)

    @classmethod
    def load(cls,
             knowledge_base_name: str):
        kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
        return cls(kb_name, vs_type, embed_model)

    @classmethod
    def delete(cls,
               knowledge_base_name: str):
        kb = cls.load(knowledge_base_name)
        if kb.vs_type in ["faiss"]:
            shutil.rmtree(kb.kb_path)
        elif kb.vs_type in ["milvus"]:
            # TODO: drop the milvus collection
            pass
        status = delete_kb_from_db(knowledge_base_name)
        return status

    @classmethod
    def list_kbs(cls):
        return list_kbs_from_db()


if __name__ == "__main__":
    # kb = KnowledgeBase("123", "faiss")
    # kb.create()
    kb = KnowledgeBase.load(knowledge_base_name="123")
    kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
    print()

@@ -1,5 +1,8 @@
import os.path
from configs.model_config import KB_ROOT_PATH
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
from functools import lru_cache


def validate_kb_name(knowledge_base_id: str) -> bool:
    # check for unexpected characters or path traversal keywords

@@ -17,4 +20,10 @@ def get_vs_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")

def get_file_path(knowledge_base_name: str, doc_name: str):
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)

@lru_cache(1)
def load_embeddings(model: str, device: str):
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
                                       model_kwargs={'device': device})
    return embeddings

@@ -7,11 +7,157 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser

import subprocess
import re
import logging
import argparse

LOG_PATH = "./logs/"
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)


parser = argparse.ArgumentParser()
#------multi worker-----------------
parser.add_argument('--model-path-address',
                    default="THUDM/chatglm2-6b@localhost@20002",
                    nargs="+",
                    type=str,
                    help="model path, host, and port, formatted as model-path@host@port")
#---------------controller-------------------------

parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
controller_args = ["controller-host","controller-port","dispatch-method"]

#----------------------worker------------------------------------------

parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
# parser.add_argument(
#     "--controller-address", type=str, default="http://localhost:21001"
# )
parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.3",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
    "--max-gpu-memory",
    type=str,
    help="The maximum memory per gpu. Use a string like '13Gib'",
)
parser.add_argument(
    "--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
    "--cpu-offloading",
    action="store_true",
    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
)
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="#bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional display comma separated names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

worker_args = [
    "worker-host","worker-port",
    "model-path","revision","device","gpus","num-gpus",
    "max-gpu-memory","load-8bit","cpu-offloading",
    "gptq-ckpt","gptq-wbits","gptq-groupsize",
    "gptq-act-order","model-names","limit-worker-concurrency",
    "stream-interval","no-register",
    "controller-address"
]
#-----------------openai server---------------------------

parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8001, help="port number")
parser.add_argument(
    "--allow-credentials", action="store_true", help="allow credentials"
)
# parser.add_argument(
#     "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
# )
# parser.add_argument(
#     "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
# )
# parser.add_argument(
#     "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
# )
parser.add_argument(
    "--api-keys",
    type=lambda s: s.split(","),
    help="Optional list of comma separated API keys",
)
server_args = ["server-host","server-port","allow-credentials","api-keys",
               "controller-address"
               ]

args = parser.parse_args()
# the http:// prefix is required, otherwise requests raises InvalidSchema: No connection adapters were found
args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})

@@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest):
        chat_box.update_msg(text, 0, streaming=False)

    now = datetime.now()
    cols[0].download_button(
    export_btn.download_button(
        "Export",
        "".join(chat_box.export2md(cur_chat_name)),
        file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",

@@ -1,6 +1,124 @@
from pydoc import Helper
import streamlit as st
from webui_pages.utils import *
import streamlit_antd_components as sac
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path
from streamlit_chatbox import *


SENTENCE_SIZE = 100


def knowledge_base_page(api: ApiRequest):
    st.write(123)
    pass
    api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
    chat_box = ChatBox(session_key="kb_messages")

    kb_list = api.list_knowledge_bases()
    kb_docs = {}
    for kb in kb_list:
        kb_docs[kb] = api.list_kb_docs(kb)

    with st.sidebar:
        def on_new_kb():
            if name := st.session_state.get("new_kb_name"):
                if name in kb_list:
                    st.error(f"名为 {name} 的知识库已经存在!")
                else:
                    ret = api.create_knowledge_base(name)
                    st.toast(ret["msg"])

        def on_del_kb():
            if name := st.session_state.get("new_kb_name"):
                if name in kb_list:
                    ret = api.delete_knowledge_base(name)
                    st.toast(ret["msg"])
                else:
                    st.error(f"名为 {name} 的知识库不存在!")

        cols = st.columns([2, 1, 1])
        new_kb_name = cols[0].text_input(
            "新知识库名称",
            placeholder="新知识库名称",
            label_visibility="collapsed",
            key="new_kb_name",
        )
        cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name))
        cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name))

        st.write("知识库:")
        if kb_list:
            try:
                index = kb_list.index(st.session_state.get("cur_kb"))
            except:
                index = 0
            kb = sac.buttons(
                kb_list,
                index,
                format_func=lambda x: f"{x} ({len(kb_docs[x])})",
            )
            st.session_state["cur_kb"] = kb
        sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
        files = st.file_uploader("上传知识文件",
                                 ["docx", "txt", "md", "csv", "xlsx", "pdf"],
                                 accept_multiple_files=True,
                                 key="files",
                                 )
        if st.button(
                "添加文件到知识库",
                help="请先上传文件,再点击添加",
                use_container_width=True,
                disabled=len(files) == 0,
        ):
            for f in files:
                ret = api.upload_kb_doc(f, kb)
                if ret["code"] == 200:
                    st.toast(ret["msg"], icon="✔")
                else:
                    st.toast(ret["msg"], icon="❌")
            st.session_state.files = []

        if st.button(
                "重建知识库",
                help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
                use_container_width=True,
                disabled=True,
        ):
            progress = st.progress(0.0, "")
            for d in api.recreate_vector_store(kb):
                progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")

    if kb_list:
        # knowledge base details
        st.subheader(f"知识库 {kb} 详情")
        df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "文档名称"])
        gb = GridOptionsBuilder.from_dataframe(df)
        gb.configure_column("No", width=50)
        gb.configure_selection()

        cols = st.columns([1, 2])

        with cols[0]:
            docs = AgGrid(df, gb.build())

        with cols[1]:
            cols = st.columns(3)
            selected_rows = docs.get("selected_rows", [])

            cols = st.columns([2, 3, 2])
            if selected_rows:
                file_name = selected_rows[0]["文档名称"]
                file_path = get_file_path(kb, file_name)
                with open(file_path, "rb") as fp:
                    cols[0].download_button("下载选中文档", fp, file_name=file_name)
            else:
                cols[0].download_button("下载选中文档", "", disabled=True)
            if cols[2].button("删除选中文档!", type="primary"):
                for row in selected_rows:
                    ret = api.delete_kb_doc(kb, row["文档名称"])
                    st.toast(ret["msg"])
                st.experimental_rerun()

        st.write("本文档包含以下知识条目:(待定内容)")