diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index efbedf9..74e9023 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -231,4 +231,15 @@ VLLM_MODEL_DICT = {
"agentlm-13b":"THUDM/agentlm-13b",
"agentlm-70b":"THUDM/agentlm-70b",
-}
\ No newline at end of file
+}
+
+## Models you consider capable of Agent use can be added here; once added, the web UI will no longer show a warning for them.
+SUPPORT_AGENT_MODEL = [
+ "Azure-OpenAI",
+ "OpenAI",
+ "Anthropic",
+ "Qwen",
+ "qwen-api",
+ "baichuan-api",
+ "agentlm"
+]
\ No newline at end of file
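# ---------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of how SUPPORT_AGENT_MODEL is
# consumed later in this change: custom_template.py and dialogue.py both gate
# on a substring match against the running model name, so an entry such as
# "Qwen" also matches "Qwen-14B-Chat". The model names below are illustrative.
SUPPORT_AGENT_MODEL = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen",
                       "qwen-api", "baichuan-api", "agentlm"]

def supports_agent(llm_model: str) -> bool:
    # The same membership test used in CustomOutputParser.parse and dialogue_page.
    return any(agent in llm_model for agent in SUPPORT_AGENT_MODEL)

assert supports_agent("agentlm-13b")      # substring hit on "agentlm"
assert not supports_agent("chatglm2-6b")  # would take the warning path in the UI
# ---------------------------------------------------------------------------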
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index e2d7ceb..6080ace 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
         <已知信息>{{ context }}</已知信息>
         <问题>{{ question }}</问题>
""",
+ "Empty": # 搜不到内容的时候调用,此时没有已知信息,这个Empty可以更改,但不能删除,会影响程序使用
+ """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+ """,
}
PROMPT_TEMPLATES["search_engine_chat"] = {
"default":
@@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
         <已知信息>{{ context }}</已知信息>
         <问题>{{ question }}</问题>
""",
-
"search":
"""
         <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
         <已知信息>{{ context }}</已知信息>
         <问题>{{ question }}</问题>
""",
+ "Empty": # 搜不到内容的时候调用,此时没有已知信息,这个Empty可以更改,但不能删除,会影响程序使用
+ """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+ """,
}
PROMPT_TEMPLATES["agent_chat"] = {
"default":
diff --git a/document_loaders/FilteredCSVloader.py b/document_loaders/FilteredCSVloader.py
new file mode 100644
index 0000000..0f8148d
--- /dev/null
+++ b/document_loaders/FilteredCSVloader.py
@@ -0,0 +1,81 @@
+## A CSV loader that reads only the specified columns
+
+from langchain.document_loaders import CSVLoader
+import csv
+from io import TextIOWrapper
+from typing import Dict, List, Optional
+from langchain.docstore.document import Document
+from langchain.document_loaders.helpers import detect_file_encodings
+
+
+class FilteredCSVLoader(CSVLoader):
+ def __init__(
+ self,
+ file_path: str,
+ columns_to_read: List[str],
+ source_column: Optional[str] = None,
+ metadata_columns: List[str] = [],
+ csv_args: Optional[Dict] = None,
+ encoding: Optional[str] = None,
+ autodetect_encoding: bool = False,
+ ):
+ super().__init__(
+ file_path=file_path,
+ source_column=source_column,
+ metadata_columns=metadata_columns,
+ csv_args=csv_args,
+ encoding=encoding,
+ autodetect_encoding=autodetect_encoding,
+ )
+ self.columns_to_read = columns_to_read
+
+ def load(self) -> List[Document]:
+ """Load data into document objects."""
+
+ docs = []
+ try:
+ with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+ docs = self.__read_file(csvfile)
+ except UnicodeDecodeError as e:
+ if self.autodetect_encoding:
+ detected_encodings = detect_file_encodings(self.file_path)
+ for encoding in detected_encodings:
+ try:
+ with open(
+ self.file_path, newline="", encoding=encoding.encoding
+ ) as csvfile:
+ docs = self.__read_file(csvfile)
+ break
+ except UnicodeDecodeError:
+ continue
+ else:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+ except Exception as e:
+ raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
+ docs = []
+ csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
+ for i, row in enumerate(csv_reader):
+ if self.columns_to_read[0] in row:
+ content = row[self.columns_to_read[0]]
+ # Extract the source if available
+ source = (
+ row.get(self.source_column, None)
+ if self.source_column is not None
+ else self.file_path
+ )
+ metadata = {"source": source, "row": i}
+
+ for col in self.metadata_columns:
+ if col in row:
+ metadata[col] = row[col]
+
+ doc = Document(page_content=content, metadata=metadata)
+ docs.append(doc)
+ else:
+ raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
+
+ return docs
\ No newline at end of file
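# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the patch). The file and column names
# are hypothetical, and it assumes FilteredCSVLoader is exported from the
# document_loaders package, which is how get_loader in
# server/knowledge_base/utils.py resolves it. Note that only columns_to_read[0]
# becomes page_content; metadata_columns are copied into doc.metadata when present.
from document_loaders import FilteredCSVLoader

loader = FilteredCSVLoader(
    file_path="example.csv",       # hypothetical file
    columns_to_read=["question"],  # first entry supplies page_content
    metadata_columns=["answer"],   # optional columns kept as metadata
)
docs = loader.load()
# ---------------------------------------------------------------------------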
diff --git a/requirements.txt b/requirements.txt
index 364cc34..419c650 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4
diff --git a/requirements_api.txt b/requirements_api.txt
index 4ef692b..126e65f 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
xformers>=0.0.22.post4
diff --git a/requirements_lite.txt b/requirements_lite.txt
index d453451..4ff659e 100644
--- a/requirements_lite.txt
+++ b/requirements_lite.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
fschat>=0.2.31
openai
# sentence_transformers
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
index a08ba8e..2744b87 100644
--- a/server/agent/custom_template.py
+++ b/server/agent/custom_template.py
@@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
+
+from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container
class CustomPromptTemplate(StringPromptTemplate):
- # The template to use
template: str
- # The list of tools available
tools: List[Tool]
def format(self, **kwargs) -> str:
- # Get the intermediate steps (AgentAction, Observation tuples)
- # Format them in a particular way
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
- # Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
- # Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
- # Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
- # Return the formatted templatepr
- # print( self.template.format(**kwargs), end="\n\n")
return self.template.format(**kwargs)
@@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
self.begin = True
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
- # Check if agent should finish
- support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
- if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
+ if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
self.begin = False
stop_words = ["Observation:"]
min_index = len(llm_output)
@@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
log=llm_output,
)
-
- # Parse out the action and action input
parts = llm_output.split("Action:")
if len(parts) < 2:
return AgentFinish(
@@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
action = parts[1].split("Action Input:")[0].strip()
action_input = parts[1].split("Action Input:")[1].strip()
-        # The original regex-based check
+        # The original regex-based check: stricter, but with a lower success rate
# regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
# print("llm_output",llm_output)
# match = re.search(regex, llm_output, re.DOTALL)
@@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
# action_input = match.group(2)
# Return the action and action input
-
try:
ans = AgentAction(
tool=action,
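# ---------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the string protocol parse()
# relies on: ReAct-style output is split on the literal markers "Action:" and
# "Action Input:", while "Final Answer:" short-circuits into an AgentFinish.
# The tool name and input below are illustrative.
llm_output = (
    "Thought: I need to look up the weather.\n"
    "Action: weather_check\n"
    "Action Input: Beijing\n"
)
parts = llm_output.split("Action:")
action = parts[1].split("Action Input:")[0].strip()        # "weather_check"
action_input = parts[1].split("Action Input:")[1].strip()  # "Beijing"
# ---------------------------------------------------------------------------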
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index b7fe85c..01744c1 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -55,8 +55,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
-
- prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+        if len(docs) == 0:  # No relevant documents found; fall back to the "Empty" template
+ prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+ else:
+ prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
url = f"/knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
+
+        if len(source_documents) == 0:  # No relevant documents found
+            source_documents.append("""未找到相关文档,该回答为大模型自身能力解答!""")
+
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
@@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
-
await task
return StreamingResponse(knowledge_base_chat_iterator(query=query,
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 3d2100e..fcda1fa 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
for inum, doc in enumerate(docs)
]
+        if len(source_documents) == 0:  # No relevant results found (unlikely)
+            source_documents.append("""未找到相关文档,该回答为大模型自身能力解答!""")
+
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 3d01f9e..11b8f45 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
-
return data
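# ---------------------------------------------------------------------------
# A small sketch (not part of the patch) of the flattening done in search_docs:
# kb.search_docs returns (Document, score) tuples, and DocumentWithScore merges
# the document fields with the score into one pydantic model. Import paths are
# assumptions.
from langchain.docstore.document import Document
from server.knowledge_base.kb_doc_api import DocumentWithScore

doc, score = Document(page_content="example", metadata={"source": "a.md"}), 0.83
item = DocumentWithScore(**doc.dict(), score=score)  # page_content + metadata + score
# ---------------------------------------------------------------------------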
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index c73d021..3981e81 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,4 +1,7 @@
import os
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))  # project root, so configs can be imported when this file is run directly
from transformers import AutoTokenizer
from configs import (
EMBEDDING_MODEL,
@@ -25,7 +28,6 @@ import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet
-
def validate_kb_name(knowledge_base_id: str) -> bool:
     # Check for unexpected characters or path-traversal keywords
if "../" in knowledge_base_id:
@@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
+ # "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
@@ -88,12 +91,12 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
def __init__(
- self,
- file_path: Union[str, Path],
- content_key: Optional[str] = None,
- metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
- text_content: bool = True,
- json_lines: bool = False,
+ self,
+ file_path: Union[str, Path],
+ content_key: Optional[str] = None,
+ metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+ text_content: bool = True,
+ json_lines: bool = False,
):
"""Initialize the JSONLoader.
@@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
     Return a document loader according to loader_name and the file path or content.
'''
try:
- if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+ if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
@@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
             # Auto-detect the file encoding so the langchain loader does not fail with encoding errors
with open(file_path_or_content, 'rb') as struct_file:
encode_detect = chardet.detect(struct_file.read())
- if encode_detect:
- loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
- else:
- loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+            if encode_detect is None or encode_detect.get("encoding") is None:
+                encode_detect = {"encoding": "utf-8"}
+
+            loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
+            # TODO: support more custom CSV-reading logic
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
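# ---------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the simplified encoding fallback
# above, as a standalone helper: chardet.detect returns a dict whose "encoding"
# key can be None for undecodable input, so both that case and a missing result
# fall back to utf-8.
import chardet

def detect_encoding(path: str) -> str:
    with open(path, "rb") as f:
        result = chardet.detect(f.read())
    if result is None or result.get("encoding") is None:
        return "utf-8"
    return result["encoding"]
# ---------------------------------------------------------------------------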
@@ -187,10 +191,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
def make_text_splitter(
- splitter_name: str = TEXT_SPLITTER_NAME,
- chunk_size: int = CHUNK_SIZE,
- chunk_overlap: int = OVERLAP_SIZE,
- llm_model: str = LLM_MODEL,
+ splitter_name: str = TEXT_SPLITTER_NAME,
+ chunk_size: int = CHUNK_SIZE,
+ chunk_overlap: int = OVERLAP_SIZE,
+ llm_model: str = LLM_MODEL,
):
"""
     Get the appropriate text splitter according to the parameters
@@ -262,6 +266,7 @@ def make_text_splitter(
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
return text_splitter
+
class KnowledgeFile:
def __init__(
self,
@@ -282,7 +287,7 @@ class KnowledgeFile:
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER_NAME
- def file2docs(self, refresh: bool=False):
+ def file2docs(self, refresh: bool = False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
@@ -290,20 +295,21 @@ class KnowledgeFile:
return self.docs
def docs2texts(
- self,
- docs: List[Document] = None,
- zh_title_enhance: bool = ZH_TITLE_ENHANCE,
- refresh: bool = False,
- chunk_size: int = CHUNK_SIZE,
- chunk_overlap: int = OVERLAP_SIZE,
- text_splitter: TextSplitter = None,
+ self,
+ docs: List[Document] = None,
+ zh_title_enhance: bool = ZH_TITLE_ENHANCE,
+ refresh: bool = False,
+ chunk_size: int = CHUNK_SIZE,
+ chunk_overlap: int = OVERLAP_SIZE,
+ text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if not docs:
return []
if self.ext not in [".csv"]:
if text_splitter is None:
- text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+ text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap)
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
for doc in docs:
@@ -320,12 +326,12 @@ class KnowledgeFile:
return self.splited_docs
def file2text(
- self,
- zh_title_enhance: bool = ZH_TITLE_ENHANCE,
- refresh: bool = False,
- chunk_size: int = CHUNK_SIZE,
- chunk_overlap: int = OVERLAP_SIZE,
- text_splitter: TextSplitter = None,
+ self,
+ zh_title_enhance: bool = ZH_TITLE_ENHANCE,
+ refresh: bool = False,
+ chunk_size: int = CHUNK_SIZE,
+ chunk_overlap: int = OVERLAP_SIZE,
+ text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
docs = self.file2docs()
@@ -359,6 +365,7 @@ def files2docs_in_thread(
     If an argument is a tuple, it takes the form (filename, kb_name).
     The generator yields status, (kb_name, file_name, docs | error).
'''
+
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@@ -373,8 +380,8 @@ def files2docs_in_thread(
kwargs = {}
try:
if isinstance(file, tuple) and len(file) >= 2:
- filename=file[0]
- kb_name=file[1]
+ filename = file[0]
+ kb_name = file[1]
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
elif isinstance(file, dict):
filename = file.pop("filename")
@@ -396,10 +403,9 @@ def files2docs_in_thread(
if __name__ == "__main__":
from pprint import pprint
- kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+ kb_file = KnowledgeFile(
+ filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
+ knowledge_base_name="samples")
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
docs = kb_file.file2docs()
-    pprint(docs[-1])
-
-    docs = kb_file.file2text()
-    pprint(docs[-1])
+    pprint(docs[-1])
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index f3fb4f6..5ff5a78 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -4,16 +4,17 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
- DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+ DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from typing import List, Dict
-
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
"chatchat_icon_blue_square_v2.png"
)
)
+
+
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
'''
返回消息历史。
@@ -55,14 +56,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if is_lite:
dialogue_modes = ["LLM 对话",
- "搜索引擎问答",
- ]
+ "搜索引擎问答",
+ ]
else:
dialogue_modes = ["LLM 对话",
- "知识库问答",
- "搜索引擎问答",
- "自定义Agent问答",
- ]
+ "知识库问答",
+ "搜索引擎问答",
+ "自定义Agent问答",
+ ]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
@@ -102,10 +103,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
- and not is_lite
- and not llm_model in config_models.get("online", {})
- and not llm_model in config_models.get("langchain", {})
- and llm_model not in running_models):
+            and not is_lite
+            and llm_model not in config_models.get("online", {})
+            and llm_model not in config_models.get("langchain", {})
+ and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
@@ -210,17 +211,19 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
elif dialogue_mode == "自定义Agent问答":
-        chat_box.ai_say([
-            f"正在思考...",
-            Markdown("...", in_expander=True, title="思考过程", state="complete"),
-        ])
+        if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+            chat_box.ai_say([
+                "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
+                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+            ])
+        else:
+            chat_box.ai_say([
+                "正在思考...",
+                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+            ])
text = ""
ans = ""
- support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
- if not any(agent in llm_model for agent in support_agent):
- ans += "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n"
- chat_box.update_msg(ans, element_index=0, streaming=False)
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
@@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
- split_result=se_top_k>1):
+ split_result=se_top_k > 1):
             if error_msg := check_error_msg(d):  # check whether an error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
@@ -305,4 +309,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
mime="text/markdown",
use_container_width=True,
- )
\ No newline at end of file
+ )