From 745a105baed71563bd88d9447a8033aede0ce0e7 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Sat, 16 Sep 2023 09:09:27 +0800 Subject: [PATCH] feat: support volc fangzhou --- configs/model_config.py.example | 8 + init_database.py | 94 ++++++++--- .../kb_service/milvus_kb_service.py | 6 +- server/knowledge_base/migrate.py | 147 ++++++++---------- server/model_workers/fangzhou.py | 91 +++++++++++ tests/test_migrate.py | 139 +++++++++++++++++ 6 files changed, 377 insertions(+), 108 deletions(-) create mode 100644 server/model_workers/fangzhou.py create mode 100644 tests/test_migrate.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index dff60f7..c6f1dc4 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -115,6 +115,14 @@ ONLINE_LLM_MODEL = { "secret_key": "", "provider": "QianFanWorker", }, + # 火山方舟 API + "fangzhou-api": { + "version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model", 更多的见文档模型支持列表中方舟部分。 + "version_url": "", # 可以不填写version,直接填写在方舟申请模型发布的API地址 + "api_key": "", + "secret_key": "", + "provider": "FangZhouWorker", + }, } diff --git a/init_database.py b/init_database.py index 42a18c5..9e807a8 100644 --- a/init_database.py +++ b/init_database.py @@ -1,44 +1,92 @@ -from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder +from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path -from startup import dump_server_info from datetime import datetime +import sys if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser() - parser.formatter_class = argparse.RawTextHelpFormatter + parser = argparse.ArgumentParser(description="please specify only one operate method once time.") parser.add_argument( + "-r", "--recreate-vs", action="store_true", help=(''' - recreate all vector store. + recreate vector store. use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed. - if your vector store is ready with the configs, just skip this option to fill info to database only. ''' ) ) - args = parser.parse_args() + parser.add_argument( + "-u", + "--update-in-db", + action="store_true", + help=(''' + update vector store for files exist in database. + use this option if you want to recreate vectors for files exist in db and skip files exist in local folder only. + ''' + ) + ) + parser.add_argument( + "-i", + "--increament", + action="store_true", + help=(''' + update vector store for files exist in local folder and not exist in database. + use this option if you want to create vectors increamentally. + ''' + ) + ) + parser.add_argument( + "--prune-db", + action="store_true", + help=(''' + delete docs in database that not existed in local folder. + it is used to delete database docs after user deleted some doc files in file browser + ''' + ) + ) + parser.add_argument( + "--prune-folder", + action="store_true", + help=(''' + delete doc files in local folder that not existed in database. + is is used to free local disk space by delete unused doc files. + ''' + ) + ) + parser.add_argument( + "--kb-name", + type=str, + nargs="+", + default=[], + help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.") + ) - dump_server_info() - - start_time = datetime.now() - - if args.recreate_vs: - reset_tables() - print("database talbes reseted") - print("recreating all vector stores") - recreate_all_vs() + if len(sys.argv) <= 1: + parser.print_help() else: - create_tables() - print("database talbes created") - print("filling kb infos to database") - for kb in list_kbs_from_folder(): - folder2db(kb, "fill_info_only") + args = parser.parse_args() + start_time = datetime.now() - end_time = datetime.now() - print(f"总计用时: {end_time-start_time}") + create_tables() # confirm tables exist + if args.recreate_vs: + reset_tables() + print("database talbes reseted") + print("recreating all vector stores") + folder2db(kb_names=args.kb_name, mode="recreate_vs") + elif args.update_in_db: + folder2db(kb_names=args.kb_name, mode="update_in_db") + elif args.increament: + folder2db(kb_names=args.kb_name, mode="increament") + elif args.prune_db: + prune_db_docs(args.kb_name) + elif args.prune_folder: + prune_folder_files(args.kb_name) + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index afb8331..80eac4e 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,9 +22,9 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) - def save_vector_store(self): - if self.milvus.col: - self.milvus.col.flush() + # def save_vector_store(self): + # if self.milvus.col: + # self.milvus.col.flush() def get_doc_by_id(self, id: str) -> Optional[Document]: if self.milvus.col: diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 1b7c7ae..b8a95d9 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -3,7 +3,7 @@ from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, list_files_from_folder,files2docs_in_thread, KnowledgeFile,) -from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType +from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.base import Base, engine import os @@ -33,8 +33,8 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: def folder2db( - kb_name: str, - mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], + kb_names: List[str], + mode: Literal["recreate_vs", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, chunk_size: int = -1, @@ -45,21 +45,11 @@ def folder2db( use existed files in local folder to populate database and/or vector store. set parameter `mode` to: recreate_vs: recreate all vector store and fill info to database using existed files in local folder - fill_info_only: do not create vector store, fill info to db using existed files only + fill_info_only(disabled): do not create vector store, fill info to db using existed files only update_in_db: update vector store and database info using local files that existed in database only increament: create vector store and database info for local files that not existed in database only ''' - kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) - kb.create_kb() - - if mode == "recreate_vs": - files_count = kb.count_files() - print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。") - kb.clear_vs() - files_count = kb.count_files() - print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。") - - kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + def files2vs(kb_name: str, kb_files: List[KnowledgeFile]): for success, result in files2docs_in_thread(kb_files, chunk_size=chunk_size, chunk_overlap=chunk_overlap, @@ -68,84 +58,77 @@ def folder2db( _, filename, docs = result print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) - kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) + kb_file.splited_docs = docs + kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True) else: print(result) - kb.save_vector_store() - elif mode == "fill_info_only": - files = list_files_from_folder(kb_name) - kb_files = file_to_kbfile(kb_name, files) - for kb_file in kb_files: - add_file_to_db(kb_file) - print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") - elif mode == "update_in_db": - files = kb.list_files() - kb_files = file_to_kbfile(kb_name, files) + kb_names = kb_names or list_kbs_from_folder() + for kb_name in kb_names: + kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) + kb.create_kb() - for kb_file in kb_files: - kb.update_doc(kb_file, not_refresh_vs_cache=True) - kb.save_vector_store() - elif mode == "increament": - db_files = kb.list_files() - folder_files = list_files_from_folder(kb_name) - files = list(set(folder_files) - set(db_files)) - kb_files = file_to_kbfile(kb_name, files) - - for success, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): - if success: - _, filename, docs = result - print(f"正在将 {kb_name}/{filename} 添加到向量库") - kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) - kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) - else: - print(result) - kb.save_vector_store() - else: - print(f"unspported migrate mode: {mode}") + # 清除向量库,从本地文件重建 + if mode == "recreate_vs": + kb.clear_vs() + kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + files2vs(kb_name, kb_files) + kb.save_vector_store() + # # 不做文件内容的向量化,仅将文件元信息存到数据库 + # # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。 + # elif mode == "fill_info_only": + # files = list_files_from_folder(kb_name) + # kb_files = file_to_kbfile(kb_name, files) + # for kb_file in kb_files: + # add_file_to_db(kb_file) + # print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") + # 以数据库中文件列表为基准,利用本地文件更新向量库 + elif mode == "update_in_db": + files = kb.list_files() + kb_files = file_to_kbfile(kb_name, files) + files2vs(kb_name, kb_files) + kb.save_vector_store() + # 对比本地目录与数据库中的文件列表,进行增量向量化 + elif mode == "increament": + db_files = kb.list_files() + folder_files = list_files_from_folder(kb_name) + files = list(set(folder_files) - set(db_files)) + kb_files = file_to_kbfile(kb_name, files) + files2vs(kb_name, kb_files) + kb.save_vector_store() + else: + print(f"unspported migrate mode: {mode}") -def recreate_all_vs( - vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, - embed_mode: str = EMBEDDING_MODEL, - **kwargs: Any, -): +def prune_db_docs(kb_names: List[str]): ''' - used to recreate a vector store or change current vector store to another type or embed_model + delete docs in database that not existed in local folder. + it is used to delete database docs after user deleted some doc files in file browser ''' - for kb_name in list_kbs_from_folder(): - folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs) + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb and kb.exists(): + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_db) - set(files_in_folder)) + kb_files = file_to_kbfile(kb_name, files) + for kb_file in kb_files: + kb.delete_doc(kb_file, not_refresh_vs_cache=True) + print(f"success to delete docs for file: {kb_name}/{kb_file.filename}") + kb.save_vector_store() -def prune_db_files(kb_name: str): - ''' - delete files in database that not existed in local folder. - it is used to delete database files after user deleted some doc files in file browser - ''' - kb = KBServiceFactory.get_service_by_name(kb_name) - if kb.exists(): - files_in_db = kb.list_files() - files_in_folder = list_files_from_folder(kb_name) - files = list(set(files_in_db) - set(files_in_folder)) - kb_files = file_to_kbfile(kb_name, files) - for kb_file in kb_files: - kb.delete_doc(kb_file, not_refresh_vs_cache=True) - kb.save_vector_store() - return kb_files - -def prune_folder_files(kb_name: str): +def prune_folder_files(kb_names: List[str]): ''' delete doc files in local folder that not existed in database. is is used to free local disk space by delete unused doc files. ''' - kb = KBServiceFactory.get_service_by_name(kb_name) - if kb.exists(): - files_in_db = kb.list_files() - files_in_folder = list_files_from_folder(kb_name) - files = list(set(files_in_folder) - set(files_in_db)) - for file in files: - os.remove(get_file_path(kb_name, file)) - return files + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb and kb.exists(): + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_folder) - set(files_in_db)) + for file in files: + os.remove(get_file_path(kb_name, file)) + print(f"success to delete file: {kb_name}/{file}") diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py new file mode 100644 index 0000000..08820de --- /dev/null +++ b/server/model_workers/fangzhou.py @@ -0,0 +1,91 @@ +from server.model_workers.base import ApiModelWorker +from configs.model_config import TEMPERATURE +from fastchat import conversation as conv +import sys +import json +from typing import List, Literal, Dict + + +class FangZhouWorker(ApiModelWorker): + """ + 火山方舟 + """ + SUPPORT_MODELS = ["chatglm-6b-model"] + def __init__( + self, + *, + version: Literal["chatglm-6b-model"] = "chatglm-6b-model", + model_names: List[str] = ["fangzhou-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 16384) + super().__init__(**kwargs) + + config = self.get_config() + self.version = version + self.api_key = config.get("api_key") + self.secret_key = config.get("secret_key") + + from volcengine.maas import ChatRole + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", + messages=[], + roles=[ChatRole.USER, ChatRole.ASSISTANT, ChatRole.SYSTEM], + sep="\n### ", + stop_str="###", + ) + + def generate_stream_gate(self, params): + super().generate_stream_gate(params) + from volcengine.maas import MaasService, MaasException, ChatRole + maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') + + maas.set_ak(self.api_key) + maas.set_sk(self.secret_key) + + # document: "https://www.volcengine.com/docs/82379/1099475" + req = { + "model": { + "name": self.version, + }, + "parameters": { + # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 + # "max_new_tokens": 1000, + "temperature": params.get("temperature", TEMPERATURE), + }, + "messages": [{"role": ChatRole.USER, "content": params["prompt"]}] + } + text = "" + try: + + resps = maas.stream_chat(req) + for resp in resps: + text += resp.choice.message.content + yield json.dumps({"error_code": 0, "text": text}, + ensure_ascii=False).encode() + b"\0" + except MaasException as e: + print(e) + + +def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = FangZhouWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20006", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20006) diff --git a/tests/test_migrate.py b/tests/test_migrate.py new file mode 100644 index 0000000..d694b02 --- /dev/null +++ b/tests/test_migrate.py @@ -0,0 +1,139 @@ +from pathlib import Path +from pprint import pprint +import os +import shutil +import sys +root_path = Path(__file__).parent.parent +sys.path.append(str(root_path)) + +from server.knowledge_base.kb_service.base import KBServiceFactory +from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile +from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files + + +# setup test knowledge base +kb_name = "test_kb_for_migrate" +test_files = { + "faq.md": str(root_path / "docs" / "faq.md"), + "install.md": str(root_path / "docs" / "install.md"), +} + + +kb_path = get_kb_path(kb_name) +doc_path = get_doc_path(kb_name) + +if not os.path.isdir(doc_path): + os.makedirs(doc_path) + +for k, v in test_files.items(): + shutil.copy(v, os.path.join(doc_path, k)) + + +def test_recreate_vs(): + folder2db([kb_name], "recreate_vs") + + kb = KBServiceFactory.get_service_by_name(kb_name) + assert kb.exists() + + files = kb.list_files() + print(files) + for name in test_files: + assert name in files + path = os.path.join(doc_path, name) + + # list docs based on file name + docs = kb.list_docs(file_name=name) + assert len(docs) > 0 + pprint(docs[0]) + for doc in docs: + assert doc.metadata["source"] == path + + # list docs base on metadata + docs = kb.list_docs(metadata={"source": path}) + assert len(docs) > 0 + + for doc in docs: + assert doc.metadata["source"] == path + + +def test_increament(): + kb = KBServiceFactory.get_service_by_name(kb_name) + kb.clear_vs() + assert kb.list_files() == [] + assert kb.list_docs() == [] + + folder2db([kb_name], "increament") + + files = kb.list_files() + print(files) + for f in test_files: + assert f in files + + docs = kb.list_docs(file_name=f) + assert len(docs) > 0 + pprint(docs[0]) + + for doc in docs: + assert doc.metadata["source"] == os.path.join(doc_path, f) + + +def test_prune_db(): + del_file, keep_file = list(test_files)[:2] + os.remove(os.path.join(doc_path, del_file)) + + prune_db_docs([kb_name]) + + kb = KBServiceFactory.get_service_by_name(kb_name) + files = kb.list_files() + print(files) + assert del_file not in files + assert keep_file in files + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + docs = kb.list_docs(file_name=keep_file) + assert len(docs) > 0 + pprint(docs[0]) + + shutil.copy(test_files[del_file], os.path.join(doc_path, del_file)) + + +def test_prune_folder(): + del_file, keep_file = list(test_files)[:2] + kb = KBServiceFactory.get_service_by_name(kb_name) + + # delete docs for file + kb.delete_doc(KnowledgeFile(del_file, kb_name)) + files = kb.list_files() + print(files) + assert del_file not in files + assert keep_file in files + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + docs = kb.list_docs(file_name=keep_file) + assert len(docs) > 0 + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + assert os.path.isfile(os.path.join(doc_path, del_file)) + + # prune folder + prune_folder_files([kb_name]) + + # check result + assert not os.path.isfile(os.path.join(doc_path, del_file)) + assert os.path.isfile(os.path.join(doc_path, keep_file)) + + +def test_drop_kb(): + kb = KBServiceFactory.get_service_by_name(kb_name) + kb.drop_kb() + assert not kb.exists() + assert not os.path.isdir(kb_path) + + kb = KBServiceFactory.get_service_by_name(kb_name) + assert kb is None