Minor detail optimizations (#1891)

Co-authored-by: zR <zRzRzRzRzRzRzR>
zR 2023-10-27 11:52:44 +08:00 committed by GitHub
parent ce8e341b9f
commit 24d1e28a07
12 changed files with 190 additions and 85 deletions

View File

@@ -231,4 +231,15 @@ VLLM_MODEL_DICT = {
    "agentlm-13b":"THUDM/agentlm-13b",
    "agentlm-70b":"THUDM/agentlm-70b",
}

+## Models you consider capable of Agent support can be added here; once added, the web UI will no longer warn about them
+SUPPORT_AGENT_MODEL = [
+    "Azure-OpenAI",
+    "OpenAI",
+    "Anthropic",
+    "Qwen",
+    "qwen-api",
+    "baichuan-api",
+    "agentlm"
+]
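
For reference, this list is matched by substring against the active model name elsewhere in this commit (see the custom_template.py and dialogue.py hunks below). A minimal sketch of that check; the model name is made up for illustration:

# Same list as added above.
SUPPORT_AGENT_MODEL = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api", "agentlm"]
llm_model = "qwen-api-7b"  # hypothetical model name, for illustration only
print(any(agent in llm_model for agent in SUPPORT_AGENT_MODEL))  # True: "qwen-api" is a substring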

View File

@@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
        <已知信息>{{ context }}</已知信息>、
        <问题>{{ question }}</问题>
        """,
+    "Empty":  # Used when the search returns nothing and there is no known information. This "Empty" entry may be modified, but deleting it will break the program.
+        """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+        """,
}
PROMPT_TEMPLATES["search_engine_chat"] = {
    "default":
@@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
        <已知信息>{{ context }}</已知信息>、
        <问题>{{ question }}</问题>
        """,
    "search":
        """
        <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
        <已知信息>{{ context }}</已知信息>、
        <问题>{{ question }}</问题>
        """,
+    "Empty":  # Used when the search returns nothing and there is no known information. This "Empty" entry may be modified, but deleting it will break the program.
+        """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+        """,
}
PROMPT_TEMPLATES["agent_chat"] = {
    "default":

View File

@@ -0,0 +1,80 @@
## CSV file loader that reads only the specified columns
from langchain.document_loaders import CSVLoader
import csv
from io import TextIOWrapper
from typing import Dict, List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.helpers import detect_file_encodings


class FilteredCSVLoader(CSVLoader):
    def __init__(
            self,
            file_path: str,
            columns_to_read: List[str],
            source_column: Optional[str] = None,
            metadata_columns: List[str] = [],
            csv_args: Optional[Dict] = None,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = False,
    ):
        super().__init__(
            file_path=file_path,
            source_column=source_column,
            metadata_columns=metadata_columns,
            csv_args=csv_args,
            encoding=encoding,
            autodetect_encoding=autodetect_encoding,
        )
        self.columns_to_read = columns_to_read

    def load(self) -> List[Document]:
        """Load data into document objects."""
        docs = []
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self.__read_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    try:
                        with open(
                            self.file_path, newline="", encoding=encoding.encoding
                        ) as csvfile:
                            docs = self.__read_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e
        return docs

    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            if self.columns_to_read[0] in row:
                content = row[self.columns_to_read[0]]
                # Extract the source if available
                source = (
                    row.get(self.source_column, None)
                    if self.source_column is not None
                    else self.file_path
                )
                metadata = {"source": source, "row": i}
                for col in self.metadata_columns:
                    if col in row:
                        metadata[col] = row[col]
                doc = Document(page_content=content, metadata=metadata)
                docs.append(doc)
            else:
                raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
        return docs
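
A minimal usage sketch for the new loader; the file name and column names here are hypothetical:

# Suppose data.csv has columns "question", "answer", and "category".
loader = FilteredCSVLoader(
    file_path="data.csv",
    columns_to_read=["question"],   # only the first listed column becomes page_content
    metadata_columns=["category"],  # extra columns copied into each Document's metadata
)
docs = loader.load()
print(docs[0].page_content, docs[0].metadata)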

View File

@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4

View File

@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4

View File

@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
fschat>=0.2.31
openai
# sentence_transformers

View File

@@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
+from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container

class CustomPromptTemplate(StringPromptTemplate):
-    # The template to use
    template: str
-    # The list of tools available
    tools: List[Tool]

    def format(self, **kwargs) -> str:
-        # Get the intermediate steps (AgentAction, Observation tuples)
-        # Format them in a particular way
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
-        # Set the agent_scratchpad variable to that value
        kwargs["agent_scratchpad"] = thoughts
-        # Create a tools variable from the list of tools provided
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
-        # Create a list of tool names for the tools provided
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
-        # Return the formatted templatepr
-        # print( self.template.format(**kwargs), end="\n\n")
        return self.template.format(**kwargs)
@@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
        self.begin = True

    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
-        # Check if agent should finish
-        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # models that currently support Agent
-        if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
+        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
            self.begin = False
            stop_words = ["Observation:"]
            min_index = len(llm_output)
@@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                log=llm_output,
            )
-        # Parse out the action and action input
        parts = llm_output.split("Action:")
        if len(parts) < 2:
            return AgentFinish(
@@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
        action = parts[1].split("Action Input:")[0].strip()
        action_input = parts[1].split("Action Input:")[1].strip()
-        # The original regex-based check
+        # The original regex-based check: stricter, but with a lower success rate
        # regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        # print("llm_output",llm_output)
        # match = re.search(regex, llm_output, re.DOTALL)
@@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
        # action_input = match.group(2)
        # Return the action and action input
        try:
            ans = AgentAction(
                tool=action,
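
For reference, a toy illustration of the Action / Action Input splitting this parser performs; the sample LLM output is made up:

# Hypothetical LLM output, used only to illustrate the split logic above.
llm_output = """Thought: I should use a tool.
Action: weather_check
Action Input: 上海"""
parts = llm_output.split("Action:")
action = parts[1].split("Action Input:")[0].strip()        # -> "weather_check"
action_input = parts[1].split("Action Input:")[1].strip()  # -> "上海"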

View File

@@ -55,8 +55,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
        )
        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
        context = "\n".join([doc.page_content for doc in docs])
-        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+        if len(docs) == 0:  ## If no relevant documents are found, use the "Empty" template
+            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+        else:
+            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])
@@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            url = f"/knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

+        if len(source_documents) == 0:  # no relevant documents were found
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
@@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)
        await task

    return StreamingResponse(knowledge_base_chat_iterator(query=query,

View File

@@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
            for inum, doc in enumerate(docs)
        ]

+        if len(source_documents) == 0:  # no relevant results were found (unlikely)
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response

View File

@@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
        return []
    docs = kb.search_docs(query, top_k, score_threshold)
    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
    return data

View File

@@ -1,4 +1,7 @@
import os
+import sys
+sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
from transformers import AutoTokenizer
from configs import (
    EMBEDDING_MODEL,
@@ -25,7 +28,6 @@ import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet

def validate_kb_name(knowledge_base_id: str) -> bool:
    # Check for unexpected characters or path-traversal keywords
    if "../" in knowledge_base_id:
@@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
+               # "FilteredCSVLoader": [".csv"],  # must be configured manually; not supported yet
               "RapidOCRPDFLoader": [".pdf"],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
@@ -88,12 +91,12 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
    '''

    def __init__(
            self,
            file_path: Union[str, Path],
            content_key: Optional[str] = None,
            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
            text_content: bool = True,
            json_lines: bool = False,
    ):
        """Initialize the JSONLoader.
@@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
    Return a document loader according to loader_name and the file path or content.
    '''
    try:
-        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
@@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
                # Auto-detect the file encoding to avoid encoding errors when the langchain loader reads the file
                with open(file_path_or_content, 'rb') as struct_file:
                    encode_detect = chardet.detect(struct_file.read())
-                if encode_detect:
-                    loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
-                else:
-                    loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+                if encode_detect is None:
+                    encode_detect = {"encoding": "utf-8"}
+                loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
+                ## TODO: support more custom CSV reading logic

        elif loader_name == "JSONLoader":
            loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
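
For context, chardet.detect returns a dict such as {'encoding': 'utf-8', 'confidence': 0.99, 'language': ''}, and the reported encoding can be None for empty or inconclusive input. A standalone sketch of the detect-with-fallback pattern; the file path is hypothetical:

import chardet

with open("example.txt", "rb") as f:  # "example.txt" is a made-up path
    detected = chardet.detect(f.read())
encoding = detected.get("encoding") or "utf-8"  # explicit fallback when detection is inconclusive
text = open("example.txt", encoding=encoding).read()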
@@ -187,10 +191,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri

def make_text_splitter(
        splitter_name: str = TEXT_SPLITTER_NAME,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        llm_model: str = LLM_MODEL,
):
    """
    Get a specific text splitter according to the parameters.
@@ -262,6 +266,7 @@ def make_text_splitter(
        text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
    return text_splitter

class KnowledgeFile:
    def __init__(
            self,
@@ -282,7 +287,7 @@ class KnowledgeFile:
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = TEXT_SPLITTER_NAME

-    def file2docs(self, refresh: bool=False):
+    def file2docs(self, refresh: bool = False):
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            loader = get_loader(self.document_loader_name, self.filepath)
@@ -290,20 +295,21 @@ class KnowledgeFile:
        return self.docs

    def docs2texts(
            self,
            docs: List[Document] = None,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []
        if self.ext not in [".csv"]:
            if text_splitter is None:
-                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap)
            if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                docs = text_splitter.split_text(docs[0].page_content)
                for doc in docs:
@@ -320,12 +326,12 @@ class KnowledgeFile:
        return self.splited_docs

    def file2text(
            self,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        if self.splited_docs is None or refresh:
            docs = self.file2docs()
@@ -359,6 +365,7 @@ def files2docs_in_thread(
    If an argument is passed as a tuple, it takes the form (filename, kb_name)
    The generator yields status, (kb_name, file_name, docs | error)
    '''

    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
        try:
            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@@ -373,8 +380,8 @@ def files2docs_in_thread(
        kwargs = {}
        try:
            if isinstance(file, tuple) and len(file) >= 2:
-                filename=file[0]
-                kb_name=file[1]
+                filename = file[0]
+                kb_name = file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                filename = file.pop("filename")
@@ -396,10 +403,9 @@ def files2docs_in_thread(
if __name__ == "__main__":
    from pprint import pprint

-    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    kb_file = KnowledgeFile(
+        filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
+        knowledge_base_name="samples")
    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    docs = kb_file.file2docs()
-    pprint(docs[-1])
+    # pprint(docs[-1])
    docs = kb_file.file2text()
    pprint(docs[-1])

View File

@@ -4,16 +4,17 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from typing import List, Dict

chat_box = ChatBox(
    assistant_avatar=os.path.join(
        "img",
        "chatchat_icon_blue_square_v2.png"
    )
)

def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
    '''
    Return the message history.
@@ -55,14 +56,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    if is_lite:
        dialogue_modes = ["LLM 对话",
                          "搜索引擎问答",
                          ]
    else:
        dialogue_modes = ["LLM 对话",
                          "知识库问答",
                          "搜索引擎问答",
                          "自定义Agent问答",
                          ]
    dialogue_mode = st.selectbox("请选择对话模式:",
                                 dialogue_modes,
                                 index=0,
@@ -102,10 +103,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        key="llm_model",
    )
    if (st.session_state.get("prev_llm_model") != llm_model
            and not is_lite
            and not llm_model in config_models.get("online", {})
            and not llm_model in config_models.get("langchain", {})
            and llm_model not in running_models):
        with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
            prev_model = st.session_state.get("prev_llm_model")
            r = api.change_llm_model(prev_model, llm_model)
@@ -210,17 +211,20 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        elif dialogue_mode == "自定义Agent问答":
-            chat_box.ai_say([
-                f"正在思考...",
-                Markdown("...", in_expander=True, title="思考过程", state="complete"),
-            ])
+            if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+                chat_box.ai_say([
+                    f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
+                ])
+            else:
+                chat_box.ai_say([
+                    f"正在思考...",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
+                ])
            text = ""
            ans = ""
-            support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # models that currently support Agent
-            if not any(agent in llm_model for agent in support_agent):
-                ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
-                chat_box.update_msg(ans, element_index=0, streaming=False)
            for d in api.agent_chat(prompt,
                                    history=history,
                                    model=llm_model,
@@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                              model=llm_model,
                              prompt_name=prompt_template_name,
                              temperature=temperature,
-                              split_result=se_top_k>1):
+                              split_result=se_top_k > 1):
        if error_msg := check_error_msg(d):  # check whether an error occurred
            st.error(error_msg)
        elif chunk := d.get("answer"):
@@ -305,4 +309,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
        mime="text/markdown",
        use_container_width=True,
    )