diff --git a/requirements.txt b/requirements.txt
index 765cf9f..257a62d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,9 +15,13 @@ starlette~=0.27.0
 numpy~=1.24.4
 pydantic~=1.10.11
 unstructured[all-docs]
+python-magic-bin; sys_platform == 'win32'
 streamlit>=1.25.0
 streamlit-option-menu
 streamlit-antd-components
 streamlit-chatbox>=1.1.6
 httpx
 
+
+faiss-cpu
+pymilvus==2.1.3 # requires milvus==2.1.3
diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py
index 17f1866..7757f98 100644
--- a/server/chat/chat_openai_chain/chat_openai_chain.py
+++ b/server/chat/chat_openai_chain/chat_openai_chain.py
@@ -4,7 +4,10 @@ from typing import Any, Dict, List, Optional
 from langchain.chains.base import Chain
 from langchain.schema import (
     BaseMessage,
-    messages_from_dict,
+    AIMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
     LLMResult
 )
 from langchain.chat_models import ChatOpenAI
@@ -16,6 +19,18 @@ from langchain.callbacks.manager import (
 
 
 from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
+def _convert_dict_to_message(_dict: dict) -> BaseMessage:
+    role = _dict["role"]
+    if role == "user":
+        return HumanMessage(content=_dict["content"])
+    elif role == "assistant":
+        return AIMessage(content=_dict["content"])
+    elif role == "system":
+        return SystemMessage(content=_dict["content"])
+    else:
+        return ChatMessage(content=_dict["content"], role=role)
+
+
 def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
     """
     前端消息传输对象DTO转换为chat消息传输对象DTO
@@ -25,7 +40,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
     messages = []
     for message_datum in message_data:
         messages.append(message_datum.dict())
-    return messages_from_dict(messages)
+    return [_convert_dict_to_message(m) for m in messages]
 
 
 class BaseChatOpenAIChain(Chain, ABC):
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 4eaf3fd..0a85df3 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import KBService
 import json
 
 
@@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
 
     async def knowledge_base_chat_iterator(query: str,
-                                           kb: KnowledgeBase,
+                                           kb: KBService,
                                            top_k: int,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
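A note on the conversion path above: `_convert_dict_to_message` replaces langchain's `messages_from_dict` and converts one OpenAI-style role dict at a time, so `convert_message_processors` has to map it over the message list rather than pass the whole list in. A minimal, self-contained sketch of that round-trip; the `convert` wrapper name and sample history are illustrative, the message classes are the real `langchain.schema` types used in the hunk:

```python
from typing import List

from langchain.schema import AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage


def _convert_dict_to_message(_dict: dict) -> BaseMessage:
    # Same role dispatch as in the patch: known roles map to concrete
    # message types, anything else falls through to a generic ChatMessage.
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        return AIMessage(content=_dict["content"])
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def convert(messages: List[dict]) -> List[BaseMessage]:
    # The helper takes a single dict; passing the whole list would raise
    # TypeError, so the caller maps it over each entry.
    return [_convert_dict_to_message(m) for m in messages]


history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好"},
]
print(convert(history))  # [SystemMessage(...), HumanMessage(...)]
```

The `ChatMessage` fallback keeps unknown role strings intact instead of failing, which matches how the OpenAI chat format treats custom roles.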
diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py
index 14b7898..b64e921 100644
--- a/server/knowledge_base/__init__.py
+++ b/server/knowledge_base/__init__.py
@@ -1,4 +1,4 @@
 from .kb_api import list_kbs, create_kb, delete_kb
 from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
-from .knowledge_base import KnowledgeBase
 from .knowledge_file import KnowledgeFile
+from .knowledge_base_factory import KBServiceFactory
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index d5ef57b..9a92ea6 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -1,13 +1,14 @@
 import urllib
 from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import list_kbs_from_db
 from configs.model_config import EMBEDDING_MODEL
 
 
 async def list_kbs():
     # Get List of Knowledge Base
-    return ListResponse(data=KnowledgeBase.list_kbs())
+    return ListResponse(data=list_kbs_from_db())
 
 
 async def create_kb(knowledge_base_name: str,
@@ -19,11 +20,11 @@ async def create_kb(knowledge_base_name: str,
         return BaseResponse(code=403, msg="Don't attack me")
     if knowledge_base_name is None or knowledge_base_name.strip() == "":
         return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
-    if KnowledgeBase.exists(knowledge_base_name):
+
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is not None:
         return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
-    kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
-                       vector_store_type=vector_store_type,
-                       embed_model=embed_model)
+    kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type)
     kb.create()
     return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
 
@@ -34,10 +35,12 @@ async def delete_kb(knowledge_base_name: str):
         return BaseResponse(code=403, msg="Don't attack me")
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
 
-    if not KnowledgeBase.exists(knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+
+    if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    status = KnowledgeBase.delete(knowledge_base_name)
+    status = kb.drop_kb()
    if status:
         return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
     else:
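`knowledge_base_factory.py` itself is not part of this patch, so the following is a hypothetical sketch of the factory the new call sites assume: `get_service` dispatches on the vector store type, while `get_service_by_name` returns `None` for unregistered knowledge bases so the handlers above can answer 404. The `FaissKBService` constructor signature and the assumption that `load_kb_from_db` now lives in `kb_service.base` (next to `list_kbs_from_db`/`list_docs_from_db`) are inferences, not shown in the diff:

```python
from typing import Optional

from server.knowledge_base.kb_service.base import KBService, load_kb_from_db
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService


class KBServiceFactory:
    @staticmethod
    def get_service(kb_name: str, vector_store_type: str) -> KBService:
        # Dispatch on the vector store type; only FAISS is sketched here,
        # matching the "faiss" branch used throughout this patch.
        if vector_store_type.lower() == "faiss":
            return FaissKBService(kb_name)
        raise ValueError(f"unsupported vector store type {vector_store_type}")

    @staticmethod
    def get_service_by_name(kb_name: str) -> Optional[KBService]:
        # Look the KB up in the sqlite registry; None is the
        # "knowledge base not found" signal the API handlers rely on.
        _, vs_type, _ = load_kb_from_db(kb_name)
        if vs_type is None:
            return None
        return KBServiceFactory.get_service(kb_name, vs_type)
```

Under this reading, `get_service_by_name` is the existence check and `get_service` is the constructor, which is how `create_kb`, `delete_kb`, and the document APIs use them.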
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 62de2d9..2ebe828 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.knowledge_file import KnowledgeFile
-from server.knowledge_base.knowledge_base import KnowledgeBase
+from server.knowledge_base.knowledge_base_factory import KBServiceFactory
+from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
+from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
 
 
 async def list_docs(knowledge_base_name: str):
@@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
         return ListResponse(code=403, msg="Don't attack me", data=[])
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
 
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
-        all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
+        all_doc_names = kb.list_docs()
         return ListResponse(data=all_doc_names)
 
 
@@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
 
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
-
     file_content = await file.read()  # 读取上传文件的内容
 
     kb_file = KnowledgeFile(filename=file.filename,
@@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
         return BaseResponse(code=403, msg="Don't attack me")
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
 
-    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
     if not kb.exist_doc(doc_name):
         return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
     kb_file = KnowledgeFile(filename=doc_name,
@@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
     '''
     recreate vector store from the content.
     this is usefull when user can copy files to content folder directly instead of upload through network.
     '''
-    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    async def output(kb: KnowledgeBase):
-        kb.recreate_vs()
+    async def output(kb):
+        kb.clear_vs()
         print(f"start to recreate vector store of {kb.kb_name}")
-        docs = kb.list_docs()
+        docs = list_docs_from_folder(knowledge_base_name)
+        print(docs)
         for i, filename in enumerate(docs):
+            yield json.dumps({
+                "total": len(docs),
+                "finished": i,
+                "doc": filename,
+            })
             kb_file = KnowledgeFile(filename=filename,
                                     knowledge_base_name=kb.kb_name)
             print(f"processing {kb_file.filepath} to vector store.")
             kb.add_doc(kb_file)
-            yield json.dumps({
-                "total": len(docs),
-                "finished": i + 1,
-                "doc": filename,
-            })
+        if kb.vs_type == SupportedVSType.FAISS:
+            refresh_vs_cache(knowledge_base_name)
     return StreamingResponse(output(kb), media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index c7fb24e..89f3c3e 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -125,6 +125,10 @@ def list_docs_from_db(kb_name):
     conn.close()
     return kbs
 
+def list_docs_from_folder(kb_name: str):
+    doc_path = get_doc_path(kb_name)
+    return [file for file in os.listdir(doc_path)
+            if os.path.isfile(os.path.join(doc_path, file))]
 
 def add_doc_to_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 64ed9ed..5e9e8b7 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -122,3 +122,4 @@ class FaissKBService(KBService):
 
     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
+        os.makedirs(self.vs_path)
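`do_clear_vs` now recreates `vs_path` right after `shutil.rmtree`, because downstream code expects the directory to exist (for example the `os.listdir(vs_path)` checks visible in the deleted module below, and FAISS's save path). The same clear-then-recreate idiom as a standalone sketch; the path is illustrative:

```python
import os
import shutil


def clear_dir(path: str) -> None:
    # Remove everything under path, then restore an empty directory so
    # callers can immediately list or write into it again without an
    # existence check of their own.
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


clear_dir("knowledge_base/samples/vector_store")
print(os.listdir("knowledge_base/samples/vector_store"))  # []
```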
diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py
deleted file mode 100644
index 10ccd9d..0000000
--- a/server/knowledge_base/knowledge_base.py
+++ /dev/null
@@ -1,407 +0,0 @@
-import os
-import sqlite3
-import datetime
-import shutil
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
-                                  DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
-from server.utils import torch_gc
-from functools import lru_cache
-from server.knowledge_base.knowledge_file import KnowledgeFile
-from typing import List
-import numpy as np
-from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)
-
-SUPPORTED_VS_TYPES = ["faiss", "milvus"]
-
-_VECTOR_STORE_TICKS = {}
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
-    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def refresh_vs_cache(kb_name: str):
-    """
-    make vector store cache refreshed when next loading
-    """
-    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
-
-
-def list_kbs_from_db():
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 kb_name TEXT,
-                 vs_type TEXT,
-                 embed_model TEXT,
-                 file_count INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''SELECT kb_name
-                  FROM knowledge_base
-                  WHERE file_count>0 ''')
-    kbs = [i[0] for i in c.fetchall() if i]
-    conn.commit()
-    conn.close()
-    return kbs
-
-
-def add_kb_to_db(kb_name, vs_type, embed_model):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # Create table
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 kb_name TEXT,
-                 vs_type TEXT,
-                 embed_model TEXT,
-                 file_count INTEGER,
-                 create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""INSERT INTO knowledge_base
-                  (kb_name, vs_type, embed_model, file_count, create_time)
-                  VALUES
-                  ('{kb_name}','{vs_type}','{embed_model}',
-                  0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
-    conn.commit()
-    conn.close()
-
-
-def kb_exists(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 kb_name TEXT,
-                 vs_type TEXT,
-                 embed_model TEXT,
-                 file_count INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''SELECT COUNT(*)
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    status = True if c.fetchone()[0] else False
-    conn.commit()
-    conn.close()
-    return status
-
-
-def load_kb_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 kb_name TEXT,
-                 vs_type TEXT,
-                 embed_model TEXT,
-                 file_count INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''SELECT kb_name, vs_type, embed_model
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    resp = c.fetchone()
-    if resp:
-        kb_name, vs_type, embed_model = resp
-    else:
-        kb_name, vs_type, embed_model = None, None, None
-    conn.commit()
-    conn.close()
-    return kb_name, vs_type, embed_model
-
-
-def delete_kb_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # delete kb from table knowledge_base
-    c.execute('''CREATE TABLE if not exists knowledge_base
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 kb_name TEXT,
-                 vs_type TEXT,
-                 embed_model TEXT,
-                 file_count INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''DELETE
-                  FROM knowledge_base
-                  WHERE kb_name="{kb_name}" ''')
-    # delete files in kb from table knowledge_files
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 file_name TEXT,
-                 file_ext TEXT,
-                 kb_name TEXT,
-                 document_loader_name TEXT,
-                 text_splitter_name TEXT,
-                 file_version INTEGER,
-                 create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""DELETE
-                  FROM knowledge_files
-                  WHERE kb_name="{kb_name}"
-               """)
-    conn.commit()
-    conn.close()
-    return True
-
-
-def list_docs_from_db(kb_name):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 file_name TEXT,
-                 file_ext TEXT,
-                 kb_name TEXT,
-                 document_loader_name TEXT,
-                 text_splitter_name TEXT,
-                 file_version INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''SELECT file_name
-                  FROM knowledge_files
-                  WHERE kb_name="{kb_name}" ''')
-    kbs = [i[0] for i in c.fetchall() if i]
-    conn.commit()
-    conn.close()
-    return kbs
-
-
-def add_doc_to_db(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # Create table
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 file_name TEXT,
-                 file_ext TEXT,
-                 kb_name TEXT,
-                 document_loader_name TEXT,
-                 text_splitter_name TEXT,
-                 file_version INTEGER,
-                 create_time DATETIME) ''')
-    # Insert a row of data
-    # TODO: 同名文件添加至知识库时,file_version增加
-    c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
-    record_exist = c.fetchone()
-    if record_exist is not None:
-        c.execute(f"""UPDATE knowledge_files
-                      SET file_version = file_version + 1
-                      WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
-                   """)
-    else:
-        c.execute(f"""INSERT INTO knowledge_files
-                      (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
-                      VALUES
-                      ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
-                      '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
-    conn.commit()
-    conn.close()
-
-
-def delete_file_from_db(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    # delete files in kb from table knowledge_files
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 file_name TEXT,
-                 file_ext TEXT,
-                 kb_name TEXT,
-                 document_loader_name TEXT,
-                 text_splitter_name TEXT,
-                 file_version INTEGER,
-                 create_time DATETIME) ''')
-    # Insert a row of data
-    c.execute(f"""DELETE
-                  FROM knowledge_files
-                  WHERE file_name="{kb_file.filename}"
-                  AND kb_name="{kb_file.kb_name}"
-               """)
-    conn.commit()
-    conn.close()
-    return True
-
-
-def doc_exists(kb_file: KnowledgeFile):
-    conn = sqlite3.connect(DB_ROOT_PATH)
-    c = conn.cursor()
-    c.execute('''CREATE TABLE if not exists knowledge_files
-                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 file_name TEXT,
-                 file_ext TEXT,
-                 kb_name TEXT,
-                 document_loader_name TEXT,
-                 text_splitter_name TEXT,
-                 file_version INTEGER,
-                 create_time DATETIME) ''')
-    c.execute(f'''SELECT COUNT(*)
-                  FROM knowledge_files
-                  WHERE file_name="{kb_file.filename}"
-                  AND kb_name="{kb_file.kb_name}" ''')
-    status = True if c.fetchone()[0] else False
-    conn.commit()
-    conn.close()
-    return status
-
-
-def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
-    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
-    if not overlapping:
-        raise ValueError("ids do not exist in the current object")
-    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
-    index_to_delete = [_reversed_index[i] for i in ids]
-    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
-    for _id in index_to_delete:
-        del vector_store.index_to_docstore_id[_id]
-    # Remove items from docstore.
-    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
-    if not overlapping2:
-        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
-    for _id in ids:
-        vector_store.docstore._dict.pop(_id)
-    return vector_store
-
-
-class KnowledgeBase:
-    def __init__(self,
-                 knowledge_base_name: str,
-                 vector_store_type: str = "faiss",
-                 embed_model: str = EMBEDDING_MODEL,
-                 ):
-        self.kb_name = knowledge_base_name
-        if vector_store_type not in SUPPORTED_VS_TYPES:
-            raise ValueError(f"暂未支持向量库类型 {vector_store_type}")
-        self.vs_type = vector_store_type
-        if embed_model not in embedding_model_dict.keys():
-            raise ValueError(f"暂未支持embedding模型 {embed_model}")
-        self.embed_model = embed_model
-        self.kb_path = get_kb_path(self.kb_name)
-        self.doc_path = get_doc_path(self.kb_name)
-        if self.vs_type in ["faiss"]:
-            self.vs_path = get_vs_path(self.kb_name)
-        elif self.vs_type in ["milvus"]:
-            pass
-
-    def create(self):
-        if not os.path.exists(self.doc_path):
-            os.makedirs(self.doc_path)
-        if self.vs_type in ["faiss"]:
-            if not os.path.exists(self.vs_path):
-                os.makedirs(self.vs_path)
-            add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
-        elif self.vs_type in ["milvus"]:
-            # TODO: 创建milvus库
-            pass
-        return True
-
-    def recreate_vs(self):
-        if self.vs_type in ["faiss"]:
-            shutil.rmtree(self.vs_path)
-            self.create()
-
-    def add_doc(self, kb_file: KnowledgeFile):
-        docs = kb_file.file2text()
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
-        if self.vs_type in ["faiss"]:
-            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-                vector_store = FAISS.load_local(self.vs_path, embeddings)
-                vector_store.add_documents(docs)
-                torch_gc()
-            else:
-                if not os.path.exists(self.vs_path):
-                    os.makedirs(self.vs_path)
-                vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
-                torch_gc()
-            vector_store.save_local(self.vs_path)
-            add_doc_to_db(kb_file)
-            refresh_vs_cache(self.kb_name)
-        elif self.vs_type in ["milvus"]:
-            # TODO: 向milvus库中增加文件
-            pass
-
-    def delete_doc(self, kb_file: KnowledgeFile):
-        if os.path.exists(kb_file.filepath):
-            os.remove(kb_file.filepath)
-        if self.vs_type in ["faiss"]:
-            # TODO: 从FAISS向量库中删除文档
-            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
-            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-                vector_store = FAISS.load_local(self.vs_path, embeddings)
-                ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
-                if len(ids) == 0:
-                    return None
-                print(len(ids))
-                vector_store = delete_doc_from_faiss(vector_store, ids)
-                vector_store.save_local(self.vs_path)
-                refresh_vs_cache(self.kb_name)
-        delete_file_from_db(kb_file)
-        return True
-
-    def exist_doc(self, file_name: str):
-        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
-                                        filename=file_name))
-
-    def list_docs(self):
-        return list_docs_from_db(self.kb_name)
-
-    def search_docs(self,
-                    query: str,
-                    top_k: int = VECTOR_SEARCH_TOP_K,
-                    embedding_device: str = EMBEDDING_DEVICE, ):
-        search_index = load_vector_store(self.kb_name,
-                                         self.embed_model,
-                                         embedding_device,
-                                         _VECTOR_STORE_TICKS.get(self.kb_name))
-        docs = search_index.similarity_search(query, k=top_k)
-        return docs
-
-    @classmethod
-    def exists(cls,
-               knowledge_base_name: str):
-        return kb_exists(knowledge_base_name)
-
-    @classmethod
-    def load(cls,
-             knowledge_base_name: str):
-        kb_name, vs_type, embed_model = load_kb_from_db(knowledge_base_name)
-        return cls(kb_name, vs_type, embed_model)
-
-    @classmethod
-    def delete(cls,
-               knowledge_base_name: str):
-        kb = cls.load(knowledge_base_name)
-        if kb.vs_type in ["faiss"]:
-            shutil.rmtree(kb.kb_path)
-        elif kb.vs_type in ["milvus"]:
-            # TODO: 删除milvus库
-            pass
-        status = delete_kb_from_db(knowledge_base_name)
-        return status
-
-    @classmethod
-    def list_kbs(cls):
-        return list_kbs_from_db()
-
-
-if __name__ == "__main__":
-    # kb = KnowledgeBase("123", "faiss")
-    # kb.create()
-    kb = KnowledgeBase.load(knowledge_base_name="123")
-    kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
-    print()
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 215ab36..fa6f132 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,5 +1,8 @@
-import os.path
-from configs.model_config import KB_ROOT_PATH
+import os
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
+from functools import lru_cache
+
 
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # 检查是否包含预期外的字符或路径攻击关键字
@@ -17,4 +20,10 @@ def get_vs_path(knowledge_base_name: str):
     return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
 
 def get_file_path(knowledge_base_name: str, doc_name: str):
-    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
\ No newline at end of file
+    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
+
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
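`load_embeddings` moves to utils.py with the same `lru_cache(1)` decorator it had in the deleted module: repeated calls with the same `(model, device)` pair reuse one `HuggingFaceEmbeddings` instance, and a call with a different pair evicts the cached one. A toy demonstration of that caching shape; the loader body is a stand-in for the real `HuggingFaceEmbeddings` construction:

```python
from functools import lru_cache


@lru_cache(1)
def load(model: str, device: str) -> object:
    print(f"loading {model} on {device}")  # runs only on cache misses
    return object()


a = load("text2vec", "cuda")
b = load("text2vec", "cuda")   # cache hit: no reload, identical object
assert a is b
load("m3e-base", "cuda")       # a different key evicts the single cached entry
```

With `maxsize=1` only one embedding model stays resident at a time, which bounds GPU memory while still avoiding a reload on every request.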
diff --git a/server/llm_api_sh.py b/server/llm_api_sh.py
index 904ac71..3a8e880 100644
--- a/server/llm_api_sh.py
+++ b/server/llm_api_sh.py
@@ -7,11 +7,157 @@
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import LOG_PATH,controller_args,worker_args,server_args,parser
+
 import subprocess
 import re
+import logging
 import argparse
 
+LOG_PATH = "./logs/"
+LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+logging.basicConfig(format=LOG_FORMAT)
+
+
+parser = argparse.ArgumentParser()
+#------multi worker-----------------
+parser.add_argument('--model-path-address',
+                    default="THUDM/chatglm2-6b@localhost@20002",
+                    nargs="+",
+                    type=str,
+                    help="model path, host, and port, formatted as model-path@host@port")
+#---------------controller-------------------------
+
+parser.add_argument("--controller-host", type=str, default="localhost")
+parser.add_argument("--controller-port", type=int, default=21001)
+parser.add_argument(
+    "--dispatch-method",
+    type=str,
+    choices=["lottery", "shortest_queue"],
+    default="shortest_queue",
+)
+controller_args = ["controller-host","controller-port","dispatch-method"]
+
+#----------------------worker------------------------------------------
+
+parser.add_argument("--worker-host", type=str, default="localhost")
+parser.add_argument("--worker-port", type=int, default=21002)
+# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+# parser.add_argument(
+#     "--controller-address", type=str, default="http://localhost:21001"
+# )
+parser.add_argument(
+    "--model-path",
+    type=str,
+    default="lmsys/vicuna-7b-v1.3",
+    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
+)
+parser.add_argument(
+    "--revision",
+    type=str,
+    default="main",
+    help="Hugging Face Hub model revision identifier",
+)
+parser.add_argument(
+    "--device",
+    type=str,
+    choices=["cpu", "cuda", "mps", "xpu"],
+    default="cuda",
+    help="The device type",
+)
+parser.add_argument(
+    "--gpus",
+    type=str,
+    default="0",
+    help="A single GPU like 1 or multiple GPUs like 0,2",
+)
+parser.add_argument("--num-gpus", type=int, default=1)
+parser.add_argument(
+    "--max-gpu-memory",
+    type=str,
+    help="The maximum memory per gpu. Use a string like '13Gib'",
+)
+parser.add_argument(
+    "--load-8bit", action="store_true", help="Use 8-bit quantization"
+)
+parser.add_argument(
+    "--cpu-offloading",
+    action="store_true",
+    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
+)
+parser.add_argument(
+    "--gptq-ckpt",
+    type=str,
+    default=None,
+    help="Load quantized model. The path to the local GPTQ checkpoint.",
+)
+parser.add_argument(
+    "--gptq-wbits",
+    type=int,
+    default=16,
+    choices=[2, 3, 4, 8, 16],
+    help="#bits to use for quantization",
+)
+parser.add_argument(
+    "--gptq-groupsize",
+    type=int,
+    default=-1,
+    help="Groupsize to use for quantization; default uses full row.",
+)
+parser.add_argument(
+    "--gptq-act-order",
+    action="store_true",
+    help="Whether to apply the activation order GPTQ heuristic",
+)
+parser.add_argument(
+    "--model-names",
+    type=lambda s: s.split(","),
+    help="Optional display comma separated names",
+)
+parser.add_argument(
+    "--limit-worker-concurrency",
+    type=int,
+    default=5,
+    help="Limit the model concurrency to prevent OOM.",
+)
+parser.add_argument("--stream-interval", type=int, default=2)
+parser.add_argument("--no-register", action="store_true")
+
+worker_args = [
+    "worker-host","worker-port",
+    "model-path","revision","device","gpus","num-gpus",
+    "max-gpu-memory","load-8bit","cpu-offloading",
+    "gptq-ckpt","gptq-wbits","gptq-groupsize",
+    "gptq-act-order","model-names","limit-worker-concurrency",
+    "stream-interval","no-register",
+    "controller-address"
+    ]
+#-----------------openai server---------------------------
+
+parser.add_argument("--server-host", type=str, default="localhost", help="host name")
+parser.add_argument("--server-port", type=int, default=8001, help="port number")
+parser.add_argument(
+    "--allow-credentials", action="store_true", help="allow credentials"
+)
+# parser.add_argument(
+#     "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
+# )
+# parser.add_argument(
+#     "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
+# )
+# parser.add_argument(
+#     "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
+# )
+parser.add_argument(
+    "--api-keys",
+    type=lambda s: s.split(","),
+    help="Optional list of comma separated API keys",
+)
+server_args = ["server-host","server-port","allow-credentials","api-keys",
+               "controller-address"
+               ]
+
 args = parser.parse_args()
 # 必须要加http//:,否则InvalidSchema: No connection adapters were found
 args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})
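A note on the option naming above: argparse exposes `--controller-host` as `args.controller_host` (dashes become underscores on the `Namespace`), while the `controller_args`/`worker_args`/`server_args` lists keep the dashed spellings, presumably because the rest of the script builds shell command lines from them. The final line then merges a dash-keyed `controller-address` entry back into the namespace via `vars()`. A minimal, runnable sketch of that naming round-trip:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
args = parser.parse_args([])

# Normal attribute access uses the underscore form...
url = f"http://{args.controller_host}:{args.controller_port}"

# ...while a dash-keyed entry can still be merged in, exactly as
# llm_api_sh.py does; it is then only reachable through getattr.
args = argparse.Namespace(**vars(args), **{"controller-address": url})
print(getattr(args, "controller-address"))  # http://localhost:21001
```

Keeping both spellings in one namespace works, but it does mean `args.controller_address` and `getattr(args, "controller-address")` are two different (and here, unrelated) lookups.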
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index b421fb6..e03a539 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -118,7 +118,7 @@ def dialogue_page(api: ApiRequest):
         chat_box.update_msg(text, 0, streaming=False)
 
     now = datetime.now()
-    cols[0].download_button(
+    export_btn.download_button(
         "Export",
         "".join(chat_box.export2md(cur_chat_name)),
         file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 119a4a2..f36bae4 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -1,6 +1,123 @@
 import streamlit as st
 from webui_pages.utils import *
+import streamlit_antd_components as sac
+from st_aggrid import AgGrid
+from st_aggrid.grid_options_builder import GridOptionsBuilder
+import pandas as pd
+from server.knowledge_base.utils import get_file_path
+from streamlit_chatbox import *
+
+
+SENTENCE_SIZE = 100
+
 
 def knowledge_base_page(api: ApiRequest):
-    st.write(123)
-    pass
\ No newline at end of file
+    api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
+    chat_box = ChatBox(session_key="kb_messages")
+
+    kb_list = api.list_knowledge_bases()
+    kb_docs = {}
+    for kb in kb_list:
+        kb_docs[kb] = api.list_kb_docs(kb)
+
+    with st.sidebar:
+        def on_new_kb():
+            if name := st.session_state.get("new_kb_name"):
+                if name in kb_list:
+                    st.error(f"名为 {name} 的知识库已经存在!")
+                else:
+                    ret = api.create_knowledge_base(name)
+                    st.toast(ret["msg"])
+
+        def on_del_kb():
+            if name := st.session_state.get("new_kb_name"):
+                if name in kb_list:
+                    ret = api.delete_knowledge_base(name)
+                    st.toast(ret["msg"])
+                else:
+                    st.error(f"名为 {name} 的知识库不存在!")
+
+        cols = st.columns([2, 1, 1])
+        new_kb_name = cols[0].text_input(
+            "新知识库名称",
+            placeholder="新知识库名称",
+            label_visibility="collapsed",
+            key="new_kb_name",
+        )
+        cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name))
+        cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name))
+
+        st.write("知识库:")
+        if kb_list:
+            try:
+                index = kb_list.index(st.session_state.get("cur_kb"))
+            except ValueError:
+                index = 0
+            kb = sac.buttons(
+                kb_list,
+                index,
+                format_func=lambda x: f"{x} ({len(kb_docs[x])})",
+            )
+            st.session_state["cur_kb"] = kb
+            sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
+            files = st.file_uploader("上传知识文件",
+                                     ["docx", "txt", "md", "csv", "xlsx", "pdf"],
+                                     accept_multiple_files=True,
+                                     key="files",
+                                     )
+            if st.button(
+                "添加文件到知识库",
+                help="请先上传文件,再点击添加",
+                use_container_width=True,
+                disabled=len(files) == 0,
+            ):
+                for f in files:
+                    ret = api.upload_kb_doc(f, kb)
+                    if ret["code"] == 200:
+                        st.toast(ret["msg"], icon="✔")
+                    else:
+                        st.toast(ret["msg"], icon="❌")
+                st.session_state.files = []
+
+            if st.button(
+                "重建知识库",
+                help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
+                use_container_width=True,
+                disabled=True,
+            ):
+                progress = st.progress(0.0, "")
+                for d in api.recreate_vector_store(kb):
d["total"], f"正在处理: {d['doc']}") + + if kb_list: + # 知识库详情 + st.subheader(f"知识库 {kb} 详情") + df = pd.DataFrame([[i + 1, x] for i, x in enumerate(kb_docs[kb])], columns=["No", "文档名称"]) + gb = GridOptionsBuilder.from_dataframe(df) + gb.configure_column("No", width=50) + gb.configure_selection() + + cols = st.columns([1, 2]) + + with cols[0]: + docs = AgGrid(df, gb.build()) + + with cols[1]: + cols = st.columns(3) + selected_rows = docs.get("selected_rows", []) + + cols = st.columns([2, 3, 2]) + if selected_rows: + file_name = selected_rows[0]["文档名称"] + file_path = get_file_path(kb, file_name) + with open(file_path, "rb") as fp: + cols[0].download_button("下载选中文档", fp, file_name=file_name) + else: + cols[0].download_button("下载选中文档", "", disabled=True) + if cols[2].button("删除选中文档!", type="primary"): + for row in selected_rows: + ret = api.delete_kb_doc(kb, row["文档名称"]) + st.toast(ret["msg"]) + st.experimental_rerun() + + st.write("本文档包含以下知识条目:(待定内容)")