diff --git a/bonus-modules/bonus-ai/pom.xml b/bonus-modules/bonus-ai/pom.xml index 228b4a6..6e33653 100644 --- a/bonus-modules/bonus-ai/pom.xml +++ b/bonus-modules/bonus-ai/pom.xml @@ -41,7 +41,14 @@ spring-boot-starter-actuator - + + org.springframework.boot + spring-boot-starter-web + + + org.springframework + spring-web + @@ -99,7 +106,10 @@ com.google.code.gson gson - + + org.springframework.boot + spring-boot-starter-web + diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java index 431803e..84fb0b2 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java @@ -18,19 +18,23 @@ public class CustomMultipartFile implements MultipartFile { @Override public String getOriginalFilename() { - return file.getName(); // 或根据需要返回不同的名称 + return file.getName(); // 可以返回上传时的文件名 } @Override public String getContentType() { // 根据文件扩展名返回内容类型 - String name = file.getName(); + String name = file.getName().toLowerCase(); if (name.endsWith(".jpg") || name.endsWith(".jpeg")) { return "image/jpeg"; } else if (name.endsWith(".png")) { return "image/png"; + } else if (name.endsWith(".txt")) { + return "text/plain"; + } else if (name.endsWith(".pdf")) { + return "application/pdf"; } - return null; // 或者抛出异常 + return "application/octet-stream"; // 返回默认的二进制流类型 } @Override @@ -45,9 +49,12 @@ public class CustomMultipartFile implements MultipartFile { @Override public byte[] getBytes() throws IOException { + // 打开文件流并读取字节内容 try (FileInputStream fis = new FileInputStream(file)) { byte[] bytes = new byte[(int) file.length()]; - fis.read(bytes); + if (fis.read(bytes) == -1) { + throw new IOException("Unable to read the file content"); + } return bytes; } } @@ -59,13 +66,14 @@ public class CustomMultipartFile implements MultipartFile { @Override public void transferTo(File dest) throws IOException, IllegalStateException { - // 实现文件传输的逻辑 - try (FileInputStream fis = new FileInputStream(file); - FileOutputStream fos = new FileOutputStream(dest)) { - byte[] buffer = new byte[1024]; - int length; - while ((length = fis.read(buffer)) > 0) { - fos.write(buffer, 0, length); + // 可以使用 NIO 的 Files.copy 方法来简化文件传输逻辑 + try (InputStream inputStream = new FileInputStream(file)) { + try (OutputStream outputStream = new FileOutputStream(dest)) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } } } } diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java index dd07580..6e17aa1 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java @@ -1,26 +1,22 @@ package com.bonus.ai.controller; -import cn.hutool.core.io.IoUtil; import com.bonus.ai.config.CustomMultipartFile; import com.bonus.ai.domain.*; import com.bonus.ai.service.DataSetService; -import com.bonus.common.core.utils.StringUtils; import com.bonus.common.core.web.controller.BaseController; import com.bonus.common.core.web.domain.AjaxResult; import com.bonus.common.core.web.page.TableDataInfo; +import com.bonus.common.security.annotation.PreventRepeatSubmit; import com.bonus.system.api.RemoteFileService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ObjectUtils; -import org.apache.poi.openxml4j.opc.internal.FileHelper; -import org.aspectj.util.FileUtil; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; -import org.springframework.web.multipart.commons.CommonsMultipartFile; import javax.annotation.Resource; -import java.io.*; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -44,6 +40,7 @@ public class DataSetController extends BaseController { @Resource private RemoteFileService remoteFileService; + /** * 根据数据集 ID 查询对应的数据集信息。 * @@ -127,6 +124,10 @@ public class DataSetController extends BaseController { return dataSetService.getCategoryById(categoryId); } + @PostMapping("/getCategories/{categoryId}") + public AjaxResult getCategoriesById(@PathVariable Long categoryId) { + return dataSetService.getCategoriesById(categoryId); + } /** * 获取所有符合条件的类别信息。 * @@ -144,6 +145,7 @@ public class DataSetController extends BaseController { * @param entity 要插入的 DataSetCategoryEntity 对象,包含类别的详细信息 * @return 返回插入操作影响的行数(成功插入的记录数,通常为 1) */ + @PreventRepeatSubmit @PostMapping("/insertCategory") public AjaxResult insertCategory(@RequestBody DataSetCategoryEntity entity) { return dataSetService.insertCategory(entity); @@ -490,41 +492,38 @@ public class DataSetController extends BaseController { @PostMapping("/uploadZipFiles") public AjaxResult uploadZipFiles(@RequestParam("files") MultipartFile[] files, @RequestParam("datasetId") Long datasetId) { - List multipartFiles = new ArrayList<>(); Path tempDir; - try { // 创建一个临时目录用于存放解压后的文件 tempDir = Files.createTempDirectory("uploads"); // 遍历上传的压缩文件 for (MultipartFile file : files) { - processZipFile(file, multipartFiles, tempDir); + // 处理每个压缩文件,并上传其中的文件 + uploadFilesFromZip(file, datasetId, tempDir); } - - // 处理完成后删除临时文件夹 - deleteTempDirectory(tempDir); - - // 将 List 转换为数组,方便后续上传 - MultipartFile[] resultFiles = multipartFiles.toArray(new MultipartFile[0]); - return handleFileUpload(resultFiles, datasetId); + return AjaxResult.success("上传成功"); } catch (IOException e) { // 如果创建临时目录失败,返回错误信息 - return AjaxResult.error("Failed to create temporary directory: " + e.getMessage()); + return AjaxResult.error("上传失败"); } } - // 处理每个压缩文件,解压并将图片文件添加到列表中 - private void processZipFile(MultipartFile file, List multipartFiles, Path tempDir) { + // 从压缩文件中解压并上传文件 + private void uploadFilesFromZip(MultipartFile file, Long datasetId, Path tempDir) { + List multipartFiles = new ArrayList<>(); + int batchSize = 100; // 每批处理的文件数量 + try (ZipInputStream zis = new ZipInputStream(file.getInputStream())) { ZipEntry entry; // 逐个读取压缩文件中的条目 while ((entry = zis.getNextEntry()) != null) { // 如果条目不是目录且是图片文件 if (!entry.isDirectory() && isImageFile(entry.getName())) { - // 创建解压后的图片文件 + // 创建解压后的图片文 件 File imgFile = new File(tempDir.toFile(), entry.getName()); imgFile.getParentFile().mkdirs(); // 确保目录存在 + // 将条目内容写入到文件 try (FileOutputStream fos = new FileOutputStream(imgFile)) { byte[] buffer = new byte[1024]; @@ -537,8 +536,27 @@ public class DataSetController extends BaseController { // 创建 MultipartFile 对象并添加到列表中 MultipartFile multipartFile = new CustomMultipartFile(imgFile); multipartFiles.add(multipartFile); + + // 每处理 100 个文件时进行一次上传 + if (multipartFiles.size() >= batchSize) { + // 上传当前批次的文件 + AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir); + multipartFiles.clear(); // 清空已上传的文件列表 + // 检查上传结果 + if (!result.isSuccess()) { + throw new RuntimeException("File upload failed for " + file.getOriginalFilename()); + } + } } - zis.closeEntry(); // 关闭当前条目 + } + // 处理剩余未上传的文件 + if (!multipartFiles.isEmpty()) { + AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir); + if (!result.isSuccess()) { + throw new RuntimeException("File upload failed for " + file.getOriginalFilename()); + } + }else { + deleteTempDirectory(tempDir); } } catch (IOException e) { // 抛出运行时异常,包含原始文件名的信息 @@ -547,7 +565,7 @@ public class DataSetController extends BaseController { } // 处理文件上传及数据库插入 - private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId) { + private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId, Path tempDir) throws IOException { // 调用远程服务上传文件 AjaxResult ajaxResult = remoteFileService.uploadFile(resultFiles); if (ajaxResult.isSuccess()) { @@ -564,7 +582,7 @@ public class DataSetController extends BaseController { entity.setFileName(map.get("name")); dataSetService.insertFile(entity); } - // 返回成功结果 + // deleteTempDirectory(tempDir); // 删除临时目录放在上传后再处理 return AjaxResult.success(); } else { // 上传失败,返回错误信息 @@ -601,4 +619,5 @@ public class DataSetController extends BaseController { e.printStackTrace(); // 打印错误信息 } } + } diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java index d11546e..5e76882 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java @@ -70,7 +70,7 @@ public interface DataSetMapper { * @return 返回所有符合条件的类别的列表 */ List getCategories(DataSetCategoryEntity entity); - + List getCategoriesById(Long categoryId); /** * 插入新的样本类别数据到数据库。 * @@ -79,6 +79,15 @@ public interface DataSetMapper { */ int insertCategory(DataSetCategoryEntity entity); + // 检查类别名称是否已存在 + int countByNameAndParentId(@Param("categoryName") String categoryName, @Param("parentId") Long parentId, @Param("categoryId") Long categoryId); + + // 获取所有祖先类别名称 + List getAncestorCategoryNames(@Param("categoryId") Long categoryId); + + // 获取所有祖先类别名称 + List getAncestorCategoryName(@Param("parentId") Long parentId); + /** * 更新指定类别的详细信息。 * @@ -93,7 +102,9 @@ public interface DataSetMapper { * @param categoryId 要删除的类别的唯一标识符 * @return 返回删除操作影响的行数(成功删除的记录数,通常为 1) */ - int deleteCategory(Long categoryId); + int deleteCategory(@Param("categoryIds") List categoryIds); + + List getCategoryId(Long categoryId); /** * 根据 ID 查询数据集文件信息 @@ -281,4 +292,6 @@ public interface DataSetMapper { * @return 影响的行数 */ int deleteAlgorithm(Long[] algorithmIds); + + } diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java index ff7fa73..ff6eb6d 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java @@ -267,4 +267,6 @@ public interface DataSetService { * @return 影响的行数 */ AjaxResult deleteAlgorithm(Long[] algorithmIds); + + AjaxResult getCategoriesById(Long categoryId); } diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java index 11449e5..97651b7 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java @@ -4,15 +4,19 @@ import com.bonus.ai.domain.*; import com.bonus.ai.mapper.DataSetMapper; import com.bonus.ai.service.DataSetService; import com.bonus.common.core.domain.R; +import com.bonus.common.core.utils.StringUtils; import com.bonus.common.core.web.domain.AjaxResult; import com.bonus.common.security.utils.SecurityUtils; import com.bonus.system.api.RemoteFileService; import com.bonus.system.api.domain.SysFile; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.ObjectUtils; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; import javax.annotation.Resource; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -138,6 +142,20 @@ public class DataSetServiceImpl implements DataSetService { } } + /** + * @param categoryId + * @return + */ + @Override + public AjaxResult getCategoriesById(Long categoryId) { + try { + List categories = mapper.getCategoriesById(categoryId); + return AjaxResult.success(categories); + } catch (Exception e) { + return AjaxResult.error("获取数据失败"); + } + } + /** * 插入新的样本类别数据到数据库。 * @@ -147,7 +165,20 @@ public class DataSetServiceImpl implements DataSetService { @Override public AjaxResult insertCategory(DataSetCategoryEntity entity) { try { + // 创建一个列表来存储所有要删除的类别ID entity.setCreatedBy(SecurityUtils.getUserId()); + List ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId()); + List ancestorNames = new ArrayList<>(ancestorCategoryName); + int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId()); + if (num > 0) { + return AjaxResult.error("新增失败,类别已存在"); + } + List ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames); + for (DataSetCategoryEntity e : ancestorCategoryNames) { + if (entity.getCategoryName().equals(e.getCategoryName())) { + return AjaxResult.error("子类别名称不能与任何祖先类别名称相同"); + } + } int i = mapper.insertCategory(entity); if (i > 0) { return AjaxResult.success("新增成功"); @@ -159,6 +190,18 @@ public class DataSetServiceImpl implements DataSetService { } } + public List findAncestorCategoryNames(Long categoryId, List ancestorNames) { + // 获取当前类别的所有子类别ID + List ancestorCategoryNames = mapper.getAncestorCategoryNames(categoryId); + // 将子类别ID加入列表 + ancestorNames.addAll(ancestorCategoryNames); + // 递归调用获取每个子类别的子类别 + for (DataSetCategoryEntity entity : ancestorCategoryNames) { + findAncestorCategoryNames(entity.getCategoryId(), ancestorNames); + } + return ancestorNames; + } + /** * 更新指定类别的详细信息。 * @@ -168,6 +211,18 @@ public class DataSetServiceImpl implements DataSetService { @Override public AjaxResult updateCategory(DataSetCategoryEntity entity) { try { + int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId()); + if (num > 0) { + return AjaxResult.error("修改失败,类别已存在"); + } + List ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId()); + List ancestorNames = new ArrayList<>(ancestorCategoryName); + List ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames); + for (DataSetCategoryEntity e : ancestorCategoryNames) { + if (entity.getCategoryName().equals(e.getCategoryName())) { + return AjaxResult.error("子类别名称不能与任何祖先类别名称相同"); + } + } int i = mapper.updateCategory(entity); return i > 0 ? AjaxResult.success("修改成功") : AjaxResult.error("修改失败"); } catch (Exception e) { @@ -184,13 +239,31 @@ public class DataSetServiceImpl implements DataSetService { @Override public AjaxResult deleteCategory(Long categoryId) { try { - int i = mapper.deleteCategory(categoryId); + // 创建一个列表来存储所有要删除的类别ID + List categoryIds = new ArrayList<>(); + // 递归获取所有子类别ID + getAllSubcategories(categoryId, categoryIds); + // 将当前类别 ID 加入待删除列表 + categoryIds.add(categoryId); + int i = mapper.deleteCategory(categoryIds); return i > 0 ? AjaxResult.success("删除成功") : AjaxResult.error("删除失败"); } catch (Exception e) { return AjaxResult.error("删除失败"); } } + // 递归方法获取所有子类别ID + private void getAllSubcategories(Long categoryId, List categoryIds) { + // 获取当前类别的所有子类别ID + List subcategoryIds = mapper.getCategoryId(categoryId); + // 将子类别ID加入列表 + categoryIds.addAll(subcategoryIds); + // 递归调用获取每个子类别的子类别 + for (Long subcategoryId : subcategoryIds) { + getAllSubcategories(subcategoryId, categoryIds); + } + } + /** * 根据 ID 查询数据集文件信息 * @@ -390,25 +463,28 @@ public class DataSetServiceImpl implements DataSetService { public AjaxResult insertModel(AiModelEntity entity) { try { entity.setCreateBy(SecurityUtils.getUsername()); - AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile()); - if (ajaxResult.isSuccess()) { - List> data = (List>) ajaxResult.get("data"); - for (Map map : data) { - entity.setModelAddress(map.get("url")); + if (ObjectUtils.isNotEmpty(entity.getModelFile())) { + AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile()); + if (ajaxResult.isSuccess()) { + List> data = (List>) ajaxResult.get("data"); + for (Map map : data) { + entity.setModelAddress(map.get("url")); + } + } else { + return AjaxResult.error("上传文件失败"); } - } else { - return AjaxResult.error("上传文件失败"); } + if (ObjectUtils.isNotEmpty(entity.getUserGuideFile())) { + AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile()); + if (userGuide.isSuccess()) { + List> data = (List>) userGuide.get("data"); + for (Map map : data) { + entity.setUserGuide(map.get("url")); + } + } else { - AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile()); - if (userGuide.isSuccess()) { - List> data = (List>) userGuide.get("data"); - for (Map map : data) { - entity.setUserGuide(map.get("url")); + return AjaxResult.error("上传文件失败"); } - } else { - - return AjaxResult.error("上传文件失败"); } int i = mapper.insertModel(entity); return i > 0 ? AjaxResult.success("新增成功") : AjaxResult.error("新增失败"); @@ -426,7 +502,7 @@ public class DataSetServiceImpl implements DataSetService { @Override public AjaxResult updateModel(AiModelEntity entity) { try { - if (entity.getModelFile().length > 0) { + if (ObjectUtils.isNotEmpty(entity.getModelFile())) { AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile()); if (ajaxResult.isSuccess()) { List> data = (List>) ajaxResult.get("data"); @@ -437,7 +513,7 @@ public class DataSetServiceImpl implements DataSetService { return AjaxResult.error("上传文件失败"); } } - if (entity.getUserGuideFile().length > 0) { + if (ObjectUtils.isNotEmpty(entity.getModelFile())) { AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile()); if (userGuide.isSuccess()) { List> data = (List>) userGuide.get("data"); @@ -551,4 +627,6 @@ public class DataSetServiceImpl implements DataSetService { return AjaxResult.error("删除失败"); } } + + } diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/FaceServiceImpl.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/FaceServiceImpl.java index 7ce9942..9303baa 100644 --- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/FaceServiceImpl.java +++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/FaceServiceImpl.java @@ -4,6 +4,7 @@ import cn.hutool.core.io.unit.DataUnit; import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONObject; +import com.bonus.ai.domain.DataSetFileEntity; import com.bonus.ai.domain.vo.FaceResultVo; import com.bonus.ai.domain.vo.FaceVo; import com.bonus.ai.mapper.AiFaceRecognizeResultMapper; @@ -121,10 +122,13 @@ public class FaceServiceImpl implements FaceService { } private AjaxResult handleFileUpload(MultipartFile file, FaceVo face) throws Exception { - R upload = remoteFileService.upload(file); - if (upload.getCode() == 200) { - face.setFaceAddress(upload.getData().getUrl().replaceFirst("http://[^/]+", "")); - return faceMapper.insertFace(face) > 0 ? AjaxResult.success() : AjaxResult.error(); + AjaxResult upload = remoteFileService.upload(file); + if (upload.isSuccess()) { + List> data = (List>) upload.get("data"); + for (Map map : data) { + face.setFaceAddress(map.get("url").replaceFirst("http://[^/]+", "")); + return faceMapper.insertFace(face) > 0 ? AjaxResult.success() : AjaxResult.error(); + } } return AjaxResult.error(); } diff --git a/bonus-modules/bonus-ai/src/main/resources/mapper/ai/DataSetMapper.xml b/bonus-modules/bonus-ai/src/main/resources/mapper/ai/DataSetMapper.xml index b542d8d..9b760e0 100644 --- a/bonus-modules/bonus-ai/src/main/resources/mapper/ai/DataSetMapper.xml +++ b/bonus-modules/bonus-ai/src/main/resources/mapper/ai/DataSetMapper.xml @@ -217,6 +217,50 @@ AND ada.model_id =#{modelId} + + + + + + INSERT INTO ai_dataset_category (category_name, parent_id, enabled, created_by, description) VALUES (#{categoryName}, #{parentId}, #{enabled}, #{createdBy}, #{description}) @@ -250,11 +295,11 @@ - INSERT INTO ai_model (model_name, model_version, sub_task_type_id, + INSERT INTO ai_model (model_name, model_version, infer_language, model_format, deploy_requirement, model_address, user_guide, description, create_by, algorithm, dataSetId, model_type) - VALUES (#{modelName}, #{modelVersion}, #{subTaskTypeId}, #{inferLanguage}, #{modelFormat}, + VALUES (#{modelName}, #{modelVersion}, #{inferLanguage}, #{modelFormat}, #{deployRequirement}, #{modelAddress}, #{userGuide}, #{description}, #{createBy}, #{algorithm}, #{dataSetId}, #{modelType}); @@ -306,7 +351,10 @@ UPDATE ai_dataset_category SET del_flag = '1' - WHERE category_id = #{categoryId} + WHERE category_id IN + + #{categoryId} + UPDATE ai_dataset_file