update bing_search.py

This commit is contained in:
imClumsyPanda 2023-05-21 22:08:38 +08:00
parent f986b756ff
commit 9c422cc6bc
2 changed files with 52 additions and 22 deletions

View File

@ -1,20 +1,19 @@
#coding=utf8 #coding=utf8
import os
from langchain.utilities import BingSearchAPIWrapper from langchain.utilities import BingSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
def bing_search(text, result_len=3):
    """Query Bing Web Search and return up to *result_len* results.

    When the required configuration (BING_SEARCH_URL / BING_SUBSCRIPTION_KEY,
    from configs.model_config) is missing, a single placeholder "result" is
    returned instead of raising, so callers always get the same
    list-of-dicts shape (keys: "snippet", "title", "link").
    """
    if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
        # Fixed typos in the user-facing title ("env inof not fould").
        return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                 "title": "env info not found",
                 "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
    search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
                                  bing_search_url=BING_SEARCH_URL)
    return search.results(text, result_len)
if __name__ == "__main__":
    # Manual smoke test: run a sample query and dump the raw result list.
    print(bing_search('python'))

View File

@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from configs.model_config import * from configs.model_config import *
import datetime import datetime
from textsplitter import ChineseTextSplitter from textsplitter import ChineseTextSplitter
from typing import List, Tuple from typing import List, Tuple, Dict
from langchain.docstore.document import Document from langchain.docstore.document import Document
import numpy as np import numpy as np
from utils import torch_gc from utils import torch_gc
@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
from models.loader.args import parser from models.loader.args import parser
from models.loader import LoaderCheckPoint from models.loader import LoaderCheckPoint
import models.shared as shared import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
def load_file(filepath, sentence_size=SENTENCE_SIZE): def load_file(filepath, sentence_size=SENTENCE_SIZE):
@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
fout.close() fout.close()
def generate_prompt(related_docs: List["Document"],
                    query: str,
                    prompt_template: str = PROMPT_TEMPLATE) -> str:
    """Fill the prompt template with the user query and retrieved context.

    Args:
        related_docs: retrieved documents; their ``page_content`` fields are
            joined with newlines into the ``{context}`` placeholder.  Note
            these are Document objects, not strings — the previous
            ``List[str]`` annotation contradicted the ``doc.page_content``
            access below.
        query: the user question, substituted for ``{question}``.
        prompt_template: template containing ``{question}`` and ``{context}``
            placeholders (defaults to the project-wide PROMPT_TEMPLATE).

    Returns:
        The completed prompt string.
    """
    context = "\n".join(doc.page_content for doc in related_docs)
    return prompt_template.replace("{question}", query).replace("{context}", context)
@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
return docs return docs
def search_result2docs(search_results):
    """Convert raw Bing search-result dicts into ``Document`` objects.

    Each result may carry "snippet", "link" and "title" keys; any missing key
    defaults to "" so downstream consumers never see ``None``.

    Args:
        search_results: iterable of dicts as returned by ``bing_search``.

    Returns:
        A list of Documents with page_content from "snippet" and metadata
        {"source": link, "filename": title}.
    """
    # dict.get() with a default replaces the unidiomatic
    # `value if key in d.keys() else ""` ternaries of the original.
    return [Document(page_content=result.get("snippet", ""),
                     metadata={"source": result.get("link", ""),
                               "filename": result.get("title", "")})
            for result in search_results]
class LocalDocQA: class LocalDocQA:
llm: BaseAnswer = None llm: BaseAnswer = None
embeddings: object = None embeddings: object = None
@ -262,7 +275,6 @@ class LocalDocQA:
"source_documents": related_docs_with_score} "source_documents": related_docs_with_score}
yield response, history yield response, history
# query 查询内容 # query 查询内容
# vs_path 知识库路径 # vs_path 知识库路径
# chunk_conent 是否启用上下文关联 # chunk_conent 是否启用上下文关联
@ -288,6 +300,21 @@ class LocalDocQA:
"source_documents": related_docs_with_score} "source_documents": related_docs_with_score}
return response, prompt return response, prompt
def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
results = bing_search(query)
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
response = {"query": query,
"result": resp,
"source_documents": result_docs}
yield response, history
if __name__ == "__main__": if __name__ == "__main__":
# 初始化消息 # 初始化消息
@ -304,13 +331,17 @@ if __name__ == "__main__":
query = "本项目使用的embedding模型是什么消耗多少显存" query = "本项目使用的embedding模型是什么消耗多少显存"
vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test" vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
last_print_len = 0 last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path, # vs_path=vs_path,
# chat_history=[],
# streaming=True):
for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
chat_history=[], chat_history=[],
streaming=True): streaming=True):
logger.info(resp["result"][last_print_len:], end="", flush=True) print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"]) last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n""" source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
else os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in for inum, doc in
enumerate(resp["source_documents"])] enumerate(resp["source_documents"])]