From 5f38645fa1c47b9cd648c250c46be99257f9c23b Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Sun, 21 May 2023 22:27:02 +0800
Subject: [PATCH] update api.py

---
 api.py           | 53 ++++++++++++++++++++++++++++--------------------
 requirements.txt | 23 ++++++++++++++-------
 2 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/api.py b/api.py
index da6d30a..6f8597a 100644
--- a/api.py
+++ b/api.py
@@ -18,7 +18,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
                                   EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
                                   VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
-from agent import bing_search as agent_bing_search
+from agent import bing_search
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
@@ -248,6 +248,35 @@ async def local_doc_chat(
     )


+async def bing_search_chat(
+        question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        history: List[List[str]] = Body(
+            [],
+            description="History of previous questions and answers",
+            example=[
+                [
+                    "工伤保险是什么?",
+                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                ]
+            ],
+        ),
+):
+    for resp, history in local_doc_qa.get_search_result_based_answer(
+            query=question, chat_history=history, streaming=True
+    ):
+        pass
+    source_documents = [
+        f"""出处 [{inum + 1}] {doc.metadata['source']}:\n\n{doc.page_content}\n\n"""
+        for inum, doc in enumerate(resp["source_documents"])
+    ]
+
+    return ChatMessage(
+        question=question,
+        response=resp["result"],
+        history=history,
+        source_documents=source_documents,
+    )
+
 async def chat(
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
         history: List[List[str]] = Body(
@@ -261,10 +290,8 @@ async def chat(
             ],
         ),
 ):
-
     for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history, streaming=True):
-
         resp = answer_result.llm_output["answer"]
         history = answer_result.history
         pass

@@ -323,22 +350,7 @@
 async def document():
     return RedirectResponse(url="/docs")


-async def bing_search(
-        search_text: str = Query(default=None, description="text you want to search", example="langchain")
-):
-    results = agent_bing_search(search_text)
-    result_str = ''
-    for result in results:
-        for k, v in result.items():
-            result_str += "%s: %s\n" % (k, v)
-        result_str += '\n'
-    return ChatMessage(
-        question=search_text,
-        response=result_str,
-        history=[],
-        source_documents=[],
-    )


 def api_start(host, port):
@@ -369,11 +381,10 @@
     app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
     app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
     app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
+    app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
     app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
     app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)

-    app.get("/bing_search", response_model=ChatMessage)(bing_search)
-
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(
         llm_model=llm_model_ins,
@@ -384,9 +395,7 @@
     uvicorn.run(app, host=host, port=port)
-


 if __name__ == "__main__":
-
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=7861)
     # 初始化消息
diff --git a/requirements.txt b/requirements.txt
index e1a892e..74a373a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,24 +1,33 @@
 pymupdf
 paddlepaddle==2.4.2
-paddleocr
+paddleocr~=2.6.1.3
 langchain==0.0.174
 transformers==4.29.1
 unstructured[local-inference]
 layoutparser[layoutmodels,tesseract]
-nltk
+nltk~=3.8.1
 sentence-transformers
 beautifulsoup4
 icetk
 cpm_kernels
 faiss-cpu
-accelerate
+accelerate~=0.18.0
 gradio==3.28.3
-fastapi
-uvicorn
-peft
-pypinyin
+fastapi~=0.95.0
+uvicorn~=0.21.1
+peft~=0.3.0
+pypinyin~=0.48.0
 click~=8.1.3
 tabulate
+azure-core
 bitsandbytes; platform_system != "Windows"
 llama-cpp-python==0.1.34; platform_system != "Windows"
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+
+torch~=2.0.0
+pydantic~=1.10.7
+starlette~=0.26.1
+numpy~=1.23.5
+tqdm~=4.65.0
+requests~=2.28.2
+tenacity~=8.2.2
\ No newline at end of file
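
Usage note: a minimal sketch, not part of the commit, of calling the new
/local_doc_qa/bing_search_chat endpoint added above. It assumes the API was
started with this patch's defaults (uvicorn on port 7861) and that Bing search
is configured for the agent module; the requests library it uses is pinned in
requirements.txt above.

    import requests

    # bing_search_chat declares two Body(...) parameters, so FastAPI expects
    # a JSON object with "question" and "history" keys.
    payload = {
        "question": "工伤保险是什么?",
        "history": [],  # optional [question, answer] pairs from earlier turns
    }
    r = requests.post("http://127.0.0.1:7861/local_doc_qa/bing_search_chat",
                      json=payload)
    r.raise_for_status()

    # The response is a ChatMessage: question, response, history,
    # and source_documents (one formatted 出处/source string per result).
    data = r.json()
    print(data["response"])
    for source in data["source_documents"]:
        print(source)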