From eb620ddefc7f2b92047cc7404308e08d428c7992 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Mon, 12 Jun 2023 16:22:07 +0800
Subject: [PATCH 01/12] Fix a typo in bing_search.py; update the Bing
 Subscription Key application instructions and caveats in model_config.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 agent/bing_search.py    |  2 +-
 configs/model_config.py | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/agent/bing_search.py b/agent/bing_search.py
index 2ff7749..962090c 100644
--- a/agent/bing_search.py
+++ b/agent/bing_search.py
@@ -7,7 +7,7 @@ from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 def bing_search(text, result_len=3):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
-                 "title": "env inof not fould",
+                 "title": "env info is not found",
                  "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
     search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
                                   bing_search_url=BING_SEARCH_URL)
diff --git a/configs/model_config.py b/configs/model_config.py
index e18f69b..9552ee0 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -165,7 +165,14 @@ flagging username: {FLAG_USER_NAME}
 OPEN_CROSS_DOMAIN = False
 
 # Required variables for Bing search
-# Bing search requires a Bing Subscription Key
-# For how to apply, see https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
+# Bing search requires a Bing Subscription Key; apply for a Bing Search trial in the Azure portal
+# For the application procedure, see
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
+# For creating a Bing Search API instance with Python, see:
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
 BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
+# Note that this is NOT the API key from Bing Webmaster Tools
+
+# Also, if a server raises "Failed to establish a new connection: [Errno 110] Connection timed out",
+# it sits behind a firewall; ask the administrator to whitelist the endpoint (on a locked-down corporate server this may not be possible)
BING_SUBSCRIPTION_KEY = ""
\ No newline at end of file
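Before dropping a key into model_config.py, it can be verified with a bare REST call against the same endpoint the project uses. A minimal sketch, not part of the patch: it assumes `pip install requests`, and the key value is a placeholder that must be replaced.

```
# Sanity-check a Bing Subscription Key against the v7.0 endpoint (sketch).
# Assumes the `requests` package; the key below is a placeholder.
import requests

BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
BING_SUBSCRIPTION_KEY = "your-subscription-key"  # placeholder, not a real key

resp = requests.get(
    BING_SEARCH_URL,
    headers={"Ocp-Apim-Subscription-Key": BING_SUBSCRIPTION_KEY},
    params={"q": "langchain", "count": 3},
    timeout=10,  # a firewalled server surfaces [Errno 110] here instead of hanging
)
resp.raise_for_status()  # 401 means the key is wrong or from the wrong product
for page in resp.json()["webPages"]["value"]:
    print(page["name"], page["url"])
```

A 401 here with a key copied from Bing Webmaster Tools is exactly the mix-up the comment above warns about.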
From 4054e46cab77aef21604b4ad40fa626fe360cf38 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Mon, 12 Jun 2023 16:28:40 +0800
Subject: [PATCH 02/12] Update FAQ: add the cause of and fix for [Errno 110]
 Connection timed out
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/FAQ.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/FAQ.md b/docs/FAQ.md
index b43483b..f712477 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -146,7 +146,6 @@ $ pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/
 $ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
 ```
 
-
 Q12 When starting api.py, the upload_file endpoint throws `partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)`

 This is caused by too new a version of charset_normalizer; downgrade it -- charset_normalizer==2.1.0 is tested to work.
@@ -174,3 +173,7 @@ download_with_progressbar(url, tmp_path)
 Then download the file manually from the given URL, e.g. "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar", and upload it to the corresponding folder, e.g. ".paddleocr/whl/rec/ch/ch_PP-OCRv3_rec_infer/ch_PP-OCRv3_rec_infer.tar".
 
 ---
+
+Q14 Calling the `bing_search_chat` API endpoint raises `Failed to establish a new connection: [Errno 110] Connection timed out`
+
+The server sits behind a firewall; ask the administrator to add a whitelist entry (on a locked-down corporate server this may not be possible).

From f7e7d318d8dfac2bf46bd8db8471b9926385620f Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Tue, 13 Jun 2023 23:30:10 +0800
Subject: [PATCH 03/12] Document in loader.py the cause of load_in_8bit
 failures and a detailed fix
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/loader/loader.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/models/loader/loader.py b/models/loader/loader.py
index c68e71a..415048b 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -30,6 +30,15 @@ class LoaderCheckPoint:
     ptuning_dir: str = None
     use_ptuning_v2: bool = False
     # If the project fails to start with 8-bit quantized loading enabled, see this issue and pick a matching CUDA version: https://github.com/TimDettmers/bitsandbytes/issues/156
+    # The root cause is usually that bitsandbytes picked a mismatched CUDA version from the system environment variables at install time,
+    # e.g. PATH contains both cuda10.2 and cuda11.2, bitsandbytes picked 10.2, while torch and the other dependencies were built against 11.2
+    # The fix is therefore to clean the mismatched CUDA versions out of PATH; a once-and-for-all procedure:
+    # 0. run `pip uninstall bitsandbytes` in a terminal
+    # 1. delete the PATH entries from .bashrc
+    # 2. run `echo $PATH >> .bashrc` in a terminal
+    # 3. run `source .bashrc` in a terminal
+    # 4. then run `pip install bitsandbytes`
+
     load_in_8bit: bool = False
     is_llamacpp: bool = False
     bf16: bool = False
@@ -99,6 +108,8 @@ class LoaderCheckPoint:
             LoaderClass = AutoModelForCausalLM
 
         # Load the model in simple 16-bit mode by default
+        # If loading succeeds but inference raises RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`,
+        # GPU memory is still insufficient; the only options are --load-in-8bit or setting the default model to `chatglm-6b-int8`
         if not any([self.llm_device.lower() == "cpu",
                     self.load_in_8bit, self.is_llamacpp]):
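The CUDA mismatch described in the comments above can be checked before reinstalling anything. A minimal sketch, assuming torch and bitsandbytes are already installed; bitsandbytes builds from this era also log the CUDA setup they detected when imported.

```
# Sketch: compare the CUDA version torch was built against with what the
# system exposes, before reinstalling bitsandbytes.
import os
import torch

print("torch built against CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("PATH entries mentioning cuda:",
      [p for p in os.environ.get("PATH", "").split(os.pathsep) if "cuda" in p.lower()])

# importing bitsandbytes logs which CUDA runtime it picked up;
# it should match torch.version.cuda
import bitsandbytes  # noqa: F401
```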
From b63c74212f77824e303f4d8721254128aa86382b Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Wed, 14 Jun 2023 08:22:05 +0800
Subject: [PATCH 04/12] update loader.py

---
 models/loader/loader.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/models/loader/loader.py b/models/loader/loader.py
index 415048b..f315e6c 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -30,14 +30,15 @@ class LoaderCheckPoint:
     ptuning_dir: str = None
     use_ptuning_v2: bool = False
     # If the project fails to start with 8-bit quantized loading enabled, see this issue and pick a matching CUDA version: https://github.com/TimDettmers/bitsandbytes/issues/156
-    # The root cause is usually that bitsandbytes picked a mismatched CUDA version from the system environment variables at install time,
+    # Another possible cause is that bitsandbytes picked a mismatched CUDA version from the system environment variables at install time,
     # e.g. PATH contains both cuda10.2 and cuda11.2, bitsandbytes picked 10.2, while torch and the other dependencies were built against 11.2
     # The fix is therefore to clean the mismatched CUDA versions out of PATH; a once-and-for-all procedure:
     # 0. run `pip uninstall bitsandbytes` in a terminal
     # 1. delete the PATH entries from .bashrc
     # 2. run `echo $PATH >> .bashrc` in a terminal
-    # 3. run `source .bashrc` in a terminal
-    # 4. then run `pip install bitsandbytes`
+    # 3. delete the mismatched CUDA version paths from PATH in .bashrc
+    # 4. run `source .bashrc` in a terminal
+    # 5. then run `pip install bitsandbytes`
 
     load_in_8bit: bool = False
     is_llamacpp: bool = False

From 987f5518f23dbf21cd6bd4310a133c0eb2965d80 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Wed, 14 Jun 2023 08:55:27 +0800
Subject: [PATCH 05/12] stream_chat_bing

---
 api.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/api.py b/api.py
index 7926fb2..dc53f69 100644
--- a/api.py
+++ b/api.py
@@ -346,6 +346,41 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
         )
         turn += 1
 
+async def stream_chat_bing(websocket: WebSocket):
+    """
+    Streaming Q&A based on Bing search
+    """
+    await websocket.accept()
+    turn = 1
+    while True:
+        input_json = await websocket.receive_json()
+        question, history = input_json["question"], input_json["history"]
+
+        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
+
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True):
+            await websocket.send_text(resp["result"][last_print_len:])
+            last_print_len = len(resp["result"])
+
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+
+        await websocket.send_text(
+            json.dumps(
+                {
+                    "question": question,
+                    "turn": turn,
+                    "flag": "end",
+                    "source_documents": source_documents,
+                },
+                ensure_ascii=False,
+            )
+        )
+        turn += 1
 
 async def document():
     return RedirectResponse(url="/docs")
@@ -377,6 +412,9 @@ def api_start(host, port):
 
     app.get("/", response_model=BaseResponse)(document)
 
+    # # add streaming Q&A based on Bing search
+    app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
+
     app.post("/chat", response_model=ChatMessage)(chat)
 
     app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)

From 660f8c6715923f2511710856d9732fe95ab322f6 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Wed, 14 Jun 2023 11:17:21 +0800
Subject: [PATCH 06/12] Change the stream_chat endpoint to take
 knowledge_base_id in the request body; add a stream_chat_bing endpoint
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/api.py b/api.py
index dc53f69..77aa1cb 100644
--- a/api.py
+++ b/api.py
@@ -305,7 +305,7 @@ async def chat(
     )
 
 
-async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
+async def stream_chat(websocket: WebSocket):
     await websocket.accept()
     turn = 1
     while True:
@@ -408,11 +408,15 @@ def api_start(host, port):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
+    # the stream_chat route changed: connect directly via ws://localhost:7861/local_doc_qa/stream_chat and pass knowledge_base_id in the request body
+    app.websocket("/local_doc_qa/stream_chat")(stream_chat)
 
     app.get("/", response_model=BaseResponse)(document)
 
-    # # add streaming Q&A based on Bing search
+    # add streaming Q&A based on Bing search
+    # note: testing websocket streaming Q&A requires a websocket-capable client such as Postman or Insomnia
+    # the open-source Insomnia is highly recommended
+    # when testing, create a new websocket request and switch the URL scheme to ws, e.g. ws://localhost:7861/local_doc_qa/stream_chat_bing
     app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
 
     app.post("/chat", response_model=ChatMessage)(chat)
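Besides Insomnia or Postman, the new websocket routes can be exercised from a few lines of Python. A minimal client sketch, assuming `pip install websockets` and api.py listening on localhost:7861; the message shape mirrors what stream_chat_bing expects.

```
# Minimal client for the stream_chat_bing websocket route (sketch).
# Assumes `pip install websockets` and the API running on localhost:7861.
import asyncio
import json
import websockets

async def main():
    uri = "ws://localhost:7861/local_doc_qa/stream_chat_bing"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"question": "What is LangChain?", "history": []}))
        while True:
            msg = await ws.recv()
            print(msg)
            try:
                if json.loads(msg).get("flag") == "end":  # server marks end of turn
                    break
            except json.JSONDecodeError:
                pass  # streamed answer fragments are plain text, not JSON

asyncio.run(main())
```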
From b2626122482e6fd0a2e0fa04b21bab8eb34c0ca5 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Thu, 15 Jun 2023 13:15:00 +0800
Subject: [PATCH 07/12] Improve the cli_demo.py flow: support input hints,
 multiple inputs, and re-entry on failure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 chains/local_doc_qa.py |  4 +++-
 cli_demo.py            | 22 +++++++++++++++++++++-
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index dac13bb..f7eea83 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -196,7 +196,9 @@ class LocalDocQA:
             return vs_path, loaded_files
         else:
             logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
-            return None, loaded_files
+            # if len(docs) is not > 0, every file failed to load, so loaded_files must be [] and returning it carries no information
+            # returning plain None also stays consistent with the error returns above, making it easier for downstream callers to check
+            return None
 
     def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
         try:
diff --git a/cli_demo.py b/cli_demo.py
index 938ebb3..a7bc732 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -23,11 +23,31 @@ def main():
                      top_k=VECTOR_SEARCH_TOP_K)
     vs_path = None
     while not vs_path:
+        print("注意输入的路径是完整的文件路径,例如content/`knowledge_base_id`/file.md,多个路径用英文逗号分割")
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
+        # if filepath is empty, prompt again, in case the user hit Enter by mistake
         if not filepath:
             continue
-        vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
+
+        # support loading multiple files
+        filepath = filepath.split(",")
+        # a bad filepath makes init_knowledge_vector_store return None; the original unpacking
+        # `vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)` would then crash with
+        # TypeError: cannot unpack non-iterable NoneType object, so add a guard to keep the program running
+        temp = local_doc_qa.init_knowledge_vector_store(filepath)
+        if temp is not None:
+            vs_path,loaded_files = temp
+            if len(loaded_files) != len(filepath):
+                reload_flag = eval(input("部分文件加载失败,若提示路径不存在,可重新加载,是否重新加载,输入True或False: "))
+                if reload_flag:
+                    vs_path = None
+                    continue
+
+            print(f"the loaded vs_path is 加载的vs_path为: {vs_path}")
+        else:
+            print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径")
+
     history = []
     while True:
         query = input("Input your question 请输入问题:")

From c42c0cbf87afcc82137f0a4c7cd9fbf4b41e4efb Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Thu, 15 Jun 2023 13:46:11 +0800
Subject: [PATCH 08/12] update cli_demo.py

---
 cli_demo.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cli_demo.py b/cli_demo.py
index a7bc732..c9ca558 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -38,6 +38,8 @@ def main():
         temp = local_doc_qa.init_knowledge_vector_store(filepath)
         if temp is not None:
             vs_path,loaded_files = temp
+            # if len(loaded_files) differs from len(filepath), some files failed to load
+            # if the cause was a wrong path, reloading should be supported
             if len(loaded_files) != len(filepath):
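Stripped of project specifics, the loop the two patches above converge on looks like this. A stand-alone sketch: `fake_init` is a hypothetical stand-in for local_doc_qa.init_knowledge_vector_store, returning None on total failure and a (vs_path, loaded_files) pair otherwise.

```
# Stand-alone sketch of the guarded multi-file input loop (hypothetical
# fake_init stands in for local_doc_qa.init_knowledge_vector_store).
import os

def fake_init(paths):
    loaded = [p for p in paths if os.path.exists(p)]
    return ("vector_store/demo", loaded) if loaded else None

vs_path = None
while not vs_path:
    raw = input("file path(s), comma separated: ")
    if not raw:  # guard against an accidental bare Enter
        continue
    paths = [p.strip() for p in raw.split(",") if p.strip()]
    result = fake_init(paths)
    if result is None:  # unpacking None would raise TypeError
        print("load failed, try again")
        continue
    vs_path, loaded_files = result
    if len(loaded_files) != len(paths):
        print(f"only {len(loaded_files)} of {len(paths)} files loaded")
```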
From 43fae98bb51438be38919689f79c8c82456c0ed9 Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Sun, 2 Jul 2023 20:47:27 +0800
Subject: [PATCH 09/12] Apply changes suggested in review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 chains/local_doc_qa.py | 5 ++---
 cli_demo.py            | 6 +++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 7755ef1..cf4281e 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -197,9 +197,8 @@ class LocalDocQA:
             return vs_path, loaded_files
         else:
             logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
-            # if len(docs) is not > 0, every file failed to load, so loaded_files must be [] and returning it carries no information
-            # returning plain None also stays consistent with the error returns above, making it easier for downstream callers to check
-            return None
+
+            return None, loaded_files
 
     def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
         try:
diff --git a/cli_demo.py b/cli_demo.py
index c9ca558..a445e14 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -23,7 +23,7 @@ def main():
                      top_k=VECTOR_SEARCH_TOP_K)
     vs_path = None
     while not vs_path:
-        print("注意输入的路径是完整的文件路径,例如content/`knowledge_base_id`/file.md,多个路径用英文逗号分割")
+        print("注意输入的路径是完整的文件路径,例如knowledge_base/`knowledge_base_id`/content/file.md,多个路径用英文逗号分割")
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
         # if filepath is empty, prompt again, in case the user hit Enter by mistake
@@ -35,9 +35,9 @@ def main():
         # a bad filepath makes init_knowledge_vector_store return None; the original unpacking
         # `vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)` would then crash with
         # TypeError: cannot unpack non-iterable NoneType object, so add a guard to keep the program running
-        temp = local_doc_qa.init_knowledge_vector_store(filepath)
+        temp,loaded_files = local_doc_qa.init_knowledge_vector_store(filepath)
         if temp is not None:
-            vs_path,loaded_files = temp
+            vs_path = temp
             # if len(loaded_files) differs from len(filepath), some files failed to load
             # if the cause was a wrong path, reloading should be supported
             if len(loaded_files) != len(filepath):

From 300d287d61b315802d7d9fd930cc438cb973101a Mon Sep 17 00:00:00 2001
From: hzg0601
Date: Thu, 6 Jul 2023 04:47:44 +0800
Subject: [PATCH 10/12] Change the default multi-GPU deployment strategy so it
 basically does not fail on new models either
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/loader/loader.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/models/loader/loader.py b/models/loader/loader.py
index f315e6c..9fe5ce9 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -140,7 +140,17 @@ class LoaderCheckPoint:
             elif 'moss' in model_name.lower():
                 self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
             else:
-                self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
+                # use the following as the default multi-GPU loading strategy; it basically does not fail on new models
+                # tested on chatglm2-6b, bloom-3b and bloomz-7b1, with relatively balanced GPU load
+                from accelerate.utils import get_balanced_memory
+                max_memory = get_balanced_memory(model,
+                                                 dtype=torch.int8 if self.load_in_8bit else None,
+                                                 low_zero=False,
+                                                 no_split_module_classes=model._no_split_modules)
+                self.device_map = infer_auto_device_map(model,
+                                                        dtype=torch.float16 if not self.load_in_8bit else torch.int8,
+                                                        max_memory=max_memory,
+                                                        no_split_module_classes=model._no_split_modules)
 
             model = dispatch_model(model, device_map=self.device_map)
         else:
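The device map built by the patch can be reproduced and inspected outside the loader. A sketch assuming `pip install transformers accelerate`, at least two visible GPUs, and `THUDM/chatglm2-6b` purely as an example checkpoint.

```
# Sketch: build and inspect a balanced device map the way the patch does.
# Assumes transformers + accelerate and >= 2 GPUs; the checkpoint name is
# only an example.
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm2-6b",
                                  trust_remote_code=True,
                                  torch_dtype=torch.float16)
max_memory = get_balanced_memory(model,
                                 dtype=torch.float16,
                                 low_zero=False,
                                 no_split_module_classes=model._no_split_modules)
device_map = infer_auto_device_map(model,
                                   dtype=torch.float16,
                                   max_memory=max_memory,
                                   no_split_module_classes=model._no_split_modules)
print(device_map)  # module name -> device index; check the split is balanced
model = dispatch_model(model, device_map=device_map)
```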
"name"修改为fastchat服务中的"api_base_url" "api_key": "EMPTY" }, + # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443): + # Max retries exceeded with url: /v1/chat/completions + # 则需要将urllib3版本修改为1.25.11 + + # 如果报出:raise NewConnectionError( + # urllib3.exceptions.NewConnectionError: : + # Failed to establish a new connection: [WinError 10060] + # 则是因为内地和香港的IP都被OPENAI封了,需要挂切换为日本、新加坡等地 "openai-chatgpt-3.5": { "name": "gpt-3.5-turbo", "pretrained_model_name": "gpt-3.5-turbo", From f68d347c25b4bdd07f293c65a6e44a673a11f614 Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Thu, 13 Jul 2023 21:22:35 +0800 Subject: [PATCH 12/12] add ptuning-v2 dir --- configs/model_config.py | 7 +++++++ models/loader/args.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/configs/model_config.py b/configs/model_config.py index db52851..21b1980 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -87,6 +87,12 @@ llm_model_dict = { "local_model_path": None, "provides": "MOSSLLM" }, + "moss-int4": { + "name": "moss", + "pretrained_model_name": "fnlp/moss-moon-003-sft-int4", + "local_model_path": None, + "provides": "MOSSLLM" + }, "vicuna-13b-hf": { "name": "vicuna-13b-hf", "pretrained_model_name": "vicuna-13b-hf", @@ -197,6 +203,7 @@ STREAMING = True # Use p-tuning-v2 PrefixEncoder USE_PTUNING_V2 = False +PTUNING_DIR='./ptuing-v2' # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/models/loader/args.py b/models/loader/args.py index b15ad5e..92105e2 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -43,6 +43,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") +parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint") +parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint") # Accelerate/transformers parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,