update local_doc_qa.py
This commit is contained in:
parent
88941d3938
commit
8ae84c6c93
|
|
@ -9,6 +9,7 @@ import os
|
|||
from configs.model_config import *
|
||||
import datetime
|
||||
from typing import List
|
||||
from textsplitter import ChineseTextSplitter
|
||||
|
||||
# return top-k text chunk from vector store
|
||||
VECTOR_SEARCH_TOP_K = 6
|
||||
|
|
@ -17,6 +18,18 @@ VECTOR_SEARCH_TOP_K = 6
|
|||
LLM_HISTORY_LEN = 3
|
||||
|
||||
|
||||
def load_file(filepath):
|
||||
if filepath.lower().endswith(".pdf"):
|
||||
loader = UnstructuredFileLoader(filepath)
|
||||
textsplitter = ChineseTextSplitter(pdf=True)
|
||||
docs = loader.load_and_split(textsplitter)
|
||||
else:
|
||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||
textsplitter = ChineseTextSplitter(pdf=False)
|
||||
docs = loader.load_and_split(text_splitter=textsplitter)
|
||||
return docs
|
||||
|
||||
|
||||
class LocalDocQA:
|
||||
llm: object = None
|
||||
embeddings: object = None
|
||||
|
|
@ -48,10 +61,10 @@ class LocalDocQA:
|
|||
elif os.path.isfile(filepath):
|
||||
file = os.path.split(filepath)[-1]
|
||||
try:
|
||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
docs = load_file(filepath)
|
||||
print(f"{file} 已成功加载")
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
return None
|
||||
elif os.path.isdir(filepath):
|
||||
|
|
@ -59,25 +72,25 @@ class LocalDocQA:
|
|||
for file in os.listdir(filepath):
|
||||
fullfilepath = os.path.join(filepath, file)
|
||||
try:
|
||||
loader = UnstructuredFileLoader(fullfilepath, mode="elements")
|
||||
docs += loader.load()
|
||||
docs += load_file(fullfilepath)
|
||||
print(f"{file} 已成功加载")
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
else:
|
||||
docs = []
|
||||
for file in filepath:
|
||||
try:
|
||||
loader = UnstructuredFileLoader(file, mode="elements")
|
||||
docs += loader.load()
|
||||
docs += load_file(file)
|
||||
print(f"{file} 已成功加载")
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
|
||||
vector_store = FAISS.from_documents(docs, self.embeddings)
|
||||
vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
||||
vector_store.save_local(vs_path)
|
||||
return vs_path if len(docs)>0 else None
|
||||
return vs_path if len(docs) > 0 else None
|
||||
|
||||
def get_knowledge_based_answer(self,
|
||||
query,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from configs.model_config import *
|
|||
from chains.local_doc_qa import LocalDocQA
|
||||
|
||||
# return top-k text chunk from vector store
|
||||
VECTOR_SEARCH_TOP_K = 10
|
||||
VECTOR_SEARCH_TOP_K = 6
|
||||
|
||||
# LLM input history length
|
||||
LLM_HISTORY_LEN = 3
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from .chinese_text_splitter import *
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from langchain.text_splitter import CharacterTextSplitter
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
class ChineseTextSplitter(CharacterTextSplitter):
|
||||
def __init__(self, pdf: bool = False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pdf = pdf
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
if self.pdf:
|
||||
text = re.sub(r"\n{3,}", "\n", text)
|
||||
text = re.sub('\s', ' ', text)
|
||||
text = text.replace("\n\n", "")
|
||||
sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :;
|
||||
sent_list = []
|
||||
for ele in sent_sep_pattern.split(text):
|
||||
if sent_sep_pattern.match(ele) and sent_list:
|
||||
sent_list[-1] += ele
|
||||
elif ele:
|
||||
sent_list.append(ele)
|
||||
return sent_list
|
||||
|
||||
|
||||
Loading…
Reference in New Issue