{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
    "from langchain.chains.llm import LLMChain\n",
    "from langchain.chains.question_answering import load_qa_chain\n",
    "from langchain.prompts import PromptTemplate\n",
    "\n",
    "from lib.embeds import MyEmbeddings\n",
    "from lib.faiss import FAISSVS\n",
    "from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
    "# NOTE(review): this star import presumably supplies the UPPER_CASE defaults, paths,\n",
    "# PROMPT_TEMPLATE and the *_model_dict lookups used below; import them explicitly instead.\n",
    "from lib.config import *\n",
    "from lib.utils import get_docs\n",
    "\n",
    "\n",
    "class LocalDocQA:\n",
    "    \"\"\"Answer questions against a local FAISS knowledge base using a local GLM model.\n",
    "\n",
    "    Construction loads the LLM, the embedding model and the saved vector store\n",
    "    ``vs_name``; the defaults are the settings star-imported from ``lib.config``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 embedding_model=EMBEDDING_MODEL,\n",
    "                 embedding_device=EMBEDDING_DEVICE,\n",
    "                 llm_model=LLM_MODEL,\n",
    "                 llm_device=LLM_DEVICE,\n",
    "                 llm_history_len=LLM_HISTORY_LEN,\n",
    "                 top_k=VECTOR_SEARCH_TOP_K,\n",
    "                 vs_name=VS_NAME,\n",
    "                 ) -> None:\n",
    "        \"\"\"Load the LLM, the embeddings and the vector store ``vs_name``.\n",
    "\n",
    "        Args:\n",
    "            embedding_model: key into ``embedding_model_dict``.\n",
    "            embedding_device: stored only; not forwarded to ``MyEmbeddings`` here.\n",
    "            llm_model: key into ``llm_model_dict``.\n",
    "            llm_device: device passed to ``AlpacaGLM.load_model``.\n",
    "            llm_history_len: stored only; not used by this class.\n",
    "            top_k: number of documents retrieved per question.\n",
    "            vs_name: name of the saved vector store under ``VECTORSTORE_PATH``.\n",
    "        \"\"\"\n",
    "        # Free cached GPU memory (e.g. left by a previous instance when this cell is re-run).\n",
    "        # A single call is sufficient.\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "        self.embedding_model = embedding_model\n",
    "        self.llm_model = llm_model\n",
    "        self.embedding_device = embedding_device\n",
    "        self.llm_device = llm_device\n",
    "        self.llm_history_len = llm_history_len\n",
    "        self.top_k = top_k\n",
    "        self.vs_name = vs_name\n",
    "\n",
    "        self.llm = AlpacaGLM()\n",
    "        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
    "\n",
    "        # NOTE(review): embedding_device is not passed to MyEmbeddings - confirm this is intended.\n",
    "        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
    "        self.load_vector_store(vs_name)\n",
    "\n",
    "        self.prompt = PromptTemplate(\n",
    "            template=PROMPT_TEMPLATE,\n",
    "            input_variables=[\"context\", \"question\"],\n",
    "        )\n",
    "        # Settings for a SerpAPI web search; nothing in this cell reads them. The API key\n",
    "        # comes from the environment instead of being hard-coded in the notebook.\n",
    "        self.search_params = {\n",
    "            \"engine\": \"bing\",\n",
    "            \"gl\": \"us\",\n",
    "            \"hl\": \"en\",\n",
    "            \"serpapi_api_key\": os.environ.get(\"SERPAPI_API_KEY\", \"\"),\n",
    "        }\n",
    "\n",
    "    def _vs_path(self, vs_name: str) -> str:\n",
    "        \"\"\"Return the on-disk path of the vector store named ``vs_name``.\"\"\"\n",
    "        return VECTORSTORE_PATH + vs_name\n",
    "\n",
    "    def init_knowledge_vector_store(self, vs_name: str):\n",
    "        \"\"\"Embed the documents under ``KNOWLEDGE_PATH`` and save them as store ``vs_name``.\n",
    "\n",
    "        If ``vs_name`` is the store this instance queries, it is reloaded so that\n",
    "        later questions use the new index instead of the stale in-memory one.\n",
    "        \"\"\"\n",
    "        docs = get_docs(KNOWLEDGE_PATH)\n",
    "        vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
    "        vector_store.save_local(self._vs_path(vs_name))\n",
    "        if vs_name == self.vs_name:\n",
    "            self.load_vector_store(vs_name)\n",
    "\n",
    "    def add_knowledge_to_vector_store(self, vs_name: str):\n",
    "        \"\"\"Embed the documents under ``ADD_KNOWLEDGE_PATH`` and merge them into store ``vs_name``.\n",
    "\n",
    "        The merged store is saved back to the same path. If it is the active store,\n",
    "        it is reloaded so later questions see the added documents.\n",
    "        \"\"\"\n",
    "        docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
    "        new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
    "        vs_path = self._vs_path(vs_name)\n",
    "        vector_store = FAISSVS.load_local(vs_path, self.embeddings)\n",
    "        vector_store.merge_from(new_vector_store)\n",
    "        vector_store.save_local(vs_path)\n",
    "        if vs_name == self.vs_name:\n",
    "            self.load_vector_store(vs_name)\n",
    "\n",
    "    def load_vector_store(self, vs_name: str):\n",
    "        \"\"\"Load the saved store ``vs_name`` and make it the store this instance queries.\"\"\"\n",
    "        self.vector_store = FAISSVS.load_local(self._vs_path(vs_name), self.embeddings)\n",
    "        self.vs_name = vs_name\n",
    "\n",
    "    def get_knowledge_based_answer(self, query):\n",
    "        \"\"\"Retrieve documents relevant to ``query`` and answer it with the LLM.\n",
    "\n",
    "        Returns:\n",
    "            ``(answer, docs, history)``: the LLM answer, the retrieved documents\n",
    "            (scores dropped) and the LLM history whose last turn is ``[query, answer]``.\n",
    "        \"\"\"\n",
    "        # This FAISSVS returns (document, score) pairs; element [1] is the score.\n",
    "        # NOTE(review): assumes FAISSVS keeps LangChain's `k` keyword for MMR search - confirm.\n",
    "        scored_docs = self.vector_store.max_marginal_relevance_search(query, k=self.top_k)\n",
    "        print(f'召回的文档和相似度分数:{scored_docs}')\n",
    "        docs = [pair[0] for pair in scored_docs]\n",
    "\n",
    "        document_prompt = PromptTemplate(\n",
    "            input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
    "        )\n",
    "        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
    "        combine_documents_chain = StuffDocumentsChain(\n",
    "            llm_chain=llm_chain,\n",
    "            document_variable_name=\"context\",\n",
    "            document_prompt=document_prompt,\n",
    "        )\n",
    "        answer = combine_documents_chain.run(input_documents=docs, question=query)\n",
    "\n",
    "        # Presumably the LLM call appended a turn keyed by the full stuffed prompt; rewrite\n",
    "        # that turn as [query, answer]. An empty history would otherwise raise IndexError.\n",
    "        # NOTE(review): assumes history entries are mutable [question, answer] lists - confirm.\n",
    "        history = self.llm.history\n",
    "        if history:\n",
    "            history[-1][0] = query\n",
    "            history[-1][-1] = answer\n",
    "        else:\n",
    "            history.append([query, answer])\n",
    "        return answer, docs, history"
   ]
  },
  { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d4342213010c4ed2ad5b04694aa436d6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00