parent ce8e341b9f
commit 24d1e28a07
@@ -231,4 +231,15 @@ VLLM_MODEL_DICT = {
     "agentlm-13b":"THUDM/agentlm-13b",
     "agentlm-70b":"THUDM/agentlm-70b",

 }
+
+## Models you consider Agent-capable can be added here; once added, the web UI no longer shows a warning for them
+SUPPORT_AGENT_MODEL = [
+    "Azure-OpenAI",
+    "OpenAI",
+    "Anthropic",
+    "Qwen",
+    "qwen-api",
+    "baichuan-api",
+    "agentlm"
+]
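The new list is consumed by substring matching against the running model's name, both in the output parser and in the web UI further down in this commit. A quick illustration (the model names are examples, not taken from this commit):

    SUPPORT_AGENT_MODEL = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api", "agentlm"]

    for llm_model in ["Qwen-14B-Chat", "agentlm-13b", "chatglm2-6b"]:
        # `any(...)` mirrors the checks this commit introduces
        print(llm_model, any(agent in llm_model for agent in SUPPORT_AGENT_MODEL))
    # Qwen-14B-Chat True / agentlm-13b True / chatglm2-6b False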
@@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
+
+    "Empty":  # Used when retrieval finds nothing, i.e. there is no known information. The text can be changed, but the "Empty" key must not be deleted or the program will break
+        """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+        """,
 }
 PROMPT_TEMPLATES["search_engine_chat"] = {
 "default":

@@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,

 "search":
 """
 <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
+
+    "Empty":  # Used when the search returns nothing, i.e. there is no known information. The text can be changed, but the "Empty" key must not be deleted or the program will break
+        """
+        <指令>请根据用户的问题,进行简洁明了的回答</指令>
+        <问题>{{ question }}</问题>
+        """,
 }
 PROMPT_TEMPLATES["agent_chat"] = {
 "default":
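For reference, a minimal sketch of the two-argument lookup that the server code later in this commit performs against this dict; the real get_prompt_template helper in the project may do more (such as hot-reloading the config), which this sketch skips:

    from configs import prompt_config

    def get_prompt_template(type: str, name: str) -> str:
        # e.g. type="knowledge_base_chat", name="Empty" when retrieval returns nothing
        return prompt_config.PROMPT_TEMPLATES[type][name]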
@@ -0,0 +1,80 @@
+## CSV file loader that reads only the specified columns
+
+from langchain.document_loaders import CSVLoader
+import csv
+from io import TextIOWrapper
+from typing import Dict, List, Optional
+from langchain.docstore.document import Document
+from langchain.document_loaders.helpers import detect_file_encodings
+
+
+class FilteredCSVLoader(CSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            columns_to_read: List[str],
+            source_column: Optional[str] = None,
+            metadata_columns: List[str] = [],
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
+    ):
+        super().__init__(
+            file_path=file_path,
+            source_column=source_column,
+            metadata_columns=metadata_columns,
+            csv_args=csv_args,
+            encoding=encoding,
+            autodetect_encoding=autodetect_encoding,
+        )
+        self.columns_to_read = columns_to_read
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        docs = []
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self.__read_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(
+                            self.file_path, newline="", encoding=encoding.encoding
+                        ) as csvfile:
+                            docs = self.__read_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            if self.columns_to_read[0] in row:
+                content = row[self.columns_to_read[0]]
+                # Extract the source if available
+                source = (
+                    row.get(self.source_column, None)
+                    if self.source_column is not None
+                    else self.file_path
+                )
+                metadata = {"source": source, "row": i}
+
+                for col in self.metadata_columns:
+                    if col in row:
+                        metadata[col] = row[col]
+
+                doc = Document(page_content=content, metadata=metadata)
+                docs.append(doc)
+            else:
+                raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
+
+        return docs
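A usage sketch for the new loader; the CSV path and column names are hypothetical, and the import assumes the document_loaders package exports the class, as the importlib lookup added to get_loader further down implies:

    from document_loaders import FilteredCSVLoader

    loader = FilteredCSVLoader(
        file_path="data/example.csv",
        columns_to_read=["content"],         # the first entry becomes page_content
        metadata_columns=["title", "date"],  # copied into each Document's metadata
    )
    docs = loader.load()
    print(docs[0].page_content, docs[0].metadata)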
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 fschat>=0.2.31
 openai
 # sentence_transformers
@@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
 from langchain.prompts import StringPromptTemplate
 from typing import List
 from langchain.schema import AgentAction, AgentFinish

+from configs import SUPPORT_AGENT_MODEL
 from server.agent import model_container

 class CustomPromptTemplate(StringPromptTemplate):
-    # The template to use
     template: str
-    # The list of tools available
     tools: List[Tool]

     def format(self, **kwargs) -> str:
-        # Get the intermediate steps (AgentAction, Observation tuples)
-        # Format them in a particular way
         intermediate_steps = kwargs.pop("intermediate_steps")
         thoughts = ""
         for action, observation in intermediate_steps:
             thoughts += action.log
             thoughts += f"\nObservation: {observation}\nThought: "
-        # Set the agent_scratchpad variable to that value
         kwargs["agent_scratchpad"] = thoughts
-        # Create a tools variable from the list of tools provided
         kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
-        # Create a list of tool names for the tools provided
         kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
-        # Return the formatted templatepr
-        # print( self.template.format(**kwargs), end="\n\n")
         return self.template.format(**kwargs)

@@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
         self.begin = True

     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
-        # Check if agent should finish
-        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # models that currently support Agent use
-        if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
+        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]
             min_index = len(llm_output)

@@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
                 return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                 log=llm_output,
             )
-
-        # Parse out the action and action input
         parts = llm_output.split("Action:")
         if len(parts) < 2:
             return AgentFinish(

@@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
         action = parts[1].split("Action Input:")[0].strip()
         action_input = parts[1].split("Action Input:")[1].strip()

-        # The original regex-based check
+        # The original regex-based check: stricter, but with a lower success rate
         # regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
         # print("llm_output",llm_output)
         # match = re.search(regex, llm_output, re.DOTALL)

@@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
         # action_input = match.group(2)

         # Return the action and action input
-
         try:
             ans = AgentAction(
                 tool=action,
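The stop-word handling that begins at min_index above truncates the model output at the earliest stop word before parsing. The tail of that logic is not shown in this hunk, so the following is a sketch of the conventional earliest-match pattern, not the project's exact code:

    def truncate_at_stop_words(llm_output: str, stop_words=("Observation:",)) -> str:
        # cut the output at the first occurrence of any stop word
        min_index = len(llm_output)
        for word in stop_words:
            idx = llm_output.find(word)
            if idx != -1:
                min_index = min(min_index, idx)
        return llm_output[:min_index]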
@@ -55,8 +55,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+        if len(docs) == 0:  ## fall back to the Empty template when no relevant documents are found
+            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+        else:
+            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

@@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             url = f"/knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)

+        if len(source_documents) == 0:  # no relevant documents were found
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
+
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response

@@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             yield json.dumps({"answer": answer,
                               "docs": source_documents},
                              ensure_ascii=False)

         await task
-
     return StreamingResponse(knowledge_base_chat_iterator(query=query,
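A client-side sketch of consuming the stream yielded above: each chunk is a JSON object carrying "answer" and, depending on stream mode, "docs" either in the same chunk or in a separate one. The parsing loop is illustrative, not project code:

    import json

    def handle_chunk(raw: str):
        d = json.loads(raw)
        if docs := d.get("docs"):
            print("sources:", docs)          # the source_documents list built above
        print(d.get("answer", ""), end="")   # incremental answer text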
@@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             for inum, doc in enumerate(docs)
         ]

+        if len(source_documents) == 0:  # no relevant results were found (unlikely)
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
+
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
@@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
         return []
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
-
     return data
@@ -1,4 +1,7 @@
 import os
+import sys
+
+sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
 from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
@@ -25,7 +28,6 @@ import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 import chardet
-


 def validate_kb_name(knowledge_base_id: str) -> bool:
     # Check for unexpected characters or path-traversal keywords
     if "../" in knowledge_base_id:
@@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],
                "CSVLoader": [".csv"],
+               # "FilteredCSVLoader": [".csv"],  # must be configured manually; not supported yet
                "RapidOCRPDFLoader": [".pdf"],
                "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
                "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
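get_LoaderClass, used further down to pick a loader class name for a file, presumably inverts this mapping by extension; a sketch consistent with the dict shown, not necessarily the project's exact implementation:

    def get_LoaderClass(file_extension):
        # return the name of the first loader whose extension list matches
        for LoaderClass, extensions in LOADER_DICT.items():
            if file_extension in extensions:
                return LoaderClass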
@@ -88,12 +91,12 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
     '''

     def __init__(
             self,
             file_path: Union[str, Path],
             content_key: Optional[str] = None,
             metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
             text_content: bool = True,
             json_lines: bool = False,
     ):
         """Initialize the JSONLoader.
@@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
     Return a document loader according to loader_name and the file path or content.
     '''
     try:
-        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader"]:
             document_loaders_module = importlib.import_module('document_loaders')
         else:
             document_loaders_module = importlib.import_module('langchain.document_loaders')
@@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
             # Auto-detect the file encoding so the langchain loader does not fail with encoding errors
             with open(file_path_or_content, 'rb') as struct_file:
                 encode_detect = chardet.detect(struct_file.read())
-            if encode_detect:
-                loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
-            else:
-                loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+            if encode_detect is None:
+                encode_detect = {"encoding": "utf-8"}
+            loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
+            ## TODO: support more custom CSV-reading logic

         elif loader_name == "JSONLoader":
             loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
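For context on the fallback added here: chardet.detect returns a dict such as {"encoding": "GB2312", "confidence": 0.99, "language": "Chinese"}, and it is the "encoding" value, rather than the dict itself, that can come back as None for empty or undecidable input. A standalone check (the file path is hypothetical):

    import chardet

    with open("some_file.csv", "rb") as f:
        detected = chardet.detect(f.read())
    print(detected)  # e.g. {'encoding': 'utf-8', 'confidence': 0.99, 'language': ''}
    encoding = detected.get("encoding") or "utf-8"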
@@ -187,10 +191,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri


 def make_text_splitter(
         splitter_name: str = TEXT_SPLITTER_NAME,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         llm_model: str = LLM_MODEL,
 ):
     """
     Get the specific text splitter according to the parameters.
|
||||||
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
|
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
|
||||||
return text_splitter
|
return text_splitter
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFile:
|
class KnowledgeFile:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -282,7 +287,7 @@ class KnowledgeFile:
|
||||||
self.document_loader_name = get_LoaderClass(self.ext)
|
self.document_loader_name = get_LoaderClass(self.ext)
|
||||||
self.text_splitter_name = TEXT_SPLITTER_NAME
|
self.text_splitter_name = TEXT_SPLITTER_NAME
|
||||||
|
|
||||||
def file2docs(self, refresh: bool=False):
|
def file2docs(self, refresh: bool = False):
|
||||||
if self.docs is None or refresh:
|
if self.docs is None or refresh:
|
||||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||||
loader = get_loader(self.document_loader_name, self.filepath)
|
loader = get_loader(self.document_loader_name, self.filepath)
|
||||||
|
|
@@ -290,20 +295,21 @@ class KnowledgeFile:
         return self.docs

     def docs2texts(
             self,
             docs: List[Document] = None,
             zh_title_enhance: bool = ZH_TITLE_ENHANCE,
             refresh: bool = False,
             chunk_size: int = CHUNK_SIZE,
             chunk_overlap: int = OVERLAP_SIZE,
             text_splitter: TextSplitter = None,
     ):
         docs = docs or self.file2docs(refresh=refresh)
         if not docs:
             return []
         if self.ext not in [".csv"]:
             if text_splitter is None:
-                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap)
         if self.text_splitter_name == "MarkdownHeaderTextSplitter":
             docs = text_splitter.split_text(docs[0].page_content)
             for doc in docs:
@@ -320,12 +326,12 @@ class KnowledgeFile:
         return self.splited_docs

     def file2text(
             self,
             zh_title_enhance: bool = ZH_TITLE_ENHANCE,
             refresh: bool = False,
             chunk_size: int = CHUNK_SIZE,
             chunk_overlap: int = OVERLAP_SIZE,
             text_splitter: TextSplitter = None,
     ):
         if self.splited_docs is None or refresh:
             docs = self.file2docs()
@@ -359,6 +365,7 @@ def files2docs_in_thread(
     If an argument is a Tuple, it takes the form (filename, kb_name)
     The generator yields status, (kb_name, file_name, docs | error)
     '''
+
     def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@@ -373,8 +380,8 @@ def files2docs_in_thread(
         kwargs = {}
         try:
             if isinstance(file, tuple) and len(file) >= 2:
-                filename=file[0]
-                kb_name=file[1]
+                filename = file[0]
+                kb_name = file[1]
                 file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             elif isinstance(file, dict):
                 filename = file.pop("filename")
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
|
kb_file = KnowledgeFile(
|
||||||
|
filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
|
||||||
|
knowledge_base_name="samples")
|
||||||
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
|
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
|
||||||
docs = kb_file.file2docs()
|
docs = kb_file.file2docs()
|
||||||
pprint(docs[-1])
|
# pprint(docs[-1])
|
||||||
|
|
||||||
docs = kb_file.file2text()
|
|
||||||
pprint(docs[-1])
|
|
||||||
|
|
|
||||||
|
|
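The smoke test above exercises the two entry points of KnowledgeFile; a condensed usage sketch, with placeholder file and knowledge-base names:

    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
    docs = kb_file.file2docs()                                     # raw Documents from the registered loader
    chunks = kb_file.file2text(chunk_size=250, chunk_overlap=50)   # split Documents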
@@ -4,16 +4,17 @@ from streamlit_chatbox import *
 from datetime import datetime
 import os
 from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
 from typing import List, Dict


 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
         "chatchat_icon_blue_square_v2.png"
     )
 )


 def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
     '''
     Return the message history.
@@ -55,14 +56,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

         if is_lite:
             dialogue_modes = ["LLM 对话",
                               "搜索引擎问答",
                               ]
         else:
             dialogue_modes = ["LLM 对话",
                               "知识库问答",
                               "搜索引擎问答",
                               "自定义Agent问答",
                               ]
         dialogue_mode = st.selectbox("请选择对话模式:",
                                      dialogue_modes,
                                      index=0,
@@ -102,10 +103,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             key="llm_model",
         )
         if (st.session_state.get("prev_llm_model") != llm_model
                 and not is_lite
                 and not llm_model in config_models.get("online", {})
                 and not llm_model in config_models.get("langchain", {})
                 and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
                 r = api.change_llm_model(prev_model, llm_model)
@@ -210,17 +211,20 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     elif dialogue_mode == "自定义Agent问答":
-        chat_box.ai_say([
-            f"正在思考...",
-            Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
-        ])
+        if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+            chat_box.ai_say([
+                f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
+                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+            ])
+        else:
+            chat_box.ai_say([
+                f"正在思考...",
+                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+            ])
         text = ""
         ans = ""
-        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # models that currently support Agent use
-        if not any(agent in llm_model for agent in support_agent):
-            ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
-            chat_box.update_msg(ans, element_index=0, streaming=False)
         for d in api.agent_chat(prompt,
                                 history=history,
                                 model=llm_model,
@@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                           model=llm_model,
                                           prompt_name=prompt_template_name,
                                           temperature=temperature,
-                                          split_result=se_top_k>1):
+                                          split_result=se_top_k > 1):
             if error_msg := check_error_msg(d):  # check whether error occured
                 st.error(error_msg)
             elif chunk := d.get("answer"):
@@ -305,4 +309,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
             mime="text/markdown",
             use_container_width=True,
         )