fix #1142: use a jinja2 template instead of an f-string in History, so rendering no longer fails when a message contains { }
parent 054ae4715d
commit 44edce6bcf
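In short: f-string / `str.format()`-style templates treat every `{ }` in the rendered text as a replacement field, so a history message that happens to contain literal braces breaks prompt rendering; Jinja2 only interprets `{{ ... }}` and `{% ... %}`. A minimal sketch of the difference (illustrative only, not part of this commit; the sample message `msg` is invented):

```python
from jinja2 import Template

# An invented history message that happens to contain literal { }.
msg = "python 里 {'a': 1} 是什么意思?"

# str.format()-style rendering (what an f-string template does) treats
# {'a': 1} as a replacement field and raises an error:
try:
    ("【已知信息】" + msg).format()
except (KeyError, ValueError) as e:
    print("f-string style rendering failed:", repr(e))

# Jinja2 only substitutes {{ ... }}, so the same text renders unchanged:
print(Template("【已知信息】{{ context }}").render(context=msg))
```

In addition, past history messages are wrapped in `{% raw %} ... {% endraw %}` by the new `History.to_msg_template()` (see the last hunk below), so even text that looks like Jinja2 syntax is passed through literally.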
@@ -142,12 +142,12 @@ SEARCH_ENGINE_TOP_K = 5
 # nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

-# 基于本地知识问答的提示词模版
+# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
 PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。

-【已知信息】{context}
+【已知信息】{{ context }}

-【问题】{question}"""
+【问题】{{ question }}"""

 # API 是否开启跨域,默认为False,如果需要开启,请设置为True
 # is open cross domain
@@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
              ),
          stream: bool = Body(False, description="流式输出"),
          ):
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]

     async def chat_iterator(query: str,
                             history: List[History] = [],
@@ -37,8 +37,9 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             model_name=LLM_MODEL
         )

+        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+            [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
@@ -1,109 +0,0 @@
-from langchain.document_loaders.github import GitHubIssuesLoader
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Literal
-from server.chat.utils import History
-from langchain.docstore.document import Document
-import json
-import os
-from functools import lru_cache
-from datetime import datetime
-
-
-GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
-
-
-@lru_cache(1)
-def load_issues(tick: str):
-    '''
-    set tick to a periodic value to refresh cache
-    '''
-    loader = GitHubIssuesLoader(
-        repo="chatchat-space/langchain-chatglm",
-        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
-        include_prs=True,
-        state="all",
-    )
-    docs = loader.load()
-    return docs
-
-
-def
-def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
-                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                include_prs: bool = Body(True, description="是否包含PR"),
-                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
-                creator: str = Body(None, description="创建者"),
-                history: List[History] = Body([],
-                                              description="历史对话",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "介绍一下本项目"},
-                                                  {"role": "assistant",
-                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
-                                              ),
-                stream: bool = Body(False, description="流式输出"),
-                ):
-    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
-        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
-
-    async def chat_iterator(query: str,
-                            search_engine_name: str,
-                            top_k: int,
-                            history: Optional[List[History]],
-                            ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        docs = lookup_search_engine(query, search_engine_name, top_k)
-        context = "\n".join([doc.page_content for doc in docs])
-
-        chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
-
-        chain = LLMChain(prompt=chat_prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        source_documents = [
-            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
-            for inum, doc in enumerate(docs)
-        ]
-
-        if stream:
-            async for token in callback.aiter():
-                # Use server-sent-events to stream the response
-                yield json.dumps({"answer": token,
-                                  "docs": source_documents},
-                                 ensure_ascii=False)
-        else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield json.dumps({"answer": token,
-                              "docs": source_documents},
-                             ensure_ascii=False)
-        await task
-
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
-                             media_type="text/event-stream")
@@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]

     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
@@ -57,8 +57,9 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])

         chain = LLMChain(prompt=chat_prompt, llm=model)

@@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

+    history = [History.from_data(h) for h in history]
+
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
@@ -91,8 +93,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])

+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])

         chain = LLMChain(prompt=chat_prompt, llm=model)

@@ -1,6 +1,7 @@
 import asyncio
-from typing import Awaitable
+from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
+from langchain.prompts.chat import ChatMessagePromptTemplate


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -28,3 +29,29 @@ class History(BaseModel):

     def to_msg_tuple(self):
         return "ai" if self.role=="assistant" else "human", self.content
+
+    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
+        role_maps = {
+            "ai": "assistant",
+            "human": "user",
+        }
+        role = role_maps.get(self.role, self.role)
+        if is_raw: # 当前默认历史消息都是没有input_variable的文本。
+            content = "{% raw %}" + self.content + "{% endraw %}"
+        else:
+            content = self.content
+
+        return ChatMessagePromptTemplate.from_template(
+            content,
+            "jinja2",
+            role=role,
+        )
+
+    @classmethod
+    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
+        if isinstance(h, (list,tuple)) and len(h) >= 2:
+            h = cls(role=h[0], content=h[1])
+        elif isinstance(h, dict):
+            h = cls(**h)
+
+        return h
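For reference, a rough usage sketch of the new `History` helpers (assumed usage based on the diff above, not part of the commit; the sample strings are invented):

```python
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History

# from_data() accepts dicts as well as (role, content) lists/tuples:
h1 = History.from_data({"role": "user", "content": "请解释 {'a': 1} 这种写法"})
h2 = History.from_data(("assistant", "这是一个 Python 字典字面量。"))

# to_msg_template() defaults to is_raw=True and wraps the content in
# {% raw %} ... {% endraw %}, so literal braces in past messages are not
# parsed as Jinja2 variables.
history_msgs = [h.to_msg_template() for h in (h1, h2)]

# The final user turn carries a real Jinja2 variable, hence is_raw=False:
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)

# Combined into a prompt the same way the chat endpoints above do:
chat_prompt = ChatPromptTemplate.from_messages(history_msgs + [input_msg])
```

Because past messages are raw-wrapped by default, existing callers do not need to escape anything; only templates that intentionally carry variables, such as the final `{{ input }}` turn or `PROMPT_TEMPLATE`, are built with `is_raw=False`.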