1. change bing_search_chat and duckduckgo_search_chat into search_engine_chat

2. add knowledge_base not found to knowledge_base_chat and add search_engine not found to search_engine_chat
This commit is contained in:
imClumsyPanda 2023-08-03 18:22:36 +08:00
parent b62ea6bd2a
commit 329c24ee73
6 changed files with 72 additions and 114 deletions

View File

@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
bing_search_chat, duckduckgo_search_chat)
search_engine_chat)
from server.knowledge_base import (list_kbs, create_kb, delete_kb,
list_docs, upload_doc, delete_doc, update_doc)
from server.utils import BaseResponse, ListResponse
@ -39,6 +39,7 @@ def create_app():
response_model=BaseResponse,
summary="swagger 文档")(document)
# Tag: Chat
app.post("/chat/fastchat",
tags=["Chat"],
summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
@ -51,14 +52,11 @@ def create_app():
tags=["Chat"],
summary="与知识库对话")(knowledge_base_chat)
app.post("/chat/bing_search_chat",
app.post("/chat/search_engine_chat",
tags=["Chat"],
summary="与Bing搜索对话")(bing_search_chat)
app.post("/chat/duckduckgo_search_chat",
tags=["Chat"],
summary="与DuckDuckGo搜索对话")(duckduckgo_search_chat)
summary="与搜索引擎对话")(search_engine_chat)
# Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
response_model=ListResponse,

View File

@ -1,5 +1,4 @@
from .chat import chat
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .duckduckgo_search_chat import duckduckgo_search_chat
from .bing_search_chat import bing_search_chat
from .search_engine_chat import search_engine_chat

View File

@ -1,64 +0,0 @@
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
search = DuckDuckGoSearchAPIWrapper()
return search.results(text, result_len)
def search_result2docs(search_results):
docs = []
for result in search_results:
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={"source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else ""})
docs.append(doc)
return docs
def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
):
async def duckduckgo_search_chat_iterator(query: str,
top_k: int
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
results = duckduckgo_search(query, result_len=top_k)
docs = search_result2docs(results)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
chain = LLMChain(prompt=prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield token
await task
return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")

View File

@ -4,6 +4,9 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
import os
from server.knowledge_base.utils import get_kb_path
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -51,6 +54,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
):
kb_path = get_kb_path(knowledge_base_name)
if not os.path.exists(kb_path):
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def knowledge_base_chat_iterator(query: str,
knowledge_base_name: str,
top_k: int,
@ -76,9 +83,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield token
yield {"answer": token,
"docs": source_documents}
await task
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),

View File

@ -1,9 +1,10 @@
from langchain.utilities import BingSearchAPIWrapper
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -12,6 +13,7 @@ import asyncio
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
@ -22,6 +24,16 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
return search.results(text, result_len)
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
search = DuckDuckGoSearchAPIWrapper()
return search.results(text, result_len)
SEARCH_ENGINES = {"bing": bing_search,
"duckduckgo": duckduckgo_search,
}
def search_result2docs(search_results):
docs = []
for result in search_results:
@ -32,12 +44,27 @@ def search_result2docs(search_results):
return docs
def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
):
async def bing_search_chat_iterator(query: str,
top_k: int,
) -> AsyncIterable[str]:
def lookup_search_engine(
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
):
results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
docs = search_result2docs(results)
return docs
def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
@ -48,8 +75,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
model_name=LLM_MODEL
)
results = bing_search(query, result_len=top_k)
docs = search_result2docs(results)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@ -61,9 +87,16 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield token
yield {"answer": token,
"docs": source_documents}
await task
return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
media_type="text/event-stream")

View File

@ -275,48 +275,27 @@ class ApiRequest:
)
return self._httpx_stream2generator(response)
def duckduckgo_search_chat(
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int,
no_remote_api: bool = None,
):
'''
对应api.py/chat/duckduckgo_search_chat接口
对应api.py/chat/search_engine_chat接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.chat.duckduckgo_search_chat import duckduckgo_search_chat
response = duckduckgo_search_chat(query)
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(query, search_engine_name, top_k)
return self._fastapi_stream2generator(response)
else:
response = self.post(
"/chat/duckduckgo_search_chat",
json=f"{query}",
stream=True,
)
return self._httpx_stream2generator(response)
def bing_search_chat(
self,
query: str,
no_remote_api: bool = None,
):
'''
对应api.py/chat/bing_search_chat接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.chat.bing_search_chat import bing_search_chat
response = bing_search_chat(query)
return self._fastapi_stream2generator(response)
else:
response = self.post(
"/chat/bing_search_chat",
json=f"{query}",
"/chat/search_engine_chat",
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
stream=True,
)
return self._httpx_stream2generator(response)