diff --git a/configs/__init__.py b/configs/__init__.py
index f4c1866..b9c1fee 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -2,6 +2,7 @@ from .basic_config import *
 from .model_config import *
 from .kb_config import *
 from .server_config import *
+from .prompt_config import *
 
 
 VERSION = "v0.2.5-preview"
diff --git a/configs/kb_config.py.exmaple b/configs/kb_config.py.exmaple
index 1c2deb9..af68a74 100644
--- a/configs/kb_config.py.exmaple
+++ b/configs/kb_config.py.exmaple
@@ -23,14 +23,6 @@ SCORE_THRESHOLD = 1
 
 SEARCH_ENGINE_TOP_K = 3
 
-# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
-
-<已知信息>{{ context }}
-
-<问题>{{ question }}"""
-
-
 # Bing 搜索必备变量
 # 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
 # 具体申请方式请见
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
new file mode 100644
index 0000000..013f946
--- /dev/null
+++ b/configs/prompt_config.py.example
@@ -0,0 +1,23 @@
+# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
+# 本配置文件支持热加载,修改prompt模板后无需重启服务。
+
+
+# LLM对话支持的变量:
+#   - input: 用户输入内容
+
+# 知识库和搜索引擎对话支持的变量:
+#   - context: 从检索结果拼接的知识文本
+#   - question: 用户提出的问题
+
+
+PROMPT_TEMPLATES = {
+    # LLM对话模板
+    "llm_chat": "{{ input }}",
+
+    # 基于本地知识问答的提示词模板
+    "knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
+
+<已知信息>{{ context }}
+
+<问题>{{ question }}""",
+}
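The entries in `PROMPT_TEMPLATES` are plain Jinja2 strings, so a new template can be previewed without starting the server. A minimal sketch follows; the `llm_chat_with_context` key is hypothetical and only illustrates what a custom entry might look like.

```python
# Preview a PROMPT_TEMPLATES entry by rendering it with Jinja2.
# The "llm_chat_with_context" key is hypothetical; only "llm_chat" and
# "knowledge_base_chat" are defined in configs/prompt_config.py.example.
from jinja2 import Template

PROMPT_TEMPLATES = {
    "llm_chat": "{{ input }}",
    "llm_chat_with_context": "<已知信息>{{ context }}\n\n<问题>{{ question }}",
}

preview = Template(PROMPT_TEMPLATES["llm_chat_with_context"]).render(
    context="知识库检索到的文本……",   # joined page_content of the retrieved documents
    question="什么是向量检索?",
)
print(preview)
```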
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index b5040bf..7a4e318 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -76,6 +76,9 @@ FSCHAT_MODEL_WORKERS = {
     "qianfan-api": {
         "port": 21004,
     },
+    "fangzhou-api": {
+        "port": 21005,
+    },
 }
 
 # fastchat multi model worker server
diff --git a/server/chat/chat.py b/server/chat/chat.py
index e7564bc..8f4d343 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -9,6 +9,7 @@ import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
+from server.utils import get_prompt_template
 
 
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -22,12 +23,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+               prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
     history = [History.from_data(h) for h in history]
 
     async def chat_iterator(query: str,
                             history: List[History] = [],
                             model_name: str = LLM_MODEL,
+                            prompt_name: str = prompt_name,
                             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -35,7 +38,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             temperature=temperature,
             callbacks=[callback],
         )
-        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)
@@ -58,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         await task
 
-    return StreamingResponse(chat_iterator(query, history, model_name),
+    return StreamingResponse(chat_iterator(query=query,
+                                           history=history,
+                                           model_name=model_name,
+                                           prompt_name=prompt_name),
                              media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 6e8a982..7316e5e 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,10 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs import (LLM_MODEL, PROMPT_TEMPLATE,
-                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     TEMPERATURE)
+from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@@ -33,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                              prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
                               ):
@@ -43,10 +42,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     history = [History.from_data(h) for h in history]
 
     async def knowledge_base_chat_iterator(query: str,
-                                           kb: KBService,
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
+                                           prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -57,7 +56,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
 
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         # 用户最后一个问题会进入PROMPT_TEMPLATE,不用再作为history 了
         if len(history) >= 1:
             history.pop()
@@ -98,5 +98,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
+    return StreamingResponse(knowledge_base_chat_iterator(query=query,
+                                                          top_k=top_k,
+                                                          history=history,
+                                                          model_name=model_name,
+                                                          prompt_name=prompt_name),
                              media_type="text/event-stream")
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index ffeff86..cfc1fa9 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,12 +1,11 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
-                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                     PROMPT_TEMPLATE, TEMPERATURE)
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 from server.chat.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse
+from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -73,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                             prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -87,6 +87,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                                           top_k: int,
                                           history: Optional[List[History]],
                                           model_name: str = LLM_MODEL,
+                                          prompt_name: str = prompt_name,
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -98,7 +99,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         docs = await lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
 
-        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+        prompt_template = get_prompt_template(prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
 
@@ -129,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                                              ensure_ascii=False)
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
+    return StreamingResponse(search_engine_chat_iterator(query=query,
+                                                         search_engine_name=search_engine_name,
+                                                         top_k=top_k,
+                                                         history=history,
+                                                         model_name=model_name,
+                                                         prompt_name=prompt_name),
                              media_type="text/event-stream")
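With these changes all three chat endpoints accept an optional `prompt_name` field in the request body and fall back to their default template when it is omitted. A hedged request sketch follows; the host, port, and the `/chat/chat` route are assumptions based on the project's usual layout and do not appear in this patch.

```python
# Sketch of calling the plain chat endpoint with the new prompt_name field.
# Assumptions: the API server listens on 127.0.0.1:7861 and the route is
# /chat/chat; the model name is whatever your fastchat workers serve.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "恼羞成怒",
        "stream": True,
        "model_name": "chatglm2-6b",   # assumed model name, adjust to your deployment
        "temperature": 0.7,
        "prompt_name": "llm_chat",     # key in PROMPT_TEMPLATES (configs/prompt_config.py)
    },
    stream=True,
)
for chunk in resp.iter_content(None, decode_unicode=True):
    print(chunk, end="")
```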
diff --git a/server/utils.py b/server/utils.py
index 4dc29ae..e0a595d 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -327,6 +327,17 @@ def webui_address() -> str:
     return f"http://{host}:{port}"
 
 
+def get_prompt_template(name: str) -> Optional[str]:
+    '''
+    从prompt_config中加载模板内容
+    '''
+    from configs import prompt_config
+    import importlib
+    importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+
+    return prompt_config.PROMPT_TEMPLATES.get(name)
+
+
 def set_httpx_timeout(timeout: float = None):
     '''
     设置httpx默认timeout。
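The TODO above notes that the reload is currently unconditional on every call. One possible refinement, sketched here and not part of this patch, is to reload only when the config file's modification time changes; the module-level cache variable is hypothetical.

```python
# Hypothetical refinement of get_prompt_template: skip importlib.reload()
# unless configs/prompt_config.py was actually modified since the last call.
import importlib
import os

_prompt_config_mtime = None  # hypothetical module-level cache


def get_prompt_template(name: str):
    global _prompt_config_mtime
    from configs import prompt_config

    mtime = os.path.getmtime(prompt_config.__file__)
    if _prompt_config_mtime is None or mtime > _prompt_config_mtime:
        importlib.reload(prompt_config)  # hot-reload edited templates
        _prompt_config_mtime = mtime
    return prompt_config.PROMPT_TEMPLATES.get(name)
```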
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index 7989499..47547f7 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
             assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
 
         print("\n")
-        print("=" * 30 + api + " by {se} output" + "="*30)
+        print("=" * 30 + api + f" by {se} output" + "="*30)
         for line in response.iter_content(None, decode_unicode=True):
             data = json.loads(line)
             if "answer" in data:
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index a94cb99..4095437 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -315,6 +315,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "llm_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -329,6 +330,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")
@@ -385,6 +387,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -403,6 +406,7 @@ class ApiRequest:
             "model_name": model,
             "temperature": temperature,
             "local_doc_url": no_remote_api,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")
@@ -428,6 +432,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
         '''
@@ -443,6 +448,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "prompt_name": prompt_name,
         }
 
         print(f"received input message:")
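On the webui side the new argument is simply forwarded in the request payload, so a caller can choose a template per request. A usage sketch follows; the `chat_chat` method name and the base URL are assumptions, since the hunks above only show parameter lists.

```python
# Hypothetical webui-side usage; ApiRequest's method name (chat_chat) and the
# api base url are not shown in this diff and may differ in your checkout.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
for piece in api.chat_chat("你好", prompt_name="llm_chat", stream=True):
    print(piece, end="")
```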