From 94977c7ab1bea0c43351bba220d8dc7910fd9fe1 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Mon, 16 Oct 2023 21:02:07 +0800
Subject: [PATCH] Fix: FAISS vector store not released correctly when
 switching embed_model, causing a `d == self.d` assert error (#1766)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes:
- When switching embed_model, the FAISS vector store was not released
  correctly, causing a `d == self.d` assert error (closes #1657:
  "[求助] 初始化知识库发生错误" / error when initializing the knowledge base)
- Added a max_tokens parameter to the chat methods of ApiRequest
- Fixed a wrong field reference in FileDocModel (field error, #1691)
---
 server/db/models/knowledge_file_model.py |  2 +-
 server/knowledge_base/kb_doc_api.py      |  3 ++-
 .../kb_service/faiss_kb_service.py       |  2 +-
 server/knowledge_base/migrate.py         |  1 +
 webui_pages/utils.py                     | 26 ++++++++++++-------
 5 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py
index c5784d1..2937dfa 100644
--- a/server/db/models/knowledge_file_model.py
+++ b/server/db/models/knowledge_file_model.py
@@ -37,4 +37,4 @@ class FileDocModel(Base):
     meta_data = Column(JSON, default={})
 
     def __repr__(self):
-        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"
+        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 75d9330..e158ad0 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -337,8 +337,9 @@ def recreate_vector_store(
     if not kb.exists() and not allow_empty_kb:
         yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
     else:
+        if kb.exists():
+            kb.clear_vs()
         kb.create_kb()
-        kb.clear_vs()
         files = list_files_from_folder(knowledge_base_name)
         kb_files = [(file, knowledge_base_name) for file in files]
         i = 0
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index e37c93e..a72fcf7 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -89,7 +89,7 @@ class FaissKBService(KBService):
 
     def do_clear_vs(self):
         with kb_faiss_pool.atomic:
-            kb_faiss_pool.pop(self.kb_name)
+            kb_faiss_pool.pop((self.kb_name, self.vector_name))
         shutil.rmtree(self.vs_path)
         os.makedirs(self.vs_path)
 
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index 2ca144a..499e915 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -72,6 +72,7 @@ def folder2db(
     # 清除向量库,从本地文件重建
     if mode == "recreate_vs":
         kb.clear_vs()
+        kb.create_kb()
         kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
         files2vs(kb_name, kb_files)
         kb.save_vector_store()
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 4a26f20..6ba4661 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -275,7 +275,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens
+        max_tokens: int = 1024,  # TODO: 根据message内容自动计算max_tokens
         no_remote_api: bool = None,
         **kwargs: Any,
     ):
@@ -316,6 +316,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        max_tokens: int = 1024,
         prompt_name: str = "llm_chat",
         no_remote_api: bool = None,
     ):
@@ -331,6 +332,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "max_tokens": max_tokens,
             "prompt_name": prompt_name,
         }
 
@@ -346,13 +348,14 @@ class ApiRequest:
         return self._httpx_stream2generator(response)
 
     def agent_chat(
-            self,
-            query: str,
-            history: List[Dict] = [],
-            stream: bool = True,
-            model: str = LLM_MODEL,
-            temperature: float = TEMPERATURE,
-            no_remote_api: bool = None,
+        self,
+        query: str,
+        history: List[Dict] = [],
+        stream: bool = True,
+        model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
+        max_tokens: int = 1024,
+        no_remote_api: bool = None,
     ):
         '''
         对应api.py/chat/agent_chat 接口
@@ -366,6 +369,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "max_tokens": max_tokens,
         }
 
         print(f"received input message:")
@@ -389,6 +393,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        max_tokens: int = 1024,
         prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
@@ -407,6 +412,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "max_tokens": max_tokens,
             "local_doc_url": no_remote_api,
             "prompt_name": prompt_name,
         }
@@ -435,6 +441,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
+        max_tokens: int = 1024,
         prompt_name: str = "knowledge_base_chat",
         no_remote_api: bool = None,
     ):
@@ -452,6 +459,7 @@ class ApiRequest:
             "stream": stream,
             "model_name": model,
             "temperature": temperature,
+            "max_tokens": max_tokens,
             "prompt_name": prompt_name,
         }
 
@@ -475,7 +483,7 @@ class ApiRequest:
     def _check_httpx_json_response(
             self,
             response: httpx.Response,
-            errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
+            errorMsg: str = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。",
     ) -> Dict:
         '''
         check whether httpx returns correct data with normal Response.
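
A note on the faiss_kb_service.py hunk: kb_faiss_pool caches one loaded FAISS index per knowledge base, and the patched call suggests its keys are (kb_name, vector_name) tuples. The old pop(self.kb_name) passed a bare string that could never match a tuple key, so eviction silently did nothing: the index built with the previous embed_model stayed in the pool, and the next operation with the new model's vector width tripped FAISS's d == self.d assertion. A minimal, self-contained sketch of that failure mode (FaissPool and FakeIndex are toy stand-ins, not the project's classes):

```python
from typing import Dict, Tuple


class FakeIndex:
    """Mimics faiss.IndexFlat: adding a vector of the wrong width asserts."""

    def __init__(self, d: int):
        self.d = d

    def add(self, vector: list) -> None:
        assert len(vector) == self.d, f"d == self.d assert error ({len(vector)} != {self.d})"


class FaissPool:
    """Toy stand-in for kb_faiss_pool: one cached index per (kb_name, vector_name)."""

    def __init__(self):
        self._cache: Dict[Tuple[str, str], FakeIndex] = {}

    def load(self, kb_name: str, vector_name: str, dim: int) -> FakeIndex:
        # Returns the cached index if one exists; the bug lives here when a
        # stale entry from the previous embed_model is never evicted.
        return self._cache.setdefault((kb_name, vector_name), FakeIndex(dim))

    def pop(self, key) -> None:
        # Pre-patch code called pop("samples"): a string never equals a
        # (kb_name, vector_name) tuple, so nothing was removed.
        self._cache.pop(key, None)


pool = FaissPool()
pool.load("samples", "vector_store", dim=768)           # built with the old embed_model
pool.pop("samples")                                     # buggy eviction: no-op
stale = pool.load("samples", "vector_store", dim=1024)  # returns the stale 768-d index
try:
    stale.add([0.0] * 1024)                             # trips `d == self.d`
except AssertionError as e:
    print(e)
pool.pop(("samples", "vector_store"))                   # patched eviction actually removes it
pool.load("samples", "vector_store", dim=1024).add([0.0] * 1024)  # fresh index, OK
```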
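The kb_doc_api.py and migrate.py hunks align the two rebuild paths on the same ordering: release the old vector store first, then create the fresh one. Once clear_vs() actually evicts the pooled index (per the fix above), it has to run before create_kb(); otherwise create_kb() could re-cache the stale index it is about to delete. A schematic of the resulting sequence, under the assumption that kb is any KBService subclass; the method names are the ones visible in the diff, and the loop is simplified (the real code batches files and yields progress):

```python
# Schematic rebuild sequence after the patch; not the project's actual code.
def recreate(kb, kb_files):
    if kb.exists():
        kb.clear_vs()        # evict the cached FAISS index and wipe on-disk vectors
    kb.create_kb()           # initialize a fresh, empty store for the current embed_model
    for kb_file in kb_files:
        kb.add_doc(kb_file)  # re-embed every document with the new model
```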
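On the client side, every chat wrapper in ApiRequest now accepts max_tokens and forwards it in the request body, so the web UI can cap reply length per call. A hypothetical invocation (the base_url, knowledge base name, and parameter values are placeholders; knowledge_base_chat is assumed to stream parsed JSON fragments, as the webui code treats it):

```python
# Hypothetical usage; assumes api.py is running and a 'samples' kb exists.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
for chunk in api.knowledge_base_chat(
    query="什么是FAISS?",
    knowledge_base_name="samples",
    max_tokens=512,   # new in this patch: caps the length of the LLM reply
    temperature=0.7,
):
    print(chunk)      # each chunk is a parsed JSON fragment (answer / source docs)
```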