From ca0ae29fef0eee36f427dbb92d9a16da2f901c57 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Mon, 28 Aug 2023 16:03:53 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=90=E8=A1=8Cstartup.py=E6=97=B6=EF=BC=8C?= =?UTF-8?q?=E5=A6=82=E6=9E=9C=E4=B8=8D=E5=8A=A0=E5=8F=82=E6=95=B0=E7=9B=B4?= =?UTF-8?q?=E6=8E=A5=E6=98=BE=E7=A4=BA=E9=85=8D=E7=BD=AE=E5=92=8C=E5=B8=AE?= =?UTF-8?q?=E5=8A=A9=E4=BF=A1=E6=81=AF=E5=90=8E=E9=80=80=E5=87=BA=20(#1284?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 统一XX_kb_service.add_doc/do_add_doc接口,不再需要embeddings参数 * 运行startup.py时,如果不加参数直接显示配置和帮助信息后退出 --------- Co-authored-by: liunux4odoo --- server/knowledge_base/kb_service/base.py | 4 +- .../kb_service/default_kb_service.py | 2 +- .../kb_service/faiss_kb_service.py | 1 - .../kb_service/milvus_kb_service.py | 2 +- .../kb_service/pg_kb_service.py | 2 +- startup.py | 60 ++++++++++--------- 6 files changed, 37 insertions(+), 34 deletions(-) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index e3a0743..ca8d1ae 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -87,8 +87,7 @@ class KBService(ABC): if docs: self.delete_doc(kb_file) - embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings=embeddings, **kwargs) + self.do_add_doc(docs, **kwargs) status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs)) else: status = False @@ -181,7 +180,6 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, ): """ 向知识库添加文档子类实自己逻辑 diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index 922e39b..a68d59c 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -13,7 +13,7 @@ class DefaultKBService(KBService): def do_drop_kb(self): pass - def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + def do_add_doc(self, docs: List[Document]): pass def do_clear_vs(self): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 936ffad..3601b57 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -109,7 +109,6 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, **kwargs, ): vector_store = self.load_vector_store() diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index dc916c9..ac7712b 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -54,7 +54,7 @@ class MilvusKBService(KBService): self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): + def do_add_doc(self, docs: List[Document], **kwargs): self.milvus.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 906818f..4b52625 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -51,7 +51,7 @@ class PGKBService(KBService): return score_threshold_process(score_threshold, top_k, self.pg_vector.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): + def do_add_doc(self, docs: List[Document], **kwargs): self.pg_vector.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): diff --git a/startup.py b/startup.py index 2ca62d9..46b886e 100644 --- a/startup.py +++ b/startup.py @@ -321,7 +321,7 @@ def parse_args() -> argparse.ArgumentParser: dest="webui", ) args = parser.parse_args() - return args + return args, parser def dump_server_info(after_start=False): @@ -330,7 +330,7 @@ def dump_server_info(after_start=False): import fastchat from configs.server_config import api_address, webui_address - print("\n\n") + print("\n") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print(f"操作系统:{platform.platform()}.") print(f"python版本:{sys.version}") @@ -351,7 +351,7 @@ def dump_server_info(after_start=False): if args.webui: print(f" Chatchat WEBUI Server: {webui_address()}") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) - print("\n\n") + print("\n") if __name__ == "__main__": @@ -359,7 +359,7 @@ if __name__ == "__main__": mp.set_start_method("spawn") queue = Queue() - args = parse_args() + args, parser = parse_args() if args.all_webui: args.openai_api = True args.model_worker = True @@ -379,8 +379,10 @@ if __name__ == "__main__": args.webui = False dump_server_info() - logger.info(f"正在启动服务:") - logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") + + if len(sys.argv) > 1: + logger.info(f"正在启动服务:") + logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") processes = {} @@ -435,28 +437,32 @@ if __name__ == "__main__": process.start() processes["webui"] = process - try: - # log infors - while True: - no = queue.get() - if no == len(processes): - time.sleep(0.5) - dump_server_info(True) - break - else: - queue.put(no) + if len(processes) == 0: + parser.print_help() + else: + try: + # log infors + while True: + no = queue.get() + if no == len(processes): + time.sleep(0.5) + dump_server_info(True) + break + else: + queue.put(no) + + if model_worker_process := processes.get("model_worker"): + model_worker_process.join() + for name, process in processes.items(): + if name != "model_worker": + process.join() + except: + if model_worker_process := processes.get("model_worker"): + model_worker_process.terminate() + for name, process in processes.items(): + if name != "model_worker": + process.terminate() - if model_worker_process := processes.get("model_worker"): - model_worker_process.join() - for name, process in processes.items(): - if name != "model_worker": - process.join() - except: - if model_worker_process := processes.get("model_worker"): - model_worker_process.terminate() - for name, process in processes.items(): - if name != "model_worker": - process.terminate() # 服务启动后接口调用示例: # import openai