From 126bd5123213492bc50d24b2d9ca3c679607b6e1 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 14 Aug 2023 10:35:47 +0800 Subject: [PATCH] fix chat and knowledge_base_chat --- server/chat/chat.py | 14 +++++++++++--- server/chat/knowledge_base_chat.py | 5 +++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/server/chat/chat.py b/server/chat/chat.py index b864609..2bc21db 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -19,6 +19,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, {"role": "assistant", "content": "虎头虎脑"}]] ), + stream: bool = Body(False, description="流式输出"), ): history = [History(**h) if isinstance(h, dict) else h for h in history] @@ -46,9 +47,16 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 callback.done), ) - async for token in callback.aiter(): - # Use server-sent-events to stream the response - yield token + if stream: + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield token + else: + answer = "" + async for token in callback.aiter(): + answer += token + yield answer + await task return StreamingResponse(chat_iterator(query, history), diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index d1adf59..d6dafe2 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -14,6 +14,7 @@ from typing import List, Optional from server.chat.utils import History from server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json +import os def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), @@ -64,7 +65,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp ) source_documents = [ - f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" + f"""出处 [{inum + 1}] [{os.path.split(doc.metadata["source"])[-1]}] \n\n{doc.page_content}\n\n""" for inum, doc in enumerate(docs) ] @@ -78,7 +79,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp answer = "" async for token in callback.aiter(): answer += token - yield json.dumps({"answer": token, + yield json.dumps({"answer": answer, "docs": source_documents}, ensure_ascii=False)