update bing_search.py
This commit is contained in:
parent
f986b756ff
commit
9c422cc6bc
|
|
@ -1,20 +1,19 @@
|
|||
#coding=utf8
|
||||
|
||||
import os
|
||||
from langchain.utilities import BingSearchAPIWrapper
|
||||
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
|
||||
|
||||
|
||||
env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
|
||||
env_bing_url = os.environ.get("BING_SEARCH_URL")
|
||||
|
||||
|
||||
def search(text, result_len=3):
|
||||
if not (env_bing_key and env_bing_url):
|
||||
return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
"title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
||||
search = BingSearchAPIWrapper()
|
||||
def bing_search(text, result_len=3):
    """Run a Bing web search for ``text``.

    Args:
        text: The query string handed to the Bing wrapper.
        result_len: Forwarded as the second argument of
            ``BingSearchAPIWrapper.results``.

    Returns:
        Whatever ``BingSearchAPIWrapper.results`` returns. If either
        ``BING_SEARCH_URL`` or ``BING_SUBSCRIPTION_KEY`` is falsy, no wrapper
        is built; a one-element list holding a placeholder record with
        ``snippet``/``title``/``link`` keys is returned instead.
    """
    # Without both settings no request can be made; answer with a hint
    # record so the caller still receives a list of dicts.
    if not BING_SEARCH_URL or not BING_SUBSCRIPTION_KEY:
        placeholder = {
            "snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
            "title": "env inof not fould",
            "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html",
        }
        return [placeholder]

    wrapper = BingSearchAPIWrapper(
        bing_subscription_key=BING_SUBSCRIPTION_KEY,
        bing_search_url=BING_SEARCH_URL,
    )
    return wrapper.results(text, result_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r = search('python')
|
||||
r = bing_search('python')
|
||||
print(r)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
|
|||
from configs.model_config import *
|
||||
import datetime
|
||||
from textsplitter import ChineseTextSplitter
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Dict
|
||||
from langchain.docstore.document import Document
|
||||
import numpy as np
|
||||
from utils import torch_gc
|
||||
|
|
@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
|
|||
from models.loader.args import parser
|
||||
from models.loader import LoaderCheckPoint
|
||||
import models.shared as shared
|
||||
from agent import bing_search
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
def load_file(filepath, sentence_size=SENTENCE_SIZE):
|
||||
|
|
@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
|
|||
fout.close()
|
||||
|
||||
|
||||
def generate_prompt(related_docs: List[str], query: str,
|
||||
prompt_template=PROMPT_TEMPLATE) -> str:
|
||||
def generate_prompt(related_docs: List[str],
                    query: str,
                    prompt_template: str = PROMPT_TEMPLATE) -> str:
    """Fill ``prompt_template`` with the question and the retrieved context.

    Args:
        related_docs: Items whose ``page_content`` attributes are joined with
            newlines to form the context.
            NOTE(review): annotated ``List[str]``, but the body reads
            ``.page_content`` from every element, so plain strings would
            fail here - confirm the intended element type.
        query: Text substituted for the ``{question}`` placeholder.
        prompt_template: Template containing ``{question}`` and ``{context}``
            placeholders.

    Returns:
        The template with both placeholders replaced.
    """
    context = "\n".join(doc.page_content for doc in related_docs)
    # The question goes in first, so a literal "{context}" inside the query
    # is expanded by the second replacement as well.
    with_question = prompt_template.replace("{question}", query)
    return with_question.replace("{context}", context)
|
||||
|
|
@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
|
|||
return docs
|
||||
|
||||
|
||||
def search_result2docs(search_results):
    """Wrap each search-result mapping in a ``Document``.

    For every item, ``snippet`` becomes the document's ``page_content``,
    while ``link`` and ``title`` are stored in the metadata under ``source``
    and ``filename``. A key that is absent yields an empty string.

    Args:
        search_results: Iterable of dict-like search hits.

    Returns:
        A list of ``Document`` objects, one per hit, in input order.
    """
    return [
        Document(
            page_content=hit.get("snippet", ""),
            metadata={
                "source": hit.get("link", ""),
                "filename": hit.get("title", ""),
            },
        )
        for hit in search_results
    ]
|
||||
|
||||
|
||||
class LocalDocQA:
|
||||
llm: BaseAnswer = None
|
||||
embeddings: object = None
|
||||
|
|
@ -262,7 +275,6 @@ class LocalDocQA:
|
|||
"source_documents": related_docs_with_score}
|
||||
yield response, history
|
||||
|
||||
|
||||
# query 查询内容
|
||||
# vs_path 知识库路径
|
||||
# chunk_conent 是否启用上下文关联
|
||||
|
|
@ -288,11 +300,26 @@ class LocalDocQA:
|
|||
"source_documents": related_docs_with_score}
|
||||
return response, prompt
|
||||
|
||||
def get_search_result_based_answer(self, query, chat_history=None, streaming: bool = STREAMING):
    """Answer ``query`` using Bing search results as the prompt context.

    The query is sent to ``bing_search``, the hits are converted to
    documents with ``search_result2docs``, and a prompt built by
    ``generate_prompt`` is passed to ``self.llm.generatorAnswer``.

    Args:
        query: The user's question.
        chat_history: Conversation history forwarded to the LLM as
            ``history``. ``None`` (the default) starts from a new empty list.
        streaming: Forwarded unchanged to ``generatorAnswer``.

    Yields:
        ``(response, history)`` for every item produced by
        ``generatorAnswer``. ``response`` is a dict with the keys ``query``,
        ``result`` and ``source_documents``; ``history`` is the history
        object carried by that item.
    """
    # A literal ``[]`` default is created once and shared by every call that
    # omits the argument; since the history is modified in place below, a
    # fresh list per call keeps unrelated conversations apart.
    if chat_history is None:
        chat_history = []

    results = bing_search(query)
    result_docs = search_result2docs(results)
    prompt = generate_prompt(result_docs, query)

    for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                  streaming=streaming):
        resp = answer_result.llm_output["answer"]
        history = answer_result.history
        # Store the raw question in the latest history entry instead of
        # whatever was recorded there (presumably the expanded prompt -
        # confirm against generatorAnswer).
        history[-1][0] = query
        response = {"query": query,
                    "result": resp,
                    "source_documents": result_docs}
        yield response, history
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 初始化消息
|
||||
args = None
|
||||
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
|
||||
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
|
||||
|
||||
args_dict = vars(args)
|
||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||
|
|
@ -304,13 +331,17 @@ if __name__ == "__main__":
|
|||
query = "本项目使用的embedding模型是什么,消耗多少显存"
|
||||
vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
|
||||
last_print_len = 0
|
||||
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
||||
vs_path=vs_path,
|
||||
chat_history=[],
|
||||
streaming=True):
|
||||
logger.info(resp["result"][last_print_len:], end="", flush=True)
|
||||
# for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
||||
# vs_path=vs_path,
|
||||
# chat_history=[],
|
||||
# streaming=True):
|
||||
for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
|
||||
chat_history=[],
|
||||
streaming=True):
|
||||
print(resp["result"][last_print_len:], end="", flush=True)
|
||||
last_print_len = len(resp["result"])
|
||||
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||
source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
|
||||
else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||
# f"""相关度:{doc.metadata['score']}\n\n"""
|
||||
for inum, doc in
|
||||
enumerate(resp["source_documents"])]
|
||||
|
|
|
|||
Loading…
Reference in New Issue