Use fastapi to implement API (#209)
* Use fastapi to implement API * Update model_config.py --------- Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
This commit is contained in:
parent
4e9de7df41
commit
e0cf26019b
|
|
@ -165,6 +165,10 @@ output/*
|
|||
log/*
|
||||
.chroma
|
||||
vector_store/*
|
||||
content/*
|
||||
api_content/*
|
||||
|
||||
llm/*
|
||||
embedding/*
|
||||
embedding/*
|
||||
|
||||
pyrightconfig.json
|
||||
|
|
|
|||
|
|
@ -93,6 +93,12 @@ $ python cli_demo.py
|
|||
$ python webui.py
|
||||
```
|
||||
|
||||
或执行 [api.py](api.py) 利用 fastapi 部署 API
|
||||
```shell
|
||||
$ python api.py
|
||||
```
|
||||
|
||||
|
||||
注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,至少15G。
|
||||
|
||||
执行后效果如下图所示:
|
||||
|
|
@ -168,7 +174,7 @@ Web UI 可以实现如下功能:
|
|||
- [ ] 删除知识库中文件
|
||||
- [ ] 利用 streamlit 实现 Web UI Demo
|
||||
- [ ] 增加 API 支持
|
||||
- [ ] 利用 fastapi 实现 API 部署方式
|
||||
- [x] 利用 fastapi 实现 API 部署方式
|
||||
- [ ] 实现调用 API 的 Web UI Demo
|
||||
|
||||
## 项目交流群
|
||||
|
|
|
|||
|
|
@ -112,6 +112,11 @@ Note: When using langchain.document_loaders.UnstructuredFileLoader for unstructu
|
|||
Execute [webui.py](webui.py) script to experience **Web interaction** <img src="https://img.shields.io/badge/Version-0.1-brightgreen">
|
||||
```commandline
|
||||
python webui.py
|
||||
|
||||
```
|
||||
Or execute [api.py](api.py) script to deploy web api.
|
||||
```shell
|
||||
$ python api.py
|
||||
```
|
||||
Note: Before executing, check the remaining space in the `$HOME/.cache/huggingface/` folder, at least 15G.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,313 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
import nltk
|
||||
import pydantic
|
||||
import uvicorn
|
||||
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from chains.local_doc_qa import LocalDocQA
|
||||
from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||
EMBEDDING_MODEL, LLM_MODEL)
|
||||
|
||||
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
|
||||
|
||||
# return top-k text chunk from vector store
|
||||
VECTOR_SEARCH_TOP_K = 6
|
||||
|
||||
# LLM input history length
|
||||
LLM_HISTORY_LEN = 3
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="HTTP status code")
|
||||
msg: str = pydantic.Field("success", description="HTTP status message")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ListDocsResponse(BaseResponse):
|
||||
data: List[str] = pydantic.Field(..., description="List of document names")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
question: str = pydantic.Field(..., description="Question text")
|
||||
response: str = pydantic.Field(..., description="Response text")
|
||||
history: List[List[str]] = pydantic.Field(..., description="History text")
|
||||
source_documents: List[str] = pydantic.Field(
|
||||
..., description="List of source documents and their scores"
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"question": "工伤保险如何办理?",
|
||||
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
|
||||
"history": [
|
||||
[
|
||||
"工伤保险是什么?",
|
||||
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
||||
]
|
||||
],
|
||||
"source_documents": [
|
||||
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
|
||||
"出处 [2] ...",
|
||||
"出处 [3] ...",
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_folder_path(local_doc_id: str):
|
||||
return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)
|
||||
|
||||
|
||||
def get_vs_path(local_doc_id: str):
|
||||
return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
|
||||
|
||||
|
||||
def get_file_path(local_doc_id: str, doc_name: str):
|
||||
return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
||||
|
||||
|
||||
async def upload_file(
|
||||
files: Annotated[
|
||||
List[UploadFile], File(description="Multiple files as UploadFile")
|
||||
],
|
||||
local_doc_id: str = Form(..., description="Local document ID", example="doc_id_1"),
|
||||
):
|
||||
saved_path = get_folder_path(local_doc_id)
|
||||
if not os.path.exists(saved_path):
|
||||
os.makedirs(saved_path)
|
||||
for file in files:
|
||||
file_path = os.path.join(saved_path, file.filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file.file.read())
|
||||
|
||||
local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(local_doc_id))
|
||||
return BaseResponse()
|
||||
|
||||
|
||||
async def list_docs(
|
||||
local_doc_id: Optional[str] = Query(description="Document ID", example="doc_id1")
|
||||
):
|
||||
if local_doc_id:
|
||||
local_doc_folder = get_folder_path(local_doc_id)
|
||||
if not os.path.exists(local_doc_folder):
|
||||
return {"code": 1, "msg": f"document {local_doc_id} not found"}
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
return ListDocsResponse(data=all_doc_names)
|
||||
else:
|
||||
if not os.path.exists(API_UPLOAD_ROOT_PATH):
|
||||
all_doc_ids = []
|
||||
else:
|
||||
all_doc_ids = [
|
||||
folder
|
||||
for folder in os.listdir(API_UPLOAD_ROOT_PATH)
|
||||
if os.path.isdir(os.path.join(API_UPLOAD_ROOT_PATH, folder))
|
||||
]
|
||||
|
||||
return ListDocsResponse(data=all_doc_ids)
|
||||
|
||||
|
||||
async def delete_docs(
|
||||
local_doc_id: str = Form(..., description="local doc id", example="doc_id_1"),
|
||||
doc_name: Optional[str] = Form(
|
||||
None, description="doc name", example="doc_name_1.pdf"
|
||||
),
|
||||
):
|
||||
if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)):
|
||||
return {"code": 1, "msg": f"document {local_doc_id} not found"}
|
||||
if doc_name:
|
||||
doc_path = get_file_path(local_doc_id, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
else:
|
||||
return {"code": 1, "msg": f"document {doc_name} not found"}
|
||||
|
||||
remain_docs = await list_docs(local_doc_id)
|
||||
if remain_docs["code"] != 0 or len(remain_docs["data"]) == 0:
|
||||
shutil.rmtree(get_folder_path(local_doc_id), ignore_errors=True)
|
||||
else:
|
||||
local_doc_qa.init_knowledge_vector_store(
|
||||
get_folder_path(local_doc_id), get_vs_path(local_doc_id)
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(get_folder_path(local_doc_id))
|
||||
return BaseResponse()
|
||||
|
||||
|
||||
async def chat(
|
||||
local_doc_id: str = Body(..., description="Document ID", example="doc_id1"),
|
||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||
history: List[List[str]] = Body(
|
||||
[],
|
||||
description="History of previous questions and answers",
|
||||
example=[
|
||||
[
|
||||
"工伤保险是什么?",
|
||||
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
||||
]
|
||||
],
|
||||
),
|
||||
):
|
||||
vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
|
||||
if not os.path.exists(vs_path):
|
||||
raise ValueError(f"Document {local_doc_id} not found")
|
||||
|
||||
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
||||
):
|
||||
pass
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||
f"""相关度:{doc.metadata['score']}\n\n"""
|
||||
for inum, doc in enumerate(resp["source_documents"])
|
||||
]
|
||||
|
||||
return ChatMessage(
|
||||
question=question,
|
||||
response=resp["result"],
|
||||
history=history,
|
||||
source_documents=source_documents,
|
||||
)
|
||||
|
||||
|
||||
async def stream_chat(websocket: WebSocket, local_doc_id: str):
|
||||
await websocket.accept()
|
||||
vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
|
||||
|
||||
if not os.path.exists(vs_path):
|
||||
await websocket.send_json({"error": f"document {local_doc_id} not found"})
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
history = []
|
||||
turn = 1
|
||||
while True:
|
||||
question = await websocket.receive_text()
|
||||
await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
|
||||
|
||||
last_print_len = 0
|
||||
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
||||
):
|
||||
await websocket.send_text(resp["result"][last_print_len:])
|
||||
last_print_len = len(resp["result"])
|
||||
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||
f"""相关度:{doc.metadata['score']}\n\n"""
|
||||
for inum, doc in enumerate(resp["source_documents"])
|
||||
]
|
||||
|
||||
await websocket.send_text(
|
||||
json.dumps(
|
||||
{
|
||||
"question": question,
|
||||
"turn": turn,
|
||||
"flag": "end",
|
||||
"sources_documents": source_documents,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
turn += 1
|
||||
|
||||
|
||||
def gen_docs():
|
||||
global app
|
||||
with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=".json") as f:
|
||||
json.dump(
|
||||
get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
),
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.flush()
|
||||
# test whether widdershins is available
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"widdershins",
|
||||
f.name,
|
||||
"-o",
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"docs",
|
||||
"API.md",
|
||||
),
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Failed to generate docs. Please install widdershins first."
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
global app
|
||||
global local_doc_qa
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=7861)
|
||||
parser.add_argument("--gen-docs", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = FastAPI()
|
||||
app.websocket("/chat-docs/stream-chat/{local_doc_id}")(stream_chat)
|
||||
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
||||
app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
|
||||
app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
|
||||
app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
|
||||
|
||||
if args.gen_docs:
|
||||
gen_docs()
|
||||
return
|
||||
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(
|
||||
llm_model=LLM_MODEL,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_device=EMBEDDING_DEVICE,
|
||||
llm_history_len=LLM_HISTORY_LEN,
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -40,6 +40,8 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
|
|||
|
||||
UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
|
||||
|
||||
API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content")
|
||||
|
||||
# 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
|
||||
PROMPT_TEMPLATE = """已知信息:
|
||||
{context}
|
||||
|
|
@ -47,4 +49,4 @@ PROMPT_TEMPLATE = """已知信息:
|
|||
根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
|
||||
|
||||
# 匹配后单段上下文长度
|
||||
CHUNK_SIZE = 500
|
||||
CHUNK_SIZE = 500
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -10,4 +10,6 @@ cpm_kernels
|
|||
faiss-cpu
|
||||
accelerate
|
||||
gradio==3.24.1
|
||||
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
||||
fastapi
|
||||
uvicorn
|
||||
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
||||
|
|
|
|||
Loading…
Reference in New Issue