fix the issue that .doc files can't be loaded

wvivi2023 2024-01-26 14:18:57 +08:00
parent 99969ef1e3
commit 1d12f84310
5 changed files with 21 additions and 45 deletions

View File

@@ -2,4 +2,4 @@ from .mypdfloader import RapidOCRPDFLoader
 from .myimgloader import RapidOCRLoader
 from .customiedpdfloader import CustomizedPDFLoader
 from .mywordload import RapidWordLoader
-#from .customercore import custom_group_broken_paragraphs
+from .customercore import custom_group_broken_paragraphs
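Note: the re-enabled import assumes customercore.py sits in the same package and defines custom_group_broken_paragraphs; that file is not part of this commit. A minimal sketch of what such a replacement cleaner could look like (hypothetical implementation, modeled loosely on unstructured's group_broken_paragraphs):

    # customercore.py -- hypothetical sketch, not shown in this commit
    import re

    def custom_group_broken_paragraphs(text: str, **kwargs) -> str:
        """Re-join lines that were hard-wrapped inside one paragraph,
        keeping blank lines as the real paragraph boundaries."""
        paragraphs = re.split(r"\n\s*\n", text)  # blank line = paragraph break
        grouped = []
        for para in paragraphs:
            lines = [ln.strip() for ln in para.splitlines() if ln.strip()]
            grouped.append("".join(lines))  # Chinese text: join without spaces
        return "\n\n".join(grouped)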

View File

@@ -5,11 +5,12 @@ from docx.document import Document as _Document
 from docx.table import _Cell
 from docx.oxml.text.paragraph import CT_P
 from docx.oxml.table import CT_Tbl
+from docx.oxml.table import CT_TblGrid
 from docx.table import _Cell, Table
 from docx.text.paragraph import Paragraph
 from unstructured.partition.text import partition_text
 import unstructured.cleaners.core
-from .customercore import custom_group_broken_paragraphs
+from customercore import custom_group_broken_paragraphs
 unstructured.cleaners.core.group_broken_paragraphs = custom_group_broken_paragraphs

 class RapidWordLoader(UnstructuredFileLoader):
@@ -32,6 +33,10 @@ class RapidWordLoader(UnstructuredFileLoader):
                     yield Paragraph(child, parent)
                 elif isinstance(child, CT_Tbl):
                     yield Table(child, parent)
+                elif isinstance(child, CT_TblGrid):
+                    yield Table(child, parent)
+                else:
+                    print("neither a paragraph nor a table")

         def read_table(table):
             # Get the table's column headers
@@ -61,7 +66,7 @@ class RapidWordLoader(UnstructuredFileLoader):
             doc = docxDocument(filepath)
             for block in iter_block_items(doc):
                 if isinstance(block, Paragraph):
-                    #print(f"Paragraph:{block.text}")
+                    print(f"Paragraph:{block.text}")
                     resp += (block.text + "\n\n")
                 elif isinstance(block, Table):
                     resp += read_table(block) + "\n"
@@ -76,7 +81,7 @@ class RapidWordLoader(UnstructuredFileLoader):
         return listText

 if __name__ == "__main__":
-    loader = RapidWordLoader(file_path="/Users/wangvivi/Desktop/Work/思极GPT/数字化部/设备类all/sb389/10kV带电作业用绝缘斗臂车.docx")
+    loader = RapidWordLoader(file_path="/Users/wangvivi/Downloads/输变电设备风险评估导则.docx")
     #loader = Docx2txtLoader(file_path="/Users/wangvivi/Desktop/Work/思极GPT/数字化部/设备类all/sb389/10kV带电作业用绝缘斗臂车.docx")
     #loader = RapidWordLoader(file_path="/Users/wangvivi/Desktop/MySelf/AI/Test/这是一个测试文档_副本2.docx")
     docs = loader.load()
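Note: iter_block_items above is the standard python-docx recipe for walking a document body in order, yielding Paragraph and Table objects as they appear. A standalone sketch of that traversal (hypothetical demo; the CT_TblGrid branch is omitted because <w:tblGrid> elements occur inside <w:tbl>, not as direct body children):

    # Minimal, self-contained version of the traversal pattern used above.
    from docx import Document as docxDocument
    from docx.document import Document as _Document
    from docx.oxml.text.paragraph import CT_P
    from docx.oxml.table import CT_Tbl
    from docx.table import _Cell, Table
    from docx.text.paragraph import Paragraph

    def iter_block_items(parent):
        # Resolve the underlying XML container: whole body or one table cell.
        if isinstance(parent, _Document):
            parent_elm = parent.element.body
        elif isinstance(parent, _Cell):
            parent_elm = parent._tc
        else:
            raise ValueError("unsupported parent type")
        for child in parent_elm.iterchildren():
            if isinstance(child, CT_P):
                yield Paragraph(child, parent)
            elif isinstance(child, CT_Tbl):
                yield Table(child, parent)

    doc = docxDocument("some_document.docx")  # hypothetical test file
    for block in iter_block_items(doc):
        print(type(block).__name__)  # Paragraph / Table, in document order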

View File

@@ -11,7 +11,7 @@ from langchain.schema.language_model import BaseLanguageModel
 from typing import List, Any, Optional
 from langchain.prompts import PromptTemplate
 from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
+from configs import FIRST_VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
 import asyncio
 from server.agent import model_container
 from pydantic import BaseModel, Field
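Note: this import only resolves if FIRST_VECTOR_SEARCH_TOP_K is actually defined and re-exported by the configs package; that side of the change is not shown in this commit. A sketch of the assumed config entry (name taken from the import above, value is a placeholder):

    # somewhere under configs/ -- assumed location and value
    FIRST_VECTOR_SEARCH_TOP_K = 3  # top-k for the first-pass vector search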

View File

@@ -123,7 +123,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
                "EverNoteLoader": ['.enex'],
                "UnstructuredFileLoader": ['.txt'],
-               "Docx2txtLoader":['.doc'],
+               "UnstructuredWordDocumentLoader":['.doc'],
                "RapidWordLoader":['.docx']
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
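Note: LOADER_DICT maps each loader class name to the extensions it handles, and SUPPORTED_EXTS flattens it for upload validation. The .doc fix works because Docx2txtLoader only understands the zip-based .docx format, while UnstructuredWordDocumentLoader can also handle legacy binary .doc files (via a LibreOffice conversion under the hood). A sketch of how such a mapping is typically consumed (helper name is hypothetical):

    import os

    def get_loader_name(file_path: str, loader_dict: dict) -> str:
        """Reverse-lookup the loader class name for a file extension."""
        ext = os.path.splitext(file_path)[-1].lower()
        for loader_name, exts in loader_dict.items():
            if ext in exts:
                return loader_name
        raise ValueError(f"unsupported extension: {ext}")

    # get_loader_name("a.doc", LOADER_DICT)  -> "UnstructuredWordDocumentLoader"
    # get_loader_name("a.docx", LOADER_DICT) -> "RapidWordLoader"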

View File

@@ -9,12 +9,12 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
                      FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.chat_models import ChatOpenAI
-from langchain.llms import OpenAI
+#from langchain.chat_models import ChatOpenAI
+from langchain._api import ChatOpenAI
+from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
 import logging
-import torch

 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -59,7 +59,6 @@ def get_ChatOpenAI(
     )
     return model

-
 def get_OpenAI(
     model_name: str,
     temperature: float,
@@ -153,6 +152,7 @@ class ChatMessage(BaseModel):

 def torch_gc():
     try:
+        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
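Note: moving import torch into the function (here and in detect_device below) turns torch into a call-time dependency instead of an import-time one, matching the comment further down: in a distributed deployment, machines that don't run the LLM can skip installing torch and still import this module. A sketch of the difference:

    # import-time dependency: importing this module fails without torch
    # import torch                      # (removed at the top of the file)

    def torch_gc():
        try:
            import torch                  # call-time: resolved only when invoked
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached GPU memory blocks
        except Exception:
            pass                          # no torch / no CUDA: nothing to do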
@@ -500,58 +500,29 @@ def set_httpx_config(
 # Auto-detect the device available to torch. In a distributed deployment, machines that don't run the LLM don't need torch installed.
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
+        import torch
         if torch.cuda.is_available():
             return "cuda"
-        if is_mps_available():
+        if torch.backends.mps.is_available():
             return "mps"
     except:
         pass
     return "cpu"

-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+    if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        return detect_device()
     return device

-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
         device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu"]:
-        return detect_device()
     return device
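Note: after this simplification the policy is simply "honor an explicit valid device, otherwise auto-detect"; the cuda/mps cross-fallback checks and the xpu option are gone, and embedding_device now correctly defaults to EMBEDDING_DEVICE instead of LLM_DEVICE. A usage sketch of the resulting behavior (example arguments are hypothetical):

    llm_device("cuda")    # -> "cuda", taken at face value (no availability check)
    llm_device("tpu")     # -> invalid value, falls through to detect_device()
    embedding_device()    # -> EMBEDDING_DEVICE if valid, else detect_device()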