feat: support volc fangzhou

liunux4odoo 2023-09-16 09:09:27 +08:00 committed by liunux4odoo
parent 3dde02be28
commit 745a105bae
6 changed files with 377 additions and 108 deletions

View File

@@ -115,6 +115,14 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "QianFanWorker",
     },
+    # Volcano Engine Ark (火山方舟) API
+    "fangzhou-api": {
+        "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; for more, see the Ark section of the supported-model list in the docs.
+        "version_url": "",  # you may leave version empty and instead fill in the API URL of the model endpoint you created in FangZhou
+        "api_key": "",
+        "secret_key": "",
+        "provider": "FangZhouWorker",
+    },
 }
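
For reference, a minimal sketch of how this entry could be read (assuming ONLINE_LLM_MODEL is importable from configs.model_config, as the hunk header suggests; key names mirror the block above):

# hypothetical lookup of the new entry, not part of this commit
from configs.model_config import ONLINE_LLM_MODEL

fangzhou_cfg = ONLINE_LLM_MODEL["fangzhou-api"]
version = fangzhou_cfg["version"]        # e.g. "chatglm-6b-model"
api_key = fangzhou_cfg["api_key"]        # Volcano Engine access key
secret_key = fangzhou_cfg["secret_key"]  # Volcano Engine secret key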

View File

@@ -1,44 +1,92 @@
-from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
+from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
 from configs.model_config import NLTK_DATA_PATH
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
-from startup import dump_server_info
 from datetime import datetime
+import sys
 
 
 if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser()
-    parser.formatter_class = argparse.RawTextHelpFormatter
+
+    parser = argparse.ArgumentParser(description="please specify only one operation at a time.")
     parser.add_argument(
         "-r",
         "--recreate-vs",
         action="store_true",
         help=('''
-            recreate all vector stores.
+            recreate vector store.
             use this option if you have copied document files to the content folder, but the vector store has not been populated or DEFAULT_VS_TYPE/EMBEDDING_MODEL changed.
             if your vector store is ready with the configs, just skip this option to fill info into the database only.
             '''
         )
     )
-    args = parser.parse_args()
+    parser.add_argument(
+        "-u",
+        "--update-in-db",
+        action="store_true",
+        help=('''
+            update vector store for files that exist in the database.
+            use this option if you want to recreate vectors for files that exist in the db and skip files that exist in the local folder only.
+            '''
+        )
+    )
+    parser.add_argument(
+        "-i",
+        "--increament",
+        action="store_true",
+        help=('''
+            update vector store for files that exist in the local folder but not in the database.
+            use this option if you want to create vectors incrementally.
+            '''
+        )
+    )
+    parser.add_argument(
+        "--prune-db",
+        action="store_true",
+        help=('''
+            delete docs in the database that no longer exist in the local folder.
+            it is used to delete database docs after the user deleted some doc files in the file browser.
+            '''
+        )
+    )
+    parser.add_argument(
+        "--prune-folder",
+        action="store_true",
+        help=('''
+            delete doc files in the local folder that do not exist in the database.
+            it is used to free local disk space by deleting unused doc files.
+            '''
+        )
+    )
+    parser.add_argument(
+        "--kb-name",
+        type=str,
+        nargs="+",
+        default=[],
+        help=("specify knowledge base names to operate on. default is all folders that exist in KB_ROOT_PATH.")
+    )
 
-    dump_server_info()
-    start_time = datetime.now()
-
-    if args.recreate_vs:
-        reset_tables()
-        print("database tables reset")
-        print("recreating all vector stores")
-        recreate_all_vs()
+    if len(sys.argv) <= 1:
+        parser.print_help()
     else:
-        create_tables()
-        print("database tables created")
-        print("filling kb infos to database")
-        for kb in list_kbs_from_folder():
-            folder2db(kb, "fill_info_only")
+        args = parser.parse_args()
+        start_time = datetime.now()
 
-    end_time = datetime.now()
-    print(f"Total time elapsed: {end_time-start_time}")
+        create_tables()  # confirm tables exist
+
+        if args.recreate_vs:
+            reset_tables()
+            print("database tables reset")
+            print("recreating all vector stores")
+            folder2db(kb_names=args.kb_name, mode="recreate_vs")
+        elif args.update_in_db:
+            folder2db(kb_names=args.kb_name, mode="update_in_db")
+        elif args.increament:
+            folder2db(kb_names=args.kb_name, mode="increament")
+        elif args.prune_db:
+            prune_db_docs(args.kb_name)
+        elif args.prune_folder:
+            prune_folder_files(args.kb_name)
+
+        end_time = datetime.now()
+        print(f"Total time elapsed: {end_time-start_time}")

View File

@@ -22,9 +22,9 @@ class MilvusKBService(KBService):
         from pymilvus import Collection
         return Collection(milvus_name)
 
-    def save_vector_store(self):
-        if self.milvus.col:
-            self.milvus.col.flush()
+    # def save_vector_store(self):
+    #     if self.milvus.col:
+    #         self.milvus.col.flush()
 
     def get_doc_by_id(self, id: str) -> Optional[Document]:
         if self.milvus.col:
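
With save_vector_store commented out here, Milvus is left to flush on its own schedule; a minimal sketch of triggering a flush manually through pymilvus if it is ever needed (assumes a local Milvus on the default port; "test_kb" is a hypothetical collection name):

from pymilvus import Collection, connections

connections.connect("default", host="127.0.0.1", port="19530")  # assumed local Milvus instance
col = Collection("test_kb")  # hypothetical collection name
col.flush()                  # same call the disabled method made via self.milvus.col.flush()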

View File

@@ -3,7 +3,7 @@ from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,
                                          KnowledgeFile,)
-from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
+from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_file_repository import add_file_to_db
 from server.db.base import Base, engine
 import os
@@ -33,8 +33,8 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
 
 
 def folder2db(
-    kb_name: str,
-    mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
+    kb_names: List[str],
+    mode: Literal["recreate_vs", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
     chunk_size: int = -1,
@@ -45,21 +45,11 @@ def folder2db(
     use existing files in the local folder to populate the database and/or vector store.
     set parameter `mode` to:
         recreate_vs: recreate all vector stores and fill info to database using existing files in the local folder
-        fill_info_only: do not create vector store, fill info to db using existed files only
+        fill_info_only (disabled): do not create vector store, fill info to db using existing files only
         update_in_db: update vector store and database info using only local files that already exist in the database
        increament: create vector store and database info only for local files that are not yet in the database
     '''
-    kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
-    kb.create_kb()
-
-    if mode == "recreate_vs":
-        files_count = kb.count_files()
-        print(f"Knowledge base {kb_name} contains {files_count} files.\nAbout to clear the vector store.")
-        kb.clear_vs()
-        files_count = kb.count_files()
-        print(f"After cleanup, knowledge base {kb_name} contains {files_count} files.")
-        kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
         for success, result in files2docs_in_thread(kb_files,
                                                     chunk_size=chunk_size,
                                                     chunk_overlap=chunk_overlap,
@@ -68,84 +58,77 @@ def folder2db(
                 _, filename, docs = result
                 print(f"Adding {kb_name}/{filename} to the vector store, {len(docs)} docs in total")
                 kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
-                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+                kb_file.splited_docs = docs
+                kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
             else:
                 print(result)
-        kb.save_vector_store()
-    elif mode == "fill_info_only":
-        files = list_files_from_folder(kb_name)
-        kb_files = file_to_kbfile(kb_name, files)
-
-        for kb_file in kb_files:
-            add_file_to_db(kb_file)
-            print(f"Added {kb_name}/{kb_file.filename} to the database")
-    elif mode == "update_in_db":
-        files = kb.list_files()
-        kb_files = file_to_kbfile(kb_name, files)
-
-        for kb_file in kb_files:
-            kb.update_doc(kb_file, not_refresh_vs_cache=True)
-        kb.save_vector_store()
-    elif mode == "increament":
-        db_files = kb.list_files()
-        folder_files = list_files_from_folder(kb_name)
-        files = list(set(folder_files) - set(db_files))
-        kb_files = file_to_kbfile(kb_name, files)
-
-        for success, result in files2docs_in_thread(kb_files,
-                                                    chunk_size=chunk_size,
-                                                    chunk_overlap=chunk_overlap,
-                                                    zh_title_enhance=zh_title_enhance):
-            if success:
-                _, filename, docs = result
-                print(f"Adding {kb_name}/{filename} to the vector store")
-                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
-                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
-            else:
-                print(result)
-        kb.save_vector_store()
-    else:
-        print(f"unsupported migrate mode: {mode}")
+
+    kb_names = kb_names or list_kbs_from_folder()
+    for kb_name in kb_names:
+        kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
+        kb.create_kb()
+
+        # clear the vector store and rebuild it from local files
+        if mode == "recreate_vs":
+            kb.clear_vs()
+            kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+            files2vs(kb_name, kb_files)
+            kb.save_vector_store()
+        # # do not vectorize file contents, only store file metadata in the database.
+        # # since the database now keeps a lot of text-splitting related info, storing file info alone is of little use, so this mode has been removed.
+        # elif mode == "fill_info_only":
+        #     files = list_files_from_folder(kb_name)
+        #     kb_files = file_to_kbfile(kb_name, files)
+        #     for kb_file in kb_files:
+        #         add_file_to_db(kb_file)
+        #         print(f"Added {kb_name}/{kb_file.filename} to the database")
+        # take the file list in the database as the baseline and update the vector store from local files
+        elif mode == "update_in_db":
+            files = kb.list_files()
+            kb_files = file_to_kbfile(kb_name, files)
+            files2vs(kb_name, kb_files)
+            kb.save_vector_store()
+        # compare the local folder against the database file list and vectorize only the newly added files
+        elif mode == "increament":
+            db_files = kb.list_files()
+            folder_files = list_files_from_folder(kb_name)
+            files = list(set(folder_files) - set(db_files))
+            kb_files = file_to_kbfile(kb_name, files)
+            files2vs(kb_name, kb_files)
+            kb.save_vector_store()
+        else:
+            print(f"unsupported migrate mode: {mode}")
 
 
-def recreate_all_vs(
-    vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
-    embed_mode: str = EMBEDDING_MODEL,
-    **kwargs: Any,
-):
+def prune_db_docs(kb_names: List[str]):
     '''
-    used to recreate a vector store or change current vector store to another type or embed_model
+    delete docs in the database that no longer exist in the local folder.
+    it is used to delete database docs after the user deleted some doc files in the file browser
     '''
-    for kb_name in list_kbs_from_folder():
-        folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
+    for kb_name in kb_names:
+        kb = KBServiceFactory.get_service_by_name(kb_name)
+        if kb and kb.exists():
+            files_in_db = kb.list_files()
+            files_in_folder = list_files_from_folder(kb_name)
+            files = list(set(files_in_db) - set(files_in_folder))
+            kb_files = file_to_kbfile(kb_name, files)
+            for kb_file in kb_files:
+                kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+                print(f"successfully deleted docs for file: {kb_name}/{kb_file.filename}")
+            kb.save_vector_store()
 
 
-def prune_db_files(kb_name: str):
-    '''
-    delete files in database that not existed in local folder.
-    it is used to delete database files after user deleted some doc files in file browser
-    '''
-    kb = KBServiceFactory.get_service_by_name(kb_name)
-    if kb.exists():
-        files_in_db = kb.list_files()
-        files_in_folder = list_files_from_folder(kb_name)
-        files = list(set(files_in_db) - set(files_in_folder))
-        kb_files = file_to_kbfile(kb_name, files)
-        for kb_file in kb_files:
-            kb.delete_doc(kb_file, not_refresh_vs_cache=True)
-        kb.save_vector_store()
-        return kb_files
-
-def prune_folder_files(kb_name: str):
+def prune_folder_files(kb_names: List[str]):
     '''
     delete doc files in the local folder that do not exist in the database.
     it is used to free local disk space by deleting unused doc files.
     '''
-    kb = KBServiceFactory.get_service_by_name(kb_name)
-    if kb.exists():
-        files_in_db = kb.list_files()
-        files_in_folder = list_files_from_folder(kb_name)
-        files = list(set(files_in_folder) - set(files_in_db))
-        for file in files:
-            os.remove(get_file_path(kb_name, file))
-        return files
+    for kb_name in kb_names:
+        kb = KBServiceFactory.get_service_by_name(kb_name)
+        if kb and kb.exists():
+            files_in_db = kb.list_files()
+            files_in_folder = list_files_from_folder(kb_name)
+            files = list(set(files_in_folder) - set(files_in_db))
+            for file in files:
+                os.remove(get_file_path(kb_name, file))
+                print(f"successfully deleted file: {kb_name}/{file}")

View File

@@ -0,0 +1,91 @@
+from server.model_workers.base import ApiModelWorker
+from configs.model_config import TEMPERATURE
+from fastchat import conversation as conv
+import sys
+import json
+from typing import List, Literal, Dict
+
+
+class FangZhouWorker(ApiModelWorker):
+    """
+    火山方舟 (Volcano Engine Ark) model worker
+    """
+    SUPPORT_MODELS = ["chatglm-6b-model"]
+
+    def __init__(
+            self,
+            *,
+            version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
+            model_names: List[str] = ["fangzhou-api"],
+            controller_addr: str,
+            worker_addr: str,
+            **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 16384)
+        super().__init__(**kwargs)
+
+        config = self.get_config()
+        self.version = version
+        self.api_key = config.get("api_key")
+        self.secret_key = config.get("secret_key")
+
+        from volcengine.maas import ChatRole
+        self.conv = conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=[ChatRole.USER, ChatRole.ASSISTANT, ChatRole.SYSTEM],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+    def generate_stream_gate(self, params):
+        super().generate_stream_gate(params)
+        from volcengine.maas import MaasService, MaasException, ChatRole
+
+        maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+        maas.set_ak(self.api_key)
+        maas.set_sk(self.secret_key)
+
+        # document: "https://www.volcengine.com/docs/82379/1099475"
+        req = {
+            "model": {
+                "name": self.version,
+            },
+            "parameters": {
+                # the parameters below are only examples; see the API docs of the specific model for the parameters it actually supports
+                # "max_new_tokens": 1000,
+                "temperature": params.get("temperature", TEMPERATURE),
+            },
+            "messages": [{"role": ChatRole.USER, "content": params["prompt"]}]
+        }
+
+        text = ""
+        try:
+            resps = maas.stream_chat(req)
+            for resp in resps:
+                text += resp.choice.message.content
+                yield json.dumps({"error_code": 0, "text": text},
+                                 ensure_ascii=False).encode() + b"\0"
+        except MaasException as e:
+            print(e)
+
+    def get_embeddings(self, params):
+        # TODO: support embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = FangZhouWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20006",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20006)
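
The Volcano Engine calls wrapped by generate_stream_gate can also be exercised on their own; a minimal sketch assuming the volcengine SDK is installed, using the same endpoint and region as above (credentials and the prompt are placeholders):

from volcengine.maas import MaasService, MaasException, ChatRole

maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
maas.set_ak("YOUR_ACCESS_KEY")  # placeholder credentials
maas.set_sk("YOUR_SECRET_KEY")

req = {
    "model": {"name": "chatglm-6b-model"},
    "parameters": {"temperature": 0.7},
    "messages": [{"role": ChatRole.USER, "content": "Hello"}],
}

try:
    for resp in maas.stream_chat(req):  # the same streaming call the worker wraps
        print(resp.choice.message.content, end="", flush=True)
except MaasException as e:
    print(e)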

tests/test_migrate.py
View File

@@ -0,0 +1,139 @@
+from pathlib import Path
+from pprint import pprint
+import os
+import shutil
+import sys
+
+root_path = Path(__file__).parent.parent
+sys.path.append(str(root_path))
+
+from server.knowledge_base.kb_service.base import KBServiceFactory
+from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile
+from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files
+
+
+# setup test knowledge base
+kb_name = "test_kb_for_migrate"
+test_files = {
+    "faq.md": str(root_path / "docs" / "faq.md"),
+    "install.md": str(root_path / "docs" / "install.md"),
+}
+
+kb_path = get_kb_path(kb_name)
+doc_path = get_doc_path(kb_name)
+
+if not os.path.isdir(doc_path):
+    os.makedirs(doc_path)
+
+for k, v in test_files.items():
+    shutil.copy(v, os.path.join(doc_path, k))
+
+
+def test_recreate_vs():
+    folder2db([kb_name], "recreate_vs")
+
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    assert kb.exists()
+
+    files = kb.list_files()
+    print(files)
+    for name in test_files:
+        assert name in files
+        path = os.path.join(doc_path, name)
+
+        # list docs based on file name
+        docs = kb.list_docs(file_name=name)
+        assert len(docs) > 0
+        pprint(docs[0])
+        for doc in docs:
+            assert doc.metadata["source"] == path
+
+        # list docs based on metadata
+        docs = kb.list_docs(metadata={"source": path})
+        assert len(docs) > 0
+        for doc in docs:
+            assert doc.metadata["source"] == path
+
+
+def test_increament():
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    kb.clear_vs()
+    assert kb.list_files() == []
+    assert kb.list_docs() == []
+
+    folder2db([kb_name], "increament")
+
+    files = kb.list_files()
+    print(files)
+    for f in test_files:
+        assert f in files
+
+        docs = kb.list_docs(file_name=f)
+        assert len(docs) > 0
+        pprint(docs[0])
+        for doc in docs:
+            assert doc.metadata["source"] == os.path.join(doc_path, f)
+
+
+def test_prune_db():
+    del_file, keep_file = list(test_files)[:2]
+    os.remove(os.path.join(doc_path, del_file))
+
+    prune_db_docs([kb_name])
+
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    files = kb.list_files()
+    print(files)
+    assert del_file not in files
+    assert keep_file in files
+
+    docs = kb.list_docs(file_name=del_file)
+    assert len(docs) == 0
+
+    docs = kb.list_docs(file_name=keep_file)
+    assert len(docs) > 0
+    pprint(docs[0])
+
+    shutil.copy(test_files[del_file], os.path.join(doc_path, del_file))
+
+
+def test_prune_folder():
+    del_file, keep_file = list(test_files)[:2]
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+
+    # delete docs for file
+    kb.delete_doc(KnowledgeFile(del_file, kb_name))
+    files = kb.list_files()
+    print(files)
+    assert del_file not in files
+    assert keep_file in files
+
+    docs = kb.list_docs(file_name=del_file)
+    assert len(docs) == 0
+
+    docs = kb.list_docs(file_name=keep_file)
+    assert len(docs) > 0
+
+    docs = kb.list_docs(file_name=del_file)
+    assert len(docs) == 0
+
+    assert os.path.isfile(os.path.join(doc_path, del_file))
+
+    # prune folder
+    prune_folder_files([kb_name])
+
+    # check result
+    assert not os.path.isfile(os.path.join(doc_path, del_file))
+    assert os.path.isfile(os.path.join(doc_path, keep_file))
+
+
+def test_drop_kb():
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    kb.drop_kb()
+    assert not kb.exists()
+    assert not os.path.isdir(kb_path)
+
+    kb = KBServiceFactory.get_service_by_name(kb_name)
+    assert kb is None
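
These tests share the module-level knowledge base setup and assume they run in file order (recreate, increament, prune, drop); a minimal way to run just this file, assuming pytest is installed:

import pytest

# run this file in isolation, in definition order
pytest.main(["-q", "tests/test_migrate.py"])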