update api.py
This commit is contained in:
parent
2707f58aa1
commit
5f38645fa1
53
api.py
53
api.py
|
|
@ -18,7 +18,7 @@ from chains.local_doc_qa import LocalDocQA
|
|||
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
||||
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
||||
from agent import bing_search as agent_bing_search
|
||||
from agent import bing_search
|
||||
import models.shared as shared
|
||||
from models.loader.args import parser
|
||||
from models.loader import LoaderCheckPoint
|
||||
|
|
@ -248,6 +248,35 @@ async def local_doc_chat(
|
|||
)
|
||||
|
||||
|
||||
async def bing_search_chat(
|
||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||
history: List[List[str]] = Body(
|
||||
[],
|
||||
description="History of previous questions and answers",
|
||||
example=[
|
||||
[
|
||||
"工伤保险是什么?",
|
||||
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
||||
]
|
||||
],
|
||||
),
|
||||
):
|
||||
for resp, history in local_doc_qa.get_search_result_based_answer(
|
||||
query=question, chat_history=history, streaming=True
|
||||
):
|
||||
pass
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] {doc.metadata['source']}:\n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(resp["source_documents"])
|
||||
]
|
||||
|
||||
return ChatMessage(
|
||||
question=question,
|
||||
response=resp["result"],
|
||||
history=history,
|
||||
source_documents=source_documents,
|
||||
)
|
||||
|
||||
async def chat(
|
||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||
history: List[List[str]] = Body(
|
||||
|
|
@ -261,10 +290,8 @@ async def chat(
|
|||
],
|
||||
),
|
||||
):
|
||||
|
||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
||||
streaming=True):
|
||||
|
||||
resp = answer_result.llm_output["answer"]
|
||||
history = answer_result.history
|
||||
pass
|
||||
|
|
@ -323,22 +350,7 @@ async def document():
|
|||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
async def bing_search(
|
||||
search_text: str = Query(default=None, description="text you want to search", example="langchain")
|
||||
):
|
||||
results = agent_bing_search(search_text)
|
||||
result_str = ''
|
||||
for result in results:
|
||||
for k, v in result.items():
|
||||
result_str += "%s: %s\n" % (k, v)
|
||||
result_str += '\n'
|
||||
|
||||
return ChatMessage(
|
||||
question=search_text,
|
||||
response=result_str,
|
||||
history=[],
|
||||
source_documents=[],
|
||||
)
|
||||
|
||||
|
||||
def api_start(host, port):
|
||||
|
|
@ -369,11 +381,10 @@ def api_start(host, port):
|
|||
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
||||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||
|
||||
app.get("/bing_search", response_model=ChatMessage)(bing_search)
|
||||
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(
|
||||
llm_model=llm_model_ins,
|
||||
|
|
@ -384,9 +395,7 @@ def api_start(host, port):
|
|||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=7861)
|
||||
# 初始化消息
|
||||
|
|
|
|||
|
|
@ -1,24 +1,33 @@
|
|||
pymupdf
|
||||
paddlepaddle==2.4.2
|
||||
paddleocr
|
||||
paddleocr~=2.6.1.3
|
||||
langchain==0.0.174
|
||||
transformers==4.29.1
|
||||
unstructured[local-inference]
|
||||
layoutparser[layoutmodels,tesseract]
|
||||
nltk
|
||||
nltk~=3.8.1
|
||||
sentence-transformers
|
||||
beautifulsoup4
|
||||
icetk
|
||||
cpm_kernels
|
||||
faiss-cpu
|
||||
accelerate
|
||||
accelerate~=0.18.0
|
||||
gradio==3.28.3
|
||||
fastapi
|
||||
uvicorn
|
||||
peft
|
||||
pypinyin
|
||||
fastapi~=0.95.0
|
||||
uvicorn~=0.21.1
|
||||
peft~=0.3.0
|
||||
pypinyin~=0.48.0
|
||||
click~=8.1.3
|
||||
tabulate
|
||||
azure-core
|
||||
bitsandbytes; platform_system != "Windows"
|
||||
llama-cpp-python==0.1.34; platform_system != "Windows"
|
||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||
|
||||
torch~=2.0.0
|
||||
pydantic~=1.10.7
|
||||
starlette~=0.26.1
|
||||
numpy~=1.23.5
|
||||
tqdm~=4.65.0
|
||||
requests~=2.28.2
|
||||
tenacity~=8.2.2
|
||||
Loading…
Reference in New Issue