commit b40beac1a8
Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev
@@ -1,7 +1,7 @@
 from fastapi.responses import StreamingResponse
 from typing import List
 import openai
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import llm_model_dict, LLM_MODEL, logger
 from pydantic import BaseModel


@@ -33,19 +33,23 @@ async def openai_chat(msg: OpenAiChatMsgIn):
         data = msg.dict()
         data["streaming"] = True
         data.pop("stream")
-        response = openai.ChatCompletion.create(**data)
-
-        if msg.stream:
-            for chunk in response.choices[0].message.content:
-                print(chunk)
-                yield chunk
-        else:
-            answer = ""
-            for chunk in response.choices[0].message.content:
-                answer += chunk
-            print(answer)
-            yield(answer)
+        try:
+            response = openai.ChatCompletion.create(**data)
+            if msg.stream:
+                for chunk in response.choices[0].message.content:
+                    print(chunk)
+                    yield chunk
+            else:
+                answer = ""
+                for chunk in response.choices[0].message.content:
+                    answer += chunk
+                print(answer)
+                yield(answer)
+        except Exception as e:
+            print(type(e))
+            logger.error(e)

     return StreamingResponse(
         get_response(msg),
         media_type='text/event-stream',
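For context, the pattern this hunk introduces reduces to a small standalone FastAPI app: the streaming generator is wrapped in try/except so an upstream failure is logged instead of breaking the StreamingResponse mid-stream. This is only a minimal sketch; `fake_llm_stream` is an invented stand-in for `openai.ChatCompletion.create()` (the pre-1.0 openai client used here), and the route name is illustrative.

```python
# Minimal sketch, assuming an invented fake_llm_stream() in place of the
# real openai.ChatCompletion.create() call shown in the diff.
import logging

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

logger = logging.getLogger(__name__)
app = FastAPI()


def fake_llm_stream(prompt: str):
    # stand-in for the model call; may raise just like the openai client
    for token in ["hello", " ", "world"]:
        yield token


@app.get("/chat")
async def chat(prompt: str, stream: bool = True):
    async def get_response():
        try:
            if stream:
                for chunk in fake_llm_stream(prompt):
                    yield chunk                      # emit tokens as they arrive
            else:
                yield "".join(fake_llm_stream(prompt))  # one final answer
        except Exception as e:
            logger.error(e)                          # log instead of killing the stream

    return StreamingResponse(get_response(), media_type="text/event-stream")
```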
@@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest):
             text = ""
             r = api.chat_chat(prompt, history)
             for t in r:
+                if error_msg := check_error_msg(t): # check whether error occured
+                    st.error(error_msg)
+                    break
                 text += t
                 chat_box.update_msg(text)
             chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
@@ -109,6 +112,8 @@ def dialogue_page(api: ApiRequest):
             ])
             text = ""
             for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
+                if error_msg := check_error_msg(t): # check whether error occured
+                    st.error(error_msg)
                 text += d["answer"]
                 chat_box.update_msg(text, 0)
             chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
@@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest):
             ])
             text = ""
             for d in api.search_engine_chat(prompt, search_engine, se_top_k):
+                if error_msg := check_error_msg(t): # check whether error occured
+                    st.error(error_msg)
                 text += d["answer"]
                 chat_box.update_msg(text, 0)
             chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
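The three dialogue branches above now run every streamed chunk through `check_error_msg` (the helper added at the end of this diff) and surface errors instead of silently concatenating them. Below is a minimal sketch of that consumer loop, with the Streamlit calls replaced by prints and an invented `fake_chat_stream` standing in for `api.chat_chat()` and friends.

```python
# Sketch only: fake_chat_stream() is an assumed stand-in for the ApiRequest
# chat generators; the real pages use st.error / chat_box.update_msg instead
# of print.
from typing import Iterable, Union


def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    # same contract as the helper added in this commit: return the error text
    # if the chunk is an error payload, else ""
    if isinstance(data, dict) and key in data:
        return data[key]
    return ""


def fake_chat_stream() -> Iterable[Union[str, dict]]:
    yield "partial "
    yield "answer"
    yield {"code": 500, "errorMsg": "API communication timed out"}


text = ""
for t in fake_chat_stream():
    if error_msg := check_error_msg(t):   # stop on the first error chunk
        print(f"ERROR: {error_msg}")
        break
    text += t
    print(text)                           # incremental update of the message
```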
@@ -119,8 +119,6 @@ def knowledge_base_page(api: ApiRequest):
     elif selected_kb:
         kb = selected_kb["kb_name"]

-
-
         # 上传文件
         # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
         files = st.file_uploader("上传知识文件",
@@ -8,6 +8,7 @@ from configs.model_config import (
     LLM_MODEL,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
+    logger,
 )
 import httpx
 import asyncio
@@ -24,6 +25,7 @@ from configs.model_config import NLTK_DATA_PATH
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

+
 def set_httpx_timeout(timeout=60.0):
     '''
     设置httpx默认timeout到60秒。
@@ -80,7 +82,8 @@ class ApiRequest:
                     return httpx.stream("GET", url, params=params, **kwargs)
                 else:
                     return httpx.get(url, params=params, **kwargs)
-            except:
+            except Exception as e:
+                logger.error(e)
                 retry -= 1

     async def aget(
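The change to `ApiRequest.get` (and the sibling methods in the hunks below) is simply replacing a bare `except:` with `except Exception as e` plus `logger.error(e)`, so failed attempts are no longer retried silently. A self-contained sketch of that retry-with-logging pattern, with an illustrative signature rather than the repository's exact one:

```python
# Standalone sketch of the retry-with-logging pattern; function name and the
# retry=3 default are illustrative, not the repository's exact API.
import logging
from typing import Optional

import httpx

logger = logging.getLogger(__name__)


def get_with_retry(url: str, retry: int = 3, **kwargs) -> Optional[httpx.Response]:
    while retry > 0:
        try:
            return httpx.get(url, **kwargs)
        except Exception as e:
            logger.error(e)   # the old bare `except:` dropped this information
            retry -= 1
    return None               # all attempts failed
```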
@@ -100,7 +103,8 @@ class ApiRequest:
                         return await client.stream("GET", url, params=params, **kwargs)
                     else:
                         return await client.get(url, params=params, **kwargs)
-                except:
+                except Exception as e:
+                    logger.error(e)
                     retry -= 1

     def post(
@@ -121,7 +125,8 @@ class ApiRequest:
                     return httpx.stream("POST", url, data=data, json=json, **kwargs)
                 else:
                     return httpx.post(url, data=data, json=json, **kwargs)
-            except:
+            except Exception as e:
+                logger.error(e)
                 retry -= 1

     async def apost(
@@ -142,7 +147,8 @@ class ApiRequest:
                         return await client.stream("POST", url, data=data, json=json, **kwargs)
                     else:
                         return await client.post(url, data=data, json=json, **kwargs)
-                except:
+                except Exception as e:
+                    logger.error(e)
                     retry -= 1

     def delete(
@@ -162,7 +168,8 @@ class ApiRequest:
                     return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
                 else:
                     return httpx.delete(url, data=data, json=json, **kwargs)
-            except:
+            except Exception as e:
+                logger.error(e)
                 retry -= 1

     async def adelete(
@@ -183,7 +190,8 @@ class ApiRequest:
                         return await client.stream("DELETE", url, data=data, json=json, **kwargs)
                     else:
                         return await client.delete(url, data=data, json=json, **kwargs)
-                except:
+                except Exception as e:
+                    logger.error(e)
                     retry -= 1

     def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
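The async variants (`aget`, `apost`, `adelete`) apply the same change inside an `httpx.AsyncClient`. Roughly, and again with illustrative names and defaults:

```python
# Async counterpart of the retry-with-logging sketch above; signature and
# defaults are illustrative rather than the repository's exact ones.
import asyncio
import logging
from typing import Optional

import httpx

logger = logging.getLogger(__name__)


async def aget_with_retry(url: str, retry: int = 3, **kwargs) -> Optional[httpx.Response]:
    async with httpx.AsyncClient() as client:
        while retry > 0:
            try:
                return await client.get(url, **kwargs)
            except Exception as e:
                logger.error(e)   # log the failure before retrying
                retry -= 1
    return None


if __name__ == "__main__":
    print(asyncio.run(aget_with_retry("https://example.com")))
```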
@@ -195,11 +203,14 @@ class ApiRequest:
         except:
             loop = asyncio.new_event_loop()

-        for chunk in iter_over_async(response.body_iterator, loop):
-            if as_json and chunk:
-                yield json.loads(chunk)
-            elif chunk.strip():
-                yield chunk
+        try:
+            for chunk in iter_over_async(response.body_iterator, loop):
+                if as_json and chunk:
+                    yield json.loads(chunk)
+                elif chunk.strip():
+                    yield chunk
+        except Exception as e:
+            logger.error(e)

     def _httpx_stream2generator(
         self,
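`_fastapi_stream2generator` drains a FastAPI `StreamingResponse` body from synchronous code through an event loop; the hunk above only wraps that drain in try/except with logging. The `iter_over_async` helper it relies on is not part of this diff; the sketch below shows a common way such a helper is written and may differ from the repository's version.

```python
# Assumed implementation of an iter_over_async()-style helper (not shown in
# this diff): drive an async iterator one item at a time from sync code.
import asyncio


def iter_over_async(ait, loop: asyncio.AbstractEventLoop):
    ait = ait.__aiter__()
    while True:
        try:
            # run one __anext__ step on the loop and hand the item to sync code
            yield loop.run_until_complete(ait.__anext__())
        except StopAsyncIteration:
            break


async def _demo_body():
    # stand-in for response.body_iterator
    for chunk in (b"hello", b"world"):
        yield chunk


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    for chunk in iter_over_async(_demo_body(), loop):
        print(chunk)
```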
@@ -209,12 +220,31 @@ class ApiRequest:
         '''
         将httpx.stream返回的GeneratorContextManager转化为普通生成器
         '''
-        with response as r:
-            for chunk in r.iter_text(None):
-                if as_json and chunk:
-                    yield json.loads(chunk)
-                elif chunk.strip():
-                    yield chunk
+        try:
+            with response as r:
+                for chunk in r.iter_text(None):
+                    if not chunk: # openai api server communicating error
+                        msg = f"API通信超时,请确认已启动FastChat与API服务(详见README '5. 启动 API 服务或 Web UI')"
+                        logger.error(msg)
+                        yield {"code": 500, "errorMsg": msg}
+                        break
+                    if as_json and chunk:
+                        yield json.loads(chunk)
+                    elif chunk.strip():
+                        yield chunk
+        except httpx.ConnectError as e:
+            msg = f"无法连接API服务器,请确认已执行python server\\api.py"
+            logger.error(msg)
+            logger.error(e)
+            yield {"code": 500, "errorMsg": msg}
+        except httpx.ReadTimeout as e:
+            msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
+            logger.error(msg)
+            logger.error(e)
+            yield {"code": 500, "errorMsg": msg}
+        except Exception as e:
+            logger.error(e)
+            yield {"code": 500, "errorMsg": str(e)}

     # 对话相关操作

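After this hunk, `_httpx_stream2generator` no longer raises on connection problems; it yields a `{"code": 500, "errorMsg": ...}` dict that the web UI can display. A condensed, standalone version of that shape follows. The URL and messages are placeholders, and unlike the real method, which receives an already-opened stream context manager, this sketch opens the stream itself.

```python
# Condensed sketch of the stream-to-generator conversion with error payloads;
# URL and error texts are placeholders.
import json
import logging
from typing import Iterator, Union

import httpx

logger = logging.getLogger(__name__)


def stream2generator(url: str, as_json: bool = False) -> Iterator[Union[str, dict]]:
    try:
        with httpx.stream("GET", url) as r:
            for chunk in r.iter_text(None):
                if as_json and chunk:
                    yield json.loads(chunk)
                elif chunk.strip():
                    yield chunk
    except httpx.ConnectError as e:
        logger.error(e)
        yield {"code": 500, "errorMsg": "cannot connect to the API server"}
    except httpx.ReadTimeout as e:
        logger.error(e)
        yield {"code": 500, "errorMsg": "API request timed out"}
    except Exception as e:
        logger.error(e)
        yield {"code": 500, "errorMsg": str(e)}
```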
@@ -353,6 +383,21 @@ class ApiRequest:

     # 知识库相关操作

+    def _check_httpx_json_response(
+        self,
+        response: httpx.Response,
+        errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
+    ) -> Dict:
+        '''
+        check whether httpx returns correct data with normal Response.
+        error in api with streaming support was checked in _httpx_stream2enerator
+        '''
+        try:
+            return response.json()
+        except Exception as e:
+            logger.error(e)
+            return {"code": 500, "errorMsg": errorMsg or str(e)}
+
     def list_knowledge_bases(
         self,
         no_remote_api: bool = None,
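`_check_httpx_json_response` centralizes the "did the API answer with valid JSON" check for the non-streaming knowledge-base calls changed below: parse the body if possible, otherwise log and fall back to an error dict the UI can display. A hypothetical usage sketch of the same helper outside the class:

```python
# Hypothetical standalone version of the helper plus a usage example; the
# fallback errorMsg text is a placeholder.
import logging
from typing import Dict

import httpx

logger = logging.getLogger(__name__)


def check_httpx_json_response(response: httpx.Response,
                              errorMsg: str = "cannot reach the API server") -> Dict:
    try:
        return response.json()
    except Exception as e:
        logger.error(e)
        return {"code": 500, "errorMsg": errorMsg or str(e)}


# e.g. a list endpoint returns [] instead of crashing when the body is not JSON
response = httpx.Response(200, text="not json")
data = check_httpx_json_response(response)
print(data.get("data", []))   # -> []
```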
@@ -369,7 +414,8 @@ class ApiRequest:
             return response.data
         else:
             response = self.get("/knowledge_base/list_knowledge_bases")
-            return response.json().get("data")
+            data = self._check_httpx_json_response(response)
+            return data.get("data", [])

     def create_knowledge_base(
         self,
@@ -399,7 +445,7 @@ class ApiRequest:
                 "/knowledge_base/create_knowledge_base",
                 json=data,
             )
-            return response.json()
+            return self._check_httpx_json_response(response)

     def delete_knowledge_base(
         self,
@@ -421,7 +467,7 @@ class ApiRequest:
                 "/knowledge_base/delete_knowledge_base",
                 json=f"{knowledge_base_name}",
             )
-            return response.json()
+            return self._check_httpx_json_response(response)

     def list_kb_docs(
         self,
@@ -443,7 +489,8 @@ class ApiRequest:
                 "/knowledge_base/list_docs",
                 params={"knowledge_base_name": knowledge_base_name}
             )
-            return response.json().get("data")
+            data = self._check_httpx_json_response(response)
+            return data.get("data", [])

     def upload_kb_doc(
         self,
@@ -487,7 +534,7 @@ class ApiRequest:
                 data={"knowledge_base_name": knowledge_base_name, "override": override},
                 files={"file": (filename, file)},
             )
-            return response.json()
+            return self._check_httpx_json_response(response)

     def delete_kb_doc(
         self,
@@ -517,7 +564,7 @@ class ApiRequest:
                 "/knowledge_base/delete_doc",
                 json=data,
             )
-            return response.json()
+            return self._check_httpx_json_response(response)

     def update_kb_doc(
         self,
@@ -540,7 +587,7 @@ class ApiRequest:
                 "/knowledge_base/update_doc",
                 json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
             )
-            return response.json()
+            return self._check_httpx_json_response(response)

     def recreate_vector_store(
         self,
@@ -572,10 +619,20 @@ class ApiRequest:
                 "/knowledge_base/recreate_vector_store",
                 json=data,
                 stream=True,
+                timeout=False,
             )
             return self._httpx_stream2generator(response, as_json=True)


+def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
+    '''
+    return error message if error occured when requests API
+    '''
+    if isinstance(data, dict) and key in data:
+        return data[key]
+    return ""
+
+
 if __name__ == "__main__":
     api = ApiRequest(no_remote_api=True)