From 69627a2fa34717660974612d2b94e14ac76152de Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Sat, 19 Aug 2023 15:14:45 +0800
Subject: [PATCH] update chat and knowledge base api: unify exception processing and return types

---
 server/chat/github_chat.py                   | 140 ++++++++++++++++++
 server/knowledge_base/kb_api.py              |  12 +-
 server/knowledge_base/kb_doc_api.py          | 143 +++++++++++--------
 server/utils.py                              |   4 +-
 webui_pages/knowledge_base/knowledge_base.py |   8 +-
 webui_pages/utils.py                         |   2 +-
 6 files changed, 241 insertions(+), 68 deletions(-)
 create mode 100644 server/chat/github_chat.py

diff --git a/server/chat/github_chat.py b/server/chat/github_chat.py
new file mode 100644
index 0000000..b161548
--- /dev/null
+++ b/server/chat/github_chat.py
@@ -0,0 +1,140 @@
+from langchain.document_loaders.github import GitHubIssuesLoader
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
+from server.chat.utils import wrap_done
+from server.utils import BaseResponse
+from langchain.chat_models import ChatOpenAI
+from langchain import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from typing import AsyncIterable
+import asyncio
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional, Literal
+from server.chat.utils import History
+from langchain.docstore.document import Document
+import json
+import os
+from functools import lru_cache
+from datetime import datetime
+
+
+GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
+
+
+@lru_cache(1)
+def load_issues(tick: str):
+    '''
+    Load all issues and pull requests of the project repository from GitHub.
+    The result is cached: set tick to a periodic value to refresh the cache.
+    '''
+    loader = GitHubIssuesLoader(
+        repo="chatchat-space/langchain-chatglm",
+        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
+        include_prs=True,
+        state="all",
+    )
+    docs = loader.load()
+    return docs
+
+
+def _filter_issues(docs: List[Document],
+                   top_k: int,
+                   include_prs: bool = True,
+                   state: Optional[str] = None,
+                   creator: Optional[str] = None,
+                   ) -> List[Document]:
+    '''
+    Return at most top_k documents that match the PR/state/creator filters,
+    in the order the loader returned them.
+    '''
+    selected = []
+    for doc in docs:
+        if len(selected) >= top_k:
+            break
+        if not include_prs and doc.metadata.get("is_pull_request"):
+            continue
+        if state not in (None, "all") and doc.metadata.get("state") != state:
+            continue
+        if creator and doc.metadata.get("creator") != creator:
+            continue
+        selected.append(doc)
+    return selected
+
+
+def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
+                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                include_prs: bool = Body(True, description="是否包含PR"),
+                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
+                creator: str = Body(None, description="创建者"),
+                history: List[History] = Body([],
+                                              description="历史对话",
+                                              examples=[[
+                                                  {"role": "user",
+                                                   "content": "介绍一下本项目"},
+                                                  {"role": "assistant",
+                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
+                                              ),
+                stream: bool = Body(False, description="流式输出"),
+                ):
+    '''
+    Answer a question with the LLM, using the project's GitHub issues/PRs as context.
+    '''
+    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
+        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
+
+    async def github_chat_iterator(query: str,
+                                   top_k: int,
+                                   history: Optional[List[History]],
+                                   ) -> AsyncIterable[str]:
+        callback = AsyncIteratorCallbackHandler()
+        model = ChatOpenAI(
+            streaming=True,
+            verbose=True,
+            callbacks=[callback],
+            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
+            model_name=LLM_MODEL
+        )
+
+        # loader.load() does blocking HTTP requests, so run it in a worker thread.
+        # The hourly tick limits how stale the cached issue list can get.
+        tick = datetime.now().strftime("%Y%m%d%H")
+        issues = await asyncio.get_running_loop().run_in_executor(None, load_issues, tick)
+        docs = _filter_issues(issues, top_k, include_prs, state, creator)
+        context = "\n".join([doc.page_content for doc in docs])
+
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)
+
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"context": context, "question": query}),
+            callback.done),
+        )
+
+        # GitHubIssuesLoader stores the issue link under the "url" metadata key.
+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["url"]}]({doc.metadata["url"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(docs)
+        ]
+
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield json.dumps({"answer": token,
+                                  "docs": source_documents},
+                                 ensure_ascii=False)
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield json.dumps({"answer": answer,
+                              "docs": source_documents},
+                             ensure_ascii=False)
+        await task
+
+    return StreamingResponse(github_chat_iterator(query, top_k, history),
+                             media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index 4753ba4..b9151b8 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -15,7 +15,7 @@ async def list_kbs():
 async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
                     vector_store_type: str = Body("faiss"),
                     embed_model: str = Body(EMBEDDING_MODEL),
-                    ):
+                    ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
         return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
 
     kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
-    kb.create_kb()
+    try:
+        kb.create_kb()
+    except Exception as e:
+        print(e)
+        return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
+
     return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
 
 
 async def delete_kb(
         knowledge_base_name: str = Body(..., examples=["samples"])
-    ):
+    ) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -51,5 +56,6 @@ async def delete_kb(
             return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
     except Exception as e:
         print(e)
+        return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
 
     return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 0bf2cb7..0d74fd6 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
                 ) -> List[DocumentWithScore]:
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
-        return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
+        return []
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
 
@@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
 
 async def list_docs(
     knowledge_base_name: str
-):
+) -> ListResponse:
     if not validate_kb_name(knowledge_base_name):
         return ListResponse(code=403, msg="Don't attack me", data=[])
 
