import os
import logging
import torch
import argparse
import json

# Log format
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)

parser = argparse.ArgumentParser()
#------multi worker-----------------
parser.add_argument('--model-path-address',
                    default="THUDM/chatglm2-6b@localhost@20002",
                    nargs="+",
                    type=str,
                    help="model path, host, and port, formatted as model-path@host@port")

#---------------controller-------------------------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
controller_args = ["controller-host", "controller-port", "dispatch-method"]

#----------------------worker------------------------------------------
parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
# parser.add_argument(
#     "--controller-address", type=str, default="http://localhost:21001"
# )
parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.3",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)

parser.add_argument(
    "--max-gpu-memory",
    type=str,
    help="The maximum memory per gpu. Use a string like '13GiB'",
)
parser.add_argument(
    "--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
    "--cpu-offloading",
    action="store_true",
    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
)
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="#bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional comma-separated display names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

worker_args = [
    "worker-host", "worker-port",
    "model-path", "revision", "device", "gpus", "num-gpus",
    "max-gpu-memory", "load-8bit", "cpu-offloading",
    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
    "gptq-act-order", "model-names", "limit-worker-concurrency",
    "stream-interval", "no-register",
    "controller-address"
]

#-----------------openai server---------------------------
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8001, help="port number")
parser.add_argument(
    "--allow-credentials", action="store_true", help="allow credentials"
)
# parser.add_argument(
#     "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
# )
# parser.add_argument(
#     "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
# )
# parser.add_argument(
#     "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
# )
parser.add_argument(
    "--api-keys",
    type=lambda s: s.split(","),
    help="Optional list of comma separated API keys",
)
server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
               "controller-address"
               ]
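
# controller_args/worker_args/server_args name the parser options each fastchat process
# consumes. A minimal sketch of rendering such a list back into a command-line string;
# string_args is written here for illustration, with simplified handling of flags and lists.
def string_args(args, args_list):
    """Turn the selected argparse options into a ' --key value' string."""
    args_str = ""
    for key, value in vars(args).items():
        key = key.replace("_", "-")
        if key not in args_list or value is None:
            continue
        if isinstance(value, bool):
            args_str += f" --{key}" if value else ""
        elif isinstance(value, list):
            args_str += f" --{key} {' '.join(map(str, value))}"
        else:
            args_str += f" --{key} {value}"
    return args_str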

#-------------------It seems all configurable items here could also be exposed as command-line options-----------------------


# Edit the values in the following dict to point at locally stored embedding models,
# e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese"
# to "text2vec": "User/Downloads/text2vec-large-chinese".
# Use absolute paths here.
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "shibing624/text2vec-base-chinese",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
    "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
    "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
    "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
    "m3e-small": "moka-ai/m3e-small",
    "m3e-base": "moka-ai/m3e-base",
    "m3e-large": "moka-ai/m3e-large",
}

# Name of the selected embedding model
EMBEDDING_MODEL = "m3e-base"

# Device on which the embedding model runs
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

llm_model_dict = {
    "chatglm-6b": {
        "local_model_path": "THUDM/chatglm-6b",
        "api_base_url": "http://localhost:8888/v1",  # set this to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },

    "chatglm-6b-int4": {
        "local_model_path": "THUDM/chatglm-6b-int4",
        "api_base_url": "http://localhost:8001/v1",  # set this to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },

    "chatglm2-6b": {
        "local_model_path": "THUDM/chatglm2-6b",
        "api_base_url": "http://localhost:8888/v1",  # set this to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },

    "chatglm2-6b-32k": {
        "local_model_path": "THUDM/chatglm2-6b-32k",
        "api_base_url": "http://localhost:8888/v1",  # set this to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },

    "vicuna-13b-hf": {
        "local_model_path": "",
        "api_base_url": "http://localhost:8000/v1",  # set this to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },

    # If calling ChatGPT raises: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
    # Max retries exceeded with url: /v1/chat/completions
    # then downgrade urllib3 to version 1.25.11.
    # If urllib3.exceptions.MaxRetryError: HTTPSConnectionPool is still raised, change https to http.
    # Reference: https://zhuanlan.zhihu.com/p/350015032

    # If the call raises: raise NewConnectionError(
    # urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
    # Failed to establish a new connection: [WinError 10060]
    # then OpenAI has blocked mainland China and Hong Kong IPs; switch to a Japanese, Singaporean, or similar exit node.
    "openai-chatgpt-3.5": {
        "local_model_path": "gpt-3.5-turbo",
        "api_base_url": "https://api.openai.com/v1",
        "api_key": os.environ.get("OPENAI_API_KEY")
    },
}

# LLM name
LLM_MODEL = "chatglm2-6b"

# Device on which the LLM runs
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Log storage path
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

# Default knowledge base storage path
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")

# Default database storage path
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")

# Number of vector stores kept in cache
CACHED_VS_NUM = 1

# Number of matching vectors returned from the knowledge base
VECTOR_SEARCH_TOP_K = 5

# Number of matching results returned from the search engine
SEARCH_ENGINE_TOP_K = 5

# nltk data storage path
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

# Prompt template for QA over local knowledge. The template itself stays in Chinese to
# match the target models; it instructs: answer concisely and professionally from the
# known information, say "cannot answer from the known information" when it does not
# suffice, never fabricate, and answer in Chinese.
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。

【已知信息】{context}

【问题】{question}"""
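
# A minimal sketch of filling the template above; render_prompt is a hypothetical helper,
# not part of the original file.
def render_prompt(context: str, question: str) -> str:
    """Substitute retrieved context and the user question into PROMPT_TEMPLATE."""
    return PROMPT_TEMPLATE.format(context=context, question=question)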

# Whether the API allows cross-origin requests; defaults to False, set True to enable
OPEN_CROSS_DOMAIN = False

# Required settings for Bing search
# Bing search needs a Bing Subscription Key; apply for a Bing Search trial in the Azure portal.
# For how to apply, see
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# For creating a Bing search instance from Python, see:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# Note: this is NOT the Bing Webmaster Tools API key.

# Also, on a server, Failed to establish a new connection: [Errno 110] Connection timed out
# means the server sits behind a firewall; ask the administrator to whitelist the endpoint
# (on a corporate server, you are probably out of luck).
BING_SUBSCRIPTION_KEY = ""
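
# A minimal sketch of a Bing Web Search request using the two settings above, following
# the REST quickstart linked earlier; bing_search is a hypothetical helper, not part of
# the original file.
def bing_search(query: str, top_k: int = SEARCH_ENGINE_TOP_K):
    import requests
    headers = {"Ocp-Apim-Subscription-Key": BING_SUBSCRIPTION_KEY}
    params = {"q": query, "count": top_k}
    resp = requests.get(BING_SEARCH_URL, headers=headers, params=params, timeout=10)
    resp.raise_for_status()
    return resp.json().get("webPages", {}).get("value", [])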