import os
# Import the submodule explicitly: a bare ``import urllib`` does not guarantee
# that ``urllib.parse`` is loaded.
import urllib.parse
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Dict, Optional
from langchain.docstore.document import Document


class DocumentWithScore(Document):
    """A ``Document`` carrying the score attached by ``search_docs``."""

    # Filled from the second element of each search hit in ``search_docs``.
    score: float = None


def _display_name(kb_file: Optional[KnowledgeFile], fallback: str) -> str:
    """Return ``kb_file.filename`` if ``kb_file`` was built, else ``fallback``.

    Error handlers use this so they can still produce a message when the
    ``KnowledgeFile`` constructor itself was the call that raised.
    """
    return fallback if kb_file is None else kb_file.filename


def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
                knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                score_threshold: float = Body(SCORE_THRESHOLD,
                                              description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                              ge=0, le=1),
                ) -> List[DocumentWithScore]:
    """Search a knowledge base and return the hits with their scores.

    Args:
        query: The user's query text.
        knowledge_base_name: Name of the knowledge base to search.
        top_k: Passed through to ``kb.search_docs``.
        score_threshold: Passed through to ``kb.search_docs``; constrained
            to the range [0, 1] by the ``Body`` declaration.

    Returns:
        One ``DocumentWithScore`` per hit, or an empty list when the
        knowledge base cannot be found.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return []
    hits = kb.search_docs(query, top_k, score_threshold)
    # Each hit is indexed as (document, score).
    return [DocumentWithScore(**hit[0].dict(), score=hit[1]) for hit in hits]


async def list_docs(
        knowledge_base_name: str
) -> ListResponse:
    """List the document names stored in a knowledge base.

    Returns:
        ``ListResponse`` with code 403 when the name fails
        ``validate_kb_name``, code 404 when the knowledge base is not found,
        otherwise the names returned by ``kb.list_docs()``.
    """
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    all_doc_names = kb.list_docs()
    return ListResponse(data=all_doc_names)


async def upload_doc(file: UploadFile = File(..., description="上传文件"),
                     knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
                     override: bool = Form(False, description="覆盖已有文件"),
                     not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
                     ) -> BaseResponse:
    """Save an uploaded file into a knowledge base and add it via ``kb.add_doc``.

    Args:
        file: The uploaded file.
        knowledge_base_name: Target knowledge base.
        override: When false, an existing file of identical size is not
            overwritten and a 404 response is returned instead.
        not_refresh_vs_cache: Forwarded to ``kb.add_doc``.

    Returns:
        ``BaseResponse`` with code 403 (invalid name), 404 (knowledge base
        not found, or file already exists), 500 (saving or adding the file
        failed) or 200 (success).
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    file_content = await file.read()  # read the uploaded file's content

    # NOTE(review): file.filename is client-supplied; confirm that
    # KnowledgeFile rejects names that would resolve outside the knowledge
    # base folder (e.g. containing "..").
    kb_file = None  # stays None if the KnowledgeFile constructor raises
    try:
        kb_file = KnowledgeFile(filename=file.filename,
                                knowledge_base_name=knowledge_base_name)

        if (os.path.exists(kb_file.filepath)
                and not override
                and os.path.getsize(kb_file.filepath) == len(file_content)
        ):
            # TODO: handle the case where the file size differs
            file_status = f"文件 {kb_file.filename} 已存在。"
            return BaseResponse(code=404, msg=file_status)

        with open(kb_file.filepath, "wb") as f:
            f.write(file_content)
    except Exception as e:
        print(e)
        # Fall back to the uploaded name so this handler cannot itself fail
        # on an unbound kb_file.
        name = _display_name(kb_file, file.filename)
        return BaseResponse(code=500, msg=f"{name} 文件上传失败,报错信息为: {e}")

    try:
        kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
    except Exception as e:
        print(e)
        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")

    return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")


