Langchain-Chatchat/chains/test.ipynb

196 lines
8.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"from langchain.chains.llm import LLMChain\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"from lib.embeds import MyEmbeddings\n",
"from lib.faiss import FAISSVS\n",
"from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
"from lib.config import *\n",
"from lib.utils import get_docs\n",
"\n",
"\n",
"class LocalDocQA:\n",
"    \"\"\"Answer questions from a local FAISS knowledge base using a GLM-style LLM.\n",
"\n",
"    Construction loads the LLM (``AlpacaGLM``), the embedding model\n",
"    (``MyEmbeddings``) and the saved vector store ``vs_name``. The defaults are\n",
"    the constants star-imported from ``lib.config``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 embedding_model=EMBEDDING_MODEL,\n",
"                 embedding_device=EMBEDDING_DEVICE,\n",
"                 llm_model=LLM_MODEL,\n",
"                 llm_device=LLM_DEVICE,\n",
"                 llm_history_len=LLM_HISTORY_LEN,\n",
"                 top_k=VECTOR_SEARCH_TOP_K,\n",
"                 vs_name=VS_NAME\n",
"                 ) -> None:\n",
"        # Release cached, unused GPU memory before loading the models. Once is enough.\n",
"        torch.cuda.empty_cache()\n",
"\n",
"        self.embedding_model = embedding_model\n",
"        self.llm_model = llm_model\n",
"        # NOTE(review): embedding_device and llm_history_len are stored but not used\n",
"        # in this class; confirm whether MyEmbeddings / AlpacaGLM should receive them.\n",
"        self.embedding_device = embedding_device\n",
"        self.llm_device = llm_device\n",
"        self.llm_history_len = llm_history_len\n",
"        self.top_k = top_k  # number of documents retrieved per query\n",
"        self.vs_name = vs_name\n",
"\n",
"        self.llm = AlpacaGLM()\n",
"        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
"\n",
"        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
"        self.load_vector_store(vs_name)\n",
"\n",
"        self.prompt = PromptTemplate(\n",
"            template=PROMPT_TEMPLATE,\n",
"            input_variables=[\"context\", \"question\"]\n",
"        )\n",
"        # Settings for the (currently disabled) web-search answer path. The SerpAPI\n",
"        # key is taken from the environment rather than hard-coded in the notebook.\n",
"        self.search_params = {\n",
"            \"engine\": \"bing\",\n",
"            \"gl\": \"us\",\n",
"            \"hl\": \"en\",\n",
"            \"serpapi_api_key\": os.environ.get(\"SERPAPI_API_KEY\", \"\")\n",
"        }\n",
"\n",
"    def _vs_path(self, vs_name: str) -> str:\n",
"        \"\"\"Return the save location of vector store ``vs_name``.\"\"\"\n",
"        # Plain concatenation: VECTORSTORE_PATH must already end with a separator\n",
"        # if vs_name is meant to be a sub-directory.\n",
"        return VECTORSTORE_PATH + vs_name\n",
"\n",
"    def _refresh_if_loaded(self, vs_name: str, vector_store) -> None:\n",
"        \"\"\"Use ``vector_store`` for retrieval if ``vs_name`` is the store currently loaded.\"\"\"\n",
"        # Keeps queries from running against a stale in-memory index after the\n",
"        # saved copy of the active store has been rebuilt or extended.\n",
"        if vs_name == getattr(self, \"_loaded_vs_name\", None):\n",
"            self.vector_store = vector_store\n",
"\n",
"    def init_knowledge_vector_store(self, vs_name: str):\n",
"        \"\"\"Build vector store ``vs_name`` from the documents in KNOWLEDGE_PATH and save it.\"\"\"\n",
"        docs = get_docs(KNOWLEDGE_PATH)\n",
"        vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
"        vector_store.save_local(self._vs_path(vs_name))\n",
"        self._refresh_if_loaded(vs_name, vector_store)\n",
"\n",
"    def add_knowledge_to_vector_store(self, vs_name: str):\n",
"        \"\"\"Embed the documents in ADD_KNOWLEDGE_PATH and merge them into saved store ``vs_name``.\"\"\"\n",
"        docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
"        new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
"        vs_path = self._vs_path(vs_name)\n",
"        vector_store = FAISSVS.load_local(vs_path, self.embeddings)\n",
"        vector_store.merge_from(new_vector_store)\n",
"        vector_store.save_local(vs_path)\n",
"        self._refresh_if_loaded(vs_name, vector_store)\n",
"\n",
"    def load_vector_store(self, vs_name: str):\n",
"        \"\"\"Load saved vector store ``vs_name`` and make it the store used for retrieval.\"\"\"\n",
"        self.vector_store = FAISSVS.load_local(self._vs_path(vs_name), self.embeddings)\n",
"        self._loaded_vs_name = vs_name\n",
"\n",
"    # TODO: web-search answering is disabled. Re-enabling it also requires\n",
"    # importing SerpAPIWrapper, which this notebook does not do.\n",
"    # def get_search_based_answer(self, query):\n",
"    #     search = SerpAPIWrapper(params=self.search_params)\n",
"    #     docs = search.run(query)\n",
"    #     search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
"    #     answer = search_chain.run(input_documents=docs, question=query)\n",
"    #     return answer\n",
"\n",
"    def get_knowledge_based_answer(self, query):\n",
"        \"\"\"Answer ``query`` using documents retrieved from the loaded vector store.\n",
"\n",
"        Retrieves ``self.top_k`` documents by max-marginal-relevance search, stuffs\n",
"        them into ``self.prompt`` as ``context`` and runs the LLM.\n",
"\n",
"        Returns:\n",
"            tuple: ``(answer, docs, history)``: the LLM's answer, the retrieved\n",
"            Documents (scores dropped) and ``self.llm.history`` with the last\n",
"            turn's first and last entries set to ``query`` and ``answer``.\n",
"        \"\"\"\n",
"        # FAISSVS returns (Document, score) pairs.\n",
"        docs = self.vector_store.max_marginal_relevance_search(query, k=self.top_k)\n",
"        # Logs \"retrieved documents and similarity scores: ...\".\n",
"        print(f'召回的文档和相似度分数:{docs}')\n",
"        # doc[1] is the score; only the Document goes into the prompt.\n",
"        docs = [doc[0] for doc in docs]\n",
"\n",
"        document_prompt = PromptTemplate(\n",
"            input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
"        )\n",
"        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
"        combine_documents_chain = StuffDocumentsChain(\n",
"            llm_chain=llm_chain,\n",
"            document_variable_name=\"context\",\n",
"            document_prompt=document_prompt,\n",
"        )\n",
"        answer = combine_documents_chain.run(\n",
"            input_documents=docs, question=query\n",
"        )\n",
"\n",
"        # NOTE(review): the last history turn presumably records the stuffed prompt;\n",
"        # overwrite it with the raw query and answer. Skip when there is no turn yet.\n",
"        if self.llm.history:\n",
"            self.llm.history[-1][0] = query\n",
"            self.llm.history[-1][-1] = answer\n",
"        return answer, docs, self.llm.history"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4342213010c4ed2ad5b04694aa436d6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load the LLM, the embedding model and the default vector store.\n",
"qa = LocalDocQA()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"召回的文档和相似度分数:[(Document(page_content='****** LOGI APT Group Intelligence Research Yearbook APT Knowledge Graph APT组织情报 研究年鉴', metadata={'source': './KnowledgeStore/APT group Intelligence Research handbook-2022.pdf', 'page': 0}), 0.45381865), (Document(page_content='9 MANDIANT APT42: Crooked Charms, Cons and Compromises FIGURE 8. APT42 impersonates University of Oxford vaccinologist. APT42 Credential harvesting page masquerading as a Yahoo login portal.', metadata={'source': './KnowledgeStore/APT42_Crooked_Charms_Cons_and_Compromises.pdf', 'page': 8}), 0.4535672), (Document(page_content='The origin story of APT32 macros T H R E A T R E S E A R C H R E P O R T R u n n i n g t h r o u g h a l l t h e S U O f i l e s t r u c t u r e s i s l a b o r i o u s a n d d i d n t y i e l d m u c h m o r e t h a n a s t r i n g d u m p w o u l d h a v e d o n e a n y w a y . W e f i n d p a t h s t o s o u r c e c o d e f i l e s , p r o j e c t n a m e s , e t c . W e c a n i n f e r f r o m t h e m y r i a d o f r e f e r e n c e s i n XmlPackageOptions , O u t l i n i n g S t a t e D i r , e t c . , t h a t t h e HtaDotnet a n d ShellcodeLoader s o l u t i o n s w e r e o r i g i n a l l y u n d e r t h e f o l d e r p a t h G:\\\\WebBuilder\\\\Gift_HtaDotnet\\\\ . T h i s i s a l s o s u p p o r t e d b y t h e P D B p a t h s o f o l d e r b u i l t b i n a r i e s w i t h i n t h e b r o a d e r S t r i k e S u i t G i f t p a c k a g e . F r o m l o o k i n g a t D e b u g g e r W a t c h e s v a l u e s i n o t h e r p r o j e c t s , w e c a n s e e t h a t t h e m a l w a r e d e v e l o p e r w a s a c t i v e l y d e b u g g i n g t h e h i s t o r i c a l p r o g r a m s . 
S U O f i l e D e b u g g e r W a t c h e s WebBuilder/HtaDotNet/HtaDotnet.v11.suo result WebBuilder/ShellcodeLoader/.vs/L/v14/.suo (char)77 WebBuilder/ShellcodeLoader/L.suo (char)77 3 4 04/2022', metadata={'source': './KnowledgeStore/Stairwell-threat-report-The-origin-of-APT32-macros.pdf', 'page': 33}), 0.38091612), (Document(page_content='2 APTs and COVID-19: How advanced persistent threats use the coronavirus as a lureTable of contents Introduction: APT groups using COVID-19 .........................................................', metadata={'source': './KnowledgeStore/200407-MWB-COVID-White-Paper_Final.pdf', 'page': 1}), 0.44476452)]\n"
]
}
],
"source": [
"query = \"make a brief introduction of APT?\"\n",
"# Don't bind `_` here: once the user assigns it, IPython stops updating its\n",
"# `_` / `__` / `___` last-output shortcuts for the rest of the session.\n",
"ans, docs, chat_history = qa.get_knowledge_based_answer(query)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nAnswer: APT stands for Advanced Persistent Threat, which is a type of malicious cyberattack that is carried out by a sophisticated hacker group or state-sponsored organization. APTs are designed to remain undetected for a long period of time and are often used to steal sensitive data or disrupt critical infrastructure.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the generated answer.\n",
"ans"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chatgpt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}