diff --git a/README.md b/README.md index b384bd4..9a7afb4 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,6 @@ Web UI 可以实现如下功能: - [ ] 实现调用 API 的 Web UI Demo ## 项目交流群 -![二维码](img/qr_code_16.jpg) +![二维码](img/qr_code_17.jpg) 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/api.py b/api.py index be6ea52..9c93c0e 100644 --- a/api.py +++ b/api.py @@ -22,6 +22,7 @@ from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVI nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + class BaseResponse(BaseModel): code: int = pydantic.Field(200, description="HTTP status code") msg: str = pydantic.Field("success", description="HTTP status message") @@ -87,7 +88,7 @@ def get_vs_path(local_doc_id: str): def get_file_path(local_doc_id: str, doc_name: str): return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name) -async def single_upload_file( +async def upload_file( file: UploadFile = File(description="A single binary file"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): @@ -106,21 +107,15 @@ async def single_upload_file( f.write(file_content) vs_path = get_vs_path(knowledge_base_id) - if os.path.exists(vs_path): - added_files = await local_doc_qa.add_files_to_knowledge_vector_store(vs_path, [file_path]) - if len(added_files) > 0: - file_status = f"文件 {file.filename} 已上传并已加载知识库,请开始提问。" - return BaseResponse(code=200, msg=file_status) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path) + if len(loaded_files) > 0: + file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。" + return BaseResponse(code=200, msg=file_status) else: - vs_path, loaded_files = await local_doc_qa.init_knowledge_vector_store([file_path], vs_path) - if len(loaded_files) > 0: - file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。" - return BaseResponse(code=200, msg=file_status) + file_status = "文件上传失败,请重新上传" + return BaseResponse(code=500, msg=file_status) - file_status = "文件上传失败,请重新上传" - return BaseResponse(code=500, msg=file_status) - -async def upload_file( +async def upload_files( files: Annotated[ List[UploadFile], File(description="Multiple files as UploadFile") ], @@ -203,7 +198,7 @@ async def delete_docs( return BaseResponse() -async def chat( +async def local_doc_chat( knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"), question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( @@ -238,7 +233,8 @@ async def chat( source_documents=source_documents, ) -async def no_knowledge_chat( + +async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( [], @@ -251,12 +247,19 @@ async def no_knowledge_chat( ], ), ): - - for resp, history in local_doc_qa._call( - query=question, chat_history=history, streaming=True + for resp, history in local_doc_qa.llm._call( + prompt=question, history=history, streaming=True ): pass + return ChatMessage( + question=question, + response=resp, + history=history, + source_documents=[], + ) + + async def stream_chat(websocket: WebSocket, knowledge_base_id: str): await websocket.accept() vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) @@ -322,16 +325,20 @@ def main(): allow_credentials=True, allow_methods=["*"], allow_headers=["*"], - ) - app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat) - app.post("/chat-docs/chat", response_model=ChatMessage)(chat) - app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat) - app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file) - app.post("/chat-docs/uploadone", response_model=BaseResponse)(single_upload_file) - app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs) - app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs) + ) + app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat) + app.get("/", response_model=BaseResponse)(document) + app.post("/chat", response_model=ChatMessage)(chat) + + app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) + app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files) + app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat) + app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) + app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs) + + local_doc_qa = LocalDocQA() local_doc_qa.init_cfg( llm_model=LLM_MODEL, diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index e0a9132..a454931 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -10,6 +10,8 @@ import numpy as np from utils import torch_gc from tqdm import tqdm from pypinyin import lazy_pinyin +from loader import UnstructuredPaddleImageLoader +from loader import UnstructuredPaddlePDFLoader DEVICE_ = EMBEDDING_DEVICE DEVICE_ID = "0" if torch.cuda.is_available() else None @@ -21,16 +23,31 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE): loader = UnstructuredFileLoader(filepath, mode="elements") docs = loader.load() elif filepath.lower().endswith(".pdf"): - loader = UnstructuredFileLoader(filepath, strategy="fast") + loader = UnstructuredPaddlePDFLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) + elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"): + loader = UnstructuredPaddleImageLoader(filepath, mode="elements") + textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) + docs = loader.load_and_split(text_splitter=textsplitter) else: loader = UnstructuredFileLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(text_splitter=textsplitter) + write_check_file(filepath, docs) return docs +def write_check_file(filepath, docs): + fout = open('load_file.txt', 'a') + fout.write("filepath=%s,len=%s" % (filepath, len(docs))) + fout.write('\n') + for i in docs: + fout.write(str(i)) + fout.write('\n') + fout.close() + + def generate_prompt(related_docs: List[str], query: str, prompt_template=PROMPT_TEMPLATE) -> str: context = "\n".join([doc.page_content for doc in related_docs]) @@ -176,7 +193,7 @@ class LocalDocQA: if len(failed_files) > 0: logger.info("以下文件未能成功加载:") for file in failed_files: - logger.info(file, end="\n") + logger.info(f"{file}\n") else: docs = [] @@ -212,7 +229,7 @@ class LocalDocQA: if not vs_path or not one_title or not one_conent: logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!") return None, [one_title] - docs = [Document(page_content=one_conent+"\n", metadata={"source": one_title})] + docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})] if not one_content_segmentation: text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = text_splitter.split_documents(docs) diff --git a/docs/test.pdf b/docs/test.pdf new file mode 100644 index 0000000..3a137ad Binary files /dev/null and b/docs/test.pdf differ diff --git a/img/qr_code_16.jpg b/img/qr_code_16.jpg deleted file mode 100644 index febb78b..0000000 Binary files a/img/qr_code_16.jpg and /dev/null differ diff --git a/img/qr_code_17.jpg b/img/qr_code_17.jpg new file mode 100644 index 0000000..1fe4e43 Binary files /dev/null and b/img/qr_code_17.jpg differ diff --git a/img/test.jpg b/img/test.jpg new file mode 100644 index 0000000..70c199b Binary files /dev/null and b/img/test.jpg differ diff --git a/loader/__init__.py b/loader/__init__.py new file mode 100644 index 0000000..e9a7aea --- /dev/null +++ b/loader/__init__.py @@ -0,0 +1,2 @@ +from .image_loader import UnstructuredPaddleImageLoader +from .pdf_loader import UnstructuredPaddlePDFLoader diff --git a/loader/image_loader.py b/loader/image_loader.py new file mode 100644 index 0000000..5a1552a --- /dev/null +++ b/loader/image_loader.py @@ -0,0 +1,28 @@ +"""Loader that loads image files.""" +from typing import List + +from langchain.document_loaders.unstructured import UnstructuredFileLoader +from paddleocr import PaddleOCR +import os + + +class UnstructuredPaddleImageLoader(UnstructuredFileLoader): + """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" + + def _get_elements(self) -> List: + def image_ocr_txt(filepath, dir_path="tmp_files"): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + filename = os.path.split(filepath)[-1] + ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False) + result = ocr.ocr(img=filepath) + + ocr_result = [i[1][0] for line in result for i in line] + txt_file_path = os.path.join(dir_path, "%s.txt" % (filename)) + with open(txt_file_path, 'w', encoding='utf-8') as fout: + fout.write("\n".join(ocr_result)) + return txt_file_path + + txt_file_path = image_ocr_txt(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(filename=txt_file_path, **self.unstructured_kwargs) diff --git a/loader/pdf_loader.py b/loader/pdf_loader.py new file mode 100644 index 0000000..cc623d9 --- /dev/null +++ b/loader/pdf_loader.py @@ -0,0 +1,44 @@ +"""Loader that loads image files.""" +from typing import List + +from langchain.document_loaders.unstructured import UnstructuredFileLoader +from paddleocr import PaddleOCR +import os +import fitz + + +class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): + """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" + + def _get_elements(self) -> List: + def pdf_ocr_txt(filepath, dir_path="tmp_files"): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + filename = os.path.split(filepath)[-1] + ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False) + doc = fitz.open(filepath) + txt_file_path = os.path.join(dir_path, "%s.txt" % (filename)) + img_name = './img/.tmp.png' + with open(txt_file_path, 'w', encoding='utf-8') as fout: + + for i in range(doc.page_count): + page = doc[i] + text = page.get_text("") + fout.write(text) + fout.write("\n") + + img_list = page.get_images() + for img in img_list: + pix = fitz.Pixmap(doc, img[0]) + + pix.save(img_name) + + result = ocr.ocr(img_name) + ocr_result = [i[1][0] for line in result for i in line] + fout.write("\n".join(ocr_result)) + os.remove(img_name) + return txt_file_path + + txt_file_path = pdf_ocr_txt(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(filename=txt_file_path, **self.unstructured_kwargs) diff --git a/requirements.txt b/requirements.txt index d7b2e4e..31d44e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +pymupdf +paddlepaddle==2.4.2 +paddleocr langchain==0.0.146 transformers==4.27.1 unstructured[local-inference] diff --git a/test_image.py b/test_image.py new file mode 100644 index 0000000..ed60890 --- /dev/null +++ b/test_image.py @@ -0,0 +1,12 @@ +from configs.model_config import * +import nltk + +nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + +filepath = "./img/test.jpg" +from loader import UnstructuredPaddleImageLoader + +loader = UnstructuredPaddleImageLoader(filepath, mode="elements") +docs = loader.load() +for doc in docs: + print(doc) diff --git a/test_pdf.py b/test_pdf.py new file mode 100644 index 0000000..32dcb34 --- /dev/null +++ b/test_pdf.py @@ -0,0 +1,12 @@ +from configs.model_config import * +import nltk + +nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + +filepath = "docs/test.pdf" +from loader import UnstructuredPaddlePDFLoader + +loader = UnstructuredPaddlePDFLoader(filepath, mode="elements") +docs = loader.load() +for doc in docs: + print(doc)