1. merge bing_search_chat and duckduckgo_search_chat into a single search_engine_chat
2. return 404 from knowledge_base_chat when the knowledge base is not found, and from search_engine_chat when the search engine is not supported
parent b62ea6bd2a
commit 329c24ee73
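The two former endpoints /chat/bing_search_chat and /chat/duckduckgo_search_chat are merged into a single /chat/search_engine_chat endpoint that takes the engine name in the request body. A minimal client call against the merged endpoint might look like the sketch below (the host and port are assumptions; use whatever address api.py is actually served on):

    import requests

    base_url = "http://127.0.0.1:7861"  # hypothetical address

    # The body mirrors the new handler's parameters: query,
    # search_engine_name ("bing" or "duckduckgo"), and top_k.
    resp = requests.post(
        f"{base_url}/chat/search_engine_chat",
        json={"query": "你好", "search_engine_name": "duckduckgo", "top_k": 3},
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")  # tokens arrive incrementally via the event stream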
@@ -9,7 +9,7 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
-                         bing_search_chat, duckduckgo_search_chat)
+                         search_engine_chat)
 from server.knowledge_base import (list_kbs, create_kb, delete_kb,
                                    list_docs, upload_doc, delete_doc, update_doc)
 from server.utils import BaseResponse, ListResponse
@@ -39,6 +39,7 @@ def create_app():
             response_model=BaseResponse,
             summary="swagger 文档")(document)

+    # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
              summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
@@ -51,14 +52,11 @@ def create_app():
              tags=["Chat"],
              summary="与知识库对话")(knowledge_base_chat)

-    app.post("/chat/bing_search_chat",
+    app.post("/chat/search_engine_chat",
              tags=["Chat"],
-             summary="与Bing搜索对话")(bing_search_chat)
+             summary="与搜索引擎对话")(search_engine_chat)

-    app.post("/chat/duckduckgo_search_chat",
-             tags=["Chat"],
-             summary="与DuckDuckGo搜索对话")(duckduckgo_search_chat)
-
+    # Tag: Knowledge Base Management
     app.get("/knowledge_base/list_knowledge_bases",
             tags=["Knowledge Base Management"],
             response_model=ListResponse,
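The registration style in create_app() relies on the fact that app.post(...) and app.get(...) return the same decorator that @app.post/@app.get would apply, so calling the result with a handler registers the route without decorating the function at its definition site. A standalone illustration:

    from fastapi import FastAPI

    app = FastAPI()

    def ping():
        return {"msg": "pong"}

    # Equivalent to writing @app.get("/ping", ...) above ping's definition:
    app.get("/ping", tags=["Demo"], summary="demo route")(ping)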
@@ -1,5 +1,4 @@
 from .chat import chat
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
-from .duckduckgo_search_chat import duckduckgo_search_chat
-from .bing_search_chat import bing_search_chat
+from .search_engine_chat import search_engine_chat
@@ -1,64 +0,0 @@
-from langchain.utilities import DuckDuckGoSearchAPIWrapper
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts import PromptTemplate
-from langchain.docstore.document import Document
-
-
-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
-    search = DuckDuckGoSearchAPIWrapper()
-    return search.results(text, result_len)
-
-
-def search_result2docs(search_results):
-    docs = []
-    for result in search_results:
-        doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
-                       metadata={"source": result["link"] if "link" in result.keys() else "",
-                                 "filename": result["title"] if "title" in result.keys() else ""})
-        docs.append(doc)
-    return docs
-
-
-def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
-                           top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                           ):
-    async def duckduckgo_search_chat_iterator(query: str,
-                                              top_k: int
-                                              ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        results = duckduckgo_search(query, result_len=top_k)
-        docs = search_result2docs(results)
-        context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        async for token in callback.aiter():
-            # Use server-sent-events to stream the response
-            yield token
-        await task
-
-    return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
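The streaming pattern removed here survives unchanged in search_engine_chat.py: an AsyncIteratorCallbackHandler feeds tokens to the response generator while the chain runs as a background task wrapped by wrap_done. wrap_done lives in server.chat.utils and is not part of this diff; a typical implementation of this helper pattern is roughly the following sketch (an assumption, not the project's actual code):

    import asyncio
    from typing import Awaitable

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        # Await the chain call and signal the event when it finishes,
        # so that callback.aiter() knows when to stop yielding tokens.
        try:
            await fn
        except Exception as e:
            print(f"Caught exception: {e}")  # avoid hanging the stream on errors
        finally:
            event.set()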
@@ -4,6 +4,9 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                   CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from server.chat.utils import wrap_done
+from server.utils import BaseResponse
+import os
+from server.knowledge_base.utils import get_kb_path
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -51,6 +54,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
+    kb_path = get_kb_path(knowledge_base_name)
+    if not os.path.exists(kb_path):
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
                                            top_k: int,
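BaseResponse comes from server.utils; its definition is outside this diff, but the call sites here and in search_engine_chat.py (keyword arguments code and msg) suggest a pydantic model along these lines (a hedged sketch, not the actual declaration):

    from pydantic import BaseModel

    class BaseResponse(BaseModel):
        code: int = 200       # defaults are guesses; the diff only shows
        msg: str = "success"  # code=404 with a custom message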
@@ -76,9 +83,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
             callback.done),
         )

+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(docs)
+        ]
+
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield token
+            yield {"answer": token,
+                   "docs": source_documents}
         await task

     return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
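Each streamed chunk is now a dict carrying both the incremental answer token and the formatted source list, rather than a bare token, so consumers must unpack chunks instead of concatenating raw text. Assuming the chunks reach the client JSON-serialized (the serialization step is not shown in this hunk), handling would look roughly like:

    import json

    def handle_chunk(raw_chunk: str):
        # Assumes the server serializes each {"answer", "docs"} dict to JSON.
        data = json.loads(raw_chunk)
        print(data["answer"], end="")  # incremental token
        return data["docs"]            # formatted "出处 [n] ..." source list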
@@ -1,9 +1,10 @@
-from langchain.utilities import BingSearchAPIWrapper
+from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
+from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -12,6 +13,7 @@ import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document

+
 def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
@@ -22,6 +24,16 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)


+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
+    search = DuckDuckGoSearchAPIWrapper()
+    return search.results(text, result_len)
+
+
+SEARCH_ENGINES = {"bing": bing_search,
+                  "duckduckgo": duckduckgo_search,
+                  }
+
+
 def search_result2docs(search_results):
     docs = []
     for result in search_results:
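SEARCH_ENGINES is a plain dispatch table mapping an engine name to its search function, which is what allows one endpoint to serve both engines. Used directly:

    from server.chat.search_engine_chat import SEARCH_ENGINES

    # Both calls go through the same table the endpoint uses:
    results = SEARCH_ENGINES["duckduckgo"]("langchain", result_len=3)
    results = SEARCH_ENGINES["bing"]("langchain", result_len=3)

    # Names missing from the table are exactly what the 404 guard
    # in search_engine_chat rejects.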
@@ -32,10 +44,25 @@ def search_result2docs(search_results):
     return docs


-def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
+def lookup_search_engine(
+        query: str,
+        search_engine_name: str,
+        top_k: int = SEARCH_ENGINE_TOP_K,
+):
+    results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
+    docs = search_result2docs(results)
+    return docs
+
+
+def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
+                       search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                        ):
-    async def bing_search_chat_iterator(query: str,
+    if search_engine_name not in SEARCH_ENGINES.keys():
+        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
+
+    async def search_engine_chat_iterator(query: str,
+                                          search_engine_name: str,
                                           top_k: int,
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
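Because lookup_search_engine resolves the engine through the table, supporting another engine only requires registering a function with the same (text, result_len) signature; the endpoint itself does not change. For instance, with a hypothetical stub engine (not part of this commit):

    from server.chat.search_engine_chat import SEARCH_ENGINES

    def my_search(text, result_len=3):
        # Any callable returning dicts with "snippet"/"link"/"title" keys
        # works with search_result2docs.
        return [{"snippet": f"stub result for {text}",
                 "link": "https://example.com",
                 "title": "stub"}]

    SEARCH_ENGINES["my_engine"] = my_search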
@@ -48,8 +75,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
             model_name=LLM_MODEL
         )

-        results = bing_search(query, result_len=top_k)
-        docs = search_result2docs(results)
+        docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])

@@ -61,9 +87,16 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
             callback.done),
         )

+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(docs)
+        ]
+
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield token
+            yield {"answer": token,
+                   "docs": source_documents}
         await task

-    return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
+                             media_type="text/event-stream")
@@ -275,48 +275,27 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response)

-    def duckduckgo_search_chat(
+    def search_engine_chat(
         self,
         query: str,
+        search_engine_name: str,
+        top_k: int,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/chat/duckduckgo_search_chat接口
+        对应api.py/chat/search_engine_chat接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

         if no_remote_api:
-            from server.chat.duckduckgo_search_chat import duckduckgo_search_chat
-            response = duckduckgo_search_chat(query)
+            from server.chat.search_engine_chat import search_engine_chat
+            response = search_engine_chat(query, search_engine_name, top_k)
             return self._fastapi_stream2generator(response)
         else:
             response = self.post(
-                "/chat/duckduckgo_search_chat",
-                json=f"{query}",
-                stream=True,
-            )
-            return self._httpx_stream2generator(response)
-
-    def bing_search_chat(
-        self,
-        query: str,
-        no_remote_api: bool = None,
-    ):
-        '''
-        对应api.py/chat/bing_search_chat接口
-        '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
-        if no_remote_api:
-            from server.chat.bing_search_chat import bing_search_chat
-            response = bing_search_chat(query)
-            return self._fastapi_stream2generator(response)
-        else:
-            response = self.post(
-                "/chat/bing_search_chat",
-                json=f"{query}",
+                "/chat/search_engine_chat",
+                json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
                 stream=True,
             )
             return self._httpx_stream2generator(response)
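On the webui side the rewritten method mirrors the server signature, so either engine is reached through one call. Typical use (constructing ApiRequest as elsewhere in this file; its constructor arguments are outside this diff):

    api = ApiRequest()

    # Streams {"answer", "docs"} chunks for either engine through one method:
    for chunk in api.search_engine_chat("你好", search_engine_name="bing", top_k=3):
        print(chunk)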