make api and webui support top_k paramter
This commit is contained in:
parent
df61f31c8e
commit
5c804aac75
|
|
@ -52,6 +52,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
):
|
):
|
||||||
async def knowledge_base_chat_iterator(query: str,
|
async def knowledge_base_chat_iterator(query: str,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
|
top_k: int,
|
||||||
) -> AsyncIterable[str]:
|
) -> AsyncIterable[str]:
|
||||||
callback = AsyncIteratorCallbackHandler()
|
callback = AsyncIteratorCallbackHandler()
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
|
|
@ -62,7 +63,6 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
model_name=LLM_MODEL
|
model_name=LLM_MODEL
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = lookup_vs(query, knowledge_base_name, top_k)
|
docs = lookup_vs(query, knowledge_base_name, top_k)
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
||||||
|
|
@ -80,4 +80,4 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
yield token
|
yield token
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name), media_type="text/event-stream")
|
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), media_type="text/event-stream")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ chat_box = ChatBox(
|
||||||
)
|
)
|
||||||
|
|
||||||
def dialogue_page(api: ApiRequest):
|
def dialogue_page(api: ApiRequest):
|
||||||
|
chat_box.init_session()
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
def on_mode_change():
|
def on_mode_change():
|
||||||
mode = st.session_state.dialogue_mode
|
mode = st.session_state.dialogue_mode
|
||||||
|
|
|
||||||
|
|
@ -254,6 +254,7 @@ class ApiRequest:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
|
top_k: int = 3,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -264,12 +265,12 @@ class ApiRequest:
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||||
response = knowledge_base_chat(query, knowledge_base_name)
|
response = knowledge_base_chat(query, knowledge_base_name, top_k)
|
||||||
return self._fastapi_stream2generator(response)
|
return self._fastapi_stream2generator(response)
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/knowledge_base_chat",
|
"/chat/knowledge_base_chat",
|
||||||
json={"query": query, "knowledge_base_name": knowledge_base_name},
|
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue