Add file chat mode (#2071)

Developers:
- Add /chat/file_chat and /knowledge_base/upload_temp_docs API endpoints
- Add CACHED_MEMO_VS_NUM and BASE_TEMP_DIR configuration options
parent 2adfa4277c
commit 3b3d948d27
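Below is a minimal client-side sketch (not part of the commit) of the new file chat mode, using the ApiRequest wrapper that this commit extends; the server address and sample file path are assumptions.

from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed API server address

# Upload files into a temporary directory; the returned id names the
# temporary FAISS vector store backing the conversation.
resp = api.upload_temp_docs(["./sample.pdf"])
knowledge_id = resp.get("data", {}).get("id")

# Ask a question grounded in the uploaded files, streaming the answer.
for d in api.file_chat("What is this document about?", knowledge_id=knowledge_id):
    print(d.get("answer", ""), end="", flush=True)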
@@ -1,6 +1,8 @@
 import logging
 import os
 import langchain
+import tempfile
+import shutil


 # Whether to display verbose logs
@@ -23,3 +25,9 @@ logging.basicConfig(format=LOG_FORMAT)
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
 if not os.path.exists(LOG_PATH):
     os.mkdir(LOG_PATH)
+
+# Temporary file directory, mainly used for file chat
+BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
+if os.path.isdir(BASE_TEMP_DIR):
+    shutil.rmtree(BASE_TEMP_DIR)
+os.makedirs(BASE_TEMP_DIR)
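A side note as a tiny sketch (not from the commit): BASE_TEMP_DIR lives under the OS temp directory and, because the config removes and recreates it at import time, temporary file-chat knowledge bases do not survive a server restart.

import os
import tempfile

# Where file-chat uploads land; the exact path varies by OS.
print(os.path.join(tempfile.gettempdir(), "chatchat"))  # e.g. /tmp/chatchat on Linux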
@@ -9,6 +9,9 @@ DEFAULT_VS_TYPE = "faiss"
 # Number of cached vector stores (for FAISS)
 CACHED_VS_NUM = 1

+# Number of cached temporary vector stores (for FAISS), used for file chat
+CACHED_MEMO_VS_NUM = 10
+
 # Maximum length of a single text segment in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
 CHUNK_SIZE = 250

@@ -142,6 +142,7 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):

 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat
+    from server.chat.file_chat import upload_temp_docs, file_chat
     from server.chat.agent_chat import agent_chat
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -152,6 +153,11 @@ def mount_knowledge_routes(app: FastAPI):
              tags=["Chat"],
              summary="与知识库对话")(knowledge_base_chat)

+    app.post("/chat/file_chat",
+             tags=["Knowledge Base Management"],
+             summary="文件对话"
+             )(file_chat)
+
     app.post("/chat/agent_chat",
              tags=["Chat"],
              summary="与agent对话")(agent_chat)
@@ -218,6 +224,11 @@ def mount_knowledge_routes(app: FastAPI):
              summary="根据content中文档重建向量库,流式输出处理进度。"
              )(recreate_vector_store)

+    app.post("/knowledge_base/upload_temp_docs",
+             tags=["Knowledge Base Management"],
+             summary="上传文件到临时目录,用于文件对话。"
+             )(upload_temp_docs)
+

 def run_api(host, port, **kwargs):
     if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
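For reference, a hedged sketch (not in the commit) of exercising the newly mounted routes over raw HTTP with httpx; host, port, and file name are assumptions.

import httpx

base = "http://127.0.0.1:7861"  # assumed server address

# Multipart upload into a temporary directory; the response carries the store id.
with open("sample.txt", "rb") as f:
    r = httpx.post(f"{base}/knowledge_base/upload_temp_docs",
                   files=[("files", ("sample.txt", f))])
knowledge_id = r.json()["data"]["id"]

# The chat route streams JSON chunks as server-sent events.
with httpx.stream("POST", f"{base}/chat/file_chat",
                  json={"query": "Summarize the file.", "knowledge_id": knowledge_id},
                  timeout=None) as resp:
    for line in resp.iter_lines():
        print(line)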
@@ -0,0 +1,167 @@
+from fastapi import Body, File, Form, UploadFile
+from fastapi.responses import StreamingResponse
+from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import (wrap_done, get_ChatOpenAI,
+                          BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
+from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
+from langchain.chains import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from typing import AsyncIterable, List, Optional
+import asyncio
+from langchain.prompts.chat import ChatPromptTemplate
+from server.chat.utils import History
+from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
+from server.knowledge_base.utils import KnowledgeFile
+import json
+import os
+from pathlib import Path
+
+
+def _parse_files_in_thread(
+    files: List[UploadFile],
+    dir: str,
+    zh_title_enhance: bool,
+    chunk_size: int,
+    chunk_overlap: int,
+):
+    """
+    Save the uploaded files to the target directory via a thread pool.
+    The generator yields the save results: [success or error, filename, msg, docs]
+    """
+    def parse_file(file: UploadFile) -> dict:
+        '''
+        Save a single file.
+        '''
+        try:
+            filename = file.filename
+            file_path = os.path.join(dir, filename)
+            file_content = file.file.read()  # read the content of the uploaded file
+            with open(file_path, "wb") as f:
+                f.write(file_content)
+            kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
+            kb_file.filepath = file_path
+            docs = kb_file.file2text(zh_title_enhance=zh_title_enhance,
+                                     chunk_size=chunk_size,
+                                     chunk_overlap=chunk_overlap)
+            return True, filename, f"成功上传文件 {filename}", docs
+        except Exception as e:
+            msg = f"{filename} 文件上传失败,报错信息为: {e}"
+            return False, filename, msg, []
+
+    params = [{"file": file} for file in files]
+    for result in run_in_thread_pool(parse_file, params=params):
+        yield result
+
+
+def upload_temp_docs(
+    files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+    prev_id: str = Form(None, description="前知识库ID"),
+    chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
+    chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
+    zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
+) -> BaseResponse:
+    '''
+    Save the files to a temporary directory and vectorize them.
+    Returns the name of the temporary directory as the ID, which is also the ID of the temporary vector store.
+    '''
+    if prev_id is not None:
+        memo_faiss_pool.pop(prev_id)
+
+    failed_files = []
+    documents = []
+    path, id = get_temp_dir(prev_id)
+    for success, file, msg, docs in _parse_files_in_thread(files=files,
+                                                           dir=path,
+                                                           zh_title_enhance=zh_title_enhance,
+                                                           chunk_size=chunk_size,
+                                                           chunk_overlap=chunk_overlap):
+        if success:
+            documents += docs
+        else:
+            failed_files.append({file: msg})
+
+    with memo_faiss_pool.load_vector_store(id).acquire() as vs:
+        vs.add_documents(documents)
+    return BaseResponse(data={"id": id, "failed_files": failed_files})
+
+
+async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
+                    knowledge_id: str = Body(..., description="临时知识库ID"),
+                    top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                    score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
+                    history: List[History] = Body([],
+                                                  description="历史对话",
+                                                  examples=[[
+                                                      {"role": "user",
+                                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                      {"role": "assistant",
+                                                       "content": "虎头虎脑"}]]
+                                                  ),
+                    stream: bool = Body(False, description="流式输出"),
+                    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+                    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+                    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                    ):
+    if knowledge_id not in memo_faiss_pool.keys():
+        return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件")
+
+    history = [History.from_data(h) for h in history]
+
+    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
+        callback = AsyncIteratorCallbackHandler()
+        model = get_ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            callbacks=[callback],
+        )
+        embed_func = EmbeddingsFunAdapter()
+        embeddings = embed_func.embed_query(query)
+        with memo_faiss_pool.acquire(knowledge_id) as vs:
+            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
+            docs = [x[0] for x in docs]
+
+        context = "\n".join([doc.page_content for doc in docs])
+        if len(docs) == 0:  # if no relevant documents were found, use the Empty template
+            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+        else:
+            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_template() for i in history] + [input_msg])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)
+
+        # Begin a task that runs in the background.
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"context": context, "question": query}),
+            callback.done),
+        )
+
+        source_documents = []
+        doc_path = get_temp_dir(knowledge_id)[0]
+        for inum, doc in enumerate(docs):
+            filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
+            text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
+            source_documents.append(text)
+
+        if len(source_documents) == 0:  # no relevant documents were found
+            source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
+
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield json.dumps({"answer": token}, ensure_ascii=False)
+            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield json.dumps({"answer": answer,
+                              "docs": source_documents},
+                             ensure_ascii=False)
+        await task
+
+    return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream")
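For orientation (not part of the commit): when stream=True, the endpoint above emits one JSON chunk per generated token, followed by a final chunk carrying the sources. Illustrative shape, with made-up values:

# {"answer": "The"}       <- one chunk per streamed token
# {"answer": " file"}
# {"docs": ["出处 [1] [sample.pdf] \n\n..."]}   <- sources, emitted once at the end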
@@ -17,7 +17,7 @@ def list_docs_from_db(session,
     '''
     docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
     if file_name:
-        docs = docs.filter_by(file_name=file_name)
+        docs = docs.filter(FileDocModel.file_name.ilike(file_name))
     for k, v in metadata.items():
         docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))

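The behavioral change in this hunk is the switch from an exact filter_by match to ilike, i.e. case-insensitive matching with SQL wildcard support. A self-contained sketch of that semantics (the Doc model and data below are invented for the demo):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Doc(Base):
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    file_name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as s:
    s.add_all([Doc(file_name="Report.PDF"), Doc(file_name="notes.txt")])
    s.commit()
    # ilike matches case-insensitively; % is the SQL wildcard
    hits = s.query(Doc).filter(Doc.file_name.ilike("report%")).all()
    print([d.file_name for d in hits])  # ['Report.PDF']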
@@ -1,4 +1,5 @@
 from langchain.embeddings.base import Embeddings
+from langchain.vectorstores.faiss import FAISS
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
@@ -25,7 +26,7 @@ class ThreadSafeObject:
         return self._key

     @contextmanager
-    def acquire(self, owner: str = "", msg: str = ""):
+    def acquire(self, owner: str = "", msg: str = "") -> FAISS:
         owner = owner or f"thread {threading.get_native_id()}"
         try:
             self._lock.acquire()
@@ -1,4 +1,4 @@
-from configs import CACHED_VS_NUM
+from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
 from server.knowledge_base.kb_cache.base import *
 from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
 from server.utils import load_local_embeddings
@@ -123,7 +123,7 @@ class MemoFaissPool(_FaissPool):


 kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
-memo_faiss_pool = MemoFaissPool()
+memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)


 if __name__ == "__main__":
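A rough mental model for the pool sizing change (an illustration only, not the actual CachePool implementation): a bounded cache keyed by the temporary knowledge base id, evicting the oldest entry once more than CACHED_MEMO_VS_NUM stores are loaded.

from collections import OrderedDict

cache: OrderedDict = OrderedDict()
cache_num = 10  # mirrors CACHED_MEMO_VS_NUM = 10

def put(key: str, vector_store) -> None:
    cache[key] = vector_store
    if len(cache) > cache_num:
        cache.popitem(last=False)  # drop the oldest temporary vector store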
@@ -13,8 +13,8 @@ from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_file_repository import get_file_detail
-from typing import List
 from langchain.docstore.document import Document
+from typing import List


 class DocumentWithScore(Document):
@@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -700,3 +700,19 @@ def load_local_embeddings(model: str = None, device: str = embedding_device()):
     model = model or EMBEDDING_MODEL
     return embeddings_pool.load_embeddings(model=model, device=device)
+
+
+def get_temp_dir(id: str = None) -> Tuple[str, str]:
+    '''
+    Create a temporary directory; return (path, directory name).
+    '''
+    from configs.basic_config import BASE_TEMP_DIR
+    import tempfile
+
+    if id is not None:  # if the specified temporary directory already exists, return it directly
+        path = os.path.join(BASE_TEMP_DIR, id)
+        if os.path.isdir(path):
+            return path, id
+
+    path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
+    return path, os.path.basename(path)
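A small usage sketch for the new helper (assumes the server package is importable and BASE_TEMP_DIR exists): the first call mints a fresh directory, and passing the returned id back reuses it.

from server.utils import get_temp_dir

path, kb_id = get_temp_dir()         # e.g. ('/tmp/chatchat/tmpab12cd', 'tmpab12cd')
path2, kb_id2 = get_temp_dir(kb_id)  # an existing id maps back to the same directory
assert (path2, kb_id2) == (path, kb_id)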
@@ -5,6 +5,7 @@ from datetime import datetime
 import os
 from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                      DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
+from server.knowledge_base.utils import LOADER_DICT
 from typing import List, Dict
@@ -36,7 +37,17 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
     return chat_box.filter_history(history_len=history_len, filter=filter)


+@st.cache_data
+def upload_temp_docs(files, _api: ApiRequest) -> str:
+    '''
+    Upload files to a temporary directory, used for file chat.
+    Returns the ID of the temporary vector store.
+    '''
+    return _api.upload_temp_docs(files).get("data", {}).get("id")
+
+
 def dialogue_page(api: ApiRequest, is_lite: bool = False):
+    st.session_state.setdefault("file_chat_id", None)
     default_model = api.get_default_llm_model()[0]
     if not chat_box.chat_inited:
         st.toast(
@@ -58,6 +69,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

     dialogue_modes = ["LLM 对话",
                       "知识库问答",
+                      "文件对话",
                       "搜索引擎问答",
                       "自定义Agent问答",
                       ]
@@ -122,6 +134,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         "自定义Agent问答": "agent_chat",
         "搜索引擎问答": "search_engine_chat",
         "知识库问答": "knowledge_base_chat",
+        "文件对话": "knowledge_base_chat",
     }
     prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
     prompt_template_name = prompt_templates_kb_list[0]
@@ -163,7 +176,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

                 ## Bge models can score above 1
                 score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
+        elif dialogue_mode == "文件对话":
+            with st.expander("文件对话配置", True):
+                files = st.file_uploader("上传知识文件:",
+                                         [i for ls in LOADER_DICT.values() for i in ls],
+                                         accept_multiple_files=True,
+                                         )
+                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
+
+                ## Bge models can score above 1
+                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
+                if st.button("开始上传", disabled=len(files)==0):
+                    st.session_state["file_chat_id"] = upload_temp_docs(files, api)
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
             if DEFAULT_SEARCH_ENGINE in search_engine_list:
@@ -288,6 +312,30 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     chat_box.update_msg(text, element_index=0)
             chat_box.update_msg(text, element_index=0, streaming=False)
             chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+        elif dialogue_mode == "文件对话":
+            if st.session_state["file_chat_id"] is None:
+                st.error("请先上传文件再进行对话")
+                st.stop()
+            chat_box.ai_say([
+                f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
+                Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
+            ])
+            text = ""
+            for d in api.file_chat(prompt,
+                                   knowledge_id=st.session_state["file_chat_id"],
+                                   top_k=kb_top_k,
+                                   score_threshold=score_threshold,
+                                   history=history,
+                                   model=llm_model,
+                                   prompt_name=prompt_template_name,
+                                   temperature=temperature):
+                if error_msg := check_error_msg(d):  # check whether an error occurred
+                    st.error(error_msg)
+                elif chunk := d.get("answer"):
+                    text += chunk
+                    chat_box.update_msg(text, element_index=0)
+            chat_box.update_msg(text, element_index=0, streaming=False)
+            chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
         elif dialogue_mode == "搜索引擎问答":
             chat_box.ai_say([
                 f"正在执行 `{search_engine}` 搜索...",
|
|
@ -384,6 +384,81 @@ class ApiRequest:
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response, as_json=True)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
||||||
|
def upload_temp_docs(
|
||||||
|
self,
|
||||||
|
files: List[Union[str, Path, bytes]],
|
||||||
|
knowledge_id: str = None,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
chunk_overlap=OVERLAP_SIZE,
|
||||||
|
zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
对应api.py/knowledge_base/upload_tmep_docs接口
|
||||||
|
'''
|
||||||
|
def convert_file(file, filename=None):
|
||||||
|
if isinstance(file, bytes): # raw bytes
|
||||||
|
file = BytesIO(file)
|
||||||
|
elif hasattr(file, "read"): # a file io like object
|
||||||
|
filename = filename or file.name
|
||||||
|
else: # a local path
|
||||||
|
file = Path(file).absolute().open("rb")
|
||||||
|
filename = filename or os.path.split(file.name)[-1]
|
||||||
|
return filename, file
|
||||||
|
|
||||||
|
files = [convert_file(file) for file in files]
|
||||||
|
data={
|
||||||
|
"knowledge_id": knowledge_id,
|
||||||
|
"chunk_size": chunk_size,
|
||||||
|
"chunk_overlap": chunk_overlap,
|
||||||
|
"zh_title_enhance": zh_title_enhance,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self.post(
|
||||||
|
"/knowledge_base/upload_temp_docs",
|
||||||
|
data=data,
|
||||||
|
files=[("files", (filename, file)) for filename, file in files],
|
||||||
|
)
|
||||||
|
return self._get_response_value(response, as_json=True)
|
||||||
|
|
||||||
|
def file_chat(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_id: str,
|
||||||
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
|
history: List[Dict] = [],
|
||||||
|
stream: bool = True,
|
||||||
|
model: str = LLM_MODELS[0],
|
||||||
|
temperature: float = TEMPERATURE,
|
||||||
|
max_tokens: int = None,
|
||||||
|
prompt_name: str = "default",
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
对应api.py/chat/file_chat接口
|
||||||
|
'''
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"knowledge_id": knowledge_id,
|
||||||
|
"top_k": top_k,
|
||||||
|
"score_threshold": score_threshold,
|
||||||
|
"history": history,
|
||||||
|
"stream": stream,
|
||||||
|
"model_name": model,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"prompt_name": prompt_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"received input message:")
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
|
response = self.post(
|
||||||
|
"/chat/file_chat",
|
||||||
|
json=data,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
||||||
def search_engine_chat(
|
def search_engine_chat(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
|
||||||