使用paddleocr实现 (#342)
* jpg and png ocr * fix * write docs to tmp file * fix * [BUGFIX] local_doc_qa.py line 172: logging have no end args. (#323) * image loader * fix * fix * update api.py * update api.py * update api.py * update README.md * update api.py * add pdf_loader * fix --------- Co-authored-by: RainGather <3255329+RainGather@users.noreply.github.com> Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
This commit is contained in:
parent
dcf6e4ffeb
commit
dd93837343
|
|
@ -207,6 +207,6 @@ Web UI 可以实现如下功能:
|
||||||
- [ ] 实现调用 API 的 Web UI Demo
|
- [ ] 实现调用 API 的 Web UI Demo
|
||||||
|
|
||||||
## 项目交流群
|
## 项目交流群
|
||||||

|

|
||||||
|
|
||||||
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||||
|
|
|
||||||
61
api.py
61
api.py
|
|
@ -22,6 +22,7 @@ from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVI
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
class BaseResponse(BaseModel):
|
||||||
code: int = pydantic.Field(200, description="HTTP status code")
|
code: int = pydantic.Field(200, description="HTTP status code")
|
||||||
msg: str = pydantic.Field("success", description="HTTP status message")
|
msg: str = pydantic.Field("success", description="HTTP status message")
|
||||||
|
|
@ -87,7 +88,7 @@ def get_vs_path(local_doc_id: str):
|
||||||
def get_file_path(local_doc_id: str, doc_name: str):
|
def get_file_path(local_doc_id: str, doc_name: str):
|
||||||
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
||||||
|
|
||||||
async def single_upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(description="A single binary file"),
|
file: UploadFile = File(description="A single binary file"),
|
||||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||||
):
|
):
|
||||||
|
|
@ -106,21 +107,15 @@ async def single_upload_file(
|
||||||
f.write(file_content)
|
f.write(file_content)
|
||||||
|
|
||||||
vs_path = get_vs_path(knowledge_base_id)
|
vs_path = get_vs_path(knowledge_base_id)
|
||||||
if os.path.exists(vs_path):
|
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
|
||||||
added_files = await local_doc_qa.add_files_to_knowledge_vector_store(vs_path, [file_path])
|
if len(loaded_files) > 0:
|
||||||
if len(added_files) > 0:
|
file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
|
||||||
file_status = f"文件 {file.filename} 已上传并已加载知识库,请开始提问。"
|
return BaseResponse(code=200, msg=file_status)
|
||||||
return BaseResponse(code=200, msg=file_status)
|
|
||||||
else:
|
else:
|
||||||
vs_path, loaded_files = await local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
|
file_status = "文件上传失败,请重新上传"
|
||||||
if len(loaded_files) > 0:
|
return BaseResponse(code=500, msg=file_status)
|
||||||
file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
|
|
||||||
return BaseResponse(code=200, msg=file_status)
|
|
||||||
|
|
||||||
file_status = "文件上传失败,请重新上传"
|
async def upload_files(
|
||||||
return BaseResponse(code=500, msg=file_status)
|
|
||||||
|
|
||||||
async def upload_file(
|
|
||||||
files: Annotated[
|
files: Annotated[
|
||||||
List[UploadFile], File(description="Multiple files as UploadFile")
|
List[UploadFile], File(description="Multiple files as UploadFile")
|
||||||
],
|
],
|
||||||
|
|
@ -203,7 +198,7 @@ async def delete_docs(
|
||||||
return BaseResponse()
|
return BaseResponse()
|
||||||
|
|
||||||
|
|
||||||
async def chat(
|
async def local_doc_chat(
|
||||||
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[str]] = Body(
|
||||||
|
|
@ -238,7 +233,8 @@ async def chat(
|
||||||
source_documents=source_documents,
|
source_documents=source_documents,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def no_knowledge_chat(
|
|
||||||
|
async def chat(
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[str]] = Body(
|
||||||
[],
|
[],
|
||||||
|
|
@ -251,12 +247,19 @@ async def no_knowledge_chat(
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
for resp, history in local_doc_qa.llm._call(
|
||||||
for resp, history in local_doc_qa._call(
|
prompt=question, history=history, streaming=True
|
||||||
query=question, chat_history=history, streaming=True
|
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
question=question,
|
||||||
|
response=resp,
|
||||||
|
history=history,
|
||||||
|
source_documents=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
||||||
|
|
@ -322,16 +325,20 @@ def main():
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
|
app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
|
||||||
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
|
||||||
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
|
|
||||||
app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
|
|
||||||
app.post("/chat-docs/uploadone", response_model=BaseResponse)(single_upload_file)
|
|
||||||
app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
|
|
||||||
app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
|
|
||||||
app.get("/", response_model=BaseResponse)(document)
|
app.get("/", response_model=BaseResponse)(document)
|
||||||
|
|
||||||
|
app.post("/chat", response_model=ChatMessage)(chat)
|
||||||
|
|
||||||
|
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
||||||
|
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||||
|
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||||
|
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||||
|
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||||
|
|
||||||
|
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg(
|
local_doc_qa.init_cfg(
|
||||||
llm_model=LLM_MODEL,
|
llm_model=LLM_MODEL,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import numpy as np
|
||||||
from utils import torch_gc
|
from utils import torch_gc
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
|
from loader import UnstructuredPaddleImageLoader
|
||||||
|
from loader import UnstructuredPaddlePDFLoader
|
||||||
|
|
||||||
DEVICE_ = EMBEDDING_DEVICE
|
DEVICE_ = EMBEDDING_DEVICE
|
||||||
DEVICE_ID = "0" if torch.cuda.is_available() else None
|
DEVICE_ID = "0" if torch.cuda.is_available() else None
|
||||||
|
|
@ -21,16 +23,31 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE):
|
||||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
elif filepath.lower().endswith(".pdf"):
|
elif filepath.lower().endswith(".pdf"):
|
||||||
loader = UnstructuredFileLoader(filepath, strategy="fast")
|
loader = UnstructuredPaddlePDFLoader(filepath)
|
||||||
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(textsplitter)
|
docs = loader.load_and_split(textsplitter)
|
||||||
|
elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
|
||||||
|
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
||||||
|
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
|
docs = loader.load_and_split(text_splitter=textsplitter)
|
||||||
else:
|
else:
|
||||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||||
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(text_splitter=textsplitter)
|
docs = loader.load_and_split(text_splitter=textsplitter)
|
||||||
|
write_check_file(filepath, docs)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def write_check_file(filepath, docs):
|
||||||
|
fout = open('load_file.txt', 'a')
|
||||||
|
fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
|
||||||
|
fout.write('\n')
|
||||||
|
for i in docs:
|
||||||
|
fout.write(str(i))
|
||||||
|
fout.write('\n')
|
||||||
|
fout.close()
|
||||||
|
|
||||||
|
|
||||||
def generate_prompt(related_docs: List[str], query: str,
|
def generate_prompt(related_docs: List[str], query: str,
|
||||||
prompt_template=PROMPT_TEMPLATE) -> str:
|
prompt_template=PROMPT_TEMPLATE) -> str:
|
||||||
context = "\n".join([doc.page_content for doc in related_docs])
|
context = "\n".join([doc.page_content for doc in related_docs])
|
||||||
|
|
@ -176,7 +193,7 @@ class LocalDocQA:
|
||||||
if len(failed_files) > 0:
|
if len(failed_files) > 0:
|
||||||
logger.info("以下文件未能成功加载:")
|
logger.info("以下文件未能成功加载:")
|
||||||
for file in failed_files:
|
for file in failed_files:
|
||||||
logger.info(file, end="\n")
|
logger.info(f"{file}\n")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
docs = []
|
docs = []
|
||||||
|
|
@ -212,7 +229,7 @@ class LocalDocQA:
|
||||||
if not vs_path or not one_title or not one_conent:
|
if not vs_path or not one_title or not one_conent:
|
||||||
logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
|
logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
|
||||||
return None, [one_title]
|
return None, [one_title]
|
||||||
docs = [Document(page_content=one_conent+"\n", metadata={"source": one_title})]
|
docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
|
||||||
if not one_content_segmentation:
|
if not one_content_segmentation:
|
||||||
text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
docs = text_splitter.split_documents(docs)
|
docs = text_splitter.split_documents(docs)
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 269 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 7.9 KiB |
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .image_loader import UnstructuredPaddleImageLoader
|
||||||
|
from .pdf_loader import UnstructuredPaddlePDFLoader
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Loader that loads image files."""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
|
from paddleocr import PaddleOCR
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
||||||
|
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
||||||
|
|
||||||
|
def _get_elements(self) -> List:
|
||||||
|
def image_ocr_txt(filepath, dir_path="tmp_files"):
|
||||||
|
if not os.path.exists(dir_path):
|
||||||
|
os.makedirs(dir_path)
|
||||||
|
filename = os.path.split(filepath)[-1]
|
||||||
|
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
|
||||||
|
result = ocr.ocr(img=filepath)
|
||||||
|
|
||||||
|
ocr_result = [i[1][0] for line in result for i in line]
|
||||||
|
txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
|
||||||
|
with open(txt_file_path, 'w', encoding='utf-8') as fout:
|
||||||
|
fout.write("\n".join(ocr_result))
|
||||||
|
return txt_file_path
|
||||||
|
|
||||||
|
txt_file_path = image_ocr_txt(self.file_path)
|
||||||
|
from unstructured.partition.text import partition_text
|
||||||
|
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Loader that loads image files."""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
|
from paddleocr import PaddleOCR
|
||||||
|
import os
|
||||||
|
import fitz
|
||||||
|
|
||||||
|
|
||||||
|
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
|
||||||
|
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
||||||
|
|
||||||
|
def _get_elements(self) -> List:
|
||||||
|
def pdf_ocr_txt(filepath, dir_path="tmp_files"):
|
||||||
|
if not os.path.exists(dir_path):
|
||||||
|
os.makedirs(dir_path)
|
||||||
|
filename = os.path.split(filepath)[-1]
|
||||||
|
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
|
||||||
|
doc = fitz.open(filepath)
|
||||||
|
txt_file_path = os.path.join(dir_path, "%s.txt" % (filename))
|
||||||
|
img_name = './img/.tmp.png'
|
||||||
|
with open(txt_file_path, 'w', encoding='utf-8') as fout:
|
||||||
|
|
||||||
|
for i in range(doc.page_count):
|
||||||
|
page = doc[i]
|
||||||
|
text = page.get_text("")
|
||||||
|
fout.write(text)
|
||||||
|
fout.write("\n")
|
||||||
|
|
||||||
|
img_list = page.get_images()
|
||||||
|
for img in img_list:
|
||||||
|
pix = fitz.Pixmap(doc, img[0])
|
||||||
|
|
||||||
|
pix.save(img_name)
|
||||||
|
|
||||||
|
result = ocr.ocr(img_name)
|
||||||
|
ocr_result = [i[1][0] for line in result for i in line]
|
||||||
|
fout.write("\n".join(ocr_result))
|
||||||
|
os.remove(img_name)
|
||||||
|
return txt_file_path
|
||||||
|
|
||||||
|
txt_file_path = pdf_ocr_txt(self.file_path)
|
||||||
|
from unstructured.partition.text import partition_text
|
||||||
|
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
pymupdf
|
||||||
|
paddlepaddle==2.4.2
|
||||||
|
paddleocr
|
||||||
langchain==0.0.146
|
langchain==0.0.146
|
||||||
transformers==4.27.1
|
transformers==4.27.1
|
||||||
unstructured[local-inference]
|
unstructured[local-inference]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from configs.model_config import *
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
filepath = "./img/test.jpg"
|
||||||
|
from loader import UnstructuredPaddleImageLoader
|
||||||
|
|
||||||
|
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
||||||
|
docs = loader.load()
|
||||||
|
for doc in docs:
|
||||||
|
print(doc)
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from configs.model_config import *
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
filepath = "docs/test.pdf"
|
||||||
|
from loader import UnstructuredPaddlePDFLoader
|
||||||
|
|
||||||
|
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
|
||||||
|
docs = loader.load()
|
||||||
|
for doc in docs:
|
||||||
|
print(doc)
|
||||||
Loading…
Reference in New Issue