fix chat and knowledge_base_chat

This commit is contained in:
liunux4odoo 2023-08-14 10:35:47 +08:00
parent 84e4981cc1
commit 126bd51232
2 changed files with 14 additions and 5 deletions

View File

@ -19,6 +19,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]] {"role": "assistant", "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"),
): ):
history = [History(**h) if isinstance(h, dict) else h for h in history] history = [History(**h) if isinstance(h, dict) else h for h in history]
@ -46,9 +47,16 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
callback.done), callback.done),
) )
async for token in callback.aiter(): if stream:
# Use server-sent-events to stream the response async for token in callback.aiter():
yield token # Use server-sent-events to stream the response
yield token
else:
answer = ""
async for token in callback.aiter():
answer += token
yield answer
await task await task
return StreamingResponse(chat_iterator(query, history), return StreamingResponse(chat_iterator(query, history),

View File

@ -14,6 +14,7 @@ from typing import List, Optional
from server.chat.utils import History from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json import json
import os
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
@ -64,7 +65,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
) )
source_documents = [ source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" f"""出处 [{inum + 1}] [{os.path.split(doc.metadata["source"])[-1]}] \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs) for inum, doc in enumerate(docs)
] ]
@ -78,7 +79,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
answer = "" answer = ""
async for token in callback.aiter(): async for token in callback.aiter():
answer += token answer += token
yield json.dumps({"answer": token, yield json.dumps({"answer": answer,
"docs": source_documents}, "docs": source_documents},
ensure_ascii=False) ensure_ascii=False)