运行startup.py时,如果不加参数直接显示配置和帮助信息后退出 (#1284)

* 统一XX_kb_service.add_doc/do_add_doc接口,不再需要embeddings参数

* 运行startup.py时,如果不加参数直接显示配置和帮助信息后退出

---------

Co-authored-by: liunux4odoo <liunu@qq.com>
This commit is contained in:
liunux4odoo 2023-08-28 16:03:53 +08:00 committed by GitHub
parent 3acbf4d5d1
commit ca0ae29fef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 37 additions and 34 deletions

View File

@ -87,8 +87,7 @@ class KBService(ABC):
if docs: if docs:
self.delete_doc(kb_file) self.delete_doc(kb_file)
embeddings = self._load_embeddings() self.do_add_doc(docs, **kwargs)
self.do_add_doc(docs, embeddings=embeddings, **kwargs)
status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs)) status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
else: else:
status = False status = False
@ -181,7 +180,6 @@ class KBService(ABC):
@abstractmethod @abstractmethod
def do_add_doc(self, def do_add_doc(self,
docs: List[Document], docs: List[Document],
embeddings: Embeddings,
): ):
""" """
向知识库添加文档子类实自己逻辑 向知识库添加文档子类实自己逻辑

View File

@ -13,7 +13,7 @@ class DefaultKBService(KBService):
def do_drop_kb(self): def do_drop_kb(self):
pass pass
def do_add_doc(self, docs: List[Document], embeddings: Embeddings): def do_add_doc(self, docs: List[Document]):
pass pass
def do_clear_vs(self): def do_clear_vs(self):

View File

@ -109,7 +109,6 @@ class FaissKBService(KBService):
def do_add_doc(self, def do_add_doc(self,
docs: List[Document], docs: List[Document],
embeddings: Embeddings,
**kwargs, **kwargs,
): ):
vector_store = self.load_vector_store() vector_store = self.load_vector_store()

View File

@ -54,7 +54,7 @@ class MilvusKBService(KBService):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): def do_add_doc(self, docs: List[Document], **kwargs):
self.milvus.add_documents(docs) self.milvus.add_documents(docs)
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):

View File

@ -51,7 +51,7 @@ class PGKBService(KBService):
return score_threshold_process(score_threshold, top_k, return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k)) self.pg_vector.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): def do_add_doc(self, docs: List[Document], **kwargs):
self.pg_vector.add_documents(docs) self.pg_vector.add_documents(docs)
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):

View File

@ -321,7 +321,7 @@ def parse_args() -> argparse.ArgumentParser:
dest="webui", dest="webui",
) )
args = parser.parse_args() args = parser.parse_args()
return args return args, parser
def dump_server_info(after_start=False): def dump_server_info(after_start=False):
@ -330,7 +330,7 @@ def dump_server_info(after_start=False):
import fastchat import fastchat
from configs.server_config import api_address, webui_address from configs.server_config import api_address, webui_address
print("\n\n") print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.") print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}") print(f"python版本{sys.version}")
@ -351,7 +351,7 @@ def dump_server_info(after_start=False):
if args.webui: if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}") print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n") print("\n")
if __name__ == "__main__": if __name__ == "__main__":
@ -359,7 +359,7 @@ if __name__ == "__main__":
mp.set_start_method("spawn") mp.set_start_method("spawn")
queue = Queue() queue = Queue()
args = parse_args() args, parser = parse_args()
if args.all_webui: if args.all_webui:
args.openai_api = True args.openai_api = True
args.model_worker = True args.model_worker = True
@ -379,8 +379,10 @@ if __name__ == "__main__":
args.webui = False args.webui = False
dump_server_info() dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") if len(sys.argv) > 1:
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {} processes = {}
@ -435,28 +437,32 @@ if __name__ == "__main__":
process.start() process.start()
processes["webui"] = process processes["webui"] = process
try: if len(processes) == 0:
# log infors parser.print_help()
while True: else:
no = queue.get() try:
if no == len(processes): # log infors
time.sleep(0.5) while True:
dump_server_info(True) no = queue.get()
break if no == len(processes):
else: time.sleep(0.5)
queue.put(no) dump_server_info(True)
break
else:
queue.put(no)
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
# 服务启动后接口调用示例: # 服务启动后接口调用示例:
# import openai # import openai