修复:切换embed_model时,FAISS向量库未正确释放,导致 `d == self.d` assert error (#1766)

修复:
- 切换embed_model时,FAISS向量库未正确释放,导致 `d == self.d` assert error (close [求助] 初始化知识库发生错误 #1657 )
- ApiRequest中chat接口增加max_tokens参数
- FileDocModel模型字段错误(字段错误 #1691)
This commit is contained in:
liunux4odoo 2023-10-16 21:02:07 +08:00 committed by GitHub
parent cd748128c3
commit 94977c7ab1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 12 deletions

View File

@ -37,4 +37,4 @@ class FileDocModel(Base):
meta_data = Column(JSON, default={}) meta_data = Column(JSON, default={})
def __repr__(self): def __repr__(self):
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>" return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"

View File

@ -337,8 +337,9 @@ def recreate_vector_store(
if not kb.exists() and not allow_empty_kb: if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"} yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else: else:
if kb.exists():
kb.clear_vs()
kb.create_kb() kb.create_kb()
kb.clear_vs()
files = list_files_from_folder(knowledge_base_name) files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files] kb_files = [(file, knowledge_base_name) for file in files]
i = 0 i = 0

View File

@ -89,7 +89,7 @@ class FaissKBService(KBService):
def do_clear_vs(self): def do_clear_vs(self):
with kb_faiss_pool.atomic: with kb_faiss_pool.atomic:
kb_faiss_pool.pop(self.kb_name) kb_faiss_pool.pop((self.kb_name, self.vector_name))
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path) os.makedirs(self.vs_path)

View File

@ -72,6 +72,7 @@ def folder2db(
# 清除向量库,从本地文件重建 # 清除向量库,从本地文件重建
if mode == "recreate_vs": if mode == "recreate_vs":
kb.clear_vs() kb.clear_vs()
kb.create_kb()
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
files2vs(kb_name, kb_files) files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()

View File

@ -275,7 +275,7 @@ class ApiRequest:
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens max_tokens: int = 1024, # TODO:根据message内容自动计算max_tokens
no_remote_api: bool = None, no_remote_api: bool = None,
**kwargs: Any, **kwargs: Any,
): ):
@ -316,6 +316,7 @@ class ApiRequest:
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024,
prompt_name: str = "llm_chat", prompt_name: str = "llm_chat",
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
@ -331,6 +332,7 @@ class ApiRequest:
"stream": stream, "stream": stream,
"model_name": model, "model_name": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name, "prompt_name": prompt_name,
} }
@ -346,13 +348,14 @@ class ApiRequest:
return self._httpx_stream2generator(response) return self._httpx_stream2generator(response)
def agent_chat( def agent_chat(
self, self,
query: str, query: str,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
no_remote_api: bool = None, max_tokens: int = 1024,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/chat/agent_chat 接口 对应api.py/chat/agent_chat 接口
@ -366,6 +369,7 @@ class ApiRequest:
"stream": stream, "stream": stream,
"model_name": model, "model_name": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens,
} }
print(f"received input message:") print(f"received input message:")
@ -389,6 +393,7 @@ class ApiRequest:
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024,
prompt_name: str = "knowledge_base_chat", prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
@ -407,6 +412,7 @@ class ApiRequest:
"stream": stream, "stream": stream,
"model_name": model, "model_name": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens,
"local_doc_url": no_remote_api, "local_doc_url": no_remote_api,
"prompt_name": prompt_name, "prompt_name": prompt_name,
} }
@ -435,6 +441,7 @@ class ApiRequest:
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024,
prompt_name: str = "knowledge_base_chat", prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
@ -452,6 +459,7 @@ class ApiRequest:
"stream": stream, "stream": stream,
"model_name": model, "model_name": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name, "prompt_name": prompt_name,
} }
@ -475,7 +483,7 @@ class ApiRequest:
def _check_httpx_json_response( def _check_httpx_json_response(
self, self,
response: httpx.Response, response: httpx.Response,
errorMsg: str = f"无法连接API服务器请确认已执行python server\\api.py", errorMsg: str = f"无法连接API服务器请确认 api.py 已正常启动。",
) -> Dict: ) -> Dict:
''' '''
check whether httpx returns correct data with normal Response. check whether httpx returns correct data with normal Response.