fix #1142: use jinja2 templates in History instead of f-strings, so messages containing { } no longer cause errors
parent 054ae4715d
commit 44edce6bcf
@@ -142,12 +142,12 @@ SEARCH_ENGINE_TOP_K = 5

 # nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

-# 基于本地知识问答的提示词模版
+# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
 PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。

-【已知信息】{context}
+【已知信息】{{ context }}

-【问题】{question}"""
+【问题】{{ question }}"""

 # API 是否开启跨域,默认为False,如果需要开启,请设置为True
 # is open cross domain
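Why the double braces: with LangChain's default f-string template format, literal { } in message content is parsed as an input variable and formatting fails; with the jinja2 format only {{ ... }} is a variable, and {% raw %} shields text entirely. A minimal sketch (not part of the commit; sample text made up):

from langchain.prompts.chat import ChatMessagePromptTemplate

user_text = '这条消息里有字面量花括号 {"a": 1}'   # literal { } in the message

# The f-string format would treat {"a": 1} as an input variable and raise at
# format time, which is the failure reported in #1142 (left commented out):
# ChatMessagePromptTemplate.from_template(user_text, role="user").format()

# jinja2 format: only {{ ... }} is a variable; {% raw %} keeps everything literal.
tmpl = ChatMessagePromptTemplate.from_template(
    "{% raw %}" + user_text + "{% endraw %}",
    "jinja2",
    role="user",
)
print(tmpl.format().content)   # -> 这条消息里有字面量花括号 {"a": 1}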
@@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
              ),
          stream: bool = Body(False, description="流式输出"),
          ):
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]

    async def chat_iterator(query: str,
                            history: List[History] = [],
@@ -37,8 +37,9 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             model_name=LLM_MODEL
         )

+        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+            [i.to_msg_template() for i in history] + [input_msg])
         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
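With the change above, chat() builds its prompt from message templates instead of plain tuples, so braces in history content no longer break formatting. A hedged usage sketch (History and the "{{ input }}" message come from this diff; the sample strings and the format_prompt call are illustrative only):

from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History

history = [
    History(role="user", content="这段 {code} 是什么意思?"),   # literal braces are safe now
    History(role="assistant", content="它只是一个占位符。"),
]
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)

chat_prompt = ChatPromptTemplate.from_messages(
    [i.to_msg_template() for i in history] + [input_msg])

# Only the final message exposes a variable; the history is rendered verbatim.
messages = chat_prompt.format_prompt(input="你好").to_messages()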
@@ -1,109 +0,0 @@
-from langchain.document_loaders.github import GitHubIssuesLoader
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Literal
-from server.chat.utils import History
-from langchain.docstore.document import Document
-import json
-import os
-from functools import lru_cache
-from datetime import datetime
-
-
-GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
-
-
-@lru_cache(1)
-def load_issues(tick: str):
-    '''
-    set tick to a periodic value to refresh cache
-    '''
-    loader = GitHubIssuesLoader(
-        repo="chatchat-space/langchain-chatglm",
-        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
-        include_prs=True,
-        state="all",
-    )
-    docs = loader.load()
-    return docs
-
-
-def
-def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
-                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                include_prs: bool = Body(True, description="是否包含PR"),
-                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
-                creator: str = Body(None, description="创建者"),
-                history: List[History] = Body([],
-                                              description="历史对话",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "介绍一下本项目"},
-                                                  {"role": "assistant",
-                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
-                                              ),
-                stream: bool = Body(False, description="流式输出"),
-                ):
-    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
-        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
-
-    async def chat_iterator(query: str,
-                            search_engine_name: str,
-                            top_k: int,
-                            history: Optional[List[History]],
-                            ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        docs = lookup_search_engine(query, search_engine_name, top_k)
-        context = "\n".join([doc.page_content for doc in docs])
-
-        chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
-
-        chain = LLMChain(prompt=chat_prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        source_documents = [
-            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
-            for inum, doc in enumerate(docs)
-        ]
-
-        if stream:
-            async for token in callback.aiter():
-                # Use server-sent-events to stream the response
-                yield json.dumps({"answer": token,
-                                  "docs": source_documents},
-                                 ensure_ascii=False)
-        else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield json.dumps({"answer": token,
-                              "docs": source_documents},
-                             ensure_ascii=False)
-        await task
-
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
-                             media_type="text/event-stream")
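The deleted module cached GitHub issues with an lru_cache keyed on a "tick" argument; per its docstring, callers were expected to pass a value that only changes periodically so the cache refreshes itself. A hedged sketch of that pattern (current_tick is a hypothetical helper, not from the repo):

from functools import lru_cache
from datetime import datetime

@lru_cache(1)
def load_issues(tick: str):
    """set tick to a periodic value to refresh cache"""
    ...  # expensive GitHub fetch, cached until `tick` changes

def current_tick(window_minutes: int = 10) -> str:
    # Hypothetical helper: returns the same string within one time window,
    # so the cached result is reused inside the window and recomputed afterwards.
    now = datetime.now()
    return f"{now:%Y%m%d%H}{now.minute // window_minutes}"

# docs = load_issues(current_tick())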
@@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(query: str,
                                           kb: KBService,
@@ -57,8 +57,9 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])

         chain = LLMChain(prompt=chat_prompt, llm=model)

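knowledge_base_chat and search_engine_chat follow the same pattern, but their final message is PROMPT_TEMPLATE, whose {{ context }} and {{ question }} placeholders are filled when the chain runs (presumably still via chain.acall({"context": context, "question": query}), as in the deleted module above, since the variable names are unchanged). A hedged sketch of how that prompt resolves (imports mirror the repo layout; sample values made up):

from langchain.prompts.chat import ChatPromptTemplate
from configs.model_config import PROMPT_TEMPLATE
from server.chat.utils import History

history = [History(role="user", content="JSON 里的 {} 是什么?")]  # braces stay literal
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)

chat_prompt = ChatPromptTemplate.from_messages(
    [i.to_msg_template() for i in history] + [input_msg])

# Only PROMPT_TEMPLATE exposes variables; history messages render verbatim.
messages = chat_prompt.format_prompt(
    context="检索到的文档片段", question="项目如何部署?").to_messages()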
@@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
    if search_engine_name not in SEARCH_ENGINES.keys():
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

+    history = [History.from_data(h) for h in history]
+
    async def search_engine_chat_iterator(query: str,
                                          search_engine_name: str,
                                          top_k: int,
@@ -91,8 +93,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])

+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])

         chain = LLMChain(prompt=chat_prompt, llm=model)

@@ -1,6 +1,7 @@
 import asyncio
-from typing import Awaitable
+from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
+from langchain.prompts.chat import ChatMessagePromptTemplate


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -28,3 +29,29 @@ class History(BaseModel):

     def to_msg_tuple(self):
         return "ai" if self.role=="assistant" else "human", self.content
+
+    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
+        role_maps = {
+            "ai": "assistant",
+            "human": "user",
+        }
+        role = role_maps.get(self.role, self.role)
+        if is_raw: # 当前默认历史消息都是没有input_variable的文本。
+            content = "{% raw %}" + self.content + "{% endraw %}"
+        else:
+            content = self.content
+
+        return ChatMessagePromptTemplate.from_template(
+            content,
+            "jinja2",
+            role=role,
+        )
+
+    @classmethod
+    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
+        if isinstance(h, (list,tuple)) and len(h) >= 2:
+            h = cls(role=h[0], content=h[1])
+        elif isinstance(h, dict):
+            h = cls(**h)
+
+        return h
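A short, hedged sketch of the two new History helpers in use (sample data made up):

from server.chat.utils import History

# from_data accepts either a dict or a (role, content) sequence.
h1 = History.from_data({"role": "user", "content": "含有 {花括号} 的消息"})
h2 = History.from_data(("assistant", "回复里甚至可以有 {{ 双大括号 }}"))

# to_msg_template() defaults to is_raw=True, wrapping the content in
# {% raw %} ... {% endraw %} so nothing inside is parsed as a variable;
# is_raw=False leaves it as a live jinja2 template (used for "{{ input }}"
# and PROMPT_TEMPLATE above).
print(h2.to_msg_template().format().content)   # -> 回复里甚至可以有 {{ 双大括号 }}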