diff --git a/agent/bing_search.py b/agent/bing_search.py
index d5ba766..2ff7749 100644
--- a/agent/bing_search.py
+++ b/agent/bing_search.py
@@ -1,20 +1,19 @@
 #coding=utf8
-import os
 from langchain.utilities import BingSearchAPIWrapper
+from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 
-env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
-env_bing_url = os.environ.get("BING_SEARCH_URL")
-
-
-def search(text, result_len=3):
-    if not (env_bing_key and env_bing_url):
-        return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
-                 "title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
-    search = BingSearchAPIWrapper()
+def bing_search(text, result_len=3):
+    if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
+        return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in configs/model_config.py",
+                 "title": "env info not found",
+                 "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
+    search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
+                                  bing_search_url=BING_SEARCH_URL)
     return search.results(text, result_len)
 
 
 if __name__ == "__main__":
-    r = search('python')
+    r = bing_search('python')
+    print(r)
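Note: `bing_search` now reads its credentials from `configs.model_config` instead of environment variables, so the two imported constants must exist there. A minimal sketch of the expected entries (the endpoint shown is the standard Bing Web Search v7 URL; the key value is a placeholder you must supply yourself):

```python
# configs/model_config.py -- sketch of the two settings the import above assumes.
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
BING_SUBSCRIPTION_KEY = ""  # your Azure Bing Search v7 subscription key
```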
diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index bc31429..8b69e69 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
 from configs.model_config import *
 import datetime
 from textsplitter import ChineseTextSplitter
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
@@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared
+from agent import bing_search
+from langchain.docstore.document import Document
 
 
 def load_file(filepath, sentence_size=SENTENCE_SIZE):
@@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
     fout.close()
 
 
-def generate_prompt(related_docs: List[str], query: str,
-                    prompt_template=PROMPT_TEMPLATE) -> str:
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template: str = PROMPT_TEMPLATE) -> str:
     context = "\n".join([doc.page_content for doc in related_docs])
     prompt = prompt_template.replace("{question}", query).replace("{context}", context)
     return prompt
@@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
     return docs
 
 
+def search_result2docs(search_results):
+    docs = []
+    for result in search_results:
+        doc = Document(page_content=result.get("snippet", ""),
+                       metadata={"source": result.get("link", ""),
+                                 "filename": result.get("title", "")})
+        docs.append(doc)
+    return docs
+
+
 class LocalDocQA:
     llm: BaseAnswer = None
     embeddings: object = None
@@ -262,7 +275,6 @@
                         "source_documents": related_docs_with_score}
             yield response, history
 
-
     # query: the query text
     # vs_path: path to the knowledge base (vector store)
     # chunk_conent: whether to enable context association between chunks
@@ -288,11 +300,26 @@
                     "source_documents": related_docs_with_score}
         return response, prompt
 
+    def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
+        results = bing_search(query)
+        result_docs = search_result2docs(results)
+        prompt = generate_prompt(result_docs, query)
+
+        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
+                                                      streaming=streaming):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": resp,
+                        "source_documents": result_docs}
+            yield response, history
+
 
 if __name__ == "__main__":
     # initialization
     args = None
-    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
+    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
 
     args_dict = vars(args)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
@@ -304,13 +331,17 @@
     query = "本项目使用的embedding模型是什么,消耗多少显存"
     vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
     last_print_len = 0
-    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                 vs_path=vs_path,
-                                                                 chat_history=[],
-                                                                 streaming=True):
-        logger.info(resp["result"][last_print_len:], end="", flush=True)
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=True):
+    for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
+                                                                     chat_history=[],
+                                                                     streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
         last_print_len = len(resp["result"])
-    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+    source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http") else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                    # f"""相关度:{doc.metadata['score']}\n\n"""
                    for inum, doc in enumerate(resp["source_documents"])]
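For reference, a minimal sketch of how the new search-based generator is consumed, mirroring the `__main__` block above (it assumes `local_doc_qa` is an initialized `LocalDocQA` whose LLM has already been loaded; with `streaming=False` the generator yields a single final response):

```python
# Consumption sketch -- assumes local_doc_qa has been set up and its
# model loaded, as in the __main__ block above.
for resp, history in local_doc_qa.get_search_result_based_answer(query="python",
                                                                 chat_history=[],
                                                                 streaming=False):
    print(resp["result"])
    for doc in resp["source_documents"]:
        # search_result2docs stores the result link under "source"
        # and the result title under "filename"
        print(doc.metadata["source"], "|", doc.metadata["filename"])
```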