Add configs/prompt_config.py so users can customize prompt templates (#1504)

1. Two templates are included by default: one for plain LLM chat, and one shared by knowledge-base chat and search-engine chat.
2. server/utils.py provides get_prompt_template to fetch the content of a named prompt template (hot reload supported).
3. The chat / knowledge_base_chat / search_engine_chat endpoints in api.py accept a prompt_name parameter (see the request sketch below).
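
A minimal request sketch for point 3 (host, port, and endpoint path are the project defaults and may differ in your deployment; only prompt_name is new):

    import requests

    # Ask /chat/chat to render the "llm_chat" template from configs/prompt_config.py.
    resp = requests.post(
        "http://127.0.0.1:7861/chat/chat",
        json={
            "query": "你好",
            "stream": False,
            "prompt_name": "llm_chat",  # any key defined in PROMPT_TEMPLATES
        },
    )
    print(resp.text)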
liunux4odoo 2023-09-17 13:27:11 +08:00 committed by GitHub
parent 598eb298df
commit a65bc4a63c
10 changed files with 79 additions and 24 deletions


@@ -2,6 +2,7 @@ from .basic_config import *
 from .model_config import *
 from .kb_config import *
 from .server_config import *
+from .prompt_config import *
 VERSION = "v0.2.5-preview"


@@ -23,14 +23,6 @@ SCORE_THRESHOLD = 1
 SEARCH_ENGINE_TOP_K = 3
-# 基于本地知识问答的提示词模版使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>
-<问题>{{ question }}</问题>"""
 # Bing 搜索必备变量
 # 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
 # 具体申请方式请见


@@ -0,0 +1,23 @@
+# prompt模板使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
+# 本配置文件支持热加载修改prompt模板后无需重启服务。
+
+# LLM对话支持的变量:
+#   - input: 用户输入内容
+
+# 知识库和搜索引擎对话支持的变量:
+#   - context: 从检索结果拼接的知识文本
+#   - question: 用户提出的问题
+
+PROMPT_TEMPLATES = {
+    # LLM对话模板
+    "llm_chat": "{{ input }}",
+
+    # 基于本地知识问答的提示词模板
+    "knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+<已知信息>{{ context }}</已知信息>
+<问题>{{ question }}</问题>""",
+}
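
To add a custom template, define another key in PROMPT_TEMPLATES and pass it as prompt_name in a request; a hypothetical entry (the name "concise_qa" is illustrative and not part of this commit):

    # configs/prompt_config.py -- hypothetical extra entry, not included in this commit
    PROMPT_TEMPLATES["concise_qa"] = (
        "<指令>仅根据已知信息作答,无法回答时请直接说明。</指令>\n"
        "<已知信息>{{ context }}</已知信息>\n"
        "<问题>{{ question }}</问题>"
    )

Because the module is re-imported on each call to get_prompt_template, such an entry takes effect without restarting the server.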


@@ -76,6 +76,9 @@ FSCHAT_MODEL_WORKERS = {
     "qianfan-api": {
         "port": 21004,
     },
+    "fangzhou-api": {
+        "port": 21005,
+    },
 }
 # fastchat multi model worker server


@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -9,6 +9,7 @@ import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
+from server.utils import get_prompt_template
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -22,12 +23,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+               prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
     history = [History.from_data(h) for h in history]
     async def chat_iterator(query: str,
                             history: List[History] = [],
                             model_name: str = LLM_MODEL,
+                            prompt_name: str = prompt_name,
                             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -35,7 +38,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             temperature=temperature,
             callbacks=[callback],
         )
-        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)
@@ -58,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         await task
-    return StreamingResponse(chat_iterator(query, history, model_name),
+    return StreamingResponse(chat_iterator(query=query,
+                                           history=history,
+                                           model_name=model_name,
+                                           prompt_name=prompt_name),
                              media_type="text/event-stream")


@@ -1,10 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs import (LLM_MODEL, PROMPT_TEMPLATE,
-                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     TEMPERATURE)
+from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@@ -33,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                              prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
                               ):
@@ -43,10 +42,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     history = [History.from_data(h) for h in history]
     async def knowledge_base_chat_iterator(query: str,
-                                           kb: KBService,
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
+                                           prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -57,7 +56,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         # 用户最后一个问题会进入PROMPT_TEMPLATE不用再作为history 了
         if len(history) >= 1:
             history.pop()
@@ -98,5 +98,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         await task
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
+    return StreamingResponse(knowledge_base_chat_iterator(query=query,
+                                                          top_k=top_k,
+                                                          history=history,
+                                                          model_name=model_name,
+                                                          prompt_name=prompt_name),
                              media_type="text/event-stream")
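
A corresponding request sketch for this endpoint (the knowledge base name "samples", host, and port are assumptions; only prompt_name is new here):

    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/knowledge_base_chat",
        json={
            "query": "如何自定义prompt模板？",
            "knowledge_base_name": "samples",  # assumed existing knowledge base
            "top_k": 3,
            "stream": False,
            "prompt_name": "knowledge_base_chat",  # or any custom key using context/question
        },
    )
    print(resp.text)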


@@ -1,12 +1,11 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                     PROMPT_TEMPLATE, TEMPERATURE)
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -73,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                             prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -87,6 +87,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
+                                           prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -98,7 +99,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         docs = await lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
@@ -129,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              ensure_ascii=False)
         await task
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
+    return StreamingResponse(search_engine_chat_iterator(query=query,
+                                                         search_engine_name=search_engine_name,
+                                                         top_k=top_k,
+                                                         history=history,
+                                                         model_name=model_name,
+                                                         prompt_name=prompt_name),
                              media_type="text/event-stream")
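
Note that search_engine_chat defaults to the "knowledge_base_chat" template: both chat types expose the same {{ context }} and {{ question }} variables, so a dedicated search-engine template is optional and can be added to PROMPT_TEMPLATES if different wording is wanted.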


@@ -327,6 +327,17 @@ def webui_address() -> str:
     return f"http://{host}:{port}"
+def get_prompt_template(name: str) -> Optional[str]:
+    '''
+    从prompt_config中加载模板内容
+    '''
+    from configs import prompt_config
+    import importlib
+    importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+    return prompt_config.PROMPT_TEMPLATES.get(name)
 def set_httpx_timeout(timeout: float = None):
     '''
     设置httpx默认timeout
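
Because get_prompt_template reloads configs.prompt_config on every call, edits to the file take effect immediately. The TODO about reloading only when the file actually changed could be handled with an mtime check along these lines (a sketch, not part of this commit; _PROMPT_CONFIG_MTIME is a hypothetical module-level cache):

    import importlib
    import os
    from typing import Optional

    _PROMPT_CONFIG_MTIME = 0.0  # hypothetical cache of the config file's last modification time

    def get_prompt_template(name: str) -> Optional[str]:
        '''Load a template from configs.prompt_config, reloading only if the file changed.'''
        global _PROMPT_CONFIG_MTIME
        from configs import prompt_config
        mtime = os.path.getmtime(prompt_config.__file__)
        if mtime > _PROMPT_CONFIG_MTIME:  # skip the reload when nothing changed on disk
            importlib.reload(prompt_config)
            _PROMPT_CONFIG_MTIME = mtime
        return prompt_config.PROMPT_TEMPLATES.get(name)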


@@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
         assert data["msg"] == f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
     print("\n")
-    print("=" * 30 + api + " by {se} output" + "="*30)
+    print("=" * 30 + api + f" by {se} output" + "="*30)
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
         if "answer" in data:


@@ -315,6 +315,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "llm_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -329,6 +330,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
         print(f"received input message:")
@@ -385,6 +387,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -403,6 +406,7 @@ class ApiRequest:
             "model_name": model,
             "temperature": temperature,
             "local_doc_url": no_remote_api,
+            "prompt_name": prompt_name,
         }
         print(f"received input message:")
@@ -428,6 +432,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -443,6 +448,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
         print(f"received input message:")