# This file contains common utilities for the webui; they can be used by different webuis.

from typing import Any, Dict, List, Tuple, Union

from pathlib import Path
import os
from configs.model_config import (
    KB_ROOT_PATH,
    LLM_MODEL,
    llm_model_dict,
    VECTOR_SEARCH_TOP_K,
    SEARCH_ENGINE_TOP_K,
)
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib
import json
from io import BytesIO
from server.knowledge_base.utils import list_kbs_from_folder


def set_httpx_timeout(timeout=60.0):
    '''
    Set the default httpx timeout to 60 seconds.
    The httpx default of 5 seconds is not enough when waiting for an LLM answer.
    '''
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()


def run_async(cor):
    '''
    Run a coroutine in a synchronous environment.
    '''
    try:
        loop = asyncio.get_event_loop()
    except:
        loop = asyncio.new_event_loop()
    return loop.run_until_complete(cor)
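
# Example: the knowledge-base helpers of ApiRequest below call this as
# run_async(list_kbs()) to drive an async view function from synchronous code.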


def iter_over_async(ait, loop):
    '''
    Wrap an async generator into a synchronous generator.
    '''
    ait = ait.__aiter__()

    async def get_next():
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj
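
# Used by ApiRequest._fastapi_stream2generator below to drain a
# StreamingResponse.body_iterator without an asynchronous caller.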


class ApiRequest:
    '''
    A wrapper around calls to api.py that mainly provides:
    1. a simplified way to call the api
    2. an api-free mode (run the view functions in server.chat.* directly to get results), so api.py does not need to be started
    '''
    def __init__(
        self,
        base_url: str = "http://127.0.0.1:7861",
        timeout: float = 60.0,
        no_remote_api: bool = False,  # call api view functions directly
    ):
        self.base_url = base_url
        self.timeout = timeout
        self.no_remote_api = no_remote_api

    def _parse_url(self, url: str) -> str:
        if not url.startswith("http") and self.base_url:
            part1 = self.base_url.strip(" /")
            part2 = url.strip(" /")
            return f"{part1}/{part2}"
        else:
            return url
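
    # e.g. _parse_url("/chat/chat") -> "http://127.0.0.1:7861/chat/chat" with the default base_url.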

    def get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        **kwargs: Any,
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        while retry > 0:
            try:
                return httpx.get(url, params=params, **kwargs)
            except:
                retry -= 1

    async def aget(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        **kwargs: Any,
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        async with httpx.AsyncClient() as client:
            while retry > 0:
                try:
                    return await client.get(url, params=params, **kwargs)
                except:
                    retry -= 1

    def post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        while retry > 0:
            try:
                # return requests.post(url, data=data, json=json, stream=stream, **kwargs)
                if stream:
                    return httpx.stream("POST", url, data=data, json=json, **kwargs)
                else:
                    return httpx.post(url, data=data, json=json, **kwargs)
            except:
                retry -= 1

    async def apost(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        **kwargs: Any
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        async with httpx.AsyncClient() as client:
            while retry > 0:
                try:
                    return await client.post(url, data=data, json=json, **kwargs)
                except:
                    retry -= 1
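
    # Note: delete_knowledge_base() and delete_kb_doc() below call self.delete(),
    # which is not defined anywhere in this file. A minimal sketch of such a helper,
    # mirroring post() and assuming httpx.request/httpx.stream accept data/json kwargs:
    def delete(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        while retry > 0:
            try:
                if stream:
                    return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
                else:
                    return httpx.request("DELETE", url, data=data, json=json, **kwargs)
            except:
                retry -= 1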

    def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool = False):
        '''
        Convert the StreamingResponse returned by the view functions in api.py into a synchronous generator.
        '''
        try:
            loop = asyncio.get_event_loop()
        except:
            loop = asyncio.new_event_loop()

        for chunk in iter_over_async(response.body_iterator, loop):
            if as_json and chunk:
                yield json.loads(chunk)
            elif chunk.strip():
                yield chunk

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        '''
        Convert the GeneratorContextManager returned by httpx.stream into a plain generator.
        '''
        with response as r:
            for chunk in r.iter_text(None):
                if as_json and chunk:
                    yield json.loads(chunk)
                elif chunk.strip():
                    yield chunk

    # chat-related operations

    def chat_fastchat(
        self,
        messages: List[Dict],
        stream: bool = True,
        model: str = LLM_MODEL,
        temperature: float = 0.7,
        max_tokens: int = 1024,  # TODO: compute max_tokens automatically from the message content
        no_remote_api: bool = None,
        **kwargs: Any,
    ):
        '''
        Corresponds to the api.py /chat/fastchat endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api
        msg = OpenAiChatMsgIn(**{
            "messages": messages,
            "stream": stream,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        })

        if no_remote_api:
            from server.chat.openai_chat import openai_chat
            response = openai_chat(msg)
            return self._fastapi_stream2generator(response)
        else:
            data = msg.dict(exclude_unset=True, exclude_none=True)
            response = self.post(
                "/chat/fastchat",
                json=data,
                stream=stream,
            )
            return self._httpx_stream2generator(response)

    def chat_chat(
        self,
        query: str,
        history: List[Dict] = [],
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /chat/chat endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.chat.chat import chat
            response = chat(query, history)
            return self._fastapi_stream2generator(response)
        else:
            response = self.post("/chat/chat", json={"query": query, "history": history}, stream=True)
            return self._httpx_stream2generator(response)

    def knowledge_base_chat(
        self,
        query: str,
        knowledge_base_name: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        history: List[Dict] = [],
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /chat/knowledge_base_chat endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.chat.knowledge_base_chat import knowledge_base_chat
            response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
            return self._fastapi_stream2generator(response, as_json=True)
        else:
            response = self.post(
                "/chat/knowledge_base_chat",
                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
                stream=True,
            )
            return self._httpx_stream2generator(response, as_json=True)

    def search_engine_chat(
        self,
        query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /chat/search_engine_chat endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.chat.search_engine_chat import search_engine_chat
            response = search_engine_chat(query, search_engine_name, top_k)
            return self._fastapi_stream2generator(response, as_json=True)
        else:
            response = self.post(
                "/chat/search_engine_chat",
                json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
                stream=True,
            )
            return self._httpx_stream2generator(response, as_json=True)

    # knowledge-base-related operations

    def list_knowledge_bases(
        self,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_api import list_kbs
            response = run_async(list_kbs())
            return response.data
        else:
            response = self.get("/knowledge_base/list_knowledge_bases")
            return response.json().get("data")

    def create_knowledge_base(
        self,
        knowledge_base_name: str,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_api import create_kb
            response = run_async(create_kb(knowledge_base_name))
            return response.dict()
        else:
            response = self.post(
                "/knowledge_base/create_knowledge_base",
                json={"knowledge_base_name": knowledge_base_name},
            )
            return response.json()

    def delete_knowledge_base(
        self,
        knowledge_base_name: str,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_api import delete_kb
            response = run_async(delete_kb(knowledge_base_name))
            return response.dict()
        else:
            response = self.delete(
                "/knowledge_base/delete_knowledge_base",
                json={"knowledge_base_name": knowledge_base_name},
            )
            return response.json()

    def list_kb_docs(
        self,
        knowledge_base_name: str,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/list_docs endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_doc_api import list_docs
            response = run_async(list_docs(knowledge_base_name))
            return response.data
        else:
            response = self.get(
                "/knowledge_base/list_docs",
                params={"knowledge_base_name": knowledge_base_name}
            )
            return response.json().get("data")

    def upload_kb_doc(
        self,
        file: Union[str, Path, bytes],
        knowledge_base_name: str,
        filename: str = None,
        override: bool = False,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/upload_doc endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if isinstance(file, bytes):
            file = BytesIO(file)
        else:
            file = Path(file).absolute().open("rb")
            filename = filename or file.name

        if no_remote_api:
            from server.knowledge_base.kb_doc_api import upload_doc
            from fastapi import UploadFile
            from tempfile import SpooledTemporaryFile

            temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
            temp_file.write(file.read())
            temp_file.seek(0)  # rewind so upload_doc reads the copied content instead of an empty stream
            response = run_async(upload_doc(
                UploadFile(temp_file, filename=filename),
                knowledge_base_name,
                override,
            ))
            return response.dict()
        else:
            response = self.post(
                "/knowledge_base/upload_doc",
                data={"knowledge_base_name": knowledge_base_name, "override": override},
                files={"file": (filename, file)},
            )
            return response.json()
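
    # Usage sketch (hypothetical path and knowledge base name):
    #     api.upload_kb_doc("/path/to/doc.md", "samples", override=True)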

    def delete_kb_doc(
        self,
        knowledge_base_name: str,
        doc_name: str,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/delete_doc endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_doc_api import delete_doc
            response = run_async(delete_doc(knowledge_base_name, doc_name))
            return response.dict()
        else:
            response = self.delete(
                "/knowledge_base/delete_doc",
                json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
            )
            return response.json()

    def recreate_vector_store(
        self,
        knowledge_base_name: str,
        no_remote_api: bool = None,
    ):
        '''
        Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint.
        '''
        if no_remote_api is None:
            no_remote_api = self.no_remote_api

        if no_remote_api:
            from server.knowledge_base.kb_doc_api import recreate_vector_store
            response = run_async(recreate_vector_store(knowledge_base_name))
            return self._fastapi_stream2generator(response, as_json=True)
        else:
            response = self.post(
                "/knowledge_base/recreate_vector_store",
                json={"knowledge_base_name": knowledge_base_name},
            )
            return self._httpx_stream2generator(response, as_json=True)


if __name__ == "__main__":
    from server.db.base import Base, engine
    Base.metadata.create_all(bind=engine)

    api = ApiRequest(no_remote_api=True)

    # print(api.chat_fastchat(
    #     messages=[{"role": "user", "content": "hello"}]
    # ))

    # with api.chat_chat("你好") as r:
    #     for t in r.iter_text(None):
    #         print(t)

    # r = api.chat_chat("你好", no_remote_api=True)
    # for t in r:
    #     print(t)

    # r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True)
    # for t in r:
    #     print(t)

    # print(api.list_knowledge_bases())

    # recreate all vector stores
    for kb in list_kbs_from_folder():
        for t in api.recreate_vector_store(kb):
            print(t)