From dfcd6f610274bcb22b54cba3d2a4e63b51d804c6 Mon Sep 17 00:00:00 2001 From: halfss Date: Tue, 16 May 2023 23:51:34 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0bing=E6=90=9C=E7=B4=A2agent?= =?UTF-8?q?=20(#378)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 1: 基于langchain的bing search agent添加bing搜索支持(百度不靠谱,Google不可用) 2: 调整输入框交互模式,对话/知识库/搜索,三选一 * fixed bug of message have no text --------- Co-authored-by: root --- agent/__init__.py | 3 +- agent/bing_search.py | 20 ++++++++ api.py | 21 ++++++++ views/src/api/chat.ts | 7 +++ views/src/views/chat/index.vue | 88 ++++++++++++++++++++++++++++------ 5 files changed, 124 insertions(+), 15 deletions(-) create mode 100644 agent/bing_search.py diff --git a/agent/__init__.py b/agent/__init__.py index 933c16e..473a5ea 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -1 +1,2 @@ -from .chatglm_with_shared_memory_openai_llm import * \ No newline at end of file +#from .chatglm_with_shared_memory_openai_llm import * +from agent.bing_search import search as bing_search diff --git a/agent/bing_search.py b/agent/bing_search.py new file mode 100644 index 0000000..d5ba766 --- /dev/null +++ b/agent/bing_search.py @@ -0,0 +1,20 @@ +#coding=utf8 + +import os +from langchain.utilities import BingSearchAPIWrapper + + +env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY") +env_bing_url = os.environ.get("BING_SEARCH_URL") + + +def search(text, result_len=3): + if not (env_bing_key and env_bing_url): + return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV", + "title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}] + search = BingSearchAPIWrapper() + return search.results(text, result_len) + + +if __name__ == "__main__": + r = search('python') diff --git a/api.py b/api.py index b65c084..558baac 100644 --- a/api.py +++ b/api.py @@ -13,10 +13,12 @@ from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing_extensions import Annotated from starlette.responses import RedirectResponse + from chains.local_doc_qa import LocalDocQA from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN) +from agent import bing_search as agent_bing_search nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -314,6 +316,23 @@ async def document(): return RedirectResponse(url="/docs") +async def bing_search( + search_text: str = Query(default=None, description="text you want to search", example="langchain") +): + results = agent_bing_search(search_text) + result_str = '' + for result in results: + for k, v in result.items(): + result_str += "%s: %s\n" % (k, v) + result_str += '\n' + + return ChatMessage( + question=search_text, + response=result_str, + history=[], + source_documents=[], + ) + def api_start(host, port): global app global local_doc_qa @@ -342,6 +361,8 @@ def api_start(host, port): app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs) + app.get("/bing_search", response_model=ChatMessage)(bing_search) + local_doc_qa = LocalDocQA() local_doc_qa.init_cfg( llm_model=LLM_MODEL, diff --git a/views/src/api/chat.ts b/views/src/api/chat.ts index 349edd7..5adbb3b 100644 --- a/views/src/api/chat.ts +++ b/views/src/api/chat.ts @@ -24,7 +24,14 @@ export const getfilelist = (knowledge_base_id: any) => { }) } +export const bing_search = (search_text: any) => { + return api({ + url: '/bing_search', + method: 'get', + params: { search_text }, + }) +} export const deletefile = (params: any) => { return api({ url: '/local_doc_qa/delete_file', diff --git a/views/src/views/chat/index.vue b/views/src/views/chat/index.vue index bb5e831..5afb648 100644 --- a/views/src/views/chat/index.vue +++ b/views/src/views/chat/index.vue @@ -3,7 +3,7 @@ import type { Ref } from 'vue' import { computed, onMounted, onUnmounted, ref } from 'vue' import { useRoute } from 'vue-router' import { storeToRefs } from 'pinia' -import { NAutoComplete, NButton, NInput, NSwitch, useDialog, useMessage } from 'naive-ui' +import { NAutoComplete, NButton, NInput, NRadioButton, NRadioGroup, useDialog, useMessage } from 'naive-ui' import html2canvas from 'html2canvas' import { Message } from './components' import { useScroll } from './hooks/useScroll' @@ -14,7 +14,7 @@ import { HoverButton, SvgIcon } from '@/components/common' import { useBasicLayout } from '@/hooks/useBasicLayout' import { useChatStore, usePromptStore } from '@/store' import { t } from '@/locales' -import { chat, chatfile } from '@/api/chat' +import { bing_search, chat, chatfile } from '@/api/chat' import { idStore } from '@/store/modules/knowledgebaseid/id' let controller = new AbortController() @@ -39,6 +39,7 @@ const conversationList = computed(() => dataSources.value.filter(item => (!item. const prompt = ref('') const loading = ref(false) const inputRef = ref(null) +const search = ref('对话') // 添加PromptStore const promptStore = usePromptStore() @@ -55,15 +56,70 @@ dataSources.value.forEach((item, index) => { updateChatSome(+uuid, index, { loading: false }) }) -function handleSubmit() { - onConversation() +async function handleSubmit() { + if (search.value == 'Bing搜索') { + loading.value = true + const options: Chat.ConversationRequest = {} + const lastText = '' + const message = prompt.value + + addChat( + +uuid, + { + dateTime: new Date().toLocaleString(), + text: message, + inversion: true, + error: false, + conversationOptions: null, + requestOptions: { prompt: message, options: null }, + }, + ) + scrollToBottom() + const res = await bing_search(prompt.value) + + const result = active.value ? `${res.data.response}\n\n数据来源:\n\n>${res.data.source_documents.join('>')}` : res.data.response + addChat( + +uuid, + { + dateTime: new Date().toLocaleString(), + text: '', + loading: true, + inversion: false, + error: false, + conversationOptions: null, + requestOptions: { prompt: message, options: { ...options } }, + }, + ) + scrollToBottom() + updateChat( + +uuid, + dataSources.value.length - 1, + { + dateTime: new Date().toLocaleString(), + text: lastText + (result ?? ''), + inversion: false, + error: false, + loading: false, + conversationOptions: null, + requestOptions: { prompt: message, options: { ...options } }, + }, + ) + prompt.value = '' + scrollToBottomIfAtBottom() + loading.value = false + } + else { + onConversation() + } } async function onConversation() { const message = prompt.value if (usingContext.value) { - for (let i = 0; i < dataSources.value.length; i = i + 2) - history.value.push([dataSources.value[i].text, dataSources.value[i + 1].text.split('\n\n数据来源:\n\n>')[0]]) + for (let i = 0; i < dataSources.value.length; i = i + 2) { + if (!i) + history.value.push([dataSources.value[i].text, dataSources.value[i + 1].text.split('\n\n数据来源:\n\n>')[0]]) + } } else { history.value.length = 0 } @@ -480,6 +536,13 @@ onUnmounted(() => { if (loading.value) controller.abort() }) +function searchfun() { + if (search.value == '知识库') + active.value = true + + else + active.value = false +}