diff --git a/api.py b/api.py
index 0413c95..4965cbf 100644
--- a/api.py
+++ b/api.py
@@ -362,7 +362,7 @@ async def chat(
     )
 
 
-async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
+async def stream_chat(websocket: WebSocket):
     await websocket.accept()
     turn = 1
     while True:
@@ -404,6 +404,41 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
         )
         turn += 1
 
+async def stream_chat_bing(websocket: WebSocket):
+    """
+    Streaming question answering backed by Bing search
+    """
+    await websocket.accept()
+    turn = 1
+    while True:
+        input_json = await websocket.receive_json()
+        question, history = input_json["question"], input_json["history"]
+
+        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
+
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True):
+            await websocket.send_text(resp["result"][last_print_len:])
+            last_print_len = len(resp["result"])
+
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+
+        await websocket.send_text(
+            json.dumps(
+                {
+                    "question": question,
+                    "turn": turn,
+                    "flag": "end",
+                    "sources_documents": source_documents,
+                },
+                ensure_ascii=False,
+            )
+        )
+        turn += 1
 
 async def document():
     return RedirectResponse(url="/docs")
@@ -428,10 +463,17 @@ def api_start(host, port):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
+    # stream_chat no longer takes knowledge_base_id in the URL path: connect to ws://localhost:7861/local_doc_qa/stream_chat and pass knowledge_base_id in the request body
+    app.websocket("/local_doc_qa/stream_chat")(stream_chat)
 
     app.get("/", response_model=BaseResponse)(document)
 
+    # Add streaming question answering backed by Bing search.
+    # Note: testing the websocket streaming endpoints requires a client with websocket support, such as Postman or Insomnia;
+    # the open-source Insomnia is highly recommended.
+    # Create a new websocket request and switch the URL scheme to ws, e.g. ws://localhost:7861/local_doc_qa/stream_chat_bing
+    app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
+
     app.post("/chat", response_model=ChatMessage)(chat)
 
     app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index fe70066..9ec3db5 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -200,6 +200,7 @@ class LocalDocQA:
             return vs_path, loaded_files
         else:
             logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
+            return None, loaded_files
 
     def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
 
diff --git a/cli_demo.py b/cli_demo.py
index 938ebb3..a445e14 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -23,11 +23,33 @@ def main():
                           top_k=VECTOR_SEARCH_TOP_K)
     vs_path = None
     while not vs_path:
+        print("Note: enter the full file path, e.g. knowledge_base/`knowledge_base_id`/content/file.md; separate multiple paths with commas.")
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
+        # If filepath is empty (e.g. the user pressed Enter by mistake), ask again.
         if not filepath:
             continue
-        vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
+
+        # Support loading multiple files.
+        filepath = filepath.split(",")
+        # init_knowledge_vector_store returns None as vs_path when loading fails; unpacking that
+        # directly (vs_path, _ = ...) raises "TypeError: cannot unpack non-iterable NoneType object"
+        # and exits the program, so check the first return value before using it.
+        temp, loaded_files = local_doc_qa.init_knowledge_vector_store(filepath)
+        if temp is not None:
+            vs_path = temp
+            # If len(loaded_files) differs from len(filepath), some files failed to load;
+            # when the cause is a wrong path, let the user reload them.
+            if len(loaded_files) != len(filepath):
+                reload_flag = input("Some files failed to load; if the path was wrong you can retry. Reload? Enter True or False: ").strip().lower() == "true"
+                if reload_flag:
+                    vs_path = None
+                    continue
+
+            print(f"the loaded vs_path is 加载的vs_path为: {vs_path}")
+        else:
+            print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径")
+
     history = []
     while True:
         query = input("Input your question 请输入问题:")
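
For reference, below is a minimal sketch of a client for the new stream_chat_bing websocket endpoint described in the comments above. It only relies on the message shapes visible in this patch ("question"/"history" in the request, "flag" and "sources_documents" in the frames); the host/port, the sample question, and the use of the third-party websockets package are assumptions, not part of the change.

# Minimal test client for ws://localhost:7861/local_doc_qa/stream_chat_bing.
# Assumes the API server from api.py is running locally on port 7861 and that
# the third-party "websockets" package is installed (pip install websockets).
import asyncio
import json

import websockets


async def ask(question: str, history: list) -> None:
    uri = "ws://localhost:7861/local_doc_qa/stream_chat_bing"
    async with websockets.connect(uri) as ws:
        # The handler reads a JSON body with "question" and "history".
        await ws.send(json.dumps({"question": question, "history": history}, ensure_ascii=False))
        while True:
            frame = await ws.recv()
            try:
                payload = json.loads(frame)
            except json.JSONDecodeError:
                payload = None
            if isinstance(payload, dict) and payload.get("flag") == "start":
                # Frame announcing the start of a turn; nothing to display.
                continue
            if isinstance(payload, dict) and payload.get("flag") == "end":
                # The final JSON frame carries the formatted source documents.
                print()
                print("\n".join(payload.get("sources_documents", [])))
                break
            # Every other text frame is an incremental chunk of the answer.
            print(frame, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(ask("What is LangChain?", []))

The same client shape works for the reworked /local_doc_qa/stream_chat route as well, provided the request body also carries the knowledge_base_id that was moved out of the URL path.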