update bing_search.py

parent f986b756ff
commit 9c422cc6bc
agent/bing_search.py
@@ -1,20 +1,19 @@
 #coding=utf8
 
-import os
 from langchain.utilities import BingSearchAPIWrapper
+from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 
-env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
-env_bing_url = os.environ.get("BING_SEARCH_URL")
-
 
-def search(text, result_len=3):
-    if not (env_bing_key and env_bing_url):
+def bing_search(text, result_len=3):
+    if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
-                 "title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
-    search = BingSearchAPIWrapper()
+                 "title": "env inof not fould",
+                 "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
+    search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
+                                  bing_search_url=BING_SEARCH_URL)
     return search.results(text, result_len)
 
 
 if __name__ == "__main__":
-    r = search('python')
+    r = bing_search('python')
     print(r)
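
With this change bing_search() reads both credentials from configs.model_config instead of the process environment, so callers no longer need to export anything. A minimal usage sketch, assuming BING_SEARCH_URL and BING_SUBSCRIPTION_KEY have been filled in configs/model_config.py:

# Same import the updated local_doc_qa.py uses below.
from agent import bing_search

# Each hit is a dict with "snippet", "title" and "link" keys, as
# returned by langchain's BingSearchAPIWrapper.results().
for hit in bing_search("langchain", result_len=3):
    print(hit["title"], hit["link"])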

chains/local_doc_qa.py
@@ -4,7 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader
 from configs.model_config import *
 import datetime
 from textsplitter import ChineseTextSplitter
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
@@ -18,6 +18,8 @@ from models.base import (BaseAnswer,
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared
+from agent import bing_search
+from langchain.docstore.document import Document
 
 
 def load_file(filepath, sentence_size=SENTENCE_SIZE):
@@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
     fout.close()
 
 
-def generate_prompt(related_docs: List[str], query: str,
-                    prompt_template=PROMPT_TEMPLATE) -> str:
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template: str = PROMPT_TEMPLATE, ) -> str:
     context = "\n".join([doc.page_content for doc in related_docs])
     prompt = prompt_template.replace("{question}", query).replace("{context}", context)
     return prompt
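
generate_prompt() performs plain string substitution: it joins the page_content of the retrieved Documents and splices that text plus the question into the template. A quick sketch of the behavior with a hypothetical template (the real one is configs.model_config.PROMPT_TEMPLATE):

from langchain.docstore.document import Document

# Hypothetical stand-in for PROMPT_TEMPLATE; it only needs the
# {context} and {question} placeholders that generate_prompt() replaces.
template = "Known information:\n{context}\n\nQuestion:\n{question}"

docs = [Document(page_content="ChatGLM-6B is an open bilingual dialogue model.")]
print(generate_prompt(docs, "What is ChatGLM-6B?", template))
# Known information:
# ChatGLM-6B is an open bilingual dialogue model.
#
# Question:
# What is ChatGLM-6B?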
@@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
     return docs
 
 
+def search_result2docs(search_results):
+    docs = []
+    for result in search_results:
+        doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
+                       metadata={"source": result["link"] if "link" in result.keys() else "",
+                                 "filename": result["title"] if "title" in result.keys() else ""})
+        docs.append(doc)
+    return docs
+
+
 class LocalDocQA:
     llm: BaseAnswer = None
     embeddings: object = None
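
search_result2docs() wraps each Bing hit in a langchain Document so the rest of the pipeline can treat web results exactly like local text chunks: the link lands in metadata["source"] and the title in metadata["filename"]. A small illustration with a hand-made hit (hypothetical data, not a live API response):

# Hypothetical Bing-style hit; real entries come from bing_search().
hits = [{"snippet": "LangChain is a framework for building LLM apps.",
         "title": "LangChain docs",
         "link": "https://python.langchain.com"}]

docs = search_result2docs(hits)
print(docs[0].page_content)        # the snippet text
print(docs[0].metadata["source"])  # the link, later shown as the citation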
@@ -262,7 +275,6 @@
                         "source_documents": related_docs_with_score}
             yield response, history
 
-
     # query        the query text
     # vs_path      path to the knowledge base
     # chunk_conent whether to enable context linking
@@ -288,6 +300,21 @@
                         "source_documents": related_docs_with_score}
             return response, prompt
 
+    def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
+        results = bing_search(query)
+        result_docs = search_result2docs(results)
+        prompt = generate_prompt(result_docs, query)
+
+        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
+                                                      streaming=streaming):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": resp,
+                        "source_documents": result_docs}
+            yield response, history
+
 
 if __name__ == "__main__":
     # initialize message
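
Taken together, the new method runs bing_search, then search_result2docs, then generate_prompt, and finally streams from the LLM, yielding (response, history) pairs just like get_knowledge_based_answer. A condensed driver sketch, assuming a model instance has already been wired in via init_cfg as this repo's __main__ does:

# Sketch only: llm_model_ins stands for a loaded BaseAnswer model,
# obtained via models.shared as in the __main__ block below.
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins)

for resp, history in local_doc_qa.get_search_result_based_answer(
        query="what is langchain", chat_history=[], streaming=True):
    print(resp["result"])  # the streamed answer so far
    # resp["source_documents"] holds the Documents built from Bing hits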
@@ -304,13 +331,17 @@ if __name__ == "__main__":
     query = "本项目使用的embedding模型是什么,消耗多少显存"
     vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
     last_print_len = 0
-    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                 vs_path=vs_path,
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=True):
+    for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
                                                                      chat_history=[],
                                                                      streaming=True):
-        logger.info(resp["result"][last_print_len:], end="", flush=True)
+        print(resp["result"][last_print_len:], end="", flush=True)
         last_print_len = len(resp["result"])
-        source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+        source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
+                       else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                        # f"""相关度:{doc.metadata['score']}\n\n"""
                        for inum, doc in
                        enumerate(resp["source_documents"])]
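
The source_text tweak means web hits are now cited by their full URL, while local files are still cited by basename. The branch in isolation, with toy values:

import os

# Toy values showing the new citation-label logic.
for source in ["https://python.langchain.com/page", "/data/docs/readme.md"]:
    label = source if source.startswith("http") else os.path.split(source)[-1]
    print(label)
# https://python.langchain.com/page
# readme.md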