diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index e3a0743..ca8d1ae 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -87,8 +87,7 @@ class KBService(ABC): if docs: self.delete_doc(kb_file) - embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings=embeddings, **kwargs) + self.do_add_doc(docs, **kwargs) status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs)) else: status = False @@ -181,7 +180,6 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, ): """ 向知识库添加文档子类实自己逻辑 diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index 922e39b..a68d59c 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -13,7 +13,7 @@ class DefaultKBService(KBService): def do_drop_kb(self): pass - def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + def do_add_doc(self, docs: List[Document]): pass def do_clear_vs(self): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 936ffad..3601b57 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -109,7 +109,6 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, **kwargs, ): vector_store = self.load_vector_store() diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index dc916c9..ac7712b 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -54,7 +54,7 @@ class MilvusKBService(KBService): self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): + def do_add_doc(self, docs: List[Document], **kwargs): self.milvus.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 906818f..4b52625 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -51,7 +51,7 @@ class PGKBService(KBService): return score_threshold_process(score_threshold, top_k, self.pg_vector.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): + def do_add_doc(self, docs: List[Document], **kwargs): self.pg_vector.add_documents(docs) def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): diff --git a/startup.py b/startup.py index 2ca62d9..46b886e 100644 --- a/startup.py +++ b/startup.py @@ -321,7 +321,7 @@ def parse_args() -> argparse.ArgumentParser: dest="webui", ) args = parser.parse_args() - return args + return args, parser def dump_server_info(after_start=False): @@ -330,7 +330,7 @@ def dump_server_info(after_start=False): import fastchat from configs.server_config import api_address, webui_address - print("\n\n") + print("\n") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print(f"操作系统:{platform.platform()}.") print(f"python版本:{sys.version}") @@ -351,7 +351,7 @@ def dump_server_info(after_start=False): if args.webui: print(f" Chatchat WEBUI Server: {webui_address()}") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) - print("\n\n") + print("\n") if __name__ == "__main__": @@ -359,7 +359,7 @@ if __name__ == "__main__": mp.set_start_method("spawn") queue = Queue() - args = parse_args() + args, parser = parse_args() if args.all_webui: args.openai_api = True args.model_worker = True @@ -379,8 +379,10 @@ if __name__ == "__main__": args.webui = False dump_server_info() - logger.info(f"正在启动服务:") - logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") + + if len(sys.argv) > 1: + logger.info(f"正在启动服务:") + logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") processes = {} @@ -435,28 +437,32 @@ if __name__ == "__main__": process.start() processes["webui"] = process - try: - # log infors - while True: - no = queue.get() - if no == len(processes): - time.sleep(0.5) - dump_server_info(True) - break - else: - queue.put(no) + if len(processes) == 0: + parser.print_help() + else: + try: + # log infors + while True: + no = queue.get() + if no == len(processes): + time.sleep(0.5) + dump_server_info(True) + break + else: + queue.put(no) + + if model_worker_process := processes.get("model_worker"): + model_worker_process.join() + for name, process in processes.items(): + if name != "model_worker": + process.join() + except: + if model_worker_process := processes.get("model_worker"): + model_worker_process.terminate() + for name, process in processes.items(): + if name != "model_worker": + process.terminate() - if model_worker_process := processes.get("model_worker"): - model_worker_process.join() - for name, process in processes.items(): - if name != "model_worker": - process.join() - except: - if model_worker_process := processes.get("model_worker"): - model_worker_process.terminate() - for name, process in processes.items(): - if name != "model_worker": - process.terminate() # 服务启动后接口调用示例: # import openai