From 3c2b8a85d68bf5a8c5369984adc54cb056724084 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 3 Aug 2023 10:49:57 +0800 Subject: [PATCH] add knowledge base api methods --- webui_pages/utils.py | 200 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 186 insertions(+), 14 deletions(-) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 89e99d4..72dd251 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -1,5 +1,6 @@ # 该文件包含webui通用工具,可以被不同的webui使用 +import tempfile from typing import * from pathlib import Path import os @@ -87,9 +88,11 @@ class ApiRequest: self, base_url: str = "http://127.0.0.1:7861", timeout: float = 60.0, + no_remote_api: bool = False, # call api view function directly ): self.base_url = base_url self.timeout = timeout + self.no_remote_api = no_remote_api def _parse_url(self, url: str) -> str: if (not url.startswith("http") @@ -112,7 +115,7 @@ class ApiRequest: kwargs.setdefault("timeout", self.timeout) while retry > 0: try: - return httpx.get(url, params, **kwargs) + return httpx.get(url, params=params, **kwargs) except: retry -= 1 @@ -128,7 +131,7 @@ class ApiRequest: async with httpx.AsyncClient() as client: while retry > 0: try: - return await client.get(url, params, **kwargs) + return await client.get(url, params=params, **kwargs) except: retry -= 1 @@ -180,6 +183,8 @@ class ApiRequest: loop = asyncio.new_event_loop() return iter_over_async(response.body_iterator, loop) + # 对话相关操作 + def chat_fastchat( self, messages: List[Dict], @@ -187,12 +192,14 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = 0.7, max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens - no_remote_api=False, # all api view function directly + no_remote_api: bool = None, **kwargs: Any, ): ''' 对应api.py/chat/fastchat接口 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api msg = OpenAiChatMsgIn(**{ "messages": messages, "stream": stream, @@ -218,11 +225,14 @@ class ApiRequest: def chat_chat( self, query: str, - no_remote_api: bool = False, + no_remote_api: bool = None, ): ''' 对应api.py/chat/chat接口 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if no_remote_api: from server.chat.chat import chat response = chat(query) @@ -235,11 +245,14 @@ class ApiRequest: self, query: str, knowledge_base_name: str, - no_remote_api: bool = False, + no_remote_api: bool = None, ): ''' - 对应/chat/knowledge_base_chat接口 + 对应api.py/chat/knowledge_base_chat接口 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat response = knowledge_base_chat(query, knowledge_base_name) @@ -255,11 +268,14 @@ class ApiRequest: def duckduckgo_search_chat( self, query: str, - no_remote_api: bool = False, + no_remote_api: bool = None, ): ''' 对应api.py/chat/duckduckgo_search_chat接口 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if no_remote_api: from server.chat.duckduckgo_search_chat import duckduckgo_search_chat response = duckduckgo_search_chat(query) @@ -275,11 +291,14 @@ class ApiRequest: def bing_search_chat( self, query: str, - no_remote_api: bool = False, + no_remote_api: bool = None, ): ''' 对应api.py/chat/bing_search_chat接口 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if no_remote_api: from server.chat.bing_search_chat import bing_search_chat response = bing_search_chat(query) @@ -292,16 +311,169 @@ class ApiRequest: ) return response + # 知识库相关操作 + + def list_knowledge_bases( + self, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/list_knowledge_bases接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_api import list_kbs + response = run_async(list_kbs()) + return response.data + else: + response = self.get("/knowledge_base/list_knowledge_bases") + return response.json().get("data") + + def create_knowledge_base( + self, + knowledge_base_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/create_knowledge_base接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_api import create_kb + response = run_async(create_kb(knowledge_base_name)) + return response.dict() + else: + response = self.post( + "/knowledge_base/create_knowledge_base", + json={"knowledge_base_name": knowledge_base_name}, + ) + return response.json() + + def delete_knowledge_base( + self, + knowledge_base_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/delete_knowledge_base接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_api import delete_kb + response = run_async(delete_kb(knowledge_base_name)) + return response.dict() + else: + response = self.delete( + "/knowledge_base/delete_knowledge_base", + json={"knowledge_base_name": knowledge_base_name}, + ) + return response.json() + + def list_kb_docs( + self, + knowledge_base_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/list_docs接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_doc_api import list_docs + response = run_async(list_docs(knowledge_base_name)) + return response.data + else: + response = self.get( + "/knowledge_base/list_docs", + params={"knowledge_base_name": knowledge_base_name} + ) + return response.json().get("data") + + def upload_kb_doc( + self, + file: Union[str, Path], + knowledge_base_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/upload_docs接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + file = Path(file).absolute() + filename = file.name + + if no_remote_api: + from server.knowledge_base.kb_doc_api import upload_doc + from fastapi import UploadFile + from tempfile import SpooledTemporaryFile + + temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) + with file.open("rb") as fp: + temp_file.write(fp.read()) + response = run_async(upload_doc( + UploadFile(temp_file, filename=filename), + knowledge_base_name, + )) + return response.dict() + else: + response = self.post( + "/knowledge_base/upload_doc", + data={"knowledge_base_name": knowledge_base_name}, + files={"file": (filename, file.open("rb"))}, + ) + return response.json() + + def delete_kb_doc( + self, + knowledge_base_name: str, + doc_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/delete_doc接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_doc_api import delete_doc + response = run_async(delete_doc(knowledge_base_name, doc_name)) + return response.dict() + else: + response = self.delete( + "/knowledge_base/delete_doc", + json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name}, + ) + return response.json() + + if __name__ == "__main__": api = ApiRequest() + # print(api.chat_fastchat( # messages=[{"role": "user", "content": "hello"}] # )) - with api.chat_chat("你好") as r: - for t in r.iter_text(None): - print(t) + # with api.chat_chat("你好") as r: + # for t in r.iter_text(None): + # print(t) - r = api.chat_chat("你好", no_remote_api=True) - for t in r: - print(t) \ No newline at end of file + # r = api.chat_chat("你好", no_remote_api=True) + # for t in r: + # print(t) + + # r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True) + # for t in r: + # print(t) + + print(api.list_knowledge_bases())