Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev

This commit is contained in:
hzg0601 2023-08-15 15:34:50 +08:00
commit b40beac1a8
4 changed files with 105 additions and 39 deletions

View File

@ -1,7 +1,7 @@
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing import List from typing import List
import openai import openai
from configs.model_config import llm_model_dict, LLM_MODEL from configs.model_config import llm_model_dict, LLM_MODEL, logger
from pydantic import BaseModel from pydantic import BaseModel
@ -33,18 +33,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
data = msg.dict() data = msg.dict()
data["streaming"] = True data["streaming"] = True
data.pop("stream") data.pop("stream")
response = openai.ChatCompletion.create(**data)
if msg.stream: try:
for chunk in response.choices[0].message.content: response = openai.ChatCompletion.create(**data)
print(chunk) if msg.stream:
yield chunk for chunk in response.choices[0].message.content:
else: print(chunk)
answer = "" yield chunk
for chunk in response.choices[0].message.content: else:
answer += chunk answer = ""
print(answer) for chunk in response.choices[0].message.content:
yield(answer) answer += chunk
print(answer)
yield(answer)
except Exception as e:
print(type(e))
logger.error(e)
return StreamingResponse( return StreamingResponse(
get_response(msg), get_response(msg),

View File

@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest):
text = "" text = ""
r = api.chat_chat(prompt, history) r = api.chat_chat(prompt, history)
for t in r: for t in r:
if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
break
text += t text += t
chat_box.update_msg(text) chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
@ -109,6 +112,8 @@ def dialogue_page(api: ApiRequest):
]) ])
text = "" text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
text += d["answer"] text += d["answer"]
chat_box.update_msg(text, 0) chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest):
]) ])
text = "" text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k): for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
text += d["answer"] text += d["answer"]
chat_box.update_msg(text, 0) chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)

View File

@ -119,8 +119,6 @@ def knowledge_base_page(api: ApiRequest):
elif selected_kb: elif selected_kb:
kb = selected_kb["kb_name"] kb = selected_kb["kb_name"]
# 上传文件 # 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件", files = st.file_uploader("上传知识文件",

View File

