update chat and knowledge base api: unify exception processing and return types

This commit is contained in:
liunux4odoo 2023-08-19 15:14:45 +08:00
parent 62d6f44b28
commit 69627a2fa3
6 changed files with 209 additions and 67 deletions

109
server/chat/github_chat.py Normal file
View File

@ -0,0 +1,109 @@
from langchain.document_loaders.github import GitHubIssuesLoader
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Literal
from server.chat.utils import History
from langchain.docstore.document import Document
import json
import os
from functools import lru_cache
from datetime import datetime
GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
@lru_cache(1)
def load_issues(tick: str):
'''
set tick to a periodic value to refresh cache
'''
loader = GitHubIssuesLoader(
repo="chatchat-space/langchain-chatglm",
access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
include_prs=True,
state="all",
)
docs = loader.load()
return docs
def
def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
include_prs: bool = Body(True, description="是否包含PR"),
state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
creator: str = Body(None, description="创建者"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "介绍一下本项目"},
{"role": "assistant",
"content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
),
stream: bool = Body(False, description="流式输出"),
):
if GITHUB_PERSONAL_ACCESS_TOKEN is None:
return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
async def chat_iterator(query: str,
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
media_type="text/event-stream")

View File

@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
):
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
):
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs(
knowledge_base_name: str
):
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@ -41,13 +41,13 @@ async def list_docs(
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names)
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
):
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -57,31 +57,37 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read() # 读取上传文件的内容
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
kb.add_doc(kb_file)
try:
kb.add_doc(kb_file)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
):
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -92,17 +98,22 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content)
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
):
) -> BaseResponse:
'''
更新知识库文档
'''
@ -113,14 +124,17 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
async def download_doc(
@ -137,18 +151,20 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
@ -163,24 +179,33 @@ async def recreate_vector_store(
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb):
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
kb.add_doc(kb_file)
except Exception as e:
print(e)
async def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
kb.add_doc(kb_file)
except Exception as e:
print(e)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
})
import asyncio
await asyncio.sleep(5)
return StreamingResponse(output(kb), media_type="text/event-stream")
return StreamingResponse(output(), media_type="text/event-stream")

View File

@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field("success", description="HTTP status message")
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
class Config:
schema_extra = {

View File

@ -249,12 +249,14 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
type="primary",
):
with st.spinner("向量库重构中"):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
print(d)
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun()
if cols[2].button(

View File

@ -229,7 +229,7 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
logger.error(msg)
logger.error(e)
yield {"code": 500, "msg": msg}