Update chat and knowledge base API: unify exception handling and return types
This commit is contained in:
parent
62d6f44b28
commit
69627a2fa3
|
|
@ -0,0 +1,109 @@
|
|||
from langchain.document_loaders.github import GitHubIssuesLoader
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
|
||||
from server.chat.utils import wrap_done
|
||||
from server.utils import BaseResponse
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional, Literal
|
||||
from server.chat.utils import History
|
||||
from langchain.docstore.document import Document
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def load_issues(tick: str):
    """
    Load the issues and pull requests (any state) of the
    chatchat-space/langchain-chatglm repository as documents.

    The result is memoized in a single-slot LRU cache keyed on ``tick``;
    pass a value that changes periodically to refresh the cache.
    """
    issues_loader = GitHubIssuesLoader(
        state="all",
        include_prs=True,
        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
        repo="chatchat-space/langchain-chatglm",
    )
    return issues_loader.load()
|
||||
|
||||
|
||||
def _select_issue_docs(docs, top_k, include_prs=True, state=None, creator=None):
    """
    Return at most ``top_k`` documents from ``docs`` that pass the filters.

    Filtering reads the ``is_pull_request``, ``state`` and ``creator``
    metadata keys. NOTE(review): these are the keys documented for
    langchain's GitHubIssuesLoader -- confirm against the installed version.

    A ``state`` of ``None`` or ``"all"`` and an empty ``creator`` disable the
    corresponding filter. The input order is preserved.
    NOTE(review): no relevance ranking against the user query is done here;
    the first ``top_k`` matching documents are taken as-is.
    """
    selected = []
    limit = max(top_k, 0)
    for doc in docs:
        if len(selected) >= limit:
            break
        meta = doc.metadata
        if not include_prs and meta.get("is_pull_request"):
            continue
        if state not in (None, "all") and meta.get("state") != state:
            continue
        if creator and meta.get("creator") != creator:
            continue
        selected.append(doc)
    return selected


def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                include_prs: bool = Body(True, description="是否包含PR"),
                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
                creator: str = Body(None, description="创建者"),
                history: List[History] = Body([],
                                              description="历史对话",
                                              examples=[[
                                                  {"role": "user",
                                                   "content": "介绍一下本项目"},
                                                  {"role": "assistant",
                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
                                              ),
                stream: bool = Body(False, description="流式输出"),
                ):
    """
    Answer ``query`` with the LLM, using the repository's GitHub issues as context.

    Parameters:
        query: the user's question.
        top_k: maximum number of issue documents placed into the prompt context.
        include_prs: when False, pull requests are excluded from the context.
        state: keep only issues/PRs in this state; ``None`` or ``'all'`` keeps all.
        creator: keep only issues/PRs created by this user; empty keeps all.
        history: earlier conversation turns, prepended to the prompt.
        stream: when True, one JSON chunk is emitted per generated token;
            otherwise a single JSON chunk with the whole answer is emitted.

    Returns:
        A ``BaseResponse`` with code 404 when GITHUB_PERSONAL_ACCESS_TOKEN is
        not set; otherwise a ``StreamingResponse`` (``text/event-stream``) of
        JSON strings with the keys ``answer`` and ``docs``.
    """
    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")

    async def chat_iterator(query: str,
                            top_k: int,
                            history: Optional[List[History]],
                            ) -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            callbacks=[callback],
            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
            model_name=LLM_MODEL
        )

        # load_issues caches on its argument, so an hour-granular tick means
        # the issue list is re-downloaded at most once per hour.
        tick = datetime.now().strftime("%Y-%m-%d %H")
        # The loader does blocking network I/O; run it in the default executor
        # so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        all_docs = await loop.run_in_executor(None, load_issues, tick)
        docs = _select_issue_docs(all_docs, top_k, include_prs, state, creator)
        context = "\n".join([doc.page_content for doc in docs])

        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in (history or [])] + [("human", PROMPT_TEMPLATE)])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = []
        for inum, doc in enumerate(docs):
            # Fall back to "url" when "source" is absent from the metadata.
            link = doc.metadata.get("source") or doc.metadata.get("url", "")
            source_documents.append(
                f"""出处 [{inum + 1}] [{link}]({link}) \n\n{doc.page_content}\n\n"""
            )

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token,
                                  "docs": source_documents},
                                 ensure_ascii=False)
        else:
            # Collect every token first, then emit the complete answer once.
            answer = "".join([token async for token in callback.aiter()])
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)

        await task

    return StreamingResponse(chat_iterator(query, top_k, history),
                             media_type="text/event-stream")
