From 329c24ee73482c1e2d8dee2bb39a2b093a45c5a6 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Thu, 3 Aug 2023 18:22:36 +0800
Subject: [PATCH] 1. merge bing_search_chat and duckduckgo_search_chat into a
 single search_engine_chat; 2. return a 404 response from knowledge_base_chat
 when the knowledge base is not found and from search_engine_chat when the
 search engine is not supported

---
 server/api.py                                 | 12 ++--
 server/chat/__init__.py                       |  3 +-
 server/chat/duckduckgo_search_chat.py         | 64 -------------------
 server/chat/knowledge_base_chat.py            | 16 +++++-
 ...g_search_chat.py => search_engine_chat.py} | 56 ++++++++++++----
 webui_pages/utils.py                          | 37 +++--------
 6 files changed, 74 insertions(+), 114 deletions(-)
 delete mode 100644 server/chat/duckduckgo_search_chat.py
 rename server/chat/{bing_search_chat.py => search_engine_chat.py} (56%)

diff --git a/server/api.py b/server/api.py
index d643731..350843f 100644
--- a/server/api.py
+++ b/server/api.py
@@ -9,7 +9,7 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
-                         bing_search_chat, duckduckgo_search_chat)
+                         search_engine_chat)
 from server.knowledge_base import (list_kbs, create_kb, delete_kb, list_docs,
                                    upload_doc, delete_doc, update_doc)
 from server.utils import BaseResponse, ListResponse
@@ -39,6 +39,7 @@ def create_app():
             response_model=BaseResponse,
             summary="swagger 文档")(document)
 
+    # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
              summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
@@ -51,14 +52,11 @@
              tags=["Chat"],
              summary="与知识库对话")(knowledge_base_chat)
 
-    app.post("/chat/bing_search_chat",
+    app.post("/chat/search_engine_chat",
              tags=["Chat"],
-             summary="与Bing搜索对话")(bing_search_chat)
-
-    app.post("/chat/duckduckgo_search_chat",
-             tags=["Chat"],
-             summary="与DuckDuckGo搜索对话")(duckduckgo_search_chat)
+             summary="与搜索引擎对话")(search_engine_chat)
 
+    # Tag: Knowledge Base Management
     app.get("/knowledge_base/list_knowledge_bases",
             tags=["Knowledge Base Management"],
             response_model=ListResponse,
diff --git a/server/chat/__init__.py b/server/chat/__init__.py
index 728db91..136bad6 100644
--- a/server/chat/__init__.py
+++ b/server/chat/__init__.py
@@ -1,5 +1,4 @@
 from .chat import chat
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
-from .duckduckgo_search_chat import duckduckgo_search_chat
-from .bing_search_chat import bing_search_chat
+from .search_engine_chat import search_engine_chat
diff --git a/server/chat/duckduckgo_search_chat.py b/server/chat/duckduckgo_search_chat.py
deleted file mode 100644
index b75ed25..0000000
--- a/server/chat/duckduckgo_search_chat.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from langchain.utilities import DuckDuckGoSearchAPIWrapper
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts import PromptTemplate
-from langchain.docstore.document import Document
-
-
-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
-    search = DuckDuckGoSearchAPIWrapper()
-    return search.results(text, result_len)
-
-
-def search_result2docs(search_results):
-    docs = []
-    for result in search_results:
-        doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
-                       metadata={"source": result["link"] if "link" in result.keys() else "",
-                                 "filename": result["title"] if "title" in result.keys() else ""})
-        docs.append(doc)
-    return docs
-
-
-def duckduckgo_search_chat(query: str = Body(..., description="用户输入", example="你好"),
-                           top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                           ):
-    async def duckduckgo_search_chat_iterator(query: str,
-                                              top_k: int
-                                              ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        results = duckduckgo_search(query, result_len=top_k)
-        docs = search_result2docs(results)
-        context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        async for token in callback.aiter():
-            # Use server-sent-events to stream the response
-            yield token
-        await task
-
-    return StreamingResponse(duckduckgo_search_chat_iterator(query, top_k), media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 1351e82..23a380a 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -4,6 +4,10 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                   CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
 from server.chat.utils import wrap_done
+from server.utils import BaseResponse
+import os
+import json
+from server.knowledge_base.utils import get_kb_path
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -51,6 +55,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
+    kb_path = get_kb_path(knowledge_base_name)
+    if not os.path.exists(kb_path):
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
                                            top_k: int,
@@ -76,9 +84,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             callback.done),
         )
 
+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(docs)
+        ]
+
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield token
+            yield json.dumps({"answer": token, "docs": source_documents},
+                             ensure_ascii=False)
         await task
 
     return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
diff --git a/server/chat/bing_search_chat.py b/server/chat/search_engine_chat.py
similarity index 56%
rename from server/chat/bing_search_chat.py
rename to server/chat/search_engine_chat.py
index 9576d6d..ad63183 100644
--- a/server/chat/bing_search_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,9 +1,11 @@
-from langchain.utilities import BingSearchAPIWrapper
+from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
+from server.utils import BaseResponse
+import json
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -12,6 +14,7 @@ import asyncio
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
 
+
 def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
@@ -22,6 +25,16 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)
 
 
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
+    search = DuckDuckGoSearchAPIWrapper()
+    return search.results(text, result_len)
+
+
+SEARCH_ENGINES = {"bing": bing_search,
+                  "duckduckgo": duckduckgo_search,
+                  }
+
+
 def search_result2docs(search_results):
     docs = []
     for result in search_results:
@@ -32,12 +45,27 @@
     return docs
 
 
-def bing_search_chat(query: str = Body(..., description="用户输入", example="你好"),
-                     top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                     ):
-    async def bing_search_chat_iterator(query: str,
-                                        top_k: int,
-                                        ) -> AsyncIterable[str]:
+def lookup_search_engine(
+        query: str,
+        search_engine_name: str,
+        top_k: int = SEARCH_ENGINE_TOP_K,
+):
+    results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
+    docs = search_result2docs(results)
+    return docs
+
+
+def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
+                       search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
+                       top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                       ):
+    if search_engine_name not in SEARCH_ENGINES:
+        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
+
+    async def search_engine_chat_iterator(query: str,
+                                          search_engine_name: str,
+                                          top_k: int,
+                                          ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
@@ -48,8 +76,7 @@
             model_name=LLM_MODEL
         )
 
-        results = bing_search(query, result_len=top_k)
-        docs = search_result2docs(results)
+        docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
 
@@ -61,9 +88,16 @@
             callback.done),
         )
 
+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(docs)
+        ]
+
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
-            yield token
+            yield json.dumps({"answer": token, "docs": source_documents},
+                             ensure_ascii=False)
         await task
 
-    return StreamingResponse(bing_search_chat_iterator(query, top_k), media_type="text/event-stream")
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
+                             media_type="text/event-stream")
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 2089179..fc1253a 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -275,48 +275,27 @@
         )
         return self._httpx_stream2generator(response)
 
-    def duckduckgo_search_chat(
+    def search_engine_chat(
         self,
         query: str,
+        search_engine_name: str,
+        top_k: int,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/chat/duckduckgo_search_chat接口
+        对应api.py/chat/search_engine_chat接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
         if no_remote_api:
-            from server.chat.duckduckgo_search_chat import duckduckgo_search_chat
-            response = duckduckgo_search_chat(query)
+            from server.chat.search_engine_chat import search_engine_chat
+            response = search_engine_chat(query, search_engine_name, top_k)
             return self._fastapi_stream2generator(response)
         else:
             response = self.post(
-                "/chat/duckduckgo_search_chat",
-                json=f"{query}",
-                stream=True,
-            )
-            return self._httpx_stream2generator(response)
-
-    def bing_search_chat(
-        self,
-        query: str,
-        no_remote_api: bool = None,
-    ):
-        '''
-        对应api.py/chat/bing_search_chat接口
-        '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
-        if no_remote_api:
-            from server.chat.bing_search_chat import bing_search_chat
-            response = bing_search_chat(query)
-            return self._fastapi_stream2generator(response)
-        else:
-            response = self.post(
-                "/chat/bing_search_chat",
-                json=f"{query}",
+                "/chat/search_engine_chat",
+                json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
                 stream=True,
            )
            return self._httpx_stream2generator(response)
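
Usage sketch (not part of the patch): the merged endpoint now streams one JSON
object per generated token, each repeating the full "docs" list, rather than
bare text, and it does so without SSE "data:" framing despite the
"text/event-stream" media type. A minimal client is shown below; the base URL
and port 7861 and the httpx dependency are assumptions, and
stream_search_engine_chat is a hypothetical helper name. Because HTTP chunk
boundaries need not align with the JSON objects the server yields, the sketch
buffers and decodes incrementally.

import json

import httpx  # assumption: any HTTP client with response streaming would do

API_BASE = "http://127.0.0.1:7861"  # assumption: wherever server/api.py is served


def stream_search_engine_chat(query: str,
                              search_engine_name: str = "duckduckgo",
                              top_k: int = 3):
    """Print answer tokens from /chat/search_engine_chat as they arrive."""
    decoder = json.JSONDecoder()
    buffer = ""
    with httpx.stream("POST",
                      f"{API_BASE}/chat/search_engine_chat",
                      json={"query": query,
                            "search_engine_name": search_engine_name,
                            "top_k": top_k},
                      timeout=None) as resp:
        for chunk in resp.iter_text():
            buffer += chunk
            # Decode as many complete JSON objects as are currently buffered.
            while buffer:
                try:
                    obj, end = decoder.raw_decode(buffer)
                except json.JSONDecodeError:
                    break  # incomplete object; wait for more data
                buffer = buffer[end:].lstrip()
                print(obj["answer"], end="", flush=True)


if __name__ == "__main__":
    stream_search_engine_chat("你好")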