feat: 添加mmr相似度搜索,支持返回相似度分数
This commit is contained in:
parent
b3f83060c8
commit
059fe82887
|
|
@ -1,7 +1,9 @@
|
||||||
from langchain.chains import RetrievalQA
|
from langchain.chains import RetrievalQA
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import FAISS
|
from chains.lib.embeddings import MyEmbeddings
|
||||||
|
# from langchain.vectorstores import FAISS
|
||||||
|
from chains.lib.vectorstores import FAISSVS
|
||||||
from langchain.document_loaders import UnstructuredFileLoader
|
from langchain.document_loaders import UnstructuredFileLoader
|
||||||
from models.chatglm_llm import ChatGLM
|
from models.chatglm_llm import ChatGLM
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
@ -50,7 +52,7 @@ class LocalDocQA:
|
||||||
use_ptuning_v2=use_ptuning_v2)
|
use_ptuning_v2=use_ptuning_v2)
|
||||||
self.llm.history_len = llm_history_len
|
self.llm.history_len = llm_history_len
|
||||||
|
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
|
self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
|
||||||
model_kwargs={'device': embedding_device})
|
model_kwargs={'device': embedding_device})
|
||||||
# self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
# self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
||||||
# device=embedding_device)
|
# device=embedding_device)
|
||||||
|
|
@ -97,12 +99,12 @@ class LocalDocQA:
|
||||||
print(f"{file} 未能成功加载")
|
print(f"{file} 未能成功加载")
|
||||||
if len(docs) > 0:
|
if len(docs) > 0:
|
||||||
if vs_path and os.path.isdir(vs_path):
|
if vs_path and os.path.isdir(vs_path):
|
||||||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
vector_store = FAISSVS.load_local(vs_path, self.embeddings)
|
||||||
vector_store.add_documents(docs)
|
vector_store.add_documents(docs)
|
||||||
else:
|
else:
|
||||||
if not vs_path:
|
if not vs_path:
|
||||||
vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
vector_store = FAISSVS.from_documents(docs, self.embeddings)
|
||||||
|
|
||||||
vector_store.save_local(vs_path)
|
vector_store.save_local(vs_path)
|
||||||
return vs_path, loaded_files
|
return vs_path, loaded_files
|
||||||
|
|
@ -127,7 +129,7 @@ class LocalDocQA:
|
||||||
input_variables=["context", "question"]
|
input_variables=["context", "question"]
|
||||||
)
|
)
|
||||||
self.llm.history = chat_history
|
self.llm.history = chat_history
|
||||||
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
vector_store = FAISSVS.load_local(vs_path, self.embeddings)
|
||||||
knowledge_chain = RetrievalQA.from_llm(
|
knowledge_chain = RetrievalQA.from_llm(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
|
retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,195 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from lib.embeds import MyEmbeddings\n",
|
||||||
|
"from lib.faiss import FAISSVS\n",
|
||||||
|
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
|
||||||
|
"from langchain.chains.llm import LLMChain\n",
|
||||||
|
"from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
|
||||||
|
"from lib.config import *\n",
|
||||||
|
"from lib.utils import get_docs\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class LocalDocQA:\n",
|
||||||
|
" def __init__(self, \n",
|
||||||
|
" embedding_model=EMBEDDING_MODEL, \n",
|
||||||
|
" embedding_device=EMBEDDING_DEVICE, \n",
|
||||||
|
" llm_model=LLM_MODEL, \n",
|
||||||
|
" llm_device=LLM_DEVICE, \n",
|
||||||
|
" llm_history_len=LLM_HISTORY_LEN, \n",
|
||||||
|
" top_k=VECTOR_SEARCH_TOP_K,\n",
|
||||||
|
" vs_name = VS_NAME\n",
|
||||||
|
" ) -> None:\n",
|
||||||
|
" \n",
|
||||||
|
" torch.cuda.empty_cache()\n",
|
||||||
|
" torch.cuda.empty_cache()\n",
|
||||||
|
"\n",
|
||||||
|
" self.embedding_model = embedding_model\n",
|
||||||
|
" self.llm_model = llm_model\n",
|
||||||
|
" self.embedding_device = embedding_device\n",
|
||||||
|
" self.llm_device = llm_device\n",
|
||||||
|
" self.llm_history_len = llm_history_len\n",
|
||||||
|
" self.top_k = top_k\n",
|
||||||
|
" self.vs_name = vs_name\n",
|
||||||
|
"\n",
|
||||||
|
" self.llm = AlpacaGLM()\n",
|
||||||
|
" self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
|
||||||
|
"\n",
|
||||||
|
" self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
|
||||||
|
" self.load_vector_store(vs_name)\n",
|
||||||
|
"\n",
|
||||||
|
" self.prompt = PromptTemplate(\n",
|
||||||
|
" template=PROMPT_TEMPLATE,\n",
|
||||||
|
" input_variables=[\"context\", \"question\"]\n",
|
||||||
|
" )\n",
|
||||||
|
" self.search_params = {\n",
|
||||||
|
" \"engine\": \"bing\",\n",
|
||||||
|
" \"gl\": \"us\",\n",
|
||||||
|
" \"hl\": \"en\",\n",
|
||||||
|
" \"serpapi_api_key\": \"\"\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" def init_knowledge_vector_store(self, vs_name: str):\n",
|
||||||
|
" \n",
|
||||||
|
" docs = get_docs(KNOWLEDGE_PATH)\n",
|
||||||
|
" vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
|
||||||
|
" vs_path = VECTORSTORE_PATH + vs_name\n",
|
||||||
|
" vector_store.save_local(vs_path)\n",
|
||||||
|
"\n",
|
||||||
|
" def add_knowledge_to_vector_store(self, vs_name: str):\n",
|
||||||
|
" docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
|
||||||
|
" new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
|
||||||
|
" vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
|
||||||
|
" vector_store.merge_from(new_vector_store)\n",
|
||||||
|
" vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
|
||||||
|
"\n",
|
||||||
|
" def load_vector_store(self, vs_name: str):\n",
|
||||||
|
" self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
|
||||||
|
"\n",
|
||||||
|
" # def get_search_based_answer(self, query):\n",
|
||||||
|
" \n",
|
||||||
|
" # search = SerpAPIWrapper(params=self.search_params)\n",
|
||||||
|
" # docs = search.run(query)\n",
|
||||||
|
" # search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
|
||||||
|
" # answer = search_chain.run(input_documents=docs, question=query)\n",
|
||||||
|
"\n",
|
||||||
|
" # return answer\n",
|
||||||
|
" \n",
|
||||||
|
" def get_knowledge_based_answer(self, query):\n",
|
||||||
|
" \n",
|
||||||
|
" docs = self.vector_store.max_marginal_relevance_search(query)\n",
|
||||||
|
" print(f'召回的文档和相似度分数:{docs}')\n",
|
||||||
|
" # 这里 doc[1] 就是对应的score \n",
|
||||||
|
" docs = [doc[0] for doc in docs]\n",
|
||||||
|
" \n",
|
||||||
|
" document_prompt = PromptTemplate(\n",
|
||||||
|
" input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
|
||||||
|
" )\n",
|
||||||
|
" llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
|
||||||
|
" combine_documents_chain = StuffDocumentsChain(\n",
|
||||||
|
" llm_chain=llm_chain,\n",
|
||||||
|
" document_variable_name=\"context\",\n",
|
||||||
|
" document_prompt=document_prompt,\n",
|
||||||
|
" )\n",
|
||||||
|
" answer = combine_documents_chain.run(\n",
|
||||||
|
" input_documents=docs, question=query\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" self.llm.history[-1][0] = query\n",
|
||||||
|
" self.llm.history[-1][-1] = answer\n",
|
||||||
|
" return answer, docs, self.llm.history"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "d4342213010c4ed2ad5b04694aa436d6",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"qa = LocalDocQA()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"召回的文档和相似度分数:[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n ’ t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h e m y r i a d o f r e f e r e n c e s i n XmlPackageOptions , O u t l i n i n g S t a t e D i r , e t c . , t h a t t h e HtaDotnet a n d ShellcodeLoader s o l u t i o n s w e r e o r i g i n a l l y u n d e r t h e f o l d e r p a t h G:\\\\WebBuilder\\\\Gift_HtaDotnet\\\\ . T h i s i s a l s o s u p p o r t e d b y t h e P D B p a t h s o f o l d e r b u i l t b i n a r i e s w i t h i n t h e b r o a d e r S t r i k e S u i t G i f t p a c k a g e . F r o m l o o k i n g a t D e b u g g e r W a t c h e s v a l u e s i n o t h e r p r o j e c t s , w e c a n s e e t h a t t h e m a l w a r e d e v e l o p e r w a s a c t i v e l y d e b u g g i n g t h e h i s t o r i c a l p r o g r a m s . S U O f i l e D e b u g g e r W a t c h e s WebBuilder/HtaDotNet/HtaDotnet.v11.suo result WebBuilder/ShellcodeLoader/.vs/L/v14/.suo (char)77 WebBuilder/ShellcodeLoader/L.suo (char)77 3 4 04/2022', metadata={'source': './KnowledgeStore/Stairwell-threat-report-The-origin-of-APT32-macros.pdf', 'page': 33}), 0.38091612), (Document(page_content='2 APTs and COVID-19: How advanced persistent threats use the coronavirus as a lureTable of contents Introduction: APT groups using COVID-19 .........................................................', metadata={'source': './KnowledgeStore/200407-MWB-COVID-White-Paper_Final.pdf', 'page': 1}), 0.44476452)]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"query = r\"\"\"make a brief introduction of APT?\"\"\"\n",
|
||||||
|
"ans, docs, _ = qa.get_knowledge_based_answer(query)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ans"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "chatgpt",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue