Merge branch 'dev' into chatglm2cpp
commit 4235270a32

@@ -87,6 +87,10 @@ $ conda create -p /your_path/env_name python=3.8
 # Activate the environment
 $ source activate /your_path/env_name

+# or, do not specify an env path, note that /your_path/env_name is to be replaced with env_name below
+$ conda create -n env_name python=3.8
+$ conda activate env_name  # Activate the environment
+
 # Deactivate the environment
 $ source deactivate /your_path/env_name

api.py (81 changed lines)

@@ -9,8 +9,9 @@ import asyncio
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse

@@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse):
 class ChatMessage(BaseModel):
     question: str = pydantic.Field(..., description="Question text")
     response: str = pydantic.Field(..., description="Response text")
-    history: List[List[str]] = pydantic.Field(..., description="History text")
+    history: List[List[Optional[str]]] = pydantic.Field(..., description="History text")
     source_documents: List[str] = pydantic.Field(
         ..., description="List of source documents and their scores"
     )
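Note: the widened history typing above matters mainly for streaming, where the newest turn may not have an answer yet. A minimal sketch of what the Optional entries allow (simplified model, not the project's exact Field definitions):

from typing import List, Optional

import pydantic


class ChatMessage(pydantic.BaseModel):
    question: str
    response: str
    history: List[List[Optional[str]]]
    source_documents: List[str]


# With List[List[str]], the None answer slot of the newest turn would fail
# validation; Optional[str] lets the placeholder through while streaming.
msg = ChatMessage(
    question="工伤保险是什么?",
    response="",
    history=[["你好", "你好,请问有什么可以帮您?"], ["工伤保险是什么?", None]],
    source_documents=[],
)
print(msg.history[-1])  # ['工伤保险是什么?', None]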

@@ -303,7 +304,8 @@ async def update_doc(
 async def local_doc_chat(
     knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: List[List[str]] = Body(
+    streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
+    history: List[List[Optional[str]]] = Body(
         [],
         description="History of previous questions and answers",
         example=[

@@ -324,27 +326,39 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
-
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+        if (streaming):
+            def generate_answer ():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len=len(resp["result"])
+
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=history,
+                source_documents=source_documents,
+            )


 async def bing_search_chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: Optional[List[List[str]]] = Body(
+    history: Optional[List[List[Optional[str]]]] = Body(
         [],
         description="History of previous questions and answers",
         example=[
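Note: the local_doc_chat change above streams deltas: generate_answer remembers how much of the answer was already sent (last_print_len) and yields only the new tail of resp["result"], wrapped in a StreamingResponse. A self-contained sketch of the same pattern with a stubbed answer source (the stub and route name are illustrative, not project code):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_answer_source():
    # Stand-in for local_doc_qa.get_knowledge_based_answer: each step yields
    # the full answer accumulated so far, not just the newly generated part.
    for text in ["工伤", "工伤保险", "工伤保险是一种社会保险制度。"]:
        yield {"result": text}


@app.post("/demo_stream")
async def demo_stream():
    def generate_answer():
        last_print_len = 0
        for resp in fake_answer_source():
            yield resp["result"][last_print_len:]   # send only the delta
            last_print_len = len(resp["result"])
    return StreamingResponse(generate_answer(), media_type="text/plain")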

@@ -374,7 +388,8 @@ async def bing_search_chat(

 async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: Optional[List[List[str]]] = Body(
+    streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
+    history: List[List[Optional[str]]] = Body(
         [],
         description="History of previous questions and answers",
         example=[

@@ -385,6 +400,30 @@ async def chat(
         ],
     ),
 ):
+    if (streaming):
+        def generate_answer ():
+            last_print_len = 0
+            answer_result_stream_result = local_doc_qa.llm_model_chain(
+                {"prompt": question, "history": history, "streaming": True})
+            for answer_result in answer_result_stream_result['answer_result_stream']:
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+
+        return StreamingResponse(generate_answer())
+    else:
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": question, "history": history, "streaming": True})
+        for answer_result in answer_result_stream_result['answer_result_stream']:
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )
     answer_result_stream_result = local_doc_qa.llm_model_chain(
         {"prompt": question, "history": history, "streaming": True})

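Note: for the new streaming mode of the chat endpoint, a rough client-side sketch of how such a response could be consumed chunk by chunk; the host, port and route here are assumptions about a local deployment, not taken from this commit:

import requests

# Assumed local deployment; adjust host, port and route to your api.py startup flags.
resp = requests.post(
    "http://127.0.0.1:7861/chat",
    json={"question": "工伤保险是什么?", "history": [], "streaming": True},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    if chunk:
        print(chunk, end="", flush=True)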

@@ -544,7 +583,7 @@ if __name__ == "__main__":
parser.add_argument("--ssl_keyfile", type=str)
|
parser.add_argument("--ssl_keyfile", type=str)
|
||||||
parser.add_argument("--ssl_certfile", type=str)
|
parser.add_argument("--ssl_certfile", type=str)
|
||||||
# 初始化消息
|
# 初始化消息
|
||||||
args = None
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args_dict = vars(args)
|
args_dict = vars(args)
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
|
|
|
||||||
|
|

@@ -31,7 +31,7 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe
 # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例
 # 在以下字典中修改属性值,以指定本地 LLM 模型存储位置
 # 如将 "chatglm-6b" 的 "local_model_path" 由 None 修改为 "User/Downloads/chatglm-6b"
-# 此处请写绝对路径
+# 此处请写绝对路径,且路径中必须包含repo-id的模型名称,因为FastChat是以模型名匹配的
 llm_model_dict = {
     "chatglm-6b-int4-qe": {
         "name": "chatglm-6b-int4-qe",
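Note: per the updated comment, the local path must be absolute and its last component must keep the repo-id model name, since FastChat matches models by name. An illustrative entry only (placeholder path, other fields of the real dict omitted):

llm_model_dict = {
    "chatglm-6b": {
        "name": "chatglm-6b",
        # absolute path whose last directory keeps the repo-id model name,
        # because FastChat matches the model by that name
        "local_model_path": "/absolute/path/to/models/THUDM/chatglm-6b",
    },
}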

@@ -218,9 +218,10 @@ BF16 = False
 # 本地lora存放的位置
 LORA_DIR = "loras/"

-# LLM lora path,默认为空,如果有请直接指定文件夹路径
-LLM_LORA_PATH = ""
-USE_LORA = True if LLM_LORA_PATH else False
+# LORA的名称,如有请指定为列表
+LORA_NAME = ""
+USE_LORA = True if LORA_NAME else False

 # LLM streaming reponse
 STREAMING = True

@@ -12,6 +12,12 @@ $ conda create -p /your_path/env_name python=3.8

 # 激活环境
 $ source activate /your_path/env_name
+
+# 或,conda安装,不指定路径, 注意以下,都将/your_path/env_name替换为env_name
+$ conda create -n env_name python=3.8
+$ conda activate env_name  # Activate the environment
+
+# 更新py库
 $ pip3 install --upgrade pip

 # 关闭环境

@@ -42,9 +42,10 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
                                      'model to add the ` '
                                      '--no-remote-model`')
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
-parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
+parser.add_argument("--use-lora",type=bool,default=USE_LORA,help="use lora or not")
+parser.add_argument('--lora', type=str, default=LORA_NAME,help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2',default=USE_PTUNING_V2,help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")

 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
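Note: a standalone sketch of how the reworked LoRA flags parse (defaults inlined rather than imported from the config). One caveat worth knowing: argparse type=bool only tests string truthiness, so any non-empty value, including the literal string "False", parses as True; only the default really carries the config value:

import argparse

parser = argparse.ArgumentParser()
# Defaults are inlined here; the project takes them from USE_LORA / LORA_NAME / LORA_DIR.
parser.add_argument("--use-lora", type=bool, default=False, help="use lora or not")
parser.add_argument("--lora", type=str, default="", help="Name of the LoRA to apply")
parser.add_argument("--lora-dir", type=str, default="loras/", help="Path to directory with all the loras")

print(parser.parse_args([]))                        # Namespace(use_lora=False, lora='', lora_dir='loras/')
print(parser.parse_args(["--use-lora", "False"]))   # use_lora=True, because bool("False") is truthy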

@@ -203,6 +203,12 @@ class LoaderCheckPoint:
             return model, tokenizer

         elif self.is_llamacpp:
+            # 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库
+            # but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载
+            # 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定
+            # 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容
+            # 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用,
+            # 建议如非必要还是不要使用llama-cpp
             try:
                 from llama_cpp import Llama
             except ImportError as exc:
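Note: the comments moved into the loader describe the llama-cpp path. A minimal guarded-load sketch with llama-cpp-python; the model path and generation parameters are placeholders, and the real LoaderCheckPoint wiring is more involved:

try:
    from llama_cpp import Llama
except ImportError as exc:
    raise SystemExit(
        "llama-cpp-python is required for ggml checkpoints; pick a release wheel "
        "whose ggml format matches your model (0.1.63 reportedly works with "
        "ggml-vicuna-13b-1.1)"
    ) from exc

# Placeholder path: point this at a local ggml checkpoint.
llm = Llama(model_path="/absolute/path/to/ggml-vicuna-13b-1.1.bin", n_ctx=2048)
out = llm("Q: 工伤保险是什么?\nA:", max_tokens=128, stop=["Q:"])
print(out["choices"][0]["text"])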

@@ -23,13 +23,6 @@ openai
 #accelerate~=0.18.0
 #peft~=0.3.0
 #bitsandbytes; platform_system != "Windows"
-
-# 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库
-# but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载
-# 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定
-# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容
-# 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用,
-# 建议如非必要还是不要使用llama-cpp
 torch~=2.0.0
 pydantic~=1.10.7
 starlette~=0.26.1

webui_st.py (29 changed lines)

@@ -1,6 +1,8 @@
 import streamlit as st
 from streamlit_chatbox import st_chatbox
 import tempfile
+from pathlib import Path
+
 ###### 从webui借用的代码 #####
 ###### 做了少量修改 #####
 import os

@@ -101,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR


 def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
-    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
-    filelist = []
-    if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
-        os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
+    vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store"
+    con_path = Path(KB_ROOT_PATH) / vs_id / "content"
+    con_path.mkdir(parents=True, exist_ok=True)
     qa = st.session_state.local_doc_qa
     if qa.llm_model_chain and qa.embeddings:
+        filelist = []
         if isinstance(files, list):
             for file in files:
                 filename = os.path.split(file.name)[-1]
-                shutil.move(file.name, os.path.join(
-                    KB_ROOT_PATH, vs_id, "content", filename))
-                filelist.append(os.path.join(
-                    KB_ROOT_PATH, vs_id, "content", filename))
+                target = con_path / filename
+                shutil.move(file.name, target)
+                filelist.append(str(target))
             vs_path, loaded_files = qa.init_knowledge_vector_store(
-                filelist, vs_path, sentence_size)
+                filelist, str(vs_path), sentence_size)
         else:
-            vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
+            vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation,
                                                          sentence_size)
         if len(loaded_files):
             file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
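Note: the get_vector_store rework swaps os.path.join/os.makedirs for pathlib. The same pattern in isolation, with made-up directory names; str() is kept where downstream APIs expect plain strings:

import shutil
from pathlib import Path

kb_root = Path("knowledge_base")              # made-up root, stands in for KB_ROOT_PATH
con_path = kb_root / "samples" / "content"
con_path.mkdir(parents=True, exist_ok=True)   # replaces the os.path.exists + os.makedirs dance

src = Path("upload.txt")
src.write_text("demo")                        # pretend this came from the uploader's temp dir
target = con_path / src.name
shutil.move(str(src), str(target))            # str() for compatibility; newer Pythons accept Path here too
print(str(target))                            # downstream APIs that expect str still work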

@@ -322,7 +324,8 @@ with st.sidebar:
     sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
     files = st.file_uploader('上传知识文件',
                              ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
-                             accept_multiple_files=True)
+                             accept_multiple_files=True,
+                             )
     if st.button('添加文件到知识库'):
         temp_dir = tempfile.mkdtemp()
         file_list = []

@@ -331,8 +334,8 @@ with st.sidebar:
             with open(file, 'wb') as fp:
                 fp.write(f.getvalue())
             file_list.append(TempFile(file))
         _, _, history = get_vector_store(
             vs_path, file_list, sentence_size, [], None, None)
         st.session_state.files = []