async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
                     doc_name: str = Body(..., examples=["file_name.md"]),
                     delete_content: bool = Body(False),
                     not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
                     ) -> BaseResponse:
    """Delete a document from a knowledge base via ``kb.delete_doc``.

    Args:
        knowledge_base_name: Knowledge base holding the document.
        doc_name: Name of the document to delete.
        delete_content: Forwarded to ``kb.delete_doc``.
        not_refresh_vs_cache: Forwarded to ``kb.delete_doc``.

    Returns:
        ``BaseResponse`` with code 403 (invalid name), 404 (knowledge base
        or document not found), 500 (deletion raised) or 200 (success).
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    if not kb.exist_doc(doc_name):
        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")

    kb_file = None  # stays None if the KnowledgeFile constructor raises
    try:
        kb_file = KnowledgeFile(filename=doc_name,
                                knowledge_base_name=knowledge_base_name)
        kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
    except Exception as e:
        print(e)
        name = _display_name(kb_file, doc_name)
        return BaseResponse(code=500, msg=f"{name} 文件删除失败,错误信息:{e}")

    return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")


async def update_doc(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["file_name"]),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """Update a knowledge base document.

    Calls ``kb.update_doc`` for ``file_name`` when the file exists on disk.

    Returns:
        ``BaseResponse`` with code 403 (invalid name), 404 (knowledge base
        not found), 200 (updated) or 500 (the file does not exist on disk,
        or the update raised).
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    kb_file = None  # stays None if the KnowledgeFile constructor raises
    try:
        kb_file = KnowledgeFile(filename=file_name,
                                knowledge_base_name=knowledge_base_name)
        if os.path.exists(kb_file.filepath):
            kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
            return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
    except Exception as e:
        print(e)
        name = _display_name(kb_file, file_name)
        return BaseResponse(code=500, msg=f"{name} 文件更新失败,错误信息是:{e}")

    # Reached only when the file does not exist on disk.
    return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")


async def download_doc(
        knowledge_base_name: str = Query(..., examples=["samples"]),
        file_name: str = Query(..., examples=["test.txt"]),
):
    """Download a knowledge base document.

    Returns:
        A ``FileResponse`` for the file when it exists on disk; otherwise a
        ``BaseResponse`` with code 403 (invalid name), 404 (knowledge base
        not found) or 500 (file missing, or building the response raised).
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    kb_file = None  # stays None if the KnowledgeFile constructor raises
    try:
        kb_file = KnowledgeFile(filename=file_name,
                                knowledge_base_name=knowledge_base_name)

        if os.path.exists(kb_file.filepath):
            return FileResponse(
                path=kb_file.filepath,
                filename=kb_file.filename,
                media_type="multipart/form-data")
    except Exception as e:
        print(e)
        name = _display_name(kb_file, file_name)
        return BaseResponse(code=500, msg=f"{name} 读取文件失败,错误信息是:{e}")

    # Reached only when the file does not exist on disk.
    return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")


async def recreate_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
):
    """Recreate the vector store from the files in the content folder.

    This is useful when the user copies files into the content folder
    directly instead of uploading them through the network.

    By default, get_service_by_name only returns a knowledge base that is in
    info.db and has document files in it. Setting allow_empty_kb to True
    makes this apply to an empty knowledge base as well, i.e. one that is
    not in info.db or has no documents.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) whose chunks are JSON
        strings: one progress message per document, an error message for
        each document that failed, or a single 404 message when the
        knowledge base does not exist and ``allow_empty_kb`` is false.
    """

    async def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Serialise like every other chunk: the response body must be
            # str/bytes, so a raw dict cannot be streamed.
            yield json.dumps({
                "code": 404,
                "msg": f"未找到知识库 ‘{knowledge_base_name}’",
            }, ensure_ascii=False)
            return

        kb.create_kb()
        kb.clear_vs()
        docs = list_docs_from_folder(knowledge_base_name)
        total = len(docs)
        for i, doc in enumerate(docs):
            try:
                kb_file = KnowledgeFile(doc, knowledge_base_name)
                # Progress is reported before the document is added, so
                # "finished" counts the documents completed so far.
                yield json.dumps({
                    "code": 200,
                    "msg": f"({i + 1} / {total}): {doc}",
                    "total": total,
                    "finished": i,
                    "doc": doc,
                }, ensure_ascii=False)
                # Only the last document is added with
                # not_refresh_vs_cache=False.
                # NOTE(review): if the last document fails, no add_doc call
                # runs with not_refresh_vs_cache=False - confirm whether that
                # leaves the vector store unsaved.
                is_last = i == total - 1
                kb.add_doc(kb_file, not_refresh_vs_cache=not is_last)
            except Exception as e:
                print(e)
                yield json.dumps({
                    "code": 500,
                    "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
                })

    return StreamingResponse(output(), media_type="text/event-stream")