diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 0552db3..8a45bd2 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -1,7 +1,9 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from chains.lib.embeddings import MyEmbeddings
+# from langchain.vectorstores import FAISS
+from chains.lib.vectorstores import FAISSVS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -50,7 +52,7 @@ class LocalDocQA:
                                  use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
 
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -97,12 +99,12 @@ class LocalDocQA:
                 print(f"{file} failed to load")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISS.from_documents(docs, self.embeddings)
+                vector_store = FAISSVS.from_documents(docs, self.embeddings)
                 vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -127,7 +129,7 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
             retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
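Note that `chains/lib/embeddings.py` itself is not part of this patch, so the real `MyEmbeddings` may differ. Judging from the call sites it is constructed exactly like the `HuggingFaceEmbeddings` it replaces (`model_name` plus optional `model_kwargs`), so a minimal sketch of such a drop-in wrapper could look like the following; the L2-normalization step is an assumption, included because it is a common reason to wrap the stock class (unit-length vectors make FAISS inner-product scores behave like cosine similarity):

```python
# Hypothetical sketch of chains/lib/embeddings.py -- the real file is not in
# this diff. Assumes MyEmbeddings is a drop-in HuggingFaceEmbeddings subclass
# that L2-normalizes every vector it produces.
from typing import List

import numpy as np
from langchain.embeddings.huggingface import HuggingFaceEmbeddings


class MyEmbeddings(HuggingFaceEmbeddings):
    @staticmethod
    def _normalize(vec: List[float]) -> List[float]:
        # Scale the vector to unit length (epsilon guards against zero norm).
        arr = np.asarray(vec, dtype=np.float32)
        return (arr / (np.linalg.norm(arr) + 1e-10)).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._normalize(v) for v in super().embed_documents(texts)]

    def embed_query(self, text: str) -> List[float]:
        return self._normalize(super().embed_query(text))
```

Because the parent constructor is untouched, both call sites in the patch keep working: `local_doc_qa.py` passes `model_kwargs={'device': embedding_device}`, while the notebook below constructs it with `model_name` only.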
diff --git a/chains/test.ipynb b/chains/test.ipynb
new file mode 100644
index 0000000..5183fa2
--- /dev/null
+++ b/chains/test.ipynb
@@ -0,0 +1,195 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains.question_answering import load_qa_chain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "from lib.embeds import MyEmbeddings\n",
+    "from lib.faiss import FAISSVS\n",
+    "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
+    "from langchain.chains.llm import LLMChain\n",
+    "from lib.chatglm_llm import ChatGLM, AlpacaGLM\n",
+    "from lib.config import *\n",
+    "from lib.utils import get_docs\n",
+    "\n",
+    "\n",
+    "class LocalDocQA:\n",
+    "    def __init__(self, \n",
+    "                 embedding_model=EMBEDDING_MODEL, \n",
+    "                 embedding_device=EMBEDDING_DEVICE, \n",
+    "                 llm_model=LLM_MODEL, \n",
+    "                 llm_device=LLM_DEVICE, \n",
+    "                 llm_history_len=LLM_HISTORY_LEN, \n",
+    "                 top_k=VECTOR_SEARCH_TOP_K,\n",
+    "                 vs_name = VS_NAME\n",
+    "                 ) -> None:\n",
+    "        \n",
+    "        torch.cuda.empty_cache()\n",
+    "        torch.cuda.empty_cache()\n",
+    "\n",
+    "        self.embedding_model = embedding_model\n",
+    "        self.llm_model = llm_model\n",
+    "        self.embedding_device = embedding_device\n",
+    "        self.llm_device = llm_device\n",
+    "        self.llm_history_len = llm_history_len\n",
+    "        self.top_k = top_k\n",
+    "        self.vs_name = vs_name\n",
+    "\n",
+    "        self.llm = AlpacaGLM()\n",
+    "        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], llm_device=llm_device)\n",
+    "\n",
+    "        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model])\n",
+    "        self.load_vector_store(vs_name)\n",
+    "\n",
+    "        self.prompt = PromptTemplate(\n",
+    "            template=PROMPT_TEMPLATE,\n",
+    "            input_variables=[\"context\", \"question\"]\n",
+    "        )\n",
+    "        self.search_params = {\n",
+    "            \"engine\": \"bing\",\n",
+    "            \"gl\": \"us\",\n",
+    "            \"hl\": \"en\",\n",
+    "            \"serpapi_api_key\": \"\"\n",
+    "        }\n",
+    "\n",
+    "    def init_knowledge_vector_store(self, vs_name: str):\n",
+    "        \n",
+    "        docs = get_docs(KNOWLEDGE_PATH)\n",
+    "        vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
+    "        vs_path = VECTORSTORE_PATH + vs_name\n",
+    "        vector_store.save_local(vs_path)\n",
+    "\n",
+    "    def add_knowledge_to_vector_store(self, vs_name: str):\n",
+    "        docs = get_docs(ADD_KNOWLEDGE_PATH)\n",
+    "        new_vector_store = FAISSVS.from_documents(docs, self.embeddings)\n",
+    "        vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings) \n",
+    "        vector_store.merge_from(new_vector_store)\n",
+    "        vector_store.save_local(VECTORSTORE_PATH + vs_name)\n",
+    "\n",
+    "    def load_vector_store(self, vs_name: str):\n",
+    "        self.vector_store = FAISSVS.load_local(VECTORSTORE_PATH + vs_name, self.embeddings)\n",
+    "\n",
+    "    # def get_search_based_answer(self, query):\n",
+    "        \n",
+    "    #     search = SerpAPIWrapper(params=self.search_params)\n",
+    "    #     docs = search.run(query)\n",
+    "    #     search_chain = load_qa_chain(self.llm, chain_type=\"stuff\")\n",
+    "    #     answer = search_chain.run(input_documents=docs, question=query)\n",
+    "\n",
+    "    #     return answer\n",
+    "    \n",
+    "    def get_knowledge_based_answer(self, query):\n",
+    "        \n",
+    "        docs = self.vector_store.max_marginal_relevance_search(query)\n",
+    "        print(f'Recalled documents and similarity scores: {docs}')\n",
+    "        # here doc[1] is the corresponding score\n",
+    "        docs = [doc[0] for doc in docs]\n",
+    "        \n",
+    "        document_prompt = PromptTemplate(\n",
+    "            input_variables=[\"page_content\"], template=\"Context:\\n{page_content}\"\n",
+    "        )\n",
+    "        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)\n",
+    "        combine_documents_chain = StuffDocumentsChain(\n",
+    "            llm_chain=llm_chain,\n",
+    "            document_variable_name=\"context\",\n",
+    "            document_prompt=document_prompt,\n",
+    "        )\n",
+    "        answer = combine_documents_chain.run(\n",
+    "            input_documents=docs, question=query\n",
+    "        )\n",
+    "\n",
+    "        self.llm.history[-1][0] = query\n",
+    "        self.llm.history[-1][-1] = answer\n",
+    "        return answer, docs, self.llm.history"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d4342213010c4ed2ad5b04694aa436d6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/3 [00:00
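One detail the notebook relies on: `get_knowledge_based_answer` unpacks the result of `max_marginal_relevance_search` as `(Document, score)` pairs (hence `docs = [doc[0] for doc in docs]`), whereas stock langchain `FAISS.max_marginal_relevance_search` returns bare `Document`s with no scores. Providing those scores is presumably what the custom `FAISSVS` class adds; its source (`chains/lib/vectorstores.py` / `lib/faiss.py`) is not in this patch, so the following is only a sketch under that assumption, with the MMR selection written out in plain numpy:

```python
# Hypothetical sketch of the FAISSVS subclass -- the real file is not in this
# diff. Assumes its purpose is to make MMR search return (Document, score)
# pairs, matching how the notebook consumes the results.
from typing import List, Tuple

import numpy as np
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS


class FAISSVS(FAISS):
    def max_marginal_relevance_search(
        self, query: str, k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5
    ) -> List[Tuple[Document, float]]:
        # Over-fetch fetch_k candidates, then greedily keep k that balance
        # relevance to the query against redundancy with earlier picks.
        query_vec = np.asarray(self.embedding_function(query), dtype=np.float32)
        scores, indices = self.index.search(query_vec[None, :], fetch_k)
        # -1 indices appear when the store holds fewer than fetch_k docs.
        cand = [(int(i), float(s)) for i, s in zip(indices[0], scores[0]) if i != -1]
        cand_vecs = np.stack([self.index.reconstruct(i) for i, _ in cand])

        def cos(a: np.ndarray, m: np.ndarray) -> np.ndarray:
            # Cosine similarity of one vector against each row of a matrix.
            return (m @ a) / (np.linalg.norm(m, axis=1) * np.linalg.norm(a) + 1e-10)

        relevance = cos(query_vec, cand_vecs)
        selected = [int(np.argmax(relevance))]
        while len(selected) < min(k, len(cand)):
            redundancy = np.max(
                np.stack([cos(cand_vecs[j], cand_vecs) for j in selected]), axis=0
            )
            mmr = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            mmr[selected] = -np.inf  # never re-pick a chosen candidate
            selected.append(int(np.argmax(mmr)))

        # Pair each chosen Document with its raw FAISS search score.
        return [
            (self.docstore.search(self.index_to_docstore_id[cand[j][0]]), cand[j][1])
            for j in selected
        ]
```

With a signature like this, the notebook's flow holds together: the scores are printed for inspection, then stripped off before the documents are stuffed into the prompt by `StuffDocumentsChain`.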