diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example index ba9e046..a407b2e 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -1,6 +1,8 @@ import logging import os import langchain +import tempfile +import shutil # 是否显示详细日志 @@ -23,3 +25,9 @@ logging.basicConfig(format=LOG_FORMAT) LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH) + +# 临时文件目录,主要用于文件对话 +BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat") +if os.path.isdir(BASE_TEMP_DIR): + shutil.rmtree(BASE_TEMP_DIR) +os.makedirs(BASE_TEMP_DIR) diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 42e08fa..04e09ec 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -9,6 +9,9 @@ DEFAULT_VS_TYPE = "faiss" # 缓存向量库数量(针对FAISS) CACHED_VS_NUM = 1 +# 缓存临时向量库数量(针对FAISS),用于文件对话 +CACHED_MEMO_VS_NUM = 10 + # 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) CHUNK_SIZE = 250 diff --git a/server/api.py b/server/api.py index 16a37e3..5fc80c4 100644 --- a/server/api.py +++ b/server/api.py @@ -142,6 +142,7 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): def mount_knowledge_routes(app: FastAPI): from server.chat.knowledge_base_chat import knowledge_base_chat + from server.chat.file_chat import upload_temp_docs, file_chat from server.chat.agent_chat import agent_chat from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, @@ -152,6 +153,11 @@ def mount_knowledge_routes(app: FastAPI): tags=["Chat"], summary="与知识库对话")(knowledge_base_chat) + app.post("/chat/file_chat", + tags=["Knowledge Base Management"], + summary="文件对话" + )(file_chat) + app.post("/chat/agent_chat", tags=["Chat"], summary="与agent对话")(agent_chat) @@ -218,6 +224,11 @@ def mount_knowledge_routes(app: FastAPI): summary="根据content中文档重建向量库,流式输出处理进度。" )(recreate_vector_store) + app.post("/knowledge_base/upload_temp_docs", + tags=["Knowledge Base Management"], + summary="上传文件到临时目录,用于文件对话。" + )(upload_temp_docs) + def run_api(host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py new file mode 100644 index 0000000..efe89cf --- /dev/null +++ b/server/chat/file_chat.py @@ -0,0 +1,167 @@ +from fastapi import Body, File, Form, UploadFile +from fastapi.responses import StreamingResponse +from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE, + CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) +from server.utils import (wrap_done, get_ChatOpenAI, + BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool) +from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool +from langchain.chains import LLMChain +from langchain.callbacks import AsyncIteratorCallbackHandler +from typing import AsyncIterable, List, Optional +import asyncio +from langchain.prompts.chat import ChatPromptTemplate +from server.chat.utils import History +from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter +from server.knowledge_base.utils import KnowledgeFile +import json +import os +from pathlib import Path + + +def _parse_files_in_thread( + files: List[UploadFile], + dir: str, + zh_title_enhance: bool, + chunk_size: int, + chunk_overlap: int, +): + """ + 通过多线程将上传的文件保存到对应目录内。 + 生成器返回保存结果:[success or error, filename, msg, docs] + """ + def parse_file(file: UploadFile) -> dict: + ''' 
+ 保存单个文件。 + ''' + try: + filename = file.filename + file_path = os.path.join(dir, filename) + file_content = file.file.read() # 读取上传文件的内容 + with open(file_path, "wb") as f: + f.write(file_content) + kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp") + kb_file.filepath = file_path + docs = kb_file.file2text(zh_title_enhance=zh_title_enhance, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap) + return True, filename, f"成功上传文件 {filename}", docs + except Exception as e: + msg = f"{filename} 文件上传失败,报错信息为: {e}" + return False, filename, msg, [] + + params = [{"file": file} for file in files] + for result in run_in_thread_pool(parse_file, params=params): + yield result + + +def upload_temp_docs( + files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + prev_id: str = Form(None, description="前知识库ID"), + chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), +) -> BaseResponse: + ''' + 将文件保存到临时目录,并进行向量化。 + 返回临时目录名称作为ID,同时也是临时向量库的ID。 + ''' + if prev_id is not None: + memo_faiss_pool.pop(prev_id) + + failed_files = [] + documents = [] + path, id = get_temp_dir(prev_id) + for success, file, msg, docs in _parse_files_in_thread(files=files, + dir=path, + zh_title_enhance=zh_title_enhance, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap): + if success: + documents += docs + else: + failed_files.append({file: msg}) + + with memo_faiss_pool.load_vector_store(id).acquire() as vs: + vs.add_documents(documents) + return BaseResponse(data={"id": id, "failed_files": failed_files}) + + +async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_id: str = Body(..., description="临时知识库ID"), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + ): + if knowledge_id not in memo_faiss_pool.keys(): + return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件") + + history = [History.from_data(h) for h in history] + + async def knowledge_base_chat_iterator() -> AsyncIterable[str]: + callback = AsyncIteratorCallbackHandler() + model = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + callbacks=[callback], + ) + embed_func = EmbeddingsFunAdapter() + embeddings = embed_func.embed_query(query) + with memo_faiss_pool.acquire(knowledge_id) as vs: + docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + docs = [x[0] for x in docs] + + context = "\n".join([doc.page_content for doc in docs]) + if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 + prompt_template = get_prompt_template("knowledge_base_chat", "Empty") + else: + prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) + input_msg = 
History(role="user", content=prompt_template).to_msg_template(False) + chat_prompt = ChatPromptTemplate.from_messages( + [i.to_msg_template() for i in history] + [input_msg]) + + chain = LLMChain(prompt=chat_prompt, llm=model) + + # Begin a task that runs in the background. + task = asyncio.create_task(wrap_done( + chain.acall({"context": context, "question": query}), + callback.done), + ) + + source_documents = [] + doc_path = get_temp_dir(knowledge_id)[0] + for inum, doc in enumerate(docs): + filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path) + text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n""" + source_documents.append(text) + + if len(source_documents) == 0: # 没有找到相关文档 + source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""") + + if stream: + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield json.dumps({"answer": token}, ensure_ascii=False) + yield json.dumps({"docs": source_documents}, ensure_ascii=False) + else: + answer = "" + async for token in callback.aiter(): + answer += token + yield json.dumps({"answer": answer, + "docs": source_documents}, + ensure_ascii=False) + await task + + return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream") diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 3d77cc0..4a37441 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -17,7 +17,7 @@ def list_docs_from_db(session, ''' docs = session.query(FileDocModel).filter_by(kb_name=kb_name) if file_name: - docs = docs.filter_by(file_name=file_name) + docs = docs.filter(FileDocModel.file_name.ilike(file_name)) for k, v in metadata.items(): docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v)) diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py index 99f854f..2b1b316 100644 --- a/server/knowledge_base/kb_cache/base.py +++ b/server/knowledge_base/kb_cache/base.py @@ -1,4 +1,5 @@ from langchain.embeddings.base import Embeddings +from langchain.vectorstores.faiss import FAISS import threading from configs import (EMBEDDING_MODEL, CHUNK_SIZE, logger, log_verbose) @@ -25,7 +26,7 @@ class ThreadSafeObject: return self._key @contextmanager - def acquire(self, owner: str = "", msg: str = ""): + def acquire(self, owner: str = "", msg: str = "") -> FAISS: owner = owner or f"thread {threading.get_native_id()}" try: self._lock.acquire() diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index bcd89a8..6823536 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -1,4 +1,4 @@ -from configs import CACHED_VS_NUM +from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM from server.knowledge_base.kb_cache.base import * from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter from server.utils import load_local_embeddings @@ -123,7 +123,7 @@ class MemoFaissPool(_FaissPool): kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM) -memo_faiss_pool = MemoFaissPool() +memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM) if __name__ == "__main__": diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index b339b6f..c086d00 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -13,8 +13,8 @@ from pydantic import Json import 
json from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import get_file_detail -from typing import List from langchain.docstore.document import Document +from typing import List class DocumentWithScore(Document): diff --git a/server/utils.py b/server/utils.py index ab307fa..28c2a9b 100644 --- a/server/utils.py +++ b/server/utils.py @@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI, AzureOpenAI, Anthropic import httpx -from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union +from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -700,3 +700,19 @@ def load_local_embeddings(model: str = None, device: str = embedding_device()): model = model or EMBEDDING_MODEL return embeddings_pool.load_embeddings(model=model, device=device) + + +def get_temp_dir(id: str = None) -> Tuple[str, str]: + ''' + 创建一个临时目录,返回(路径,文件夹名称) + ''' + from configs.basic_config import BASE_TEMP_DIR + import tempfile + + if id is not None: # 如果指定的临时目录已存在,直接返回 + path = os.path.join(BASE_TEMP_DIR, id) + if os.path.isdir(path): + return path, id + + path = tempfile.mkdtemp(dir=BASE_TEMP_DIR) + return path, os.path.basename(path) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 7286cae..0a05be5 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -5,6 +5,7 @@ from datetime import datetime import os from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL) +from server.knowledge_base.utils import LOADER_DICT from typing import List, Dict @@ -36,7 +37,17 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> return chat_box.filter_history(history_len=history_len, filter=filter) +@st.cache_data +def upload_temp_docs(files, _api: ApiRequest) -> str: + ''' + 将文件上传到临时目录,用于文件对话 + 返回临时向量库ID + ''' + return _api.upload_temp_docs(files).get("data", {}).get("id") + + def dialogue_page(api: ApiRequest, is_lite: bool = False): + st.session_state.setdefault("file_chat_id", None) default_model = api.get_default_llm_model()[0] if not chat_box.chat_inited: st.toast( @@ -57,10 +68,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): st.toast(text) dialogue_modes = ["LLM 对话", - "知识库问答", - "搜索引擎问答", - "自定义Agent问答", - ] + "知识库问答", + "文件对话", + "搜索引擎问答", + "自定义Agent问答", + ] dialogue_mode = st.selectbox("请选择对话模式:", dialogue_modes, index=3, @@ -122,6 +134,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): "自定义Agent问答": "agent_chat", "搜索引擎问答": "search_engine_chat", "知识库问答": "knowledge_base_chat", + "文件对话": "knowledge_base_chat", } prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys()) prompt_template_name = prompt_templates_kb_list[0] @@ -163,7 +176,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): ## Bge 模型会超过1 score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01) + elif dialogue_mode == "文件对话": + with st.expander("文件对话配置", True): + files = st.file_uploader("上传知识文件:", + [i for ls in LOADER_DICT.values() for i in ls], + accept_multiple_files=True, + ) + kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) + ## Bge 模型会超过1 + score_threshold = st.slider("知识匹配分数阈值:", 0.0, 
2.0, float(SCORE_THRESHOLD), 0.01) + if st.button("开始上传", disabled=len(files)==0): + st.session_state["file_chat_id"] = upload_temp_docs(files, api) elif dialogue_mode == "搜索引擎问答": search_engine_list = api.list_search_engines() if DEFAULT_SEARCH_ENGINE in search_engine_list: @@ -288,6 +312,30 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.update_msg(text, element_index=0) chat_box.update_msg(text, element_index=0, streaming=False) chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) + elif dialogue_mode == "文件对话": + if st.session_state["file_chat_id"] is None: + st.error("请先上传文件再进行对话") + st.stop() + chat_box.ai_say([ + f"正在查询文件 `{st.session_state['file_chat_id']}` ...", + Markdown("...", in_expander=True, title="文件匹配结果", state="complete"), + ]) + text = "" + for d in api.file_chat(prompt, + knowledge_id=st.session_state["file_chat_id"], + top_k=kb_top_k, + score_threshold=score_threshold, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + elif chunk := d.get("answer"): + text += chunk + chat_box.update_msg(text, element_index=0) + chat_box.update_msg(text, element_index=0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) elif dialogue_mode == "搜索引擎问答": chat_box.ai_say([ f"正在执行 `{search_engine}` 搜索...", diff --git a/webui_pages/utils.py b/webui_pages/utils.py index bbae64e..fd84cd5 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -384,6 +384,81 @@ class ApiRequest: ) return self._httpx_stream2generator(response, as_json=True) + def upload_temp_docs( + self, + files: List[Union[str, Path, bytes]], + knowledge_id: str = None, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, + ): + ''' + 对应api.py/knowledge_base/upload_tmep_docs接口 + ''' + def convert_file(file, filename=None): + if isinstance(file, bytes): # raw bytes + file = BytesIO(file) + elif hasattr(file, "read"): # a file io like object + filename = filename or file.name + else: # a local path + file = Path(file).absolute().open("rb") + filename = filename or os.path.split(file.name)[-1] + return filename, file + + files = [convert_file(file) for file in files] + data={ + "knowledge_id": knowledge_id, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "zh_title_enhance": zh_title_enhance, + } + + response = self.post( + "/knowledge_base/upload_temp_docs", + data=data, + files=[("files", (filename, file)) for filename, file in files], + ) + return self._get_response_value(response, as_json=True) + + def file_chat( + self, + query: str, + knowledge_id: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", + ): + ''' + 对应api.py/chat/file_chat接口 + ''' + data = { + "query": query, + "knowledge_id": knowledge_id, + "top_k": top_k, + "score_threshold": score_threshold, + "history": history, + "stream": stream, + "model_name": model, + "temperature": temperature, + "max_tokens": max_tokens, + "prompt_name": prompt_name, + } + + print(f"received input message:") + pprint(data) + + response = self.post( + "/chat/file_chat", + json=data, + stream=True, + ) + return self._httpx_stream2generator(response, as_json=True) + def 
search_engine_chat( self, query: str,
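A minimal client-side sketch of the new file-chat flow, driving the `upload_temp_docs` and `file_chat` helpers this patch adds to `webui_pages/utils.py`. The `base_url`, the sample file path, and the query string are illustrative assumptions; `ApiRequest` is assumed to accept a `base_url` argument as it is used elsewhere in `webui_pages`, and the streamed dicts are handled the same way `dialogue.py` handles them above.

```python
# Sketch only: base_url and the sample file path are assumptions, not values
# taken from the patch.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed API server address

# 1. Upload local files into a temporary FAISS store. Per upload_temp_docs in
#    server/chat/file_chat.py, the returned id is both the temp directory name
#    and the temporary vector store id.
resp = api.upload_temp_docs(["./sample.txt"])
knowledge_id = resp.get("data", {}).get("id")
failed_files = resp.get("data", {}).get("failed_files", [])

# 2. Stream a conversation grounded on the uploaded files. Each streamed dict
#    carries either an "answer" fragment or, at the end, the matched "docs".
for d in api.file_chat("这份文件讲了什么?",
                       knowledge_id=knowledge_id,
                       top_k=3,
                       score_threshold=1.0):
    if answer := d.get("answer"):
        print(answer, end="", flush=True)
    elif docs := d.get("docs"):
        print("\n" + "\n\n".join(docs))
```

Note the lifecycle implied by the config changes: `BASE_TEMP_DIR` is recreated on every startup (so temporary stores do not survive a restart), and `memo_faiss_pool` keeps at most `CACHED_MEMO_VS_NUM` temporary vector stores in memory at a time.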