diff --git a/api.py b/api.py
index 6f09ae5..90ba1dc 100644
--- a/api.py
+++ b/api.py
@@ -170,32 +170,36 @@ async def delete_docs(
 async def chat(
-    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-    question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    knowledge_base_id: str = Body(..., description="知识库名字", example="kb1"),
+    question: str = Body(..., description="问题", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
-        description="History of previous questions and answers",
+        description="问题及答案的历史记录",
         example=[
             [
-                "工伤保险是什么?",
-                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                "这里是问题,如:工伤保险是什么?",
+                "答案:工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
             ]
         ],
     ),
 ):
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
-    if not os.path.exists(vs_path):
-        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
-
-    for resp, history in local_doc_qa.get_knowledge_based_answer(
-        query=question, vs_path=vs_path, chat_history=history, streaming=True
-    ):
-        pass
-    source_documents = [
-        f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-        f"""相关度:{doc.metadata['score']}\n\n"""
-        for inum, doc in enumerate(resp["source_documents"])
-    ]
+    resp = {}
+    if os.path.exists(vs_path) and knowledge_base_id:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+            query=question, vs_path=vs_path, chat_history=history, streaming=False
+        ):
+            pass
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+    else:
+        for resp_s, history in local_doc_qa.llm._call(prompt=question, history=history, streaming=False):
+            pass
+        resp["result"] = resp_s
+        source_documents =[("当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。")]
     return ChatMessage(
         question=question,
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 60abf86..e0b1ae5 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -43,7 +43,7 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
 class ChatGLM(LLM):
     max_token: int = 10000
-    temperature: float = 0.01
+    temperature: float = 0.8
     top_p = 0.9
     # history = []
     tokenizer: object = None
@@ -68,6 +68,7 @@ class ChatGLM(LLM):
                 history=history[-self.history_len:-1] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
+                top_p=self.top_p,
             )):
                 torch_gc()
                 if inum == 0:
@@ -83,6 +84,7 @@ class ChatGLM(LLM):
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
+                top_p=self.top_p,
             )
             torch_gc()
             history += [[prompt, response]]
@@ -141,7 +143,7 @@ class ChatGLM(LLM):
             from accelerate import dispatch_model
             model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
-                                              config=model_config, **kwargs)
+                                             config=model_config, **kwargs)
         if LLM_LORA_PATH and use_lora:
             from peft import PeftModel
             model = PeftModel.from_pretrained(model, LLM_LORA_PATH)