diff --git a/api.py b/api.py
new file mode 100644
index 0000000..4c93492
--- /dev/null
+++ b/api.py
@@ -0,0 +1,101 @@
+from configs.model_config import *
+from chains.local_doc_qa import LocalDocQA
+import os
+import nltk
+
+import uvicorn
+from fastapi import FastAPI, File, UploadFile
+from pydantic import BaseModel
+from starlette.responses import RedirectResponse
+
+app = FastAPI()
+
+global local_doc_qa, vs_path
+
+nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
+
+# return top-k text chunks from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = False
+
+class Query(BaseModel):
+    query: str
+
+@app.get('/')
+async def document():
+    return RedirectResponse(url="/docs")
+
+@app.on_event("startup")
+async def get_local_doc_qa():
+    global local_doc_qa
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
+                          embedding_model=EMBEDDING_MODEL,
+                          embedding_device=EMBEDDING_DEVICE,
+                          llm_history_len=LLM_HISTORY_LEN,
+                          top_k=VECTOR_SEARCH_TOP_K)
+
+
+@app.post("/file")
+async def upload_file(UserFile: UploadFile = File(...)):
+    global vs_path
+    response = {
+        "msg": None,
+        "status": 0
+    }
+    try:
+        filepath = './content/' + UserFile.filename
+        content = await UserFile.read()
+        with open(filepath, 'wb') as f:
+            f.write(content)
+        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+        response = {
+            'msg': 'successful',
+            'status': 1
+        }
+
+    except Exception as err:
+        response["msg"] = str(err)
+
+    return response
+
+@app.post("/qa")
+async def get_answer(UserQuery: Query):
+    response = {
+        "status": 0,
+        "message": "",
+        "answer": None
+    }
+    global vs_path
+    history = []
+    try:
+        resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
+                                                                vs_path=vs_path,
+                                                                chat_history=history)
+        if REPLY_WITH_SOURCE:
+            response["answer"] = resp
+        else:
+            response['answer'] = resp["result"]
+
+        response["message"] = 'successful'
+        response["status"] = 1
+
+    except Exception as err:
+        response["message"] = str(err)
+
+    return response
+
+
+if __name__ == "__main__":
+    uvicorn.run(
+        app='api:app',
+        host='0.0.0.0',
+        port=8100,
+        reload=True,
+    )
+
diff --git a/webui.py b/webui.py
index 65c21d8..67e6612 100644
--- a/webui.py
+++ b/webui.py
@@ -54,7 +54,8 @@ def init_model():
     try:
         local_doc_qa.init_cfg()
         return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
-    except:
+    except Exception as e:
+        print(e)
         return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
 
 
@@ -66,14 +67,15 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
                               use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
-    except:
+    except Exception as e:
+        print(e)
         model_status = """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
     return history + [[None, model_status]]
 
 def get_vector_store(filepath, history):
-    if local_doc_qa.llm and local_doc_qa.llm:
+    if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
         if vs_path:
             file_status = "文件已成功加载,请开始提问"
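
Reviewer note: the new endpoints are easiest to verify by hand. Below is a minimal smoke-test sketch, not part of this diff, assuming the server was started with "python api.py" so it listens on localhost:8100, and that a local file named sample.txt exists; the file name and the question text are illustrative placeholders.

    # smoke_test.py -- manual check of the new /file and /qa endpoints.
    # Assumes: the server from this diff is running on localhost:8100, and
    # a local sample.txt exists to upload (both are placeholders).
    import requests

    BASE_URL = "http://localhost:8100"

    # Upload a document. The server saves it under ./content/ and builds the
    # knowledge vector store from it; the multipart field name must match
    # the UploadFile parameter name, "UserFile".
    with open("sample.txt", "rb") as f:
        r = requests.post(BASE_URL + "/file", files={"UserFile": f})
    print(r.json())  # expected: {"msg": "successful", "status": 1}

    # Ask a question about the uploaded document. The JSON body must match
    # the Query model, i.e. contain a "query" string field.
    r = requests.post(BASE_URL + "/qa", json={"query": "What is this document about?"})
    print(r.json())  # expected keys: "status", "message", "answer"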