@ -8,6 +8,7 @@ from configs.model_config import (
LLM_MODEL, LLM_MODEL,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K, SEARCH_ENGINE_TOP_K,
logger,
) )
import httpx import httpx
import asyncio import asyncio
@ -24,6 +25,7 @@ from configs.model_config import NLTK_DATA_PATH
import nltk import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
def set_httpx_timeout(timeout=60.0): def set_httpx_timeout(timeout=60.0):
''' '''
设置httpx默认timeout到60秒 设置httpx默认timeout到60秒
@ -80,7 +82,8 @@ class ApiRequest:
return httpx.stream("GET", url, params=params, **kwargs) return httpx.stream("GET", url, params=params, **kwargs)
else: else:
return httpx.get(url, params=params, **kwargs) return httpx.get(url, params=params, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
async def aget( async def aget(
@ -100,7 +103,8 @@ class ApiRequest:
return await client.stream("GET", url, params=params, **kwargs) return await client.stream("GET", url, params=params, **kwargs)
else: else:
return await client.get(url, params=params, **kwargs) return await client.get(url, params=params, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
def post( def post(
@ -121,7 +125,8 @@ class ApiRequest:
return httpx.stream("POST", url, data=data, json=json, **kwargs) return httpx.stream("POST", url, data=data, json=json, **kwargs)
else: else:
return httpx.post(url, data=data, json=json, **kwargs) return httpx.post(url, data=data, json=json, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
async def apost( async def apost(
@ -142,7 +147,8 @@ class ApiRequest:
return await client.stream("POST", url, data=data, json=json, **kwargs) return await client.stream("POST", url, data=data, json=json, **kwargs)
else: else:
return await client.post(url, data=data, json=json, **kwargs) return await client.post(url, data=data, json=json, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
def delete( def delete(
@ -162,7 +168,8 @@ class ApiRequest:
return httpx.stream("DELETE", url, data=data, json=json, **kwargs) return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
else: else:
return httpx.delete(url, data=data, json=json, **kwargs) return httpx.delete(url, data=data, json=json, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
async def adelete( async def adelete(
@ -183,7 +190,8 @@ class ApiRequest:
return await client.stream("DELETE", url, data=data, json=json, **kwargs) return await client.stream("DELETE", url, data=data, json=json, **kwargs)
else: else:
return await client.delete(url, data=data, json=json, **kwargs) return await client.delete(url, data=data, json=json, **kwargs)
except: except Exception as e:
logger.error(e)
retry -= 1 retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
@ -195,11 +203,14 @@ class ApiRequest:
except: except:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
for chunk in iter_over_async(response.body_iterator, loop): try:
if as_json and chunk: for chunk in iter_over_async(response.body_iterator, loop):
yield json.loads(chunk) if as_json and chunk:
elif chunk.strip(): yield json.loads(chunk)
yield chunk elif chunk.strip():
yield chunk
except Exception as e:
logger.error(e)
def _httpx_stream2generator( def _httpx_stream2generator(
self, self,
@ -209,12 +220,31 @@ class ApiRequest:
''' '''
将httpx.stream返回的GeneratorContextManager转化为普通生成器 将httpx.stream返回的GeneratorContextManager转化为普通生成器
''' '''
with response as r: try:
for chunk in r.iter_text(None): with response as r:
if as_json and chunk: for chunk in r.iter_text(None):
yield json.loads(chunk) if not chunk: # openai api server communicating error
elif chunk.strip(): msg = f"API通信超时请确认已启动FastChat与API服务详见README '5. 启动 API 服务或 Web UI'"
yield chunk logger.error(msg)
yield {"code": 500, "errorMsg": msg}
break
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见README '5. 启动 API 服务或 Web UI'"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "errorMsg": str(e)}
# 对话相关操作 # 对话相关操作
@ -353,6 +383,21 @@ class ApiRequest:
# 知识库相关操作 # 知识库相关操作
def _check_httpx_json_response(
self,
response: httpx.Response,
errorMsg: str = f"无法连接API服务器请确认已执行python server\\api.py",
) -> Dict:
'''
check whether httpx returns correct data with normal Response.
error in api with streaming support was checked in _httpx_stream2generator
'''
try:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "errorMsg": errorMsg or str(e)}
def list_knowledge_bases( def list_knowledge_bases(
self, self,
no_remote_api: bool = None, no_remote_api: bool = None,
@ -369,7 +414,8 @@ class ApiRequest:
return response.data return response.data
else: else:
response = self.get("/knowledge_base/list_knowledge_bases") response = self.get("/knowledge_base/list_knowledge_bases")
return response.json().get("data") data = self._check_httpx_json_response(response)
return data.get("data", [])
def create_knowledge_base( def create_knowledge_base(
self, self,
@ -399,7 +445,7 @@ class ApiRequest:
"/knowledge_base/create_knowledge_base", "/knowledge_base/create_knowledge_base",
json=data, json=data,
) )
return response.json() return self._check_httpx_json_response(response)
def delete_knowledge_base( def delete_knowledge_base(
self, self,
@ -421,7 +467,7 @@ class ApiRequest:
"/knowledge_base/delete_knowledge_base", "/knowledge_base/delete_knowledge_base",
json=f"{knowledge_base_name}", json=f"{knowledge_base_name}",
) )
return response.json() return self._check_httpx_json_response(response)
def list_kb_docs( def list_kb_docs(
self, self,
@ -443,7 +489,8 @@ class ApiRequest:
"/knowledge_base/list_docs", "/knowledge_base/list_docs",
params={"knowledge_base_name": knowledge_base_name} params={"knowledge_base_name": knowledge_base_name}
) )
return response.json().get("data") data = self._check_httpx_json_response(response)
return data.get("data", [])
def upload_kb_doc( def upload_kb_doc(
self, self,
@ -487,7 +534,7 @@ class ApiRequest:
data={"knowledge_base_name": knowledge_base_name, "override": override}, data={"knowledge_base_name": knowledge_base_name, "override": override},
files={"file": (filename, file)}, files={"file": (filename, file)},
) )
return response.json() return self._check_httpx_json_response(response)
def delete_kb_doc( def delete_kb_doc(
self, self,
@ -517,7 +564,7 @@ class ApiRequest:
"/knowledge_base/delete_doc", "/knowledge_base/delete_doc",
json=data, json=data,
) )
return response.json() return self._check_httpx_json_response(response)
def update_kb_doc( def update_kb_doc(
self, self,
@ -540,7 +587,7 @@ class ApiRequest:
"/knowledge_base/update_doc", "/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
) )
return response.json() return self._check_httpx_json_response(response)
def recreate_vector_store( def recreate_vector_store(
self, self,
@ -572,10 +619,20 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store", "/knowledge_base/recreate_vector_store",
json=data, json=data,
stream=True, stream=True,
timeout=False,
) )
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
return error message if error occured when requests API
'''
if isinstance(data, dict) and key in data:
return data[key]
return ""
if __name__ == "__main__": if __name__ == "__main__":
api = ApiRequest(no_remote_api=True) api = ApiRequest(no_remote_api=True)