update local_doc_qa.py

commit 8ae84c6c93 (parent 88941d3938)
@@ -9,6 +9,7 @@ import os
 from configs.model_config import *
 import datetime
 from typing import List
+from textsplitter import ChineseTextSplitter

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -17,6 +18,18 @@ VECTOR_SEARCH_TOP_K = 6
 LLM_HISTORY_LEN = 3


+def load_file(filepath):
+    if filepath.lower().endswith(".pdf"):
+        loader = UnstructuredFileLoader(filepath)
+        textsplitter = ChineseTextSplitter(pdf=True)
+        docs = loader.load_and_split(textsplitter)
+    else:
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        textsplitter = ChineseTextSplitter(pdf=False)
+        docs = loader.load_and_split(text_splitter=textsplitter)
+    return docs
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
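The new load_file helper centralises loading: PDFs go through the plain UnstructuredFileLoader with the splitter in pdf mode, everything else through the element-wise loader. A hedged usage sketch; both file paths are hypothetical, and it assumes the langchain and unstructured packages that back UnstructuredFileLoader are installed:

    # Hypothetical paths; load_file is the helper added in the hunk above.
    docs = load_file("samples/report.pdf")   # PDF branch: ChineseTextSplitter(pdf=True)
    print(len(docs), "chunks")
    docs = load_file("samples/notes.txt")    # default branch: mode="elements" loader
    print(docs[0].page_content[:80])         # langchain Documents carry page_content

The two branches pass the splitter positionally and by keyword respectively; both end up as load_and_split's text_splitter parameter, so the behaviour is the same.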
@@ -48,10 +61,10 @@ class LocalDocQA:
         elif os.path.isfile(filepath):
             file = os.path.split(filepath)[-1]
             try:
-                loader = UnstructuredFileLoader(filepath, mode="elements")
-                docs = loader.load()
+                docs = load_file(filepath)
                 print(f"{file} 已成功加载")
-            except:
+            except Exception as e:
+                print(e)
                 print(f"{file} 未能成功加载")
                 return None
         elif os.path.isdir(filepath):
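Switching from a bare except: to except Exception as e surfaces the underlying error and stops the handler from intercepting KeyboardInterrupt and SystemExit, which a bare except also catches. The status messages 已成功加载 / 未能成功加载 mean "loaded successfully" / "failed to load". A minimal sketch with a hypothetical failing loader:

    def risky_load(path):
        # Hypothetical stand-in for a loader call that fails.
        raise FileNotFoundError(path)

    try:
        docs = risky_load("missing.md")
    except Exception as e:
        print(e)                            # the cause is printed, not swallowed
        print("missing.md failed to load")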
@@ -59,25 +72,25 @@ class LocalDocQA:
             for file in os.listdir(filepath):
                 fullfilepath = os.path.join(filepath, file)
                 try:
-                    loader = UnstructuredFileLoader(fullfilepath, mode="elements")
-                    docs += loader.load()
+                    docs += load_file(fullfilepath)
                     print(f"{file} 已成功加载")
-                except:
+                except Exception as e:
+                    print(e)
                     print(f"{file} 未能成功加载")
         else:
             docs = []
             for file in filepath:
                 try:
-                    loader = UnstructuredFileLoader(file, mode="elements")
-                    docs += loader.load()
+                    docs += load_file(file)
                     print(f"{file} 已成功加载")
-                except:
+                except Exception as e:
+                    print(e)
                     print(f"{file} 未能成功加载")

         vector_store = FAISS.from_documents(docs, self.embeddings)
         vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
         vector_store.save_local(vs_path)
-        return vs_path if len(docs)>0 else None
+        return vs_path if len(docs) > 0 else None

     def get_knowledge_based_answer(self,
                                    query,
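After ingest the method builds a FAISS index over every loaded chunk, persists it under vs_path, and now returns the path only when at least one document actually loaded, so callers can tell an empty ingest from a successful one. A hedged sketch of reloading the saved store later with langchain's FAISS wrapper; the path is a placeholder and the embeddings model must match whatever the store was built with:

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    # Placeholder embeddings; use the same model the store was saved with.
    embeddings = HuggingFaceEmbeddings()
    vs_path = "./vector_store/notes_FAISS_20230414_120000"  # hypothetical path
    vector_store = FAISS.load_local(vs_path, embeddings)
    related_docs = vector_store.similarity_search("example query", k=6)  # k mirrors VECTOR_SEARCH_TOP_K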
@@ -2,7 +2,7 @@ from configs.model_config import *
 from chains.local_doc_qa import LocalDocQA

 # return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
+VECTOR_SEARCH_TOP_K = 6

 # LLM input history length
 LLM_HISTORY_LEN = 3
Binary file not shown.
@@ -0,0 +1,2 @@
+from .chinese_text_splitter import *
+
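The two-line package __init__ re-exports the splitter, which is what makes the short import added to local_doc_qa.py resolve:

    # Equivalent once the wildcard re-export above is in place:
    from textsplitter import ChineseTextSplitter
    from textsplitter.chinese_text_splitter import ChineseTextSplitter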
@@ -0,0 +1,25 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+
+
+class ChineseTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+
+    def split_text(self, text: str) -> List[str]:
+        if self.pdf:
+            text = re.sub(r"\n{3,}", "\n", text)
+            text = re.sub('\s', ' ', text)
+            text = text.replace("\n\n", "")
+        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
+        sent_list = []
+        for ele in sent_sep_pattern.split(text):
+            if sent_sep_pattern.match(ele) and sent_list:
+                sent_list[-1] += ele
+            elif ele:
+                sent_list.append(ele)
+        return sent_list
+
+
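A hedged usage sketch for the new splitter, run from the repository root. The sample sentence is made up; the expected result follows from re.split keeping the captured sentence-ending punctuation, which the loop then glues back onto the preceding sentence:

    from textsplitter import ChineseTextSplitter

    splitter = ChineseTextSplitter(pdf=False)
    print(splitter.split_text("今天天气很好。我们去公园吧!好的。"))
    # expected: ['今天天气很好。', '我们去公园吧!', '好的。']

With pdf=True the same call first collapses the runs of line breaks and whitespace that PDF extraction tends to leave behind.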