Add configs/prompt_config.py so users can customize prompt templates (#1504)

1. Two templates are provided by default: one for plain LLM chat, and one shared by knowledge-base chat and search-engine chat.
2. server/utils.py adds get_prompt_template() to look up a template by name (hot reload supported).
3. The chat / knowledge_base_chat / search_engine_chat endpoints in api.py accept a new prompt_name parameter.
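A minimal sketch of how a client might pass the new prompt_name parameter. The host, port, and the "/chat/chat" route are assumptions for illustration (the diff below only shows "/chat/search_engine_chat" in the test file); the query/history/stream/prompt_name fields match the request-body parameters added or shown in this commit:

    import requests

    # Illustrative request only; URL and route are assumed, not taken from the diff.
    resp = requests.post(
        "http://127.0.0.1:7861/chat/chat",
        json={
            "query": "你好",
            "history": [],
            "stream": False,
            "prompt_name": "llm_chat",  # new parameter introduced by this commit
        },
    )
    print(resp.text)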
commit a65bc4a63c (parent 598eb298df)
@@ -2,6 +2,7 @@ from .basic_config import *
 from .model_config import *
 from .kb_config import *
 from .server_config import *
+from .prompt_config import *


 VERSION = "v0.2.5-preview"
@@ -23,14 +23,6 @@ SCORE_THRESHOLD = 1
 SEARCH_ENGINE_TOP_K = 3


-# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-
-<已知信息>{{ context }}</已知信息>
-
-<问题>{{ question }}</问题>"""
-
-
 # Bing 搜索必备变量
 # 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
 # 具体申请方式请见
@@ -0,0 +1,23 @@
+# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
+# 本配置文件支持热加载,修改prompt模板后无需重启服务。
+
+
+# LLM对话支持的变量:
+#   - input: 用户输入内容
+
+# 知识库和搜索引擎对话支持的变量:
+#   - context: 从检索结果拼接的知识文本
+#   - question: 用户提出的问题
+
+
+PROMPT_TEMPLATES = {
+    # LLM对话模板
+    "llm_chat": "{{ input }}",
+
+    # 基于本地知识问答的提示词模板
+    "knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+
+<已知信息>{{ context }}</已知信息>
+
+<问题>{{ question }}</问题>""",
+}
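The templates use Jinja2 double-brace placeholders instead of f-string single braces. A hedged sketch of how a user-defined entry could be added to PROMPT_TEMPLATES and rendered; the "concise_chat" name and the standalone jinja2 rendering are illustrative assumptions, not part of the commit (the server renders templates through its own History/ChatPromptTemplate path):

    from jinja2 import Template

    # Hypothetical extra entry a user might append in configs/prompt_config.py,
    # selected later via prompt_name="concise_chat".
    PROMPT_TEMPLATES = {
        "llm_chat": "{{ input }}",
        "concise_chat": "请用一句话回答:{{ input }}",
    }

    # Double braces are Jinja2 placeholders, filled at render time.
    text = Template(PROMPT_TEMPLATES["concise_chat"]).render(input="什么是向量数据库?")
    print(text)  # -> 请用一句话回答:什么是向量数据库?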
@@ -76,6 +76,9 @@ FSCHAT_MODEL_WORKERS = {
     "qianfan-api": {
         "port": 21004,
     },
+    "fangzhou-api": {
+        "port": 21005,
+    },
 }

 # fastchat multi model worker server
@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -9,6 +9,7 @@ import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
+from server.utils import get_prompt_template


 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -22,12 +23,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+               prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
     history = [History.from_data(h) for h in history]

     async def chat_iterator(query: str,
                             history: List[History] = [],
                             model_name: str = LLM_MODEL,
+                            prompt_name: str = prompt_name,
                             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -35,7 +38,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             temperature=temperature,
             callbacks=[callback],
         )
-        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)
@@ -58,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         await task

-    return StreamingResponse(chat_iterator(query, history, model_name),
+    return StreamingResponse(chat_iterator(query=query,
+                                           history=history,
+                                           model_name=model_name,
+                                           prompt_name=prompt_name),
                              media_type="text/event-stream")
@@ -1,10 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs import (LLM_MODEL, PROMPT_TEMPLATE,
-                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     TEMPERATURE)
+from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@@ -33,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                              prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
                               ):
@@ -43,10 +42,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     history = [History.from_data(h) for h in history]

     async def knowledge_base_chat_iterator(query: str,
-                                           kb: KBService,
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
+                                           prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -57,7 +56,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         # 用户最后一个问题会进入PROMPT_TEMPLATE,不用再作为history 了
         if len(history) >= 1:
             history.pop()
@@ -98,5 +98,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         await task

-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
+    return StreamingResponse(knowledge_base_chat_iterator(query=query,
+                                                          top_k=top_k,
+                                                          history=history,
+                                                          model_name=model_name,
+                                                          prompt_name=prompt_name),
                              media_type="text/event-stream")
@@ -1,12 +1,11 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                     PROMPT_TEMPLATE, TEMPERATURE)
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -73,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                             prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -87,6 +87,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
+                                           prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -98,7 +99,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         docs = await lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])

-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

@@ -129,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                                     ensure_ascii=False)
         await task

-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
+    return StreamingResponse(search_engine_chat_iterator(query=query,
+                                                         search_engine_name=search_engine_name,
+                                                         top_k=top_k,
+                                                         history=history,
+                                                         model_name=model_name,
+                                                         prompt_name=prompt_name),
                              media_type="text/event-stream")
@@ -327,6 +327,17 @@ def webui_address() -> str:
     return f"http://{host}:{port}"


+def get_prompt_template(name: str) -> Optional[str]:
+    '''
+    从prompt_config中加载模板内容
+    '''
+    from configs import prompt_config
+    import importlib
+    importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+
+    return prompt_config.PROMPT_TEMPLATES.get(name)
+
+
 def set_httpx_timeout(timeout: float = None):
     '''
     设置httpx默认timeout。
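get_prompt_template() currently reloads prompt_config on every call; the TODO notes it should reload only when configs/prompt_config.py has actually changed. A hedged sketch of one way that check could look, keyed on the module file's mtime; this is an illustrative alternative, not part of the commit:

    import importlib
    import os
    from typing import Optional

    _last_mtime = 0.0

    def get_prompt_template(name: str) -> Optional[str]:
        """Reload prompt_config only when its source file has been modified."""
        global _last_mtime
        from configs import prompt_config

        mtime = os.path.getmtime(prompt_config.__file__)
        if mtime > _last_mtime:
            importlib.reload(prompt_config)
            _last_mtime = mtime
        return prompt_config.PROMPT_TEMPLATES.get(name)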
@@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
     assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"

     print("\n")
-    print("=" * 30 + api + " by {se} output" + "="*30)
+    print("=" * 30 + api + f" by {se} output" + "="*30)
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
         if "answer" in data:
@@ -315,6 +315,7 @@ class ApiRequest:
                   stream: bool = True,
                   model: str = LLM_MODEL,
                   temperature: float = TEMPERATURE,
+                  prompt_name: str = "llm_chat",
                   no_remote_api: bool = None,
                   ):
        '''
@@ -329,6 +330,7 @@
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
+            "prompt_name": prompt_name,
        }

        print(f"received input message:")
@@ -385,6 +387,7 @@
                   stream: bool = True,
                   model: str = LLM_MODEL,
                   temperature: float = TEMPERATURE,
+                  prompt_name: str = "knowledge_base_chat",
                   no_remote_api: bool = None,
                   ):
        '''
@@ -403,6 +406,7 @@
            "model_name": model,
            "temperature": temperature,
            "local_doc_url": no_remote_api,
+            "prompt_name": prompt_name,
        }

        print(f"received input message:")
@@ -428,6 +432,7 @@
                   stream: bool = True,
                   model: str = LLM_MODEL,
                   temperature: float = TEMPERATURE,
+                  prompt_name: str = "knowledge_base_chat",
                   no_remote_api: bool = None,
                   ):
        '''
@@ -443,6 +448,7 @@
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
+            "prompt_name": prompt_name,
        }

        print(f"received input message:")
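For completeness, a hedged sketch of consuming the streaming knowledge-base endpoint with an explicit prompt_name, mirroring the iter_content/json.loads pattern from the test hunk above; the host, port, route, and the knowledge_base_name value "samples" are assumptions for illustration:

    import json
    import requests

    # Assumed URL and placeholder knowledge base name; only prompt_name is the new field.
    resp = requests.post(
        "http://127.0.0.1:7861/chat/knowledge_base_chat",
        json={
            "query": "如何自定义prompt模板?",
            "knowledge_base_name": "samples",
            "stream": True,
            "prompt_name": "knowledge_base_chat",
        },
        stream=True,
    )
    for line in resp.iter_content(None, decode_unicode=True):
        data = json.loads(line)
        if "answer" in data:
            print(data["answer"], end="", flush=True)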