@@ -41,13 +41,13 @@ async def list_docs(
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
         all_doc_names = kb.list_docs()
-        return ListResponse(data=all_doc_names)    
+        return ListResponse(data=all_doc_names)
 
 
 async def upload_doc(file: UploadFile = File(..., description="上传文件"),
                      knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
                      override: bool = Form(False, description="覆盖已有文件"),
-                     ):
+                     ) -> BaseResponse:
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
 
@@ -57,31 +57,37 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
 
     file_content = await file.read()  # 读取上传文件的内容
 
-    kb_file = KnowledgeFile(filename=file.filename,
-                            knowledge_base_name=knowledge_base_name)
-
-    if (os.path.exists(kb_file.filepath)
-            and not override
-            and os.path.getsize(kb_file.filepath) == len(file_content)
-    ):
-        # TODO: filesize 不同后的处理
-        file_status = f"文件 {kb_file.filename} 已存在。"
-        return BaseResponse(code=404, msg=file_status)
-
     try:
+        kb_file = KnowledgeFile(filename=file.filename,
+                                knowledge_base_name=knowledge_base_name)
+
+        if (os.path.exists(kb_file.filepath)
+                and not override
+                and os.path.getsize(kb_file.filepath) == len(file_content)
+        ):
+            # TODO: filesize 不同后的处理
+            file_status = f"文件 {kb_file.filename} 已存在。"
+            return BaseResponse(code=404, msg=file_status)
+
         with open(kb_file.filepath, "wb") as f:
             f.write(file_content)
     except Exception as e:
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
+        print(e)
+        return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}")
 
-    kb.add_doc(kb_file)
+    try:
+        kb.add_doc(kb_file)
+    except Exception as e:
+        print(e)
+        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
+
     return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
 
 
 async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
                      doc_name: str = Body(..., examples=["file_name.md"]),
                      delete_content: bool = Body(False),
-                     ):
+                     ) -> BaseResponse:
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
 
@@ -92,17 +98,22 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
 
     if not kb.exist_doc(doc_name):
         return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
-    kb_file = KnowledgeFile(filename=doc_name,
-                            knowledge_base_name=knowledge_base_name)
-    kb.delete_doc(kb_file, delete_content)
+
+    try:
+        kb_file = KnowledgeFile(filename=doc_name,
+                                knowledge_base_name=knowledge_base_name)
+        kb.delete_doc(kb_file, delete_content)
+    except Exception as e:
+        print(e)
+        return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败,错误信息:{e}")
+
     return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
-    # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
 
 
 async def update_doc(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         file_name: str = Body(..., examples=["file_name"]),
-    ):
+    ) -> BaseResponse:
     '''
     更新知识库文档
     '''
@@ -113,14 +124,17 @@ async def update_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    kb_file = KnowledgeFile(filename=file_name,
-                            knowledge_base_name=knowledge_base_name)
+    try:
+        kb_file = KnowledgeFile(filename=file_name,
+                                knowledge_base_name=knowledge_base_name)
+        if os.path.exists(kb_file.filepath):
+            kb.update_doc(kb_file)
+            return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
+    except Exception as e:
+        print(e)
+        return BaseResponse(code=500, msg=f"{file_name} 文件更新失败,错误信息是:{e}")
 
