update kb_doc_api.py
This commit is contained in:
parent
8773149a3e
commit
a447529c2e
|
|
@ -31,32 +31,31 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
saved_path = get_doc_path(knowledge_base_name)
|
||||
if not os.path.exists(saved_path):
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
|
||||
file_path = os.path.join(saved_path, file.filename)
|
||||
if (os.path.exists(file_path)
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(file_path) == len(file_content)
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
file_status = f"文件 {file.filename} 已存在。"
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as f:
|
||||
with open(kb_file.filepath, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}")
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
kb.add_doc(kb_file)
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str,
|
||||
|
|
@ -66,25 +65,17 @@ async def delete_doc(knowledge_base_name: str,
|
|||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
if not os.path.exists(get_kb_path(knowledge_base_name)):
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
doc_path = get_file_path(knowledge_base_name, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
remain_docs = await list_docs(knowledge_base_name)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True)
|
||||
return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
|
||||
else:
|
||||
# TODO: 重写从向量库中删除文件
|
||||
status = "" # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))
|
||||
if "success" in status:
|
||||
refresh_vs_cache(knowledge_base_name)
|
||||
return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败")
|
||||
else:
|
||||
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file)
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
|
||||
|
||||
|
||||
async def update_doc():
|
||||
|
|
|
|||
|
|
@ -172,6 +172,50 @@ def add_doc_to_db(kb_file: KnowledgeFile):
|
|||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_file_from_db(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# delete files in kb from table knowledge_files
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
# Insert a row of data
|
||||
c.execute(f"""DELETE
|
||||
FROM knowledge_files
|
||||
WHERE file_name="{kb_file.filename}"
|
||||
AND kb_name="{kb_file.kb_name}"
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
def doc_exists(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute(f'''SELECT COUNT(*)
|
||||
FROM knowledge_files
|
||||
WHERE file_name="{kb_file.filename}"
|
||||
AND kb_name="{kb_file.kb_name}" ''')
|
||||
status = True if c.fetchone()[0] else False
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return status
|
||||
|
||||
|
||||
class KnowledgeBase:
|
||||
def __init__(self,
|
||||
|
|
@ -226,6 +270,17 @@ class KnowledgeBase:
|
|||
# TODO: 向milvus库中增加文件
|
||||
pass
|
||||
|
||||
def delete_doc(self, kb_file: KnowledgeFile):
|
||||
if os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
if self.vs_type in ["faiss"]:
|
||||
# TODO: 从FAISS向量库中删除文档
|
||||
delete_file_from_db(kb_file)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
filename=file_name))
|
||||
|
||||
def list_docs(self):
|
||||
return list_docs_from_db(self.kb_name)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue