Langchain-Chatchat/chains/local_doc_qa.py
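# Scratch script: split a local text file into chunks, index them in Chroma
# with HuggingFace embeddings, and answer a query via a "stuff" QA chain.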

from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain import LLMChain
from langchain.llms import OpenAI
from configs.model_config import *
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
with open("../knowledge_base/samples/content/test.txt", encoding="utf-8") as f:
    state_of_the_union = f.read()
# TODO: define params
# text_splitter = MyTextSplitter()
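# A hypothetical sketch of the MyTextSplitter placeholder above (not part of
# this repo): any TextSplitter subclass implementing split_text() can be
# swapped in, e.g. one that cuts on Chinese sentence punctuation:
#
#   from langchain.text_splitter import TextSplitter
#
#   class MyTextSplitter(TextSplitter):
#       def split_text(self, text):
#           import re
#           pieces = re.split(r"(?<=[。！？])", text)  # cut after 。！？
#           return self._merge_splits(pieces, separator="")  # re-pack to chunk_size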
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
texts = text_splitter.split_text(state_of_the_union)
# TODO: define params
# embeddings = MyEmbeddings()
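# Likewise, the MyEmbeddings placeholder only needs the langchain Embeddings
# interface, i.e. embed_documents() and embed_query(); a minimal sketch:
#
#   class MyEmbeddings:
#       def embed_documents(self, texts):
#           return [self.embed_query(t) for t in texts]
#       def embed_query(self, text):
#           ...  # return one vector (list[float]) per text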
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
                                   model_kwargs={'device': EMBEDDING_DEVICE})
docsearch = Chroma.from_texts(
    texts,
    embeddings,
    metadatas=[{"source": str(i)} for i in range(len(texts))]
).as_retriever()
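# from_texts() embeds every chunk up front and builds the index; as_retriever()
# wraps the store so chunks can be fetched by vector similarity (langchain's
# default retriever returns the top k=4 matches).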
# Smoke-test retrieval against the new index
query = "什么是Prompt工程"  # "What is prompt engineering?"
docs = docsearch.get_relevant_documents(query)
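# docs is a List[Document]; each carries .page_content (the chunk text) and
# the {"source": ...} metadata attached above.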
# print(docs)
# prompt_template = PROMPT_TEMPLATE
llm = OpenAI(model_name=LLM_MODEL,
             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             streaming=True)
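# Note: streaming=True only has an effect when a streaming-capable callback
# handler is attached; plain chain calls below still block and return full text.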
# print(PROMPT)
prompt = PromptTemplate(input_variables=["input"], template="{input}")
chain = LLMChain(prompt=prompt, llm=llm)
resp = chain("你好")  # "Hello": quick check that the LLM endpoint answers
print(resp["text"])  # LLMChain returns a dict; the generation lives under "text"
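# Retrieval-augmented QA: PROMPT_TEMPLATE (from configs.model_config) is filled
# with the retrieved chunks as {context} and the user query as {question}.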
PROMPT = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
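# chain_type="stuff" concatenates all input_documents into the prompt's
# {context} slot; fine for a few chunks, but it can overflow the model's
# context window on larger document sets.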
response = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
print(response["output_text"])