From 0d9db37f45bbc4931898a71c564a3c9887cff249 Mon Sep 17 00:00:00 2001
From: shrimp <411161555@qq.com>
Date: Fri, 5 May 2023 22:31:41 +0800
Subject: [PATCH] Improve the API and model loading (#247)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix knowledge base path issues and improve the API

Unified the knowledge base paths used by the webui and the API. Going forward the paths are:
Knowledge base (vector store) path: /<project folder>/vector_store/'<knowledge base name>'
Uploaded file path: /<project folder>/content/'<knowledge base name>'
Fixed the bug when creating a knowledge base through the API, and improved the API functionality.

* Update model_config.py

* Fix knowledge base path issues and improve the API (#245) (#246)

* Fix: knowledge base upload failing because of an incorrect NLTK_DATA_PATH (#236)

* Update chatglm_llm.py (#242)

* Fix knowledge base path issues and improve the API

Unified the knowledge base paths used by the webui and the API. Going forward the paths are:
Knowledge base (vector store) path: /<project folder>/vector_store/'<knowledge base name>'
Uploaded file path: /<project folder>/content/'<knowledge base name>'
Fixed the bug when creating a knowledge base through the API, and improved the API functionality.

* Update model_config.py

---------

Co-authored-by: shrimp <411161555@qq.com>
Co-authored-by: Bob Chang

* Optimize the API and add the model's top_p parameter

Optimized the API: the knowledge base is now optional.
Added the model's top_p parameter.

* Improve the API and model loading

The knowledge base is no longer a required load for the API.
Added the model's top_p parameter.

---------

Co-authored-by: imClumsyPanda
Co-authored-by: Bob Chang
---
 api.py                | 38 +++++++++++++++++++++-----------------
 models/chatglm_llm.py |  6 ++++--
 2 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/api.py b/api.py
index 6f09ae5..90ba1dc 100644
--- a/api.py
+++ b/api.py
@@ -170,32 +170,36 @@ async def delete_docs(
 
 
 async def chat(
-    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-    question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    knowledge_base_id: str = Body(..., description="知识库名字", example="kb1"),
+    question: str = Body(..., description="问题", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
-        description="History of previous questions and answers",
+        description="问题及答案的历史记录",
         example=[
             [
-                "工伤保险是什么?",
-                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                "这里是问题,如:工伤保险是什么?",
+                "答案:工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
             ]
         ],
     ),
 ):
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
-    if not os.path.exists(vs_path):
-        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
-
-    for resp, history in local_doc_qa.get_knowledge_based_answer(
-        query=question, vs_path=vs_path, chat_history=history, streaming=True
-    ):
-        pass
-    source_documents = [
-        f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-        f"""相关度:{doc.metadata['score']}\n\n"""
-        for inum, doc in enumerate(resp["source_documents"])
-    ]
+    resp = {}
+    if os.path.exists(vs_path) and knowledge_base_id:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+            query=question, vs_path=vs_path, chat_history=history, streaming=False
+        ):
+            pass
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+    else:
+        for resp_s, history in local_doc_qa.llm._call(prompt=question, history=history, streaming=False):
+            pass
+        resp["result"] = resp_s
+        source_documents = [("当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。")]
 
     return ChatMessage(
         question=question,
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 60abf86..e0b1ae5 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -43,7 +43,7 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
 
 class ChatGLM(LLM):
     max_token: int = 10000
-    temperature: float = 0.01
+    temperature: float = 0.8
     top_p = 0.9
     # history = []
     tokenizer: object = None
@@ -68,6 +68,7 @@ class ChatGLM(LLM):
                 history=history[-self.history_len:-1] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
+                top_p=self.top_p,
             )):
                 torch_gc()
                 if inum == 0:
@@ -83,6 +84,7 @@ class ChatGLM(LLM):
             history=history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
+            top_p=self.top_p,
         )
         torch_gc()
         history += [[prompt, response]]
@@ -141,7 +143,7 @@ class ChatGLM(LLM):
             from accelerate import dispatch_model
 
             model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
-                                             config=model_config, **kwargs)
+                                              config=model_config, **kwargs)
             if LLM_LORA_PATH and use_lora:
                 from peft import PeftModel
                 model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
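The path convention described in the commit message can be checked with plain os.path joins. A minimal sketch follows; PROJECT_ROOT, VS_ROOT_PATH, and UPLOAD_ROOT_PATH are assumed names mirroring the commit message, so check configs/model_config.py for the project's real constants.

    import os

    # Assumed layout after this patch (names mirror the commit message; the
    # project's real constants live in configs/model_config.py).
    PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
    VS_ROOT_PATH = os.path.join(PROJECT_ROOT, "vector_store")
    UPLOAD_ROOT_PATH = os.path.join(PROJECT_ROOT, "content")

    knowledge_base_id = "kb1"
    print(os.path.join(VS_ROOT_PATH, knowledge_base_id))      # vector store for kb1
    print(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id))  # uploaded files for kb1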
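To exercise the new fallback in the chat handler end to end, a hedged sketch using requests follows. The base URL, port, and the /chat route are assumptions about how api.py is served; adjust them to the actual deployment.

    import requests

    BASE_URL = "http://localhost:7861"  # assumed host/port; adjust to your server

    payload = {
        "knowledge_base_id": "kb1",  # an existing knowledge base
        "question": "工伤保险是什么?",
        "history": [],
    }
    # Existing knowledge base: answered via get_knowledge_based_answer, with sources.
    print(requests.post(f"{BASE_URL}/chat", json=payload).json())

    # Missing knowledge base: after this patch the handler no longer raises
    # ValueError; it falls back to plain LLM chat and returns a placeholder
    # source_documents entry.
    payload["knowledge_base_id"] = "not_created_yet"
    print(requests.post(f"{BASE_URL}/chat", json=payload).json())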
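For reviewers unfamiliar with the top_p argument now forwarded to model.chat and model.stream_chat: it enables nucleus sampling, which draws only from the smallest set of tokens whose cumulative probability reaches top_p. The toy sketch below illustrates the idea only; it is not the transformers implementation, which operates on logits.

    import random
    from typing import Dict

    def nucleus_sample(probs: Dict[str, float], top_p: float = 0.9) -> str:
        # Rank tokens by probability, highest first.
        ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
        kept, cumulative = [], 0.0
        for token, p in ranked:
            kept.append((token, p))
            cumulative += p
            if cumulative >= top_p:  # smallest prefix covering top_p mass
                break
        # Renormalize over the kept tokens and sample one.
        tokens, weights = zip(*kept)
        return random.choices(tokens, weights=weights, k=1)[0]

    # With top_p=0.9, the lowest-probability token here is never sampled.
    print(nucleus_sample({"保险": 0.55, "制度": 0.30, "待遇": 0.10, "了": 0.05}))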