fix bug in webui.py
This commit is contained in:
parent
8ae84c6c93
commit
048c71f893
|
|
@ -0,0 +1,102 @@
|
||||||
|
from configs.model_config import *
|
||||||
|
from chains.local_doc_qa import LocalDocQA
|
||||||
|
import os
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, File, UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
global local_doc_qa, vs_path
|
||||||
|
|
||||||
|
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
|
||||||
|
|
||||||
|
# return top-k text chunk from vector store
|
||||||
|
VECTOR_SEARCH_TOP_K = 10
|
||||||
|
|
||||||
|
# LLM input history length
|
||||||
|
LLM_HISTORY_LEN = 3
|
||||||
|
|
||||||
|
# Show reply with source text from input document
|
||||||
|
REPLY_WITH_SOURCE = False
|
||||||
|
|
||||||
|
class Query(BaseModel):
|
||||||
|
query: str
|
||||||
|
|
||||||
|
@app.get('/')
|
||||||
|
async def document():
|
||||||
|
return RedirectResponse(url="/docs")
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def get_local_doc_qa():
|
||||||
|
global local_doc_qa
|
||||||
|
local_doc_qa = LocalDocQA()
|
||||||
|
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
|
||||||
|
embedding_model=EMBEDDING_MODEL,
|
||||||
|
embedding_device=EMBEDDING_DEVICE,
|
||||||
|
llm_history_len=LLM_HISTORY_LEN,
|
||||||
|
top_k=VECTOR_SEARCH_TOP_K)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/file")
|
||||||
|
async def upload_file(UserFile: UploadFile=File(...)):
|
||||||
|
global vs_path
|
||||||
|
response = {
|
||||||
|
"msg": None,
|
||||||
|
"status": 0
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
filepath = './content/' + UserFile.filename
|
||||||
|
content = await UserFile.read()
|
||||||
|
# print(UserFile.filename)
|
||||||
|
with open(filepath, 'wb') as f:
|
||||||
|
f.write(content)
|
||||||
|
vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||||
|
response = {
|
||||||
|
'msg': 'seccessful',
|
||||||
|
'status': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
response["message"] = err
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@app.post("/qa")
|
||||||
|
async def get_answer(UserQuery: Query):
|
||||||
|
response = {
|
||||||
|
"status": 0,
|
||||||
|
"message": "",
|
||||||
|
"answer": None
|
||||||
|
}
|
||||||
|
global vs_path
|
||||||
|
history = []
|
||||||
|
try:
|
||||||
|
resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
|
||||||
|
vs_path=vs_path,
|
||||||
|
chat_history=history)
|
||||||
|
if REPLY_WITH_SOURCE:
|
||||||
|
response["answer"] = resp
|
||||||
|
else:
|
||||||
|
response['answer'] = resp["result"]
|
||||||
|
|
||||||
|
response["message"] = 'successful'
|
||||||
|
response["status"] = 1
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
response["message"] = err
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(
|
||||||
|
app='api:app',
|
||||||
|
host='0.0.0.0',
|
||||||
|
port=8100,
|
||||||
|
reload = True,
|
||||||
|
)
|
||||||
|
|
||||||
12
webui.py
12
webui.py
|
|
@ -9,6 +9,10 @@ VECTOR_SEARCH_TOP_K = 6
|
||||||
|
|
||||||
# LLM input history length
|
# LLM input history length
|
||||||
LLM_HISTORY_LEN = 3
|
LLM_HISTORY_LEN = 3
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
|
||||||
|
>>>>>>> f87a5f5 (fix bug in webui.py)
|
||||||
|
|
||||||
def get_file_list():
|
def get_file_list():
|
||||||
if not os.path.exists("content"):
|
if not os.path.exists("content"):
|
||||||
|
|
@ -54,7 +58,8 @@ def init_model():
|
||||||
try:
|
try:
|
||||||
local_doc_qa.init_cfg()
|
local_doc_qa.init_cfg()
|
||||||
return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
|
return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
|
return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,14 +71,15 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
|
||||||
use_ptuning_v2=use_ptuning_v2,
|
use_ptuning_v2=use_ptuning_v2,
|
||||||
top_k=top_k)
|
top_k=top_k)
|
||||||
model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
|
model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
model_status = """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
|
model_status = """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
|
||||||
return history + [[None, model_status]]
|
return history + [[None, model_status]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store(filepath, history):
|
def get_vector_store(filepath, history):
|
||||||
if local_doc_qa.llm and local_doc_qa.llm:
|
if local_doc_qa.llm and local_doc_qa.embeddings:
|
||||||
vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
|
vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
|
||||||
if vs_path:
|
if vs_path:
|
||||||
file_status = "文件已成功加载,请开始提问"
|
file_status = "文件已成功加载,请开始提问"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue