From 44edce6bcf715ae1ec66fb5f1ca78afd22b8f177 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 23 Aug 2023 08:35:26 +0800
Subject: [PATCH] fix #1142: use a Jinja2 template instead of an f-string in
 History, so messages that contain { } no longer break prompt rendering
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py.example    |   6 +-
 server/chat/chat.py                |   5 +-
 server/chat/github_chat.py         | 109 -----------------------------
 server/chat/knowledge_base_chat.py |   5 +-
 server/chat/search_engine_chat.py  |   5 +-
 server/chat/utils.py               |  29 +++++++-
 6 files changed, 41 insertions(+), 118 deletions(-)
 delete mode 100644 server/chat/github_chat.py

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 49c4c52..3e84b8b 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -142,12 +142,12 @@ SEARCH_ENGINE_TOP_K = 5
 
 # nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
-# 基于本地知识问答的提示词模版
+# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号)
 PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
 
-【已知信息】{context}
+【已知信息】{{ context }}
 
-【问题】{question}"""
+【问题】{{ question }}"""
 
 # API 是否开启跨域,默认为False,如果需要开启,请设置为True
 # is open cross domain
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 2bc21db..3c4ebcf 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
                   ),
          stream: bool = Body(False, description="流式输出"),
          ):
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]
 
     async def chat_iterator(query: str,
                             history: List[History] = [],
@@ -37,8 +37,9 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             model_name=LLM_MODEL
         )
 
+        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+            [i.to_msg_template() for i in history] + [input_msg])
 
         chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
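Why the template change above matters, as a minimal self-contained sketch (plain `str.format` plus the `jinja2` package rather than the project's own code; the sample message is invented): with f-string-style templates a past chat message becomes part of the prompt template itself, so any `{ }` inside it is parsed as an input variable, while Jinja2 only treats `{{ }}` / `{% %}` as special and `{% raw %}` keeps even those verbatim.

```python
from jinja2 import Template

# A hypothetical past message that happens to contain braces.
history_msg = 'my code crashes on print(f"{x}")'

# f-string style: the message text itself is treated as a template,
# so "{x}" is read as an input variable that was never supplied.
try:
    history_msg.format()
except KeyError as e:
    print("f-string template failed:", repr(e))   # KeyError: 'x'

# Jinja2 style: single braces are plain text, and wrapping the message in
# {% raw %} ... {% endraw %} (what to_msg_template(is_raw=True) does below)
# keeps even "{{" or "{%" sequences verbatim.
protected = "{% raw %}" + history_msg + "{% endraw %}"
print(Template(protected).render())               # message comes out unchanged
```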
diff --git a/server/chat/github_chat.py b/server/chat/github_chat.py
deleted file mode 100644
index b161548..0000000
--- a/server/chat/github_chat.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from langchain.document_loaders.github import GitHubIssuesLoader
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
-from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Literal
-from server.chat.utils import History
-from langchain.docstore.document import Document
-import json
-import os
-from functools import lru_cache
-from datetime import datetime
-
-
-GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
-
-
-@lru_cache(1)
-def load_issues(tick: str):
-    '''
-    set tick to a periodic value to refresh cache
-    '''
-    loader = GitHubIssuesLoader(
-        repo="chatchat-space/langchain-chatglm",
-        access_token=GITHUB_PERSONAL_ACCESS_TOKEN,
-        include_prs=True,
-        state="all",
-    )
-    docs = loader.load()
-    return docs
-
-
-def
-def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]),
-                top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                include_prs: bool = Body(True, description="是否包含PR"),
-                state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"),
-                creator: str = Body(None, description="创建者"),
-                history: List[History] = Body([],
-                                              description="历史对话",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "介绍一下本项目"},
-                                                  {"role": "assistant",
-                                                   "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]]
-                                              ),
-                stream: bool = Body(False, description="流式输出"),
-                ):
-    if GITHUB_PERSONAL_ACCESS_TOKEN is None:
-        return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN")
-
-    async def chat_iterator(query: str,
-                            search_engine_name: str,
-                            top_k: int,
-                            history: Optional[List[History]],
-                            ) -> AsyncIterable[str]:
-        callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-
-        docs = lookup_search_engine(query, search_engine_name, top_k)
-        context = "\n".join([doc.page_content for doc in docs])
-
-        chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
-
-        chain = LLMChain(prompt=chat_prompt, llm=model)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"context": context, "question": query}),
-            callback.done),
-        )
-
-        source_documents = [
-            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
-            for inum, doc in enumerate(docs)
-        ]
-
-        if stream:
-            async for token in callback.aiter():
-                # Use server-sent-events to stream the response
-                yield json.dumps({"answer": token,
-                                  "docs": source_documents},
-                                 ensure_ascii=False)
-        else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield json.dumps({"answer": token,
-                              "docs": source_documents},
-                             ensure_ascii=False)
-        await task
-
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
-                             media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 84c62f0..4914beb 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    history = [History(**h) if isinstance(h, dict) else h for h in history]
+    history = [History.from_data(h) for h in history]
 
     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
@@ -57,8 +57,9 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
 
+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])
 
         chain = LLMChain(prompt=chat_prompt, llm=model)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 15834d0..638c3be 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
 
+    history = [History.from_data(h) for h in history]
+
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
@@ -91,8 +93,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
 
+        input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+            [i.to_msg_template() for i in history] + [input_msg])
 
         chain = LLMChain(prompt=chat_prompt, llm=model)
diff --git a/server/chat/utils.py b/server/chat/utils.py
index f8afb10..2167f10 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,6 +1,7 @@
 import asyncio
-from typing import Awaitable
+from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
+from langchain.prompts.chat import ChatMessagePromptTemplate
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -28,3 +29,29 @@ class History(BaseModel):
 
     def to_msg_tuple(self):
         return "ai" if self.role=="assistant" else "human", self.content
+
+    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
+        role_maps = {
+            "ai": "assistant",
+            "human": "user",
+        }
+        role = role_maps.get(self.role, self.role)
+        if is_raw:  # 当前默认历史消息都是没有input_variable的文本。
+            content = "{% raw %}" + self.content + "{% endraw %}"
+        else:
+            content = self.content
+
+        return ChatMessagePromptTemplate.from_template(
+            content,
+            "jinja2",
+            role=role,
+        )
+
+    @classmethod
+    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
+        if isinstance(h, (list, tuple)) and len(h) >= 2:
+            h = cls(role=h[0], content=h[1])
+        elif isinstance(h, dict):
+            h = cls(**h)
+
+        return h
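Taken together, the new helpers are meant to be used the way the chat endpoints above now use them. A usage sketch, assuming the repository layout shown in this patch (`server.chat.utils.History`) and the langchain version the project pins; the message texts are invented:

```python
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History

# from_data accepts either a dict or a (role, content) pair.
history = [
    History.from_data({"role": "user", "content": 'my template "{foo}" raises an error'}),
    History.from_data(("assistant", "Single braces are parsed as input variables.")),
]

# Past turns default to is_raw=True and are wrapped in {% raw %}, so their braces
# survive rendering; the final turn keeps {{ input }} as a real Jinja2 variable.
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
    [h.to_msg_template() for h in history] + [input_msg])

for msg in chat_prompt.format_messages(input="Please explain that again."):
    print(msg)
```

The same pattern is what knowledge_base_chat.py and search_engine_chat.py follow with PROMPT_TEMPLATE as the final turn: history is rendered verbatim, and only {{ context }}, {{ question }} (or {{ input }}) are substituted.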