From 745a105baed71563bd88d9447a8033aede0ce0e7 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Sat, 16 Sep 2023 09:09:27 +0800 Subject: [PATCH 1/2] feat: support volc fangzhou --- configs/model_config.py.example | 8 + init_database.py | 94 ++++++++--- .../kb_service/milvus_kb_service.py | 6 +- server/knowledge_base/migrate.py | 147 ++++++++---------- server/model_workers/fangzhou.py | 91 +++++++++++ tests/test_migrate.py | 139 +++++++++++++++++ 6 files changed, 377 insertions(+), 108 deletions(-) create mode 100644 server/model_workers/fangzhou.py create mode 100644 tests/test_migrate.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index dff60f7..c6f1dc4 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -115,6 +115,14 @@ ONLINE_LLM_MODEL = { "secret_key": "", "provider": "QianFanWorker", }, + # 火山方舟 API + "fangzhou-api": { + "version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model", 更多的见文档模型支持列表中方舟部分。 + "version_url": "", # 可以不填写version,直接填写在方舟申请模型发布的API地址 + "api_key": "", + "secret_key": "", + "provider": "FangZhouWorker", + }, } diff --git a/init_database.py b/init_database.py index 42a18c5..9e807a8 100644 --- a/init_database.py +++ b/init_database.py @@ -1,44 +1,92 @@ -from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder +from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path -from startup import dump_server_info from datetime import datetime +import sys if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser() - parser.formatter_class = argparse.RawTextHelpFormatter + parser = argparse.ArgumentParser(description="please specify only one operate method once time.") parser.add_argument( + "-r", "--recreate-vs", action="store_true", help=(''' - recreate all vector store. + recreate vector store. use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed. - if your vector store is ready with the configs, just skip this option to fill info to database only. ''' ) ) - args = parser.parse_args() + parser.add_argument( + "-u", + "--update-in-db", + action="store_true", + help=(''' + update vector store for files exist in database. + use this option if you want to recreate vectors for files exist in db and skip files exist in local folder only. + ''' + ) + ) + parser.add_argument( + "-i", + "--increament", + action="store_true", + help=(''' + update vector store for files exist in local folder and not exist in database. + use this option if you want to create vectors increamentally. + ''' + ) + ) + parser.add_argument( + "--prune-db", + action="store_true", + help=(''' + delete docs in database that not existed in local folder. + it is used to delete database docs after user deleted some doc files in file browser + ''' + ) + ) + parser.add_argument( + "--prune-folder", + action="store_true", + help=(''' + delete doc files in local folder that not existed in database. + is is used to free local disk space by delete unused doc files. + ''' + ) + ) + parser.add_argument( + "--kb-name", + type=str, + nargs="+", + default=[], + help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.") + ) - dump_server_info() - - start_time = datetime.now() - - if args.recreate_vs: - reset_tables() - print("database talbes reseted") - print("recreating all vector stores") - recreate_all_vs() + if len(sys.argv) <= 1: + parser.print_help() else: - create_tables() - print("database talbes created") - print("filling kb infos to database") - for kb in list_kbs_from_folder(): - folder2db(kb, "fill_info_only") + args = parser.parse_args() + start_time = datetime.now() - end_time = datetime.now() - print(f"总计用时: {end_time-start_time}") + create_tables() # confirm tables exist + if args.recreate_vs: + reset_tables() + print("database talbes reseted") + print("recreating all vector stores") + folder2db(kb_names=args.kb_name, mode="recreate_vs") + elif args.update_in_db: + folder2db(kb_names=args.kb_name, mode="update_in_db") + elif args.increament: + folder2db(kb_names=args.kb_name, mode="increament") + elif args.prune_db: + prune_db_docs(args.kb_name) + elif args.prune_folder: + prune_folder_files(args.kb_name) + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index afb8331..80eac4e 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,9 +22,9 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) - def save_vector_store(self): - if self.milvus.col: - self.milvus.col.flush() + # def save_vector_store(self): + # if self.milvus.col: + # self.milvus.col.flush() def get_doc_by_id(self, id: str) -> Optional[Document]: if self.milvus.col: diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 1b7c7ae..b8a95d9 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -3,7 +3,7 @@ from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, list_files_from_folder,files2docs_in_thread, KnowledgeFile,) -from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType +from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.base import Base, engine import os @@ -33,8 +33,8 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: def folder2db( - kb_name: str, - mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], + kb_names: List[str], + mode: Literal["recreate_vs", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, chunk_size: int = -1, @@ -45,21 +45,11 @@ def folder2db( use existed files in local folder to populate database and/or vector store. set parameter `mode` to: recreate_vs: recreate all vector store and fill info to database using existed files in local folder - fill_info_only: do not create vector store, fill info to db using existed files only + fill_info_only(disabled): do not create vector store, fill info to db using existed files only update_in_db: update vector store and database info using local files that existed in database only increament: create vector store and database info for local files that not existed in database only ''' - kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) - kb.create_kb() - - if mode == "recreate_vs": - files_count = kb.count_files() - print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。") - kb.clear_vs() - files_count = kb.count_files() - print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。") - - kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + def files2vs(kb_name: str, kb_files: List[KnowledgeFile]): for success, result in files2docs_in_thread(kb_files, chunk_size=chunk_size, chunk_overlap=chunk_overlap, @@ -68,84 +58,77 @@ def folder2db( _, filename, docs = result print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) - kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) + kb_file.splited_docs = docs + kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True) else: print(result) - kb.save_vector_store() - elif mode == "fill_info_only": - files = list_files_from_folder(kb_name) - kb_files = file_to_kbfile(kb_name, files) - for kb_file in kb_files: - add_file_to_db(kb_file) - print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") - elif mode == "update_in_db": - files = kb.list_files() - kb_files = file_to_kbfile(kb_name, files) + kb_names = kb_names or list_kbs_from_folder() + for kb_name in kb_names: + kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) + kb.create_kb() - for kb_file in kb_files: - kb.update_doc(kb_file, not_refresh_vs_cache=True) - kb.save_vector_store() - elif mode == "increament": - db_files = kb.list_files() - folder_files = list_files_from_folder(kb_name) - files = list(set(folder_files) - set(db_files)) - kb_files = file_to_kbfile(kb_name, files) - - for success, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): - if success: - _, filename, docs = result - print(f"正在将 {kb_name}/{filename} 添加到向量库") - kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) - kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) - else: - print(result) - kb.save_vector_store() - else: - print(f"unspported migrate mode: {mode}") + # 清除向量库,从本地文件重建 + if mode == "recreate_vs": + kb.clear_vs() + kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + files2vs(kb_name, kb_files) + kb.save_vector_store() + # # 不做文件内容的向量化,仅将文件元信息存到数据库 + # # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。 + # elif mode == "fill_info_only": + # files = list_files_from_folder(kb_name) + # kb_files = file_to_kbfile(kb_name, files) + # for kb_file in kb_files: + # add_file_to_db(kb_file) + # print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") + # 以数据库中文件列表为基准,利用本地文件更新向量库 + elif mode == "update_in_db": + files = kb.list_files() + kb_files = file_to_kbfile(kb_name, files) + files2vs(kb_name, kb_files) + kb.save_vector_store() + # 对比本地目录与数据库中的文件列表,进行增量向量化 + elif mode == "increament": + db_files = kb.list_files() + folder_files = list_files_from_folder(kb_name) + files = list(set(folder_files) - set(db_files)) + kb_files = file_to_kbfile(kb_name, files) + files2vs(kb_name, kb_files) + kb.save_vector_store() + else: + print(f"unspported migrate mode: {mode}") -def recreate_all_vs( - vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, - embed_mode: str = EMBEDDING_MODEL, - **kwargs: Any, -): +def prune_db_docs(kb_names: List[str]): ''' - used to recreate a vector store or change current vector store to another type or embed_model + delete docs in database that not existed in local folder. + it is used to delete database docs after user deleted some doc files in file browser ''' - for kb_name in list_kbs_from_folder(): - folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs) + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb and kb.exists(): + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_db) - set(files_in_folder)) + kb_files = file_to_kbfile(kb_name, files) + for kb_file in kb_files: + kb.delete_doc(kb_file, not_refresh_vs_cache=True) + print(f"success to delete docs for file: {kb_name}/{kb_file.filename}") + kb.save_vector_store() -def prune_db_files(kb_name: str): - ''' - delete files in database that not existed in local folder. - it is used to delete database files after user deleted some doc files in file browser - ''' - kb = KBServiceFactory.get_service_by_name(kb_name) - if kb.exists(): - files_in_db = kb.list_files() - files_in_folder = list_files_from_folder(kb_name) - files = list(set(files_in_db) - set(files_in_folder)) - kb_files = file_to_kbfile(kb_name, files) - for kb_file in kb_files: - kb.delete_doc(kb_file, not_refresh_vs_cache=True) - kb.save_vector_store() - return kb_files - -def prune_folder_files(kb_name: str): +def prune_folder_files(kb_names: List[str]): ''' delete doc files in local folder that not existed in database. is is used to free local disk space by delete unused doc files. ''' - kb = KBServiceFactory.get_service_by_name(kb_name) - if kb.exists(): - files_in_db = kb.list_files() - files_in_folder = list_files_from_folder(kb_name) - files = list(set(files_in_folder) - set(files_in_db)) - for file in files: - os.remove(get_file_path(kb_name, file)) - return files + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb and kb.exists(): + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_folder) - set(files_in_db)) + for file in files: + os.remove(get_file_path(kb_name, file)) + print(f"success to delete file: {kb_name}/{file}") diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py new file mode 100644 index 0000000..08820de --- /dev/null +++ b/server/model_workers/fangzhou.py @@ -0,0 +1,91 @@ +from server.model_workers.base import ApiModelWorker +from configs.model_config import TEMPERATURE +from fastchat import conversation as conv +import sys +import json +from typing import List, Literal, Dict + + +class FangZhouWorker(ApiModelWorker): + """ + 火山方舟 + """ + SUPPORT_MODELS = ["chatglm-6b-model"] + def __init__( + self, + *, + version: Literal["chatglm-6b-model"] = "chatglm-6b-model", + model_names: List[str] = ["fangzhou-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 16384) + super().__init__(**kwargs) + + config = self.get_config() + self.version = version + self.api_key = config.get("api_key") + self.secret_key = config.get("secret_key") + + from volcengine.maas import ChatRole + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", + messages=[], + roles=[ChatRole.USER, ChatRole.ASSISTANT, ChatRole.SYSTEM], + sep="\n### ", + stop_str="###", + ) + + def generate_stream_gate(self, params): + super().generate_stream_gate(params) + from volcengine.maas import MaasService, MaasException, ChatRole + maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') + + maas.set_ak(self.api_key) + maas.set_sk(self.secret_key) + + # document: "https://www.volcengine.com/docs/82379/1099475" + req = { + "model": { + "name": self.version, + }, + "parameters": { + # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 + # "max_new_tokens": 1000, + "temperature": params.get("temperature", TEMPERATURE), + }, + "messages": [{"role": ChatRole.USER, "content": params["prompt"]}] + } + text = "" + try: + + resps = maas.stream_chat(req) + for resp in resps: + text += resp.choice.message.content + yield json.dumps({"error_code": 0, "text": text}, + ensure_ascii=False).encode() + b"\0" + except MaasException as e: + print(e) + + +def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = FangZhouWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20006", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20006) diff --git a/tests/test_migrate.py b/tests/test_migrate.py new file mode 100644 index 0000000..d694b02 --- /dev/null +++ b/tests/test_migrate.py @@ -0,0 +1,139 @@ +from pathlib import Path +from pprint import pprint +import os +import shutil +import sys +root_path = Path(__file__).parent.parent +sys.path.append(str(root_path)) + +from server.knowledge_base.kb_service.base import KBServiceFactory +from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile +from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files + + +# setup test knowledge base +kb_name = "test_kb_for_migrate" +test_files = { + "faq.md": str(root_path / "docs" / "faq.md"), + "install.md": str(root_path / "docs" / "install.md"), +} + + +kb_path = get_kb_path(kb_name) +doc_path = get_doc_path(kb_name) + +if not os.path.isdir(doc_path): + os.makedirs(doc_path) + +for k, v in test_files.items(): + shutil.copy(v, os.path.join(doc_path, k)) + + +def test_recreate_vs(): + folder2db([kb_name], "recreate_vs") + + kb = KBServiceFactory.get_service_by_name(kb_name) + assert kb.exists() + + files = kb.list_files() + print(files) + for name in test_files: + assert name in files + path = os.path.join(doc_path, name) + + # list docs based on file name + docs = kb.list_docs(file_name=name) + assert len(docs) > 0 + pprint(docs[0]) + for doc in docs: + assert doc.metadata["source"] == path + + # list docs base on metadata + docs = kb.list_docs(metadata={"source": path}) + assert len(docs) > 0 + + for doc in docs: + assert doc.metadata["source"] == path + + +def test_increament(): + kb = KBServiceFactory.get_service_by_name(kb_name) + kb.clear_vs() + assert kb.list_files() == [] + assert kb.list_docs() == [] + + folder2db([kb_name], "increament") + + files = kb.list_files() + print(files) + for f in test_files: + assert f in files + + docs = kb.list_docs(file_name=f) + assert len(docs) > 0 + pprint(docs[0]) + + for doc in docs: + assert doc.metadata["source"] == os.path.join(doc_path, f) + + +def test_prune_db(): + del_file, keep_file = list(test_files)[:2] + os.remove(os.path.join(doc_path, del_file)) + + prune_db_docs([kb_name]) + + kb = KBServiceFactory.get_service_by_name(kb_name) + files = kb.list_files() + print(files) + assert del_file not in files + assert keep_file in files + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + docs = kb.list_docs(file_name=keep_file) + assert len(docs) > 0 + pprint(docs[0]) + + shutil.copy(test_files[del_file], os.path.join(doc_path, del_file)) + + +def test_prune_folder(): + del_file, keep_file = list(test_files)[:2] + kb = KBServiceFactory.get_service_by_name(kb_name) + + # delete docs for file + kb.delete_doc(KnowledgeFile(del_file, kb_name)) + files = kb.list_files() + print(files) + assert del_file not in files + assert keep_file in files + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + docs = kb.list_docs(file_name=keep_file) + assert len(docs) > 0 + + docs = kb.list_docs(file_name=del_file) + assert len(docs) == 0 + + assert os.path.isfile(os.path.join(doc_path, del_file)) + + # prune folder + prune_folder_files([kb_name]) + + # check result + assert not os.path.isfile(os.path.join(doc_path, del_file)) + assert os.path.isfile(os.path.join(doc_path, keep_file)) + + +def test_drop_kb(): + kb = KBServiceFactory.get_service_by_name(kb_name) + kb.drop_kb() + assert not kb.exists() + assert not os.path.isdir(kb_path) + + kb = KBServiceFactory.get_service_by_name(kb_name) + assert kb is None From 9a7beef270e27ecea8c18cf264b9e5b7bee899d5 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sun, 17 Sep 2023 00:21:13 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BD=BF=E7=81=AB=E5=B1=B1=E6=96=B9?= =?UTF-8?q?=E8=88=9F=E6=AD=A3=E5=B8=B8=E5=B7=A5=E4=BD=9C=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E5=92=8C=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/model_workers/__init__.py | 1 + server/model_workers/fangzhou.py | 96 ++++++++++++++++++++----------- server/model_workers/minimax.py | 4 +- server/model_workers/qianfan.py | 4 +- server/model_workers/xinghuo.py | 4 +- server/model_workers/zhipu.py | 4 +- tests/online_api/test_fangzhou.py | 22 +++++++ tests/online_api/test_qianfan.py | 2 +- 8 files changed, 95 insertions(+), 42 deletions(-) create mode 100644 tests/online_api/test_fangzhou.py diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index a3a162f..611d94e 100644 --- a/server/model_workers/__init__.py +++ b/server/model_workers/__init__.py @@ -2,3 +2,4 @@ from .zhipu import ChatGLMWorker from .minimax import MiniMaxWorker from .xinghuo import XingHuoWorker from .qianfan import QianFanWorker +from .fangzhou import FangZhouWorker diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py index 08820de..5207fdb 100644 --- a/server/model_workers/fangzhou.py +++ b/server/model_workers/fangzhou.py @@ -3,9 +3,52 @@ from configs.model_config import TEMPERATURE from fastchat import conversation as conv import sys import json +from pprint import pprint +from server.utils import get_model_worker_config from typing import List, Literal, Dict +def request_volc_api( + messages: List[Dict], + model_name: str = "fangzhou-api", + version: str = "chatglm-6b-model", + temperature: float = TEMPERATURE, + api_key: str = None, + secret_key: str = None, +): + from volcengine.maas import MaasService, MaasException, ChatRole + + maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') + config = get_model_worker_config(model_name) + version = version or config.get("version") + version_url = config.get("version_url") + api_key = api_key or config.get("api_key") + secret_key = secret_key or config.get("secret_key") + + maas.set_ak(api_key) + maas.set_sk(secret_key) + + # document: "https://www.volcengine.com/docs/82379/1099475" + req = { + "model": { + "name": version, + }, + "parameters": { + # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 + "max_new_tokens": 1000, + "temperature": temperature, + }, + "messages": messages, + } + + try: + resps = maas.stream_chat(req) + for resp in resps: + yield resp + except MaasException as e: + print(e) + + class FangZhouWorker(ApiModelWorker): """ 火山方舟 @@ -21,7 +64,7 @@ class FangZhouWorker(ApiModelWorker): **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) + kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小 super().__init__(**kwargs) config = self.get_config() @@ -41,40 +84,27 @@ class FangZhouWorker(ApiModelWorker): def generate_stream_gate(self, params): super().generate_stream_gate(params) - from volcengine.maas import MaasService, MaasException, ChatRole - maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') - maas.set_ak(self.api_key) - maas.set_sk(self.secret_key) - - # document: "https://www.volcengine.com/docs/82379/1099475" - req = { - "model": { - "name": self.version, - }, - "parameters": { - # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明 - # "max_new_tokens": 1000, - "temperature": params.get("temperature", TEMPERATURE), - }, - "messages": [{"role": ChatRole.USER, "content": params["prompt"]}] - } + messages = self.prompt_to_messages(params["prompt"]) text = "" - try: - resps = maas.stream_chat(req) - for resp in resps: - text += resp.choice.message.content - yield json.dumps({"error_code": 0, "text": text}, - ensure_ascii=False).encode() + b"\0" - except MaasException as e: - print(e) + for resp in request_volc_api(messages=messages, + model_name=self.model_names[0], + version=self.version, + temperature=params.get("temperature", TEMPERATURE), + ): + error = resp.error + if error.code_n > 0: + data = {"error_code": error.code_n, "text": error.message} + elif chunk := resp.choice.message.content: + text += chunk + data = {"error_code": 0, "text": text} + yield json.dumps(data, ensure_ascii=False).encode() + b"\0" - -def get_embeddings(self, params): - # TODO: 支持embeddings - print("embedding") - print(params) + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) if __name__ == "__main__": @@ -84,8 +114,8 @@ if __name__ == "__main__": worker = FangZhouWorker( controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20006", + worker_addr="http://127.0.0.1:21005", ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=20006) + uvicorn.run(app, port=21005) diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index c772c0d..39ff293 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -93,8 +93,8 @@ if __name__ == "__main__": worker = MiniMaxWorker( controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20004", + worker_addr="http://127.0.0.1:21002", ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=20003) + uvicorn.run(app, port=21002) diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index f7a5161..387d4b7 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -168,8 +168,8 @@ if __name__ == "__main__": worker = QianFanWorker( controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20006", + worker_addr="http://127.0.0.1:21004" ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=20006) \ No newline at end of file + uvicorn.run(app, port=21004) \ No newline at end of file diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 499e8bc..bc98a9c 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -94,8 +94,8 @@ if __name__ == "__main__": worker = XingHuoWorker( controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20005", + worker_addr="http://127.0.0.1:21003", ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=20005) + uvicorn.run(app, port=21003) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index f835ac0..18cec5b 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -67,8 +67,8 @@ if __name__ == "__main__": worker = ChatGLMWorker( controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:20003", + worker_addr="http://127.0.0.1:21001", ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=20003) + uvicorn.run(app, port=21001) diff --git a/tests/online_api/test_fangzhou.py b/tests/online_api/test_fangzhou.py new file mode 100644 index 0000000..1157537 --- /dev/null +++ b/tests/online_api/test_fangzhou.py @@ -0,0 +1,22 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) + +from server.model_workers.fangzhou import request_volc_api +from pprint import pprint +import pytest + + +@pytest.mark.parametrize("version", ["chatglm-6b-model"]) +def test_qianfan(version): + messages = [{"role": "user", "content": "hello"}] + print("\n" + version + "\n") + i = 1 + for x in request_volc_api(messages, version=version): + print(type(x)) + pprint(x) + if chunk := x.choice.message.content: + print(chunk) + assert x.choice.message + i += 1 diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py index 0e8a948..b4b9b15 100644 --- a/tests/online_api/test_qianfan.py +++ b/tests/online_api/test_qianfan.py @@ -8,7 +8,7 @@ from pprint import pprint import pytest -@pytest.mark.parametrize("version", MODEL_VERSIONS.keys()) +@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2]) def test_qianfan(version): messages = [{"role": "user", "content": "你好"}] print("\n" + version + "\n")