update api.py
This commit is contained in:
parent
2707f58aa1
commit
5f38645fa1
53
api.py
53
api.py
|
|
@ -18,7 +18,7 @@ from chains.local_doc_qa import LocalDocQA
|
||||||
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||||
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
||||||
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
||||||
from agent import bing_search as agent_bing_search
|
from agent import bing_search
|
||||||
import models.shared as shared
|
import models.shared as shared
|
||||||
from models.loader.args import parser
|
from models.loader.args import parser
|
||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
|
|
@ -248,6 +248,35 @@ async def local_doc_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def bing_search_chat(
|
||||||
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
|
history: List[List[str]] = Body(
|
||||||
|
[],
|
||||||
|
description="History of previous questions and answers",
|
||||||
|
example=[
|
||||||
|
[
|
||||||
|
"工伤保险是什么?",
|
||||||
|
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
||||||
|
]
|
||||||
|
],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
for resp, history in local_doc_qa.get_search_result_based_answer(
|
||||||
|
query=question, chat_history=history, streaming=True
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
source_documents = [
|
||||||
|
f"""出处 [{inum + 1}] {doc.metadata['source']}:\n\n{doc.page_content}\n\n"""
|
||||||
|
for inum, doc in enumerate(resp["source_documents"])
|
||||||
|
]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
question=question,
|
||||||
|
response=resp["result"],
|
||||||
|
history=history,
|
||||||
|
source_documents=source_documents,
|
||||||
|
)
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[str]] = Body(
|
||||||
|
|
@ -261,10 +290,8 @@ async def chat(
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
||||||
streaming=True):
|
streaming=True):
|
||||||
|
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
pass
|
pass
|
||||||
|
|
@ -323,22 +350,7 @@ async def document():
|
||||||
return RedirectResponse(url="/docs")
|
return RedirectResponse(url="/docs")
|
||||||
|
|
||||||
|
|
||||||
async def bing_search(
|
|
||||||
search_text: str = Query(default=None, description="text you want to search", example="langchain")
|
|
||||||
):
|
|
||||||
results = agent_bing_search(search_text)
|
|
||||||
result_str = ''
|
|
||||||
for result in results:
|
|
||||||
for k, v in result.items():
|
|
||||||
result_str += "%s: %s\n" % (k, v)
|
|
||||||
result_str += '\n'
|
|
||||||
|
|
||||||
return ChatMessage(
|
|
||||||
question=search_text,
|
|
||||||
response=result_str,
|
|
||||||
history=[],
|
|
||||||
source_documents=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def api_start(host, port):
|
def api_start(host, port):
|
||||||
|
|
@ -369,11 +381,10 @@ def api_start(host, port):
|
||||||
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
||||||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||||
|
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
||||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||||
|
|
||||||
app.get("/bing_search", response_model=ChatMessage)(bing_search)
|
|
||||||
|
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg(
|
local_doc_qa.init_cfg(
|
||||||
llm_model=llm_model_ins,
|
llm_model=llm_model_ins,
|
||||||
|
|
@ -384,9 +395,7 @@ def api_start(host, port):
|
||||||
uvicorn.run(app, host=host, port=port)
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
parser.add_argument("--port", type=int, default=7861)
|
parser.add_argument("--port", type=int, default=7861)
|
||||||
# 初始化消息
|
# 初始化消息
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,33 @@
|
||||||
pymupdf
|
pymupdf
|
||||||
paddlepaddle==2.4.2
|
paddlepaddle==2.4.2
|
||||||
paddleocr
|
paddleocr~=2.6.1.3
|
||||||
langchain==0.0.174
|
langchain==0.0.174
|
||||||
transformers==4.29.1
|
transformers==4.29.1
|
||||||
unstructured[local-inference]
|
unstructured[local-inference]
|
||||||
layoutparser[layoutmodels,tesseract]
|
layoutparser[layoutmodels,tesseract]
|
||||||
nltk
|
nltk~=3.8.1
|
||||||
sentence-transformers
|
sentence-transformers
|
||||||
beautifulsoup4
|
beautifulsoup4
|
||||||
icetk
|
icetk
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
faiss-cpu
|
faiss-cpu
|
||||||
accelerate
|
accelerate~=0.18.0
|
||||||
gradio==3.28.3
|
gradio==3.28.3
|
||||||
fastapi
|
fastapi~=0.95.0
|
||||||
uvicorn
|
uvicorn~=0.21.1
|
||||||
peft
|
peft~=0.3.0
|
||||||
pypinyin
|
pypinyin~=0.48.0
|
||||||
click~=8.1.3
|
click~=8.1.3
|
||||||
tabulate
|
tabulate
|
||||||
|
azure-core
|
||||||
bitsandbytes; platform_system != "Windows"
|
bitsandbytes; platform_system != "Windows"
|
||||||
llama-cpp-python==0.1.34; platform_system != "Windows"
|
llama-cpp-python==0.1.34; platform_system != "Windows"
|
||||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
|
|
||||||
|
torch~=2.0.0
|
||||||
|
pydantic~=1.10.7
|
||||||
|
starlette~=0.26.1
|
||||||
|
numpy~=1.23.5
|
||||||
|
tqdm~=4.65.0
|
||||||
|
requests~=2.28.2
|
||||||
|
tenacity~=8.2.2
|
||||||
Loading…
Reference in New Issue