一些细节优化 (#1891)

Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
zR 2023-10-27 11:52:44 +08:00 committed by GitHub
parent ce8e341b9f
commit 24d1e28a07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 190 additions and 85 deletions

View File

@ -232,3 +232,14 @@ VLLM_MODEL_DICT = {
"agentlm-70b":"THUDM/agentlm-70b",
}
## 你认为支持Agent能力的模型可以在这里添加添加后不会出现可视化界面的警告
SUPPORT_AGENT_MODEL = [
"Azure-OpenAI",
"OpenAI",
"Anthropic",
"Qwen",
"qwen-api",
"baichuan-api",
"agentlm"
]

View File

@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
"Empty": # 搜不到内容的时候调用此时没有已知信息这个Empty可以更改但不能删除会影响程序使用
"""
<指令>请根据用户的问题,进行简洁明了的回答</指令>
<问题>{{ question }}</问题>
""",
}
PROMPT_TEMPLATES["search_engine_chat"] = {
"default":
@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
"search":
"""
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
"Empty": # 搜不到内容的时候调用此时没有已知信息这个Empty可以更改但不能删除会影响程序使用
"""
<指令>请根据用户的问题,进行简洁明了的回答</指令>
<问题>{{ question }}</问题>
""",
}
PROMPT_TEMPLATES["agent_chat"] = {
"default":

View File

@ -0,0 +1,80 @@
## 指定制定列的csv文件加载器
from langchain.document_loaders import CSVLoader
import csv
from io import TextIOWrapper
from typing import Dict, List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.helpers import detect_file_encodings
class FilteredCSVLoader(CSVLoader):
def __init__(
self,
file_path: str,
columns_to_read: List[str],
source_column: Optional[str] = None,
metadata_columns: List[str] = [],
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
):
super().__init__(
file_path=file_path,
source_column=source_column,
metadata_columns=metadata_columns,
csv_args=csv_args,
encoding=encoding,
autodetect_encoding=autodetect_encoding,
)
self.columns_to_read = columns_to_read
def load(self) -> List[Document]:
"""Load data into document objects."""
docs = []
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self.__read_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
try:
with open(
self.file_path, newline="", encoding=encoding.encoding
) as csvfile:
docs = self.__read_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
except Exception as e:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
if self.columns_to_read[0] in row:
content = row[self.columns_to_read[0]]
# Extract the source if available
source = (
row.get(self.source_column, None)
if self.source_column is not None
else self.file_path
)
metadata = {"source": source, "row": i}
for col in self.metadata_columns:
if col in row:
metadata[col] = row[col]
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
else:
raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
return docs

View File

@ -1,4 +1,4 @@
langchain>=0.0.319
langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4

View File

@ -1,4 +1,4 @@
langchain>=0.0.319
langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4

View File

@ -1,4 +1,4 @@
langchain>=0.0.319
langchain>=0.0.324
fschat>=0.2.31
openai
# sentence_transformers

View File

@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
# The list of tools available
tools: List[Tool]
def format(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples)
# Format them in a particular way
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
# Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
# Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
# Return the formatted templatepr
# print( self.template.format(**kwargs), end="\n\n")
return self.template.format(**kwargs)
@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
self.begin = True
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
# Check if agent should finish
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
self.begin = False
stop_words = ["Observation:"]
min_index = len(llm_output)
@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
log=llm_output,
)
# Parse out the action and action input
parts = llm_output.split("Action:")
if len(parts) < 2:
return AgentFinish(
@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
action = parts[1].split("Action Input:")[0].strip()
action_input = parts[1].split("Action Input:")[1].strip()
# 原来的正则化检查方式
# 原来的正则化检查方式,更严格,但是成功率更低
# regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
# print("llm_output",llm_output)
# match = re.search(regex, llm_output, re.DOTALL)
@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
# action_input = match.group(2)
# Return the action and action input
try:
ans = AgentAction(
tool=action,

View File

@ -55,8 +55,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
url = f"/knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
return StreamingResponse(knowledge_base_chat_iterator(query=query,

View File

@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
for inum, doc in enumerate(docs)
]
if len(source_documents) == 0: # 没有找到相关资料(不太可能)
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response

View File

@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data

View File

@ -1,4 +1,7 @@
import os
import sys
sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
from transformers import AutoTokenizer
from configs import (
EMBEDDING_MODEL,
@ -25,7 +28,6 @@ import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet
def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id:
@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
@ -88,12 +91,12 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
def __init__(
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
):
"""Initialize the JSONLoader.
@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
根据loader_name和文件路径或内容返回文档加载器
'''
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
# 自动识别文件编码类型避免langchain loader 加载文件报编码错误
with open(file_path_or_content, 'rb') as struct_file:
encode_detect = chardet.detect(struct_file.read())
if encode_detect:
loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
else:
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
if encode_detect is None:
encode_detect = {"encoding": "utf-8"}
loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
## TODO支持更多的自定义CSV读取逻辑
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
@ -187,10 +191,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
def make_text_splitter(
splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
llm_model: str = LLM_MODEL,
splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
llm_model: str = LLM_MODEL,
):
"""
根据参数获取特定的分词器
@ -262,6 +266,7 @@ def make_text_splitter(
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
return text_splitter
class KnowledgeFile:
def __init__(
self,
@ -282,7 +287,7 @@ class KnowledgeFile:
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER_NAME
def file2docs(self, refresh: bool=False):
def file2docs(self, refresh: bool = False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
@ -290,20 +295,21 @@ class KnowledgeFile:
return self.docs
def docs2texts(
self,
docs: List[Document] = None,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
self,
docs: List[Document] = None,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if not docs:
return []
if self.ext not in [".csv"]:
if text_splitter is None:
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
for doc in docs:
@ -320,12 +326,12 @@ class KnowledgeFile:
return self.splited_docs
def file2text(
self,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
self,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
docs = self.file2docs()
@ -359,6 +365,7 @@ def files2docs_in_thread(
如果传入参数是Tuple形式为(filename, kb_name)
生成器返回值为 status, (kb_name, file_name, docs | error)
'''
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@ -373,8 +380,8 @@ def files2docs_in_thread(
kwargs = {}
try:
if isinstance(file, tuple) and len(file) >= 2:
filename=file[0]
kb_name=file[1]
filename = file[0]
kb_name = file[1]
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
elif isinstance(file, dict):
filename = file.pop("filename")
@ -396,10 +403,9 @@ def files2docs_in_thread(
if __name__ == "__main__":
from pprint import pprint
kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
kb_file = KnowledgeFile(
filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
knowledge_base_name="samples")
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
docs = kb_file.file2docs()
pprint(docs[-1])
docs = kb_file.file2text()
pprint(docs[-1])
# pprint(docs[-1])

View File

@ -4,16 +4,17 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
"chatchat_icon_blue_square_v2.png"
)
)
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
'''
返回消息历史
@ -55,14 +56,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if is_lite:
dialogue_modes = ["LLM 对话",
"搜索引擎问答",
]
"搜索引擎问答",
]
else:
dialogue_modes = ["LLM 对话",
"知识库问答",
"搜索引擎问答",
"自定义Agent问答",
]
"知识库问答",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
@ -102,10 +103,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not is_lite
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
and not is_lite
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
@ -210,17 +211,20 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
elif dialogue_mode == "自定义Agent问答":
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
chat_box.ai_say([
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
])
else:
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
text = ""
ans = ""
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
if not any(agent in llm_model for agent in support_agent):
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
split_result=se_top_k>1):
split_result=se_top_k > 1):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):