使用orm操作数据库
This commit is contained in:
parent
41fd1acc9c
commit
b61e0772c9
|
|
@ -0,0 +1,12 @@
|
||||||
|
SQLALCHEMY_DATABASE_URI = "sqlite:///./langchain_chat_glm.db"
|
||||||
|
kbs_config = {
|
||||||
|
"faiss": {
|
||||||
|
},
|
||||||
|
"milvus": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": "19530",
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"secure": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -278,15 +278,3 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
|
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
|
||||||
BING_SUBSCRIPTION_KEY = ""
|
BING_SUBSCRIPTION_KEY = ""
|
||||||
|
|
||||||
kbs_config = {
|
|
||||||
"faiss": {
|
|
||||||
},
|
|
||||||
"milvus": {
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": "19530",
|
|
||||||
"user": "",
|
|
||||||
"password": "",
|
|
||||||
"secure": False,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from configs.config import SQLALCHEMY_DATABASE_URI
|
||||||
|
|
||||||
|
engine = create_engine(SQLALCHEMY_DATABASE_URI)
|
||||||
|
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Column, DateTime, String, Integer
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel:
|
||||||
|
"""
|
||||||
|
基础模型
|
||||||
|
"""
|
||||||
|
id = Column(Integer, primary_key=True, index=True, comment="主键ID")
|
||||||
|
create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间")
|
||||||
|
update_time = Column(DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间")
|
||||||
|
create_by = Column(String, default=None, comment="创建者")
|
||||||
|
update_by = Column(String, default=None, comment="更新者")
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
from sqlalchemy import Column, Integer, String, DateTime
|
||||||
|
|
||||||
|
from server.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseModel(Base):
|
||||||
|
"""
|
||||||
|
知识库模型
|
||||||
|
"""
|
||||||
|
__tablename__ = 'knowledge_base'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
|
||||||
|
kb_name = Column(String, comment='知识库名称')
|
||||||
|
vs_type = Column(String, comment='嵌入模型类型')
|
||||||
|
embed_model = Column(String, comment='嵌入模型名称')
|
||||||
|
file_count = Column(Integer, comment='文件数量')
|
||||||
|
create_time = Column(DateTime, comment='创建时间')
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}', vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
from sqlalchemy import Column, Integer, String, DateTime
|
||||||
|
|
||||||
|
from server.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeFileModel(Base):
|
||||||
|
"""
|
||||||
|
知识文件模型
|
||||||
|
"""
|
||||||
|
__tablename__ = 'knowledge_file'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
|
||||||
|
file_name = Column(String, comment='文件名')
|
||||||
|
file_ext = Column(String, comment='文件扩展名')
|
||||||
|
kb_name = Column(String, comment='所属知识库名称')
|
||||||
|
document_loader_name = Column(String, comment='文档加载器名称')
|
||||||
|
text_splitter_name = Column(String, comment='文本分割器名称')
|
||||||
|
file_version = Column(Integer, comment='文件版本')
|
||||||
|
create_time = Column(DateTime, comment='创建时间')
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from server.db.models.knowledge_base_model import KnowledgeBaseModel
|
||||||
|
from server.db.session import with_session
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def add_kb_to_db(session, kb_name, vs_type, embed_model):
|
||||||
|
# 创建知识库实例
|
||||||
|
kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
|
||||||
|
session.add(kb)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def list_kbs_from_db(session):
|
||||||
|
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all()
|
||||||
|
kbs = [kb[0] for kb in kbs]
|
||||||
|
return kbs
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def kb_exists(session, kb_name):
|
||||||
|
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
|
||||||
|
status = True if kb else False
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def load_kb_from_db(session, kb_name):
|
||||||
|
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
|
||||||
|
if kb:
|
||||||
|
kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
|
||||||
|
else:
|
||||||
|
kb_name, vs_type, embed_model = None, None, None
|
||||||
|
return kb_name, vs_type, embed_model
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def delete_kb_from_db(session, kb_name):
|
||||||
|
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
|
||||||
|
if kb:
|
||||||
|
session.delete(kb)
|
||||||
|
return True
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
from server.db.models.knowledge_base_model import KnowledgeBaseModel
|
||||||
|
from server.db.models.knowledge_file_model import KnowledgeFileModel
|
||||||
|
from server.db.session import with_session
|
||||||
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def list_docs_from_db(session, kb_name):
|
||||||
|
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
|
||||||
|
docs = [f.file_name for f in files]
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def add_doc_to_db(session, kb_file: KnowledgeFile):
|
||||||
|
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||||
|
if kb:
|
||||||
|
# 如果已经存在该文件,则更新文件版本号
|
||||||
|
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
|
||||||
|
kb_name=kb_file.kb_name).first()
|
||||||
|
if existing_file:
|
||||||
|
existing_file.file_version += 1
|
||||||
|
# 否则,添加新文件
|
||||||
|
else:
|
||||||
|
session.add(kb_file)
|
||||||
|
kb.file_count += 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def delete_file_from_db(session, kb_file: KnowledgeFile):
|
||||||
|
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
|
||||||
|
kb_name=kb_file.kb_name).first()
|
||||||
|
if existing_file:
|
||||||
|
session.delete(existing_file)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||||
|
if kb:
|
||||||
|
kb.file_count -= 1
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
|
||||||
|
def doc_exists(session, kb_file: KnowledgeFile):
|
||||||
|
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
|
||||||
|
kb_name=kb_file.kb_name).first()
|
||||||
|
return True if existing_file else False
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from server.db.base import engine, SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_scope():
|
||||||
|
"""上下文管理器用于自动获取 Session, 避免错误"""
|
||||||
|
session = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def with_session(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with session_scope() as session:
|
||||||
|
try:
|
||||||
|
result = f(session, *args, **kwargs)
|
||||||
|
session.commit()
|
||||||
|
return result
|
||||||
|
except:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> SessionLocal:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db0() -> SessionLocal:
|
||||||
|
db = SessionLocal()
|
||||||
|
return db
|
||||||
|
|
@ -8,9 +8,13 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
|
from configs.config import kbs_config
|
||||||
|
from configs.model_config import (VECTOR_SEARCH_TOP_K,
|
||||||
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||||
import datetime
|
|
||||||
|
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists
|
||||||
|
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
|
||||||
|
list_docs_from_db
|
||||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
@ -22,171 +26,11 @@ class SupportedVSType:
|
||||||
DEFAULT = 'default'
|
DEFAULT = 'default'
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
|
||||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
kb_name TEXT,
|
|
||||||
vs_type TEXT,
|
|
||||||
embed_model TEXT,
|
|
||||||
file_count INTEGER,
|
|
||||||
create_time DATETIME) ''')
|
|
||||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
|
||||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
file_name TEXT,
|
|
||||||
file_ext TEXT,
|
|
||||||
kb_name TEXT,
|
|
||||||
document_loader_name TEXT,
|
|
||||||
text_splitter_name TEXT,
|
|
||||||
file_version INTEGER,
|
|
||||||
create_time DATETIME) ''')
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def list_kbs_from_db():
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''SELECT kb_name
|
|
||||||
FROM knowledge_base
|
|
||||||
WHERE file_count>0 ''')
|
|
||||||
kbs = [i[0] for i in c.fetchall() if i]
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return kbs
|
|
||||||
|
|
||||||
|
|
||||||
def add_kb_to_db(kb_name, vs_type, embed_model):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
# Insert a row of data
|
|
||||||
c.execute(f"""INSERT INTO knowledge_base
|
|
||||||
(kb_name, vs_type, embed_model, file_count, create_time)
|
|
||||||
VALUES
|
|
||||||
('{kb_name}','{vs_type}','{embed_model}',
|
|
||||||
0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def kb_exists(kb_name):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''SELECT COUNT(*)
|
|
||||||
FROM knowledge_base
|
|
||||||
WHERE kb_name="{kb_name}" ''')
|
|
||||||
status = True if c.fetchone()[0] else False
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return status
|
|
||||||
|
|
||||||
|
|
||||||
def load_kb_from_db(kb_name):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''SELECT kb_name, vs_type, embed_model
|
|
||||||
FROM knowledge_base
|
|
||||||
WHERE kb_name="{kb_name}" ''')
|
|
||||||
resp = c.fetchone()
|
|
||||||
if resp:
|
|
||||||
kb_name, vs_type, embed_model = resp
|
|
||||||
else:
|
|
||||||
kb_name, vs_type, embed_model = None, None, None
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return kb_name, vs_type, embed_model
|
|
||||||
|
|
||||||
|
|
||||||
def delete_kb_from_db(kb_name):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''DELETE
|
|
||||||
FROM knowledge_base
|
|
||||||
WHERE kb_name="{kb_name}" ''')
|
|
||||||
c.execute(f"""DELETE
|
|
||||||
FROM knowledge_files
|
|
||||||
WHERE kb_name="{kb_name}"
|
|
||||||
""")
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def list_docs_from_db(kb_name):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''SELECT file_name
|
|
||||||
FROM knowledge_files
|
|
||||||
WHERE kb_name="{kb_name}" ''')
|
|
||||||
kbs = [i[0] for i in c.fetchall() if i]
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return kbs
|
|
||||||
|
|
||||||
def list_docs_from_folder(kb_name: str):
|
def list_docs_from_folder(kb_name: str):
|
||||||
doc_path = get_doc_path(kb_name)
|
doc_path = get_doc_path(kb_name)
|
||||||
return [file for file in os.listdir(doc_path)
|
return [file for file in os.listdir(doc_path)
|
||||||
if os.path.isfile(os.path.join(doc_path, file))]
|
if os.path.isfile(os.path.join(doc_path, file))]
|
||||||
|
|
||||||
def add_doc_to_db(kb_file: KnowledgeFile):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
# Insert a row of data
|
|
||||||
c.execute(
|
|
||||||
f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
|
|
||||||
record_exist = c.fetchone()
|
|
||||||
if record_exist is not None:
|
|
||||||
c.execute(f"""UPDATE knowledge_files
|
|
||||||
SET file_version = file_version + 1
|
|
||||||
WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
|
|
||||||
""")
|
|
||||||
else:
|
|
||||||
c.execute(f"""INSERT INTO knowledge_files
|
|
||||||
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
|
|
||||||
VALUES
|
|
||||||
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
|
|
||||||
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
|
||||||
c.execute(f"""UPDATE knowledge_base
|
|
||||||
SET file_count = file_count + 1
|
|
||||||
WHERE kb_name="{kb_file.kb_name}"
|
|
||||||
""")
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def delete_file_from_db(kb_file: KnowledgeFile):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
# Insert a row of data
|
|
||||||
c.execute(f"""DELETE
|
|
||||||
FROM knowledge_files
|
|
||||||
WHERE file_name="{kb_file.filename}"
|
|
||||||
AND kb_name="{kb_file.kb_name}"
|
|
||||||
""")
|
|
||||||
c.execute(f"""UPDATE knowledge_base
|
|
||||||
SET file_count = file_count - 1
|
|
||||||
WHERE kb_name="{kb_file.kb_name}"
|
|
||||||
""")
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def doc_exists(kb_file: KnowledgeFile):
|
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(f'''SELECT COUNT(*)
|
|
||||||
FROM knowledge_files
|
|
||||||
WHERE file_name="{kb_file.filename}"
|
|
||||||
AND kb_name="{kb_file.kb_name}" ''')
|
|
||||||
status = True if c.fetchone()[0] else False
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return status
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(1)
|
@lru_cache(1)
|
||||||
def load_embeddings(model: str, device: str):
|
def load_embeddings(model: str, device: str):
|
||||||
|
|
@ -323,7 +167,7 @@ class KBService(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_delete_doc(self,
|
def do_delete_doc(self,
|
||||||
kb_file: KnowledgeFile):
|
kb_file: KnowledgeFile):
|
||||||
"""
|
"""
|
||||||
从知识库删除文档子类实自己逻辑
|
从知识库删除文档子类实自己逻辑
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,24 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
from server.knowledge_base.kb_service.base import KBService
|
from server.knowledge_base.kb_service.base import KBService
|
||||||
|
|
||||||
|
|
||||||
class DefaultKBService(KBService):
|
class DefaultKBService(KBService):
|
||||||
|
def do_create_kb(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_drop_kb(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_clear_vs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def vs_type(self) -> str:
|
def vs_type(self) -> str:
|
||||||
return "default"
|
return "default"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,10 @@ from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.vectorstores import Milvus
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
from configs.config import kbs_config
|
||||||
|
from configs.model_config import EMBEDDING_DEVICE
|
||||||
from server.knowledge_base import KnowledgeFile
|
from server.knowledge_base import KnowledgeFile
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||||
|
|
||||||
|
|
||||||
class MilvusKBService(KBService):
|
class MilvusKBService(KBService):
|
||||||
|
|
@ -55,6 +56,7 @@ class MilvusKBService(KBService):
|
||||||
"""
|
"""
|
||||||
docs = kb_file.file2text()
|
docs = kb_file.file2text()
|
||||||
self.milvus.add_documents(docs)
|
self.milvus.add_documents(docs)
|
||||||
|
from server.db.repository.knowledge_file_repository import add_doc_to_db
|
||||||
status = add_doc_to_db(kb_file)
|
status = add_doc_to_db(kb_file)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
@ -72,8 +74,11 @@ class MilvusKBService(KBService):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# 测试建表使用
|
||||||
|
from server.db.base import Base, engine
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
milvusService = MilvusKBService("test")
|
milvusService = MilvusKBService("test")
|
||||||
# milvusService.add_doc(KnowledgeFile("test.pdf", "test"))
|
milvusService.add_doc(KnowledgeFile("test.pdf", "test"))
|
||||||
# milvusService.delete_doc(KnowledgeFile("test.pdf", "test"))
|
milvusService.delete_doc(KnowledgeFile("test.pdf", "test"))
|
||||||
milvusService.do_drop_kb()
|
milvusService.do_drop_kb()
|
||||||
print(milvusService.search_docs("测试"))
|
print(milvusService.search_docs("测试"))
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
|
from server.db.repository.knowledge_base_repository import load_kb_from_db
|
||||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
|
||||||
|
|
||||||
|
|
||||||
class KBServiceFactory:
|
class KBServiceFactory:
|
||||||
|
|
@ -21,7 +21,7 @@ class KBServiceFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_service_by_name(kb_name: str
|
def get_service_by_name(kb_name: str
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
kb_name, vs_type = load_kb_from_db(kb_name)
|
kb_name, vs_type, _ = load_kb_from_db(kb_name)
|
||||||
return KBServiceFactory.get_service(kb_name, vs_type)
|
return KBServiceFactory.get_service(kb_name, vs_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -30,8 +30,10 @@ class KBServiceFactory:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# 测试建表使用
|
||||||
|
# from server.db.base import Base, engine
|
||||||
|
# Base.metadata.create_all(bind=engine)
|
||||||
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
||||||
init_db()
|
|
||||||
KBService.create_kb()
|
KBService.create_kb()
|
||||||
KBService = KBServiceFactory.get_default()
|
KBService = KBServiceFactory.get_default()
|
||||||
print(KBService.list_kbs())
|
print(KBService.list_kbs())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue