fix #1142: use Jinja2 templates instead of f-strings in History, to avoid errors when a message contains { }

liunux4odoo 2023-08-23 08:35:26 +08:00
parent 054ae4715d
commit 44edce6bcf
6 changed files with 41 additions and 118 deletions
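
Root cause in brief: LangChain's default f-string prompt templates treat every "{...}" in message text as a placeholder, so a chat history containing literal braces (code snippets, JSON, etc.) crashes the formatter. A minimal standalone sketch of the failure and of the Jinja2 behavior this commit switches to (not part of the commit itself):

    # f-string style formatting: any brace in user text becomes a placeholder.
    history_msg = "my code is `d = {}`"              # user message with literal braces
    template = history_msg + "\n【问题】{question}"

    try:
        template.format(question="why does this fail?")
    except (IndexError, KeyError) as e:
        # "{}" is parsed as a positional placeholder, so this raises IndexError
        print("f-string template failed:", type(e).__name__, e)

    # Jinja2 only interprets double braces; single braces pass through intact.
    from jinja2 import Template
    print(Template(history_msg + "\n【问题】{{ question }}").render(question="why?"))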


@@ -142,12 +142,12 @@ SEARCH_ENGINE_TOP_K = 5
 # nltk model storage path
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
-# Prompt template for local knowledge-base QA
+# Prompt template for local knowledge-base QA (uses Jinja2 syntax: simply put, double braces replace the f-string single braces)
 PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
-【已知信息】{context}
-【问题】{question}"""
+【已知信息】{{ context }}
+【问题】{{ question }}"""
 
 # Whether the API enables cross-origin requests; defaults to False, set to True to enable
 # is open cross domain
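
A quick sanity check (not in the commit) that the double-brace template renders under Jinja2, and that braces inside substituted values are left alone:

    from jinja2 import Template

    PROMPT_TEMPLATE = "【已知信息】{{ context }}\n【问题】{{ question }}"
    print(Template(PROMPT_TEMPLATE).render(context="some {code} snippet", question="what?"))
    # 【已知信息】some {code} snippet
    # 【问题】what?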


@@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
                 ),
         stream: bool = Body(False, description="流式输出"),
         ):
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]
 
     async def chat_iterator(query: str,
                             history: List[History] = [],
@@ -37,8 +37,9 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             model_name=LLM_MODEL
         )
 
+        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+            [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
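
The same pattern repeats in every endpoint below: each history message goes through to_msg_template(), which wraps its content in {% raw %} so any braces stay literal, while only the final user message carries a live {{ input }} variable. A hedged, self-contained sketch of how those pieces compose (the example strings are invented):

    from langchain.prompts.chat import ChatMessagePromptTemplate, ChatPromptTemplate

    # A past message whose braces must survive formatting untouched:
    past = ChatMessagePromptTemplate.from_template(
        "{% raw %}user typed {braces} here{% endraw %}", "jinja2", role="user")

    # The live slot for the incoming query:
    current = ChatMessagePromptTemplate.from_template("{{ input }}", "jinja2", role="user")

    prompt = ChatPromptTemplate.from_messages([past, current])
    print(prompt.format(input="hello"))  # history braces intact, input substituted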


@@ -1,109 +0,0 @@
-from langchain.document_loaders.github import GitHubIssuesLoader
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Literal
-from server.chat.utils import History
-from langchain.docstore.document import Document
-import json
-import os
-from functools import lru_cache
-from datetime import datetime
-
-GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
-
-
-@lru_cache(1)
-def load_issues(tick: str):
-    '''
-    set tick to a periodic value to refresh cache
-    '''
-    loader = GitHubIssuesLoader(
-        repo="chatchat-space/langchain-chatglm",
-        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
-        include_prs=True,
-        state="all",
-    )
-    docs = loader.load()
-    return docs
-
-
-def
-def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
-                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                include_prs: bool = Body(True, description="是否包含PR"),
-                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
-                creator: str = Body(None, description="创建者"),
-                history: List[History] = Body([],
-                                              description="历史对话",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "介绍一下本项目"},
-                                                  {"role": "assistant",
-                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
-                                              ),
-                stream: bool = Body(False, description="流式输出"),
-                ):
-    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
-        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
-
-    async def chat_iterator(query: str,
-                            search_engine_name: str,
-                            top_k: int,
-                            history: Optional[List[History]],
-                            ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        docs = lookup_search_engine(query, search_engine_name, top_k)
-        context = "\n".join([doc.page_content for doc in docs])
-
-        chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
-
-        chain = LLMChain(prompt=chat_prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        source_documents = [
-            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
-            for inum, doc in enumerate(docs)
-        ]
-
-        if stream:
-            async for token in callback.aiter():
-                # Use server-sent-events to stream the response
-                yield json.dumps({"answer": token,
-                                  "docs": source_documents},
-                                 ensure_ascii=False)
-        else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield json.dumps({"answer": token,
-                              "docs": source_documents},
-                             ensure_ascii=False)
-
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
-                             media_type="text/event-stream")
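
Unrelated to the template fix, the removed loader shows a small caching idiom worth noting: @lru_cache(1) keyed on a tick string re-runs the GitHub fetch only when the caller passes a new tick value. A minimal sketch of the pattern (the fetch function here is hypothetical):

    from functools import lru_cache
    from datetime import datetime

    @lru_cache(1)
    def fetch(tick: str):
        # Re-executed only when `tick` differs from the cached call's argument.
        print("refreshing cache at", tick)
        return ["...expensive result..."]

    # Truncating the timestamp to the hour refreshes the cache at most hourly:
    fetch(datetime.now().strftime("%Y-%m-%d %H"))
    fetch(datetime.now().strftime("%Y-%m-%d %H"))  # cache hit within the same hour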


@@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]
 
     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
@@ -57,8 +57,9 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
 
+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)


@@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
 
+    history = [History.from_data(h) for h in history]
+
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
@@ -91,8 +93,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
 
+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)


@@ -1,6 +1,7 @@
 import asyncio
-from typing import Awaitable
+from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
+from langchain.prompts.chat import ChatMessagePromptTemplate
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -28,3 +29,29 @@ class History(BaseModel):
 
     def to_msg_tuple(self):
         return "ai" if self.role=="assistant" else "human", self.content
+
+    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
+        role_maps = {
+            "ai": "assistant",
+            "human": "user",
+        }
+        role = role_maps.get(self.role, self.role)
+        if is_raw:  # By default, history messages are plain text with no input_variable.
+            content = "{% raw %}" + self.content + "{% endraw %}"
+        else:
+            content = self.content
+
+        return ChatMessagePromptTemplate.from_template(
+            content,
+            "jinja2",
+            role=role,
+        )
+
+    @classmethod
+    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
+        if isinstance(h, (list, tuple)) and len(h) >= 2:
+            h = cls(role=h[0], content=h[1])
+        elif isinstance(h, dict):
+            h = cls(**h)
+        return h
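
Assuming the History class above, a short usage sketch: from_data normalizes dicts, tuples, and lists into History objects, and the default is_raw=True escaping keeps braces in stored messages literal:

    h1 = History.from_data({"role": "user", "content": "hi {there}"})
    h2 = History.from_data(("user", "hi {there}"))   # same result from a tuple
    h3 = History.from_data(["user", "hi {there}"])   # ...or a list

    tmpl = h1.to_msg_template()   # is_raw=True wraps content in {% raw %} ... {% endraw %}
    print(tmpl.format().content)  # prints "hi {there}": braces rendered literally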