From c52308d6050d914a3f77740bf4f95b79d3cb6687 Mon Sep 17 00:00:00 2001 From: Zhi-guo Huang Date: Sun, 2 Jul 2023 21:54:12 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BF=AE=E6=94=B9stream=5Fchat=E7=9A=84?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=EF=BC=8C=E5=9C=A8=E8=AF=B7=E6=B1=82=E4=BD=93?= =?UTF-8?q?=E4=B8=AD=E9=80=89=E6=8B=A9knowledge=5Fbase=5Fid;2.=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0stream=5Fchat=5Fbing=E6=8E=A5=E5=8F=A3=EF=BC=9B3.=20?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E8=B0=83=E7=94=A8=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E7=9A=84=E6=96=B9=E6=B3=95=E8=AF=B4=E6=98=8E?= =?UTF-8?q?=EF=BC=9B4.=20=E4=BC=98=E5=8C=96cli=5Fdemo.py=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=9A=E6=94=AF=E6=8C=81=20=E8=BE=93=E5=85=A5?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=EF=BC=9B=E5=A4=9A=E8=BE=93=E5=85=A5=EF=BC=9B?= =?UTF-8?q?=E9=87=8D=E6=96=B0=E8=BE=93=E5=85=A5=20(#630)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复 bing_search.py的typo;更新model_config.py中Bing Subscription Key申请方式及注意事项 * 更新FAQ,增加了[Errno 110] Connection timed out的原因与解决方案 * 修改loader.py中load_in_8bit失败的原因和详细解决方案 * update loader.py * stream_chat_bing * 修改stream_chat的接口,在请求体中选择knowledge_base_id;增加stream_chat_bing接口 * 优化cli_demo.py的逻辑:支持 输入提示;多输入;重新输入 * update cli_demo.py * 按照review建议进行修改 --- api.py | 46 ++++++++++++++++++++++++++++++++++++++++-- chains/local_doc_qa.py | 1 + cli_demo.py | 24 +++++++++++++++++++++- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/api.py b/api.py index 0413c95..4965cbf 100644 --- a/api.py +++ b/api.py @@ -362,7 +362,7 @@ async def chat( ) -async def stream_chat(websocket: WebSocket, knowledge_base_id: str): +async def stream_chat(websocket: WebSocket): await websocket.accept() turn = 1 while True: @@ -404,6 +404,41 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): ) turn += 1 +async def stream_chat_bing(websocket: WebSocket): + """ + 基于bing搜索的流式问答 + """ + await websocket.accept() + turn = 1 + while True: + input_json = await websocket.receive_json() + question, history = input_json["question"], input_json["history"] + + await websocket.send_json({"question": question, "turn": turn, "flag": "start"}) + + last_print_len = 0 + for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True): + await websocket.send_text(resp["result"][last_print_len:]) + last_print_len = len(resp["result"]) + + source_documents = [ + f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] + + await websocket.send_text( + json.dumps( + { + "question": question, + "turn": turn, + "flag": "end", + "sources_documents": source_documents, + }, + ensure_ascii=False, + ) + ) + turn += 1 async def document(): return RedirectResponse(url="/docs") @@ -428,10 +463,17 @@ def api_start(host, port): allow_methods=["*"], allow_headers=["*"], ) - app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat) + # 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id + app.websocket("/local_doc_qa/stream_chat")(stream_chat) app.get("/", response_model=BaseResponse)(document) + # 增加基于bing搜索的流式问答 + # 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia + # 强烈推荐开源的insomnia + # 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing + app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing) + app.post("/chat", response_model=ChatMessage)(chat) app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index fe70066..9ec3db5 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -200,6 +200,7 @@ class LocalDocQA: return vs_path, loaded_files else: logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。") + return None, loaded_files def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size): diff --git a/cli_demo.py b/cli_demo.py index 938ebb3..a445e14 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -23,11 +23,33 @@ def main(): top_k=VECTOR_SEARCH_TOP_K) vs_path = None while not vs_path: + print("注意输入的路径是完整的文件路径,例如knowledge_base/`knowledge_base_id`/content/file.md,多个路径用英文逗号分割") filepath = input("Input your local knowledge file path 请输入本地知识文件路径:") + # 判断 filepath 是否为空,如果为空的话,重新让用户输入,防止用户误触回车 if not filepath: continue - vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath) + + # 支持加载多个文件 + filepath = filepath.split(",") + # filepath错误的返回为None, 如果直接用原先的vs_path,_ = local_doc_qa.init_knowledge_vector_store(filepath) + # 会直接导致TypeError: cannot unpack non-iterable NoneType object而使得程序直接退出 + # 因此需要先加一层判断,保证程序能继续运行 + temp,loaded_files = local_doc_qa.init_knowledge_vector_store(filepath) + if temp is not None: + vs_path = temp + # 如果loaded_files和len(filepath)不一致,则说明部分文件没有加载成功 + # 如果是路径错误,则应该支持重新加载 + if len(loaded_files) != len(filepath): + reload_flag = eval(input("部分文件加载失败,若提示路径不存在,可重新加载,是否重新加载,输入True或False: ")) + if reload_flag: + vs_path = None + continue + + print(f"the loaded vs_path is 加载的vs_path为: {vs_path}") + else: + print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径") + history = [] while True: query = input("Input your question 请输入问题:")