commit 24d1e28a07 (parent ce8e341b9f)
@@ -232,3 +232,14 @@ VLLM_MODEL_DICT = {
     "agentlm-70b":"THUDM/agentlm-70b",
 }
+
+## Models you consider capable of Agent use can be added here; once added, the warning in the web UI no longer appears
+SUPPORT_AGENT_MODEL = [
+    "Azure-OpenAI",
+    "OpenAI",
+    "Anthropic",
+    "Qwen",
+    "qwen-api",
+    "baichuan-api",
+    "agentlm"
+]
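Downstream code treats these entries as substrings of the running model's name (see the `any(agent in llm_model ...)` checks later in this diff). A minimal sketch of that matching, with illustrative model names:

```python
# Minimal sketch of how SUPPORT_AGENT_MODEL is consumed elsewhere in this commit:
# each entry is matched as a substring of the loaded model's name.
SUPPORT_AGENT_MODEL = [
    "Azure-OpenAI", "OpenAI", "Anthropic", "Qwen",
    "qwen-api", "baichuan-api", "agentlm",
]

def supports_agent(model_name: str) -> bool:
    return any(agent in model_name for agent in SUPPORT_AGENT_MODEL)

# "agentlm-70b" qualifies because it contains "agentlm"; the model names here are illustrative.
assert supports_agent("agentlm-70b")
assert not supports_agent("chatglm2-6b")
```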
@@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
+    "Empty":  # Used when retrieval finds nothing and there is no known information; the "Empty" key may be renamed but must not be deleted, or the program will break
+        """
+<指令>请根据用户的问题,进行简洁明了的回答</指令>
+<问题>{{ question }}</问题>
+""",
 }
 PROMPT_TEMPLATES["search_engine_chat"] = {
     "default":
@@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,

     "search":
         """
 <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
+    "Empty":  # Used when retrieval finds nothing and there is no known information; the "Empty" key may be renamed but must not be deleted, or the program will break
+        """
+<指令>请根据用户的问题,进行简洁明了的回答</指令>
+<问题>{{ question }}</问题>
+""",
 }
 PROMPT_TEMPLATES["agent_chat"] = {
     "default":
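These templates are jinja2-style strings; the new "Empty" variant simply drops the `{{ context }}` slot, so it can be rendered from the question alone. An illustrative render, using jinja2 directly rather than the project's own `History.to_msg_template` plumbing:

```python
# Illustrative only: rendering the "Empty" template outside the app. In the
# repo the rendering goes through History.to_msg_template / langchain, not a
# bare jinja2 Template.
from jinja2 import Template

empty = """<指令>请根据用户的问题,进行简洁明了的回答</指令>
<问题>{{ question }}</问题>
"""
print(Template(empty).render(question="什么是向量数据库?"))
```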
@@ -0,0 +1,80 @@
+## A CSV file loader that reads only the specified columns
+
+from langchain.document_loaders import CSVLoader
+import csv
+from io import TextIOWrapper
+from typing import Dict, List, Optional
+from langchain.docstore.document import Document
+from langchain.document_loaders.helpers import detect_file_encodings
+
+
+class FilteredCSVLoader(CSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            columns_to_read: List[str],
+            source_column: Optional[str] = None,
+            metadata_columns: List[str] = [],
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
+    ):
+        super().__init__(
+            file_path=file_path,
+            source_column=source_column,
+            metadata_columns=metadata_columns,
+            csv_args=csv_args,
+            encoding=encoding,
+            autodetect_encoding=autodetect_encoding,
+        )
+        self.columns_to_read = columns_to_read
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        docs = []
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self.__read_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(
+                                self.file_path, newline="", encoding=encoding.encoding
+                        ) as csvfile:
+                            docs = self.__read_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            if self.columns_to_read[0] in row:
+                content = row[self.columns_to_read[0]]
+                # Extract the source if available
+                source = (
+                    row.get(self.source_column, None)
+                    if self.source_column is not None
+                    else self.file_path
+                )
+                metadata = {"source": source, "row": i}
+
+                for col in self.metadata_columns:
+                    if col in row:
+                        metadata[col] = row[col]
+
+                doc = Document(page_content=content, metadata=metadata)
+                docs.append(doc)
+            else:
+                raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
+
+        return docs
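A hypothetical usage sketch of the new loader. The file path and column names below are assumptions for illustration, and the import assumes the local `document_loaders` package re-exports the class (a later hunk shows it being resolved from that package via `importlib`):

```python
# Hypothetical example: read only the "question" column as page_content and
# carry "category" along as metadata. The path and column names are assumptions.
from document_loaders import FilteredCSVLoader

loader = FilteredCSVLoader(
    file_path="knowledge_base/samples/content/faq.csv",  # assumed path
    columns_to_read=["question"],   # only the first entry is used as page_content
    metadata_columns=["category"],  # copied into each Document's metadata
    autodetect_encoding=True,
)
docs = loader.load()
print(docs[0].page_content, docs[0].metadata)
```

Note that `__read_file` only ever reads `columns_to_read[0]`; extra entries in the list are currently ignored.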
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 fschat>=0.2.31
 openai
 # sentence_transformers
@@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
 from langchain.prompts import StringPromptTemplate
 from typing import List
 from langchain.schema import AgentAction, AgentFinish

+from configs import SUPPORT_AGENT_MODEL
 from server.agent import model_container

 class CustomPromptTemplate(StringPromptTemplate):
     # The template to use
     template: str
     # The list of tools available
     tools: List[Tool]

     def format(self, **kwargs) -> str:
         # Get the intermediate steps (AgentAction, Observation tuples)
         # Format them in a particular way
         intermediate_steps = kwargs.pop("intermediate_steps")
         thoughts = ""
         for action, observation in intermediate_steps:
             thoughts += action.log
             thoughts += f"\nObservation: {observation}\nThought: "
         # Set the agent_scratchpad variable to that value
         kwargs["agent_scratchpad"] = thoughts
         # Create a tools variable from the list of tools provided
         kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
         # Create a list of tool names for the tools provided
         kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
         # Return the formatted template
         # print(self.template.format(**kwargs), end="\n\n")
         return self.template.format(**kwargs)
@@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
         self.begin = True

     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
         # Check if agent should finish
-        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api", "agentlm"]  # models that currently support agent use
-        if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
+        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]
             min_index = len(llm_output)
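For models without agent alignment, this branch truncates the raw completion at the earliest stop word so stray "Observation:" turns are discarded. A standalone sketch of that truncation, reconstructed from the surrounding lines rather than copied from the commit:

```python
# Sketch (a reconstruction, not verbatim from the commit): cut the completion
# at the earliest occurrence of any stop word.
def truncate_at_stop_words(llm_output: str, stop_words=("Observation:",)) -> str:
    min_index = len(llm_output)
    for word in stop_words:
        idx = llm_output.find(word)
        if idx != -1 and idx < min_index:
            min_index = idx
    return llm_output[:min_index]

print(truncate_at_stop_words("Final Answer: 42\nObservation: ..."))  # -> "Final Answer: 42\n"
```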
@@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
                 return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                 log=llm_output,
             )

         # Parse out the action and action input
         parts = llm_output.split("Action:")
         if len(parts) < 2:
             return AgentFinish(
@@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
         action = parts[1].split("Action Input:")[0].strip()
         action_input = parts[1].split("Action Input:")[1].strip()

-        # The original regex-based check
+        # The original regex-based check: stricter, but with a lower success rate
         # regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
         # print("llm_output",llm_output)
         # match = re.search(regex, llm_output, re.DOTALL)
@@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
             # action_input = match.group(2)

         # Return the action and action input
-
         try:
             ans = AgentAction(
                 tool=action,
@@ -55,7 +55,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+
+        if len(docs) == 0:  ## If no relevant documents were found, use the Empty template
+            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+        else:
+            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
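The new branch falls back to the "Empty" template whenever retrieval returns no documents, so the model answers from its own knowledge instead of rendering an empty `{{ context }}`. Factored out as a helper, the decision looks like this (a sketch; the commit inlines it as the if/else above, and `get_prompt_template` comes from the surrounding code):

```python
# The selection logic this hunk introduces, factored out for clarity.
def select_kb_prompt(docs: list, prompt_name: str = "default") -> str:
    name = "Empty" if len(docs) == 0 else prompt_name
    return get_prompt_template("knowledge_base_chat", name)
```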
@@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             url = f"/knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)
+
+        if len(source_documents) == 0:  # no relevant documents were found
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
+
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
@@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             yield json.dumps({"answer": answer,
                               "docs": source_documents},
                              ensure_ascii=False)

         await task
-
     return StreamingResponse(knowledge_base_chat_iterator(query=query,
@@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             for inum, doc in enumerate(docs)
         ]
+
+        if len(source_documents) == 0:  # no relevant results were found (unlikely)
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
@@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
         return []
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]

     return data
-
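`kb.search_docs` returns `(Document, score)` pairs, and the comprehension flattens each pair into a `DocumentWithScore`. A sketch of what that model presumably looks like, inferred from the expression rather than quoted from the repo:

```python
# Inferred shape, not quoted from the repo: a langchain Document (a pydantic
# model) extended with the retrieval score, so DocumentWithScore(**doc.dict(),
# score=s) constructs it directly.
from langchain.docstore.document import Document

class DocumentWithScore(Document):
    score: float = None
```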
@@ -1,4 +1,7 @@
 import os
+import sys
+
+sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
 from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
@@ -25,7 +28,6 @@ import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 import chardet

-
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # Check for unexpected characters or path-traversal keywords
     if "../" in knowledge_base_id:
@@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
+              # "FilteredCSVLoader": [".csv"],  # must be configured manually; not yet supported
               "RapidOCRPDFLoader": [".pdf"],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
@@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
     Return a document loader according to loader_name and the file path or content.
     '''
     try:
-        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader"]:
             document_loaders_module = importlib.import_module('document_loaders')
         else:
             document_loaders_module = importlib.import_module('langchain.document_loaders')
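The branch above only selects the module; the loader class itself is then fetched by name. A condensed sketch of the full resolution path (the `getattr` step is an assumption about the code below the hunk; the module choice mirrors the diff):

```python
# Condensed sketch of the routing this change extends: project-local loaders,
# including the new FilteredCSVLoader, resolve from the document_loaders
# package; everything else comes from langchain.document_loaders.
import importlib

def resolve_loader(loader_name: str):
    if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader"]:
        module = importlib.import_module('document_loaders')
    else:
        module = importlib.import_module('langchain.document_loaders')
    return getattr(module, loader_name)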
@@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
         # Auto-detect the file encoding so the langchain loader does not fail with encoding errors
         with open(file_path_or_content, 'rb') as struct_file:
             encode_detect = chardet.detect(struct_file.read())
-        if encode_detect:
+        if encode_detect is None:
+            encode_detect = {"encoding": "utf-8"}

-            loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
-        else:
-            loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+        loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
+        ## TODO: support more custom CSV reading logic

     elif loader_name == "JSONLoader":
         loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
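Worth noting: `chardet.detect` itself always returns a dict, but its `"encoding"` value can be `None` for unrecognizable bytes. A slightly more defensive variant (an assumption, not what the commit does) would guard both cases:

```python
# Assumed, more defensive variant; the commit only guards against a None dict.
import chardet

def detect_encoding(path: str, default: str = "utf-8") -> str:
    with open(path, "rb") as f:
        result = chardet.detect(f.read())
    # result looks like {"encoding": "GB2312", "confidence": 0.99, "language": "Chinese"}
    return (result or {}).get("encoding") or default
```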
@@ -262,6 +266,7 @@ def make_text_splitter(
         text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
         return text_splitter
+

 class KnowledgeFile:
     def __init__(
             self,
@@ -282,7 +287,7 @@ class KnowledgeFile:
         self.document_loader_name = get_LoaderClass(self.ext)
         self.text_splitter_name = TEXT_SPLITTER_NAME

-    def file2docs(self, refresh: bool=False):
+    def file2docs(self, refresh: bool = False):
         if self.docs is None or refresh:
             logger.info(f"{self.document_loader_name} used for {self.filepath}")
             loader = get_loader(self.document_loader_name, self.filepath)
@@ -303,7 +308,8 @@ class KnowledgeFile:
             return []
         if self.ext not in [".csv"]:
             if text_splitter is None:
-                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap)
             if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                 docs = text_splitter.split_text(docs[0].page_content)
                 for doc in docs:
@@ -359,6 +365,7 @@ def files2docs_in_thread(
     If the argument is a Tuple, it takes the form (filename, kb_name)
     The generator yields status, (kb_name, file_name, docs | error)
     '''
+
     def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@@ -373,8 +380,8 @@ def files2docs_in_thread(
         kwargs = {}
         try:
             if isinstance(file, tuple) and len(file) >= 2:
-                filename=file[0]
-                kb_name=file[1]
+                filename = file[0]
+                kb_name = file[1]
                 file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             elif isinstance(file, dict):
                 filename = file.pop("filename")
@@ -396,10 +403,9 @@ def files2docs_in_thread(
 if __name__ == "__main__":
     from pprint import pprint

-    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    kb_file = KnowledgeFile(
+        filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
+        knowledge_base_name="samples")
     # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
     docs = kb_file.file2docs()
     pprint(docs[-1])

     docs = kb_file.file2text()
-    pprint(docs[-1])
+    # pprint(docs[-1])
@@ -4,16 +4,17 @@ from streamlit_chatbox import *
 from datetime import datetime
 import os
 from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
 from typing import List, Dict


 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
         "chatchat_icon_blue_square_v2.png"
     )
 )


 def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
     '''
     Return the message history.
@@ -210,6 +211,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

         elif dialogue_mode == "自定义Agent问答":
-            chat_box.ai_say([
-                f"正在思考...",
-                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+            if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+                chat_box.ai_say([
+                    f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+                ])
+            else:
+                chat_box.ai_say([
+                    f"正在思考...",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
@@ -217,10 +225,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             ])
             text = ""
             ans = ""
-            support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api", "agentlm"]  # models that currently support agent use
-            if not any(agent in llm_model for agent in support_agent):
-                ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
-                chat_box.update_msg(ans, element_index=0, streaming=False)
             for d in api.agent_chat(prompt,
                                     history=history,
                                     model=llm_model,
@@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                     model=llm_model,
                                     prompt_name=prompt_template_name,
                                     temperature=temperature,
-                                    split_result=se_top_k>1):
+                                    split_result=se_top_k > 1):
             if error_msg := check_error_msg(d):  # check whether an error occurred
                 st.error(error_msg)
             elif chunk := d.get("answer"):