-    if os.path.exists(kb_file.filepath):
-        kb.update_doc(kb_file)
-        return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
-    else:
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
+    return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
 
 
 async def download_doc(
@@ -137,18 +151,20 @@ async def download_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    kb_file = KnowledgeFile(filename=file_name,
-                            knowledge_base_name=knowledge_base_name)
-
-    if os.path.exists(kb_file.filepath):
-        return FileResponse(
-            path=kb_file.filepath,
-            filename=kb_file.filename,
-            media_type="multipart/form-data")
-    else:
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
+    try:
+        kb_file = KnowledgeFile(filename=file_name,
+                                knowledge_base_name=knowledge_base_name)
 
+        if os.path.exists(kb_file.filepath):
+            return FileResponse(
+                path=kb_file.filepath,
+                filename=kb_file.filename,
+                media_type="multipart/form-data")
+    except Exception as e:
+        print(e)
+        return BaseResponse(code=500, msg=f"{file_name} 读取文件失败,错误信息是:{e}")
 
+    return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
 
 
 async def recreate_vector_store(
@@ -163,24 +179,33 @@ async def recreate_vector_store(
     by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
     set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
     '''
-    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
-    if not kb.exists() and not allow_empty_kb:
-        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    async def output(kb):
-        kb.create_kb()
-        kb.clear_vs()
-        docs = list_docs_from_folder(knowledge_base_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, knowledge_base_name)
-                yield json.dumps({
-                    "total": len(docs),
-                    "finished": i,
-                    "doc": doc,
-                }, ensure_ascii=False)
-                kb.add_doc(kb_file)
-            except Exception as e:
-                print(e)
+    async def output():
+        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+        if not kb.exists() and not allow_empty_kb:
+            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}, ensure_ascii=False)
+        else:
+            kb.create_kb()
+            kb.clear_vs()
+            docs = list_docs_from_folder(knowledge_base_name)
+            for i, doc in enumerate(docs):
+                try:
+                    kb_file = KnowledgeFile(doc, knowledge_base_name)
+                    yield json.dumps({
+                        "code": 200,
+                        "msg": f"({i + 1} / {len(docs)}): {doc}",
+                        "total": len(docs),
+                        "finished": i,
+                        "doc": doc,
+                    }, ensure_ascii=False)
+                    kb.add_doc(kb_file)
+                except Exception as e:
+                    print(e)
+                    yield json.dumps({
+                        "code": 500,
+                        "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
+                    }, ensure_ascii=False)
+                    import asyncio
+                    await asyncio.sleep(5)
 
-    return StreamingResponse(output(kb), media_type="text/event-stream")
+    return StreamingResponse(output(), media_type="text/event-stream")
diff --git a/server/utils.py b/server/utils.py
index c0f11a5..4a88722 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -9,8 +9,8 @@ from typing import Any, Optional
 
 
 class BaseResponse(BaseModel):
-    code: int = pydantic.Field(200, description="HTTP status code")
-    msg: str = pydantic.Field("success", description="HTTP status message")
+    code: int = pydantic.Field(200, description="API status code")
+    msg: str = pydantic.Field("success", description="API status message")
 
     class Config:
         schema_extra = {
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 89e274c..3bd531a 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -249,12 +249,14 @@ def knowledge_base_page(api: ApiRequest):
                 use_container_width=True,
                 type="primary",
         ):
-            with st.spinner("向量库重构中"):
+            with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
                 empty = st.empty()
                 empty.progress(0.0, "")
                 for d in api.recreate_vector_store(kb):
-                    print(d)
-                    empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
+                    if msg := check_error_msg(d):
+                        st.toast(msg)
+                    else:
+                        empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
                 st.experimental_rerun()
 
         if cols[2].button(
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index cc38ef5..18a24e4 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -229,7 +229,7 @@ class ApiRequest:
                     elif chunk.strip():
                         yield chunk
         except httpx.ConnectError as e:
-            msg = f"无法连接API服务器,请确认已执行python server\\api.py"
+            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
             logger.error(msg)
             logger.error(e)
             yield {"code": 500, "msg": msg}