Add configs/prompt_config.py so users can customize prompt templates (#1504)

1. Two templates are included by default: one for plain LLM chat, and one shared by knowledge-base and search-engine chat.
2. server/utils.py provides get_prompt_template() to fetch the content of a named prompt template (with hot-reload support).
3. The chat, knowledge_base_chat and search_engine_chat endpoints in api.py accept a prompt_name parameter.
parent 598eb298df
commit a65bc4a63c
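For quick reference, a minimal client-side sketch of the new prompt_name parameter (the base URL and the /chat/chat route are assumptions for a default local deployment, not part of this commit):

# Minimal sketch: select a template from configs/prompt_config.py via prompt_name.
# Base URL and route are assumptions for a default local deployment.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "介绍一下你自己",
        "history": [],
        "stream": False,
        "prompt_name": "llm_chat",  # key in PROMPT_TEMPLATES
    },
    stream=True,
)
for chunk in resp.iter_content(None, decode_unicode=True):
    print(chunk, end="")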
@@ -2,6 +2,7 @@ from .basic_config import *
 from .model_config import *
 from .kb_config import *
 from .server_config import *
+from .prompt_config import *
 
 
 VERSION = "v0.2.5-preview"
@@ -23,14 +23,6 @@ SCORE_THRESHOLD = 1
 SEARCH_ENGINE_TOP_K = 3
 
 
-# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-
-<已知信息>{{ context }}</已知信息>
-
-<问题>{{ question }}</问题>"""
-
-
 # Bing 搜索必备变量
 # 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
 # 具体申请方式请见
@@ -0,0 +1,23 @@
+# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
+# 本配置文件支持热加载,修改prompt模板后无需重启服务。
+
+
+# LLM对话支持的变量:
+# - input: 用户输入内容
+
+# 知识库和搜索引擎对话支持的变量:
+# - context: 从检索结果拼接的知识文本
+# - question: 用户提出的问题
+
+
+PROMPT_TEMPLATES = {
+    # LLM对话模板
+    "llm_chat": "{{ input }}",
+
+    # 基于本地知识问答的提示词模板
+    "knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+
+<已知信息>{{ context }}</已知信息>
+
+<问题>{{ question }}</问题>""",
+}
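As a quick illustration of the double-brace (Jinja2) placeholders above — a sketch only; inside the server the same string is rendered through langchain's ChatPromptTemplate via the History helper rather than directly with jinja2:

# Sketch: how {{ context }} and {{ question }} are substituted in the new template.
from jinja2 import Template
from configs.prompt_config import PROMPT_TEMPLATES

text = Template(PROMPT_TEMPLATES["knowledge_base_chat"]).render(
    context="检索到的文档内容……",
    question="今天天气如何?",
)
print(text)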
@@ -76,6 +76,9 @@ FSCHAT_MODEL_WORKERS = {
     "qianfan-api": {
         "port": 21004,
     },
+    "fangzhou-api": {
+        "port": 21005,
+    },
 }
 
 # fastchat multi model worker server
@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler

@@ -9,6 +9,7 @@ import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
+from server.utils import get_prompt_template
 
 
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),

@@ -22,12 +23,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
         # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+        prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         ):
     history = [History.from_data(h) for h in history]
 
     async def chat_iterator(query: str,
             history: List[History] = [],
             model_name: str = LLM_MODEL,
+            prompt_name: str = prompt_name,
             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(

@@ -35,7 +38,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             temperature=temperature,
             callbacks=[callback],
         )
-        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)

@@ -58,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         await task
 
-    return StreamingResponse(chat_iterator(query, history, model_name),
+    return StreamingResponse(chat_iterator(query=query,
+                                           history=history,
+                                           model_name=model_name,
+                                           prompt_name=prompt_name),
                              media_type="text/event-stream")
@@ -1,10 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs import (LLM_MODEL, PROMPT_TEMPLATE,
-                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     TEMPERATURE)
+from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional

@@ -33,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+        prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
         request: Request = None,
         ):

@@ -43,10 +42,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     history = [History.from_data(h) for h in history]
 
     async def knowledge_base_chat_iterator(query: str,
-            kb: KBService,
             top_k: int,
             history: Optional[List[History]],
             model_name: str = LLM_MODEL,
+            prompt_name: str = prompt_name,
             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(

@@ -57,7 +56,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
 
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         # 用户最后一个问题会进入PROMPT_TEMPLATE,不用再作为history 了
         if len(history) >= 1:
             history.pop()

@@ -98,5 +98,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
 
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
+    return StreamingResponse(knowledge_base_chat_iterator(query=query,
+                                                          top_k=top_k,
+                                                          history=history,
+                                                          model_name=model_name,
+                                                          prompt_name=prompt_name),
                              media_type="text/event-stream")
@@ -1,12 +1,11 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                     PROMPT_TEMPLATE, TEMPERATURE)
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable

@@ -73,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+        prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -87,6 +87,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             top_k: int,
             history: Optional[List[History]],
             model_name: str = LLM_MODEL,
+            prompt_name: str = prompt_name,
             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(

@@ -98,7 +99,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         docs = await lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
 
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
 

@@ -129,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                             ensure_ascii=False)
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
+    return StreamingResponse(search_engine_chat_iterator(query=query,
+                                                         search_engine_name=search_engine_name,
+                                                         top_k=top_k,
+                                                         history=history,
+                                                         model_name=model_name,
+                                                         prompt_name=prompt_name),
                              media_type="text/event-stream")
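The search-engine endpoint defaults to the same "knowledge_base_chat" template, since it fills the same {{ context }} / {{ question }} variables. A hedged client sketch (the base URL and the availability of the duckduckgo engine are assumptions, not part of this diff):

# Sketch: /chat/search_engine_chat with an explicit prompt_name.
# Base URL is an assumption for a default local deployment; "duckduckgo"
# assumes that engine is enabled in SEARCH_ENGINES.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/search_engine_chat",
    json={
        "query": "上海明天的天气",
        "search_engine_name": "duckduckgo",
        "stream": False,
        "prompt_name": "knowledge_base_chat",
    },
    stream=True,
)
for chunk in resp.iter_content(None, decode_unicode=True):
    print(chunk, end="")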
@@ -327,6 +327,17 @@ def webui_address() -> str:
     return f"http://{host}:{port}"
 
 
+def get_prompt_template(name: str) -> Optional[str]:
+    '''
+    从prompt_config中加载模板内容
+    '''
+    from configs import prompt_config
+    import importlib
+    importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+
+    return prompt_config.PROMPT_TEMPLATES.get(name)
+
+
 def set_httpx_timeout(timeout: float = None):
     '''
     设置httpx默认timeout。
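A short usage sketch of the new helper; because it re-imports configs.prompt_config on every call, edits to PROMPT_TEMPLATES take effect without restarting the server:

# Sketch of get_prompt_template behaviour (hot reload via importlib.reload).
from server.utils import get_prompt_template

print(get_prompt_template("llm_chat"))             # "{{ input }}"
print(get_prompt_template("knowledge_base_chat"))  # the <指令>... template
print(get_prompt_template("no_such_template"))     # None for unknown names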
@@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
             assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
 
         print("\n")
-        print("=" * 30 + api + " by {se} output" + "="*30)
+        print("=" * 30 + api + f" by {se} output" + "="*30)
         for line in response.iter_content(None, decode_unicode=True):
             data = json.loads(line)
             if "answer" in data:
@@ -315,6 +315,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "llm_chat",
         no_remote_api: bool = None,
     ):
         '''

@@ -329,6 +330,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")

@@ -385,6 +387,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''

@@ -403,6 +406,7 @@ class ApiRequest:
             "model_name": model,
             "temperature": temperature,
             "local_doc_url": no_remote_api,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")

@@ -428,6 +432,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''

@@ -443,6 +448,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")
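The hunks above also thread prompt_name defaults through a webui-side ApiRequest client. A hedged sketch of how that might be used (the module path webui_pages/utils.py and the method name chat_chat are assumptions based on the project layout, not shown in this diff):

# Hypothetical webui-side call; module path and method name are assumptions.
from webui_pages.utils import ApiRequest

api = ApiRequest()
for chunk in api.chat_chat("你好", prompt_name="llm_chat"):
    print(chunk, end="")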