|
||||
|
|
@ -15,7 +15,7 @@ async def list_kbs():
|
|||
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
):
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb.create_kb()
|
||||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
||||
async def delete_kb(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
):
|
||||
) -> BaseResponse:
|
||||
# Delete selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -51,5 +56,6 @@ async def delete_kb(
|
|||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
|
||||
|
||||
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
|
|||
) -> List[DocumentWithScore]:
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
|
||||
return []
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
|
|||
|
||||
async def list_docs(
|
||||
knowledge_base_name: str
|
||||
):
|
||||
) -> ListResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
|
|
@ -41,13 +41,13 @@ async def list_docs(
|
|||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = kb.list_docs()
|
||||
return ListResponse(data=all_doc_names)
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
):
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -57,31 +57,37 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
with open(kb_file.filepath, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
kb.add_doc(kb_file)
|
||||
try:
|
||||
kb.add_doc(kb_file)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||
delete_content: bool = Body(False),
|
||||
):
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -92,17 +98,22 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content)
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
|
||||
|
||||
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
):
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
更新知识库文档
|
||||
'''
|
||||
|
|
@ -113,14 +124,17 @@ async def update_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
|
||||
|
||||
async def download_doc(
|
||||
|
|
@ -137,18 +151,20 @@ async def download_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
|
||||
|
||||
async def recreate_vector_store(
|
||||
|
|
@ -163,24 +179,33 @@ async def recreate_vector_store(
|
|||
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
|
||||
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
|
||||
'''
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
async def output(kb):
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_docs_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i,
|
||||
"doc": doc,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
async def output():
    """
    Rebuild the vector store and stream one JSON progress message per file.

    Every yielded chunk is a JSON string carrying ``code`` and ``msg``:
    404 when the knowledge base does not exist (and ``allow_empty_kb`` is
    false), 200 before each file is added (with ``total``, ``finished`` and
    ``doc``), and 500 when handling a file raised, in which case that file
    is skipped.
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists() and not allow_empty_kb:
        # Serialize like the other branches: the stream carries JSON strings,
        # not raw dict objects.
        yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                         ensure_ascii=False)
        return

    kb.create_kb()
    kb.clear_vs()
    docs = list_docs_from_folder(knowledge_base_name)
    total = len(docs)
    for i, doc in enumerate(docs):
        try:
            kb_file = KnowledgeFile(doc, knowledge_base_name)
            # Progress is reported before the file is added.
            yield json.dumps({
                "code": 200,
                "msg": f"({i + 1} / {total}): {doc}",
                "total": total,
                "finished": i,
                "doc": doc,
            }, ensure_ascii=False)
            kb.add_doc(kb_file)
        except Exception as e:
            print(e)
            yield json.dumps({
                "code": 500,
                "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
            }, ensure_ascii=False)
            # NOTE(review): pause kept from the original; presumably it gives
            # the client time to display the error -- confirm it is still wanted.
            import asyncio
            await asyncio.sleep(5)
|
||||
|
||||
return StreamingResponse(output(kb), media_type="text/event-stream")
|
||||
return StreamingResponse(output(), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from typing import Any, Optional
|
|||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="HTTP status code")
|
||||
msg: str = pydantic.Field("success", description="HTTP status message")
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
|
|
|
|||
|
|
@ -249,12 +249,14 @@ def knowledge_base_page(api: ApiRequest):
|
|||
use_container_width=True,
|
||||
type="primary",
|
||||
):
|
||||
with st.spinner("向量库重构中"):
|
||||
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
|
||||
empty = st.empty()
|
||||
empty.progress(0.0, "")
|
||||
for d in api.recreate_vector_store(kb):
|
||||
print(d)
|
||||
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
|
||||
if msg := check_error_msg(d):
|
||||
st.toast(msg)
|
||||
else:
|
||||
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[2].button(
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class ApiRequest:
|
|||
elif chunk.strip():
|
||||
yield chunk
|
||||
except httpx.ConnectError as e:
|
||||
msg = f"无法连接API服务器,请确认已执行python server\\api.py"
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": msg}
|
||||
|
|
|
|||
Loading…
Reference in New Issue