1. change bing_search_chat and duckduckgo_search_chat into search_engine_chat
2. add knowledge_base not found to knowledge_base_chat and add search_engine not found to search_engine_chat
This commit is contained in:
parent
b62ea6bd2a
commit
329c24ee73
|
|
@ -9,7 +9,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
bing_search_chat, duckduckgo_search_chat)
|
||||
search_engine_chat)
|
||||
from server.knowledge_base import (list_kbs, create_kb, delete_kb,
|
||||
list_docs, upload_doc, delete_doc, update_doc)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
|
|
@ -39,6 +39,7 @@ def create_app():
|
|||
response_model=BaseResponse,
|
||||
summary="swagger 文档")(document)
|
||||
|
||||
# Tag: Chat
|
||||
app.post("/chat/fastchat",
|
||||
tags=["Chat"],
|
||||
summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
|
||||
|
|
@ -51,14 +52,11 @@ def create_app():
|
|||
tags=["Chat"],
|
||||
summary="与知识库对话")(knowledge_base_chat)
|
||||
|
||||
app.post("/chat/bing_search_chat",
|
||||
app.post("/chat/search_engine_chat",
|
||||
tags=["Chat"],
|
||||
summary="与Bing搜索对话")(bing_search_chat)
|
||||
|
||||
app.post("/chat/duckduckgo_search_chat",
|
||||
tags=["Chat"],
|
||||
summary="与DuckDuckGo搜索对话")(duckduckgo_search_chat)
|
||||
summary="与搜索引擎对话")(search_engine_chat)
|
||||
|
||||
# Tag: Knowledge Base Management
|
||||
app.get("/knowledge_base/list_knowledge_bases",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=ListResponse,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from .chat import chat
|
||||
from .knowledge_base_chat import knowledge_base_chat
|
||||
from .openai_chat import openai_chat
|
||||
from .duckduckgo_search_chat import duckduckgo_search_chat
|
||||
from .bing_search_chat import bing_search_chat
|
||||
from .search_engine_chat import search_engine_chat
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
from langchain.utilities import DuckDuckGoSearchAPIWrapper
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
|
||||
from server.chat.utils import wrap_done
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
search = DuckDuckGoSearchAPIWrapper()
|
||||
return search.results(text, result_len)
|
||||
|
||||
|
||||
def search_result2docs(search_results):
|
||||
docs = []
|
||||
for result in search_results:
|
||||
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
|
||||
metadata={"source": result["link"] if "link" in result.keys() else "",
|
||||
"filename": result["title"] if "title" in result.keys() else ""})
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
|
||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||
):
|
||||
async def duckduckgo_search_chat_iterator(query: str,
|
||||
top_k: int
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
)
|
||||
|
||||
results = duckduckgo_search(query, result_len=top_k)
|
||||
docs = search_result2docs(results)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
||||
|
||||
chain = LLMChain(prompt=prompt, llm=model)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}),
|
||||
callback.done),
|
||||
)
|
||||
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield token
|
||||
await task
|
||||
|
||||
return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
|
||||
|
|
@ -4,6 +4,9 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
|||
CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
|
||||
embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
|
||||
from server.chat.utils import wrap_done
|
||||
from server.utils import BaseResponse
|
||||
import os
|
||||
from server.knowledge_base.utils import get_kb_path
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
|
|
@ -51,6 +54,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
):
|
||||
kb_path = get_kb_path(knowledge_base_name)
|
||||
if not os.path.exists(kb_path):
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
async def knowledge_base_chat_iterator(query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int,
|
||||
|
|
@ -76,9 +83,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
callback.done),
|
||||
)
|
||||
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield token
|
||||
yield {"answer": token,
|
||||
"docs": source_documents}
|
||||
await task
|
||||
|
||||
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from langchain.utilities import BingSearchAPIWrapper
|
||||
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
|
||||
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
|
||||
from server.chat.utils import wrap_done
|
||||
from server.utils import BaseResponse
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
|
|
@ -12,6 +13,7 @@ import asyncio
|
|||
from langchain.prompts import PromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
||||
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
|
|
@ -22,6 +24,16 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
|||
return search.results(text, result_len)
|
||||
|
||||
|
||||
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
search = DuckDuckGoSearchAPIWrapper()
|
||||
return search.results(text, result_len)
|
||||
|
||||
|
||||
SEARCH_ENGINES = {"bing": bing_search,
|
||||
"duckduckgo": duckduckgo_search,
|
||||
}
|
||||
|
||||
|
||||
def search_result2docs(search_results):
|
||||
docs = []
|
||||
for result in search_results:
|
||||
|
|
@ -32,12 +44,27 @@ def search_result2docs(search_results):
|
|||
return docs
|
||||
|
||||
|
||||
def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
|
||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||
):
|
||||
async def bing_search_chat_iterator(query: str,
|
||||
top_k: int,
|
||||
) -> AsyncIterable[str]:
|
||||
def lookup_search_engine(
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||
):
|
||||
results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
|
||||
docs = search_result2docs(results)
|
||||
return docs
|
||||
|
||||
|
||||
def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
|
||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||
):
|
||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||
|
||||
async def search_engine_chat_iterator(query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
|
|
@ -48,8 +75,7 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
|
|||
model_name=LLM_MODEL
|
||||
)
|
||||
|
||||
results = bing_search(query, result_len=top_k)
|
||||
docs = search_result2docs(results)
|
||||
docs = lookup_search_engine(query, search_engine_name, top_k)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
||||
|
||||
|
|
@ -61,9 +87,16 @@ def bing_search_chat(query: str = Body(..., description="用户输入", example=
|
|||
callback.done),
|
||||
)
|
||||
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield token
|
||||
yield {"answer": token,
|
||||
"docs": source_documents}
|
||||
await task
|
||||
|
||||
return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
|
||||
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
|
||||
media_type="text/event-stream")
|
||||
|
|
@ -275,48 +275,27 @@ class ApiRequest:
|
|||
)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def duckduckgo_search_chat(
|
||||
def search_engine_chat(
|
||||
self,
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/duckduckgo_search_chat接口
|
||||
对应api.py/chat/search_engine_chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if no_remote_api:
|
||||
from server.chat.duckduckgo_search_chat import duckduckgo_search_chat
|
||||
response = duckduckgo_search_chat(query)
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
response = search_engine_chat(query, search_engine_name, top_k)
|
||||
return self._fastapi_stream2generator(response)
|
||||
else:
|
||||
response = self.post(
|
||||
"/chat/duckduckgo_search_chat",
|
||||
json=f"{query}",
|
||||
stream=True,
|
||||
)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def bing_search_chat(
|
||||
self,
|
||||
query: str,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/bing_search_chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if no_remote_api:
|
||||
from server.chat.bing_search_chat import bing_search_chat
|
||||
response = bing_search_chat(query)
|
||||
return self._fastapi_stream2generator(response)
|
||||
else:
|
||||
response = self.post(
|
||||
"/chat/bing_search_chat",
|
||||
json=f"{query}",
|
||||
"/chat/search_engine_chat",
|
||||
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
|
||||
stream=True,
|
||||
)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
|
|
|||
Loading…
Reference in New Issue