diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 58ee7e0..68d99e4 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -13,6 +13,7 @@ from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
 from server.knowledge_base.utils import lookup_vs
+import json
 
 
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
@@ -55,8 +56,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
 
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield {"answer": token,
-                   "docs": source_documents}
+            yield json.dumps({"answer": token,
+                              "docs": source_documents})
         await task
 
     return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index ad63183..4689d58 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -12,6 +12,7 @@ from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
+import json
 
 
 def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
@@ -94,8 +95,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
 
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield {"answer": token,
-                   "docs": source_documents}
+            yield json.dumps({"answer": token,
+                              "docs": source_documents})
         await task
 
     return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 383c89d..0504a85 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -72,7 +72,7 @@ def lookup_vs(
         embedding_model: str = EMBEDDING_MODEL,
         embedding_device: str = EMBEDDING_DEVICE,
 ):
-    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
+    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device, _VECTOR_STORE_TICKS.get(knowledge_base_name))
     docs = search_index.similarity_search(query, k=top_k)
     return docs
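
Note on the two chat endpoints: `StreamingResponse` only accepts `str` or `bytes` chunks, so the raw dicts previously yielded would fail when Starlette tries to encode the response body; serializing each chunk with `json.dumps` fixes that. A minimal sketch of the pattern, with an illustrative route and payload (not the project's actual API):

```python
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo_stream")  # illustrative route, not the project's API
def demo_stream():
    async def iterator():
        for token in ["he", "llo"]:
            # StreamingResponse only accepts str/bytes chunks; yielding a
            # raw dict fails when Starlette encodes the response body.
            yield json.dumps({"answer": token, "docs": []})

    return StreamingResponse(iterator(), media_type="text/event-stream")
```

One design note: `json.dumps` escapes non-ASCII by default, so clients will see `\uXXXX` sequences for Chinese text unless `ensure_ascii=False` is passed.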
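Note on the `utils.py` change: passing a per-knowledge-base "tick" into `load_vector_store` presumably lets a memoized store be reloaded after the knowledge base changes. A sketch of how such tick-based cache busting could work, assuming `load_vector_store` is wrapped in `lru_cache` and `_VECTOR_STORE_TICKS` is bumped on every update (neither is shown in this diff):

```python
from functools import lru_cache

_VECTOR_STORE_TICKS = {}  # knowledge_base_name -> int, bumped on every update


@lru_cache(maxsize=8)
def load_vector_store(knowledge_base_name: str,
                      embedding_model: str,
                      embedding_device: str,
                      tick: int = 0):
    # `tick` is never used in the body; it only widens the cache key, so a
    # bumped tick misses the cache and forces a fresh load from disk.
    print(f"loading vector store {knowledge_base_name!r} (tick={tick})")
    return object()  # stands in for FAISS.load_local(...) in the real project


def refresh_vs_cache(knowledge_base_name: str):
    # Call after documents are added to or deleted from the knowledge base.
    _VECTOR_STORE_TICKS[knowledge_base_name] = _VECTOR_STORE_TICKS.get(knowledge_base_name, 0) + 1
```

Under these assumptions, `lookup_vs` passing `_VECTOR_STORE_TICKS.get(knowledge_base_name)` means searches pick up a freshly rebuilt index rather than a stale cached one; for a knowledge base that has never been updated, `.get()` returns `None`, which is still a stable cache key.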