update chat and knowledge base api: unify exception processing and return types

This commit is contained in:
liunux4odoo 2023-08-19 15:14:45 +08:00
parent 62d6f44b28
commit 69627a2fa3
6 changed files with 209 additions and 67 deletions

109
server/chat/github_chat.py Normal file
View File

@ -0,0 +1,109 @@
from langchain.document_loaders.github import GitHubIssuesLoader
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Literal
from server.chat.utils import History
from langchain.docstore.document import Document
import json
import os
from functools import lru_cache
from datetime import datetime
GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
@lru_cache(1)
def load_issues(tick: str):
'''
set tick to a periodic value to refresh cache
'''
loader = GitHubIssuesLoader(
repo="chatchat-space/langchain-chatglm",
access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
include_prs=True,
state="all",
)
docs = loader.load()
return docs
def
def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
include_prs: bool = Body(True, description="是否包含PR"),
state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
creator: str = Body(None, description="创建者"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "介绍一下本项目"},
{"role": "assistant",
"content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
),
stream: bool = Body(False, description="流式输出"),
):
if GITHUB_PERSONAL_ACCESS_TOKEN is None:
return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
async def chat_iterator(query: str,
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
media_type="text/event-stream")

View File

@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"), vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(EMBEDDING_MODEL),
): ) -> BaseResponse:
# Create selected knowledge base # Create selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
kb.create_kb() kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb( async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"]) knowledge_base_name: str = Body(..., examples=["samples"])
): ) -> BaseResponse:
# Delete selected knowledge base # Delete selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e: except Exception as e:
print(e) print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]: ) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None: if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []} return []
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs( async def list_docs(
knowledge_base_name: str knowledge_base_name: str
): ) -> ListResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[]) return ListResponse(code=403, msg="Don't attack me", data=[])
@ -47,7 +47,7 @@ async def list_docs(
async def upload_doc(file: UploadFile = File(..., description="上传文件"), async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"), override: bool = Form(False, description="覆盖已有文件"),
): ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -57,6 +57,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read() # 读取上传文件的内容 file_content = await file.read() # 读取上传文件的内容
try:
kb_file = KnowledgeFile(filename=file.filename, kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
@ -68,20 +69,25 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_status = f"文件 {kb_file.filename} 已存在。" file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status) return BaseResponse(code=404, msg=file_status)
try:
with open(kb_file.filepath, "wb") as f: with open(kb_file.filepath, "wb") as f:
f.write(file_content) f.write(file_content)
except Exception as e: except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
try:
kb.add_doc(kb_file) kb.add_doc(kb_file)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]), doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False), delete_content: bool = Body(False),
): ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -92,17 +98,22 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name): if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
try:
kb_file = KnowledgeFile(filename=doc_name, kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content) kb.delete_doc(kb_file, delete_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc( async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]), file_name: str = Body(..., examples=["file_name"]),
): ) -> BaseResponse:
''' '''
更新知识库文档 更新知识库文档
''' '''
@ -113,13 +124,16 @@ async def update_doc(
if kb is None: if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
try:
kb_file = KnowledgeFile(filename=file_name, kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath): if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file) kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else: except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
@ -137,6 +151,7 @@ async def download_doc(
if kb is None: if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
try:
kb_file = KnowledgeFile(filename=file_name, kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
@ -145,12 +160,13 @@ async def download_doc(
path=kb_file.filepath, path=kb_file.filepath,
filename=kb_file.filename, filename=kb_file.filename,
media_type="multipart/form-data") media_type="multipart/form-data")
else: except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store( async def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True), allow_empty_kb: bool = Body(True),
@ -163,11 +179,12 @@ async def recreate_vector_store(
by default, get_service_by_name only return knowledge base in the info.db and having document files in it. by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
''' '''
async def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb: if not kb.exists() and not allow_empty_kb:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
async def output(kb):
kb.create_kb() kb.create_kb()
kb.clear_vs() kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name) docs = list_docs_from_folder(knowledge_base_name)
@ -175,6 +192,8 @@ async def recreate_vector_store(
try: try:
kb_file = KnowledgeFile(doc, knowledge_base_name) kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({ yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs), "total": len(docs),
"finished": i, "finished": i,
"doc": doc, "doc": doc,
@ -182,5 +201,11 @@ async def recreate_vector_store(
kb.add_doc(kb_file) kb.add_doc(kb_file)
except Exception as e: except Exception as e:
print(e) print(e)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
})
import asyncio
await asyncio.sleep(5)
return StreamingResponse(output(kb), media_type="text/event-stream") return StreamingResponse(output(), media_type="text/event-stream")

View File

@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code") code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="HTTP status message") msg: str = pydantic.Field("success", description="API status message")
class Config: class Config:
schema_extra = { schema_extra = {

View File

@ -249,11 +249,13 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True, use_container_width=True,
type="primary", type="primary",
): ):
with st.spinner("向量库重构中"): with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty() empty = st.empty()
empty.progress(0.0, "") empty.progress(0.0, "")
for d in api.recreate_vector_store(kb): for d in api.recreate_vector_store(kb):
print(d) if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun() st.experimental_rerun()

View File

@ -229,7 +229,7 @@ class ApiRequest:
elif chunk.strip(): elif chunk.strip():
yield chunk yield chunk
except httpx.ConnectError as e: except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py" msg = f"无法连接API服务器请确认 api.py 已正常启动。"
logger.error(msg) logger.error(msg)
logger.error(e) logger.error(e)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}