diff --git a/bonus-modules/bonus-ai/pom.xml b/bonus-modules/bonus-ai/pom.xml
index 228b4a6..6e33653 100644
--- a/bonus-modules/bonus-ai/pom.xml
+++ b/bonus-modules/bonus-ai/pom.xml
@@ -41,7 +41,14 @@
spring-boot-starter-actuator
-
+
+ org.springframework.boot
+ spring-boot-starter-web
+
+
+ org.springframework
+ spring-web
+
@@ -99,7 +106,10 @@
com.google.code.gson
gson
-
+
+ org.springframework.boot
+ spring-boot-starter-web
+
diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java
index 431803e..84fb0b2 100644
--- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java
+++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/config/CustomMultipartFile.java
@@ -18,19 +18,23 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public String getOriginalFilename() {
- return file.getName(); // 或根据需要返回不同的名称
+ return file.getName(); // 可以返回上传时的文件名
}
@Override
public String getContentType() {
// 根据文件扩展名返回内容类型
- String name = file.getName();
+ String name = file.getName().toLowerCase();
if (name.endsWith(".jpg") || name.endsWith(".jpeg")) {
return "image/jpeg";
} else if (name.endsWith(".png")) {
return "image/png";
+ } else if (name.endsWith(".txt")) {
+ return "text/plain";
+ } else if (name.endsWith(".pdf")) {
+ return "application/pdf";
}
- return null; // 或者抛出异常
+ return "application/octet-stream"; // 返回默认的二进制流类型
}
@Override
@@ -45,9 +49,12 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public byte[] getBytes() throws IOException {
+ // 打开文件流并读取字节内容
try (FileInputStream fis = new FileInputStream(file)) {
byte[] bytes = new byte[(int) file.length()];
- fis.read(bytes);
+ if (fis.read(bytes) == -1) {
+ throw new IOException("Unable to read the file content");
+ }
return bytes;
}
}
@@ -59,13 +66,14 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public void transferTo(File dest) throws IOException, IllegalStateException {
- // 实现文件传输的逻辑
- try (FileInputStream fis = new FileInputStream(file);
- FileOutputStream fos = new FileOutputStream(dest)) {
- byte[] buffer = new byte[1024];
- int length;
- while ((length = fis.read(buffer)) > 0) {
- fos.write(buffer, 0, length);
+ // 可以使用 NIO 的 Files.copy 方法来简化文件传输逻辑
+ try (InputStream inputStream = new FileInputStream(file)) {
+ try (OutputStream outputStream = new FileOutputStream(dest)) {
+ byte[] buffer = new byte[8192];
+ int bytesRead;
+ while ((bytesRead = inputStream.read(buffer)) != -1) {
+ outputStream.write(buffer, 0, bytesRead);
+ }
}
}
}
diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java
index dd07580..6e17aa1 100644
--- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java
+++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/controller/DataSetController.java
@@ -1,26 +1,22 @@
package com.bonus.ai.controller;
-import cn.hutool.core.io.IoUtil;
import com.bonus.ai.config.CustomMultipartFile;
import com.bonus.ai.domain.*;
import com.bonus.ai.service.DataSetService;
-import com.bonus.common.core.utils.StringUtils;
import com.bonus.common.core.web.controller.BaseController;
import com.bonus.common.core.web.domain.AjaxResult;
import com.bonus.common.core.web.page.TableDataInfo;
+import com.bonus.common.security.annotation.PreventRepeatSubmit;
import com.bonus.system.api.RemoteFileService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
-import org.apache.poi.openxml4j.opc.internal.FileHelper;
-import org.aspectj.util.FileUtil;
-import org.springframework.http.HttpStatus;
-import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
-import org.springframework.web.multipart.commons.CommonsMultipartFile;
import javax.annotation.Resource;
-import java.io.*;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -44,6 +40,7 @@ public class DataSetController extends BaseController {
@Resource
private RemoteFileService remoteFileService;
+
/**
* 根据数据集 ID 查询对应的数据集信息。
*
@@ -127,6 +124,10 @@ public class DataSetController extends BaseController {
return dataSetService.getCategoryById(categoryId);
}
+ @PostMapping("/getCategories/{categoryId}")
+ public AjaxResult getCategoriesById(@PathVariable Long categoryId) {
+ return dataSetService.getCategoriesById(categoryId);
+ }
/**
* 获取所有符合条件的类别信息。
*
@@ -144,6 +145,7 @@ public class DataSetController extends BaseController {
* @param entity 要插入的 DataSetCategoryEntity 对象,包含类别的详细信息
* @return 返回插入操作影响的行数(成功插入的记录数,通常为 1)
*/
+ @PreventRepeatSubmit
@PostMapping("/insertCategory")
public AjaxResult insertCategory(@RequestBody DataSetCategoryEntity entity) {
return dataSetService.insertCategory(entity);
@@ -490,41 +492,38 @@ public class DataSetController extends BaseController {
@PostMapping("/uploadZipFiles")
public AjaxResult uploadZipFiles(@RequestParam("files") MultipartFile[] files, @RequestParam("datasetId") Long datasetId) {
- List multipartFiles = new ArrayList<>();
Path tempDir;
-
try {
// 创建一个临时目录用于存放解压后的文件
tempDir = Files.createTempDirectory("uploads");
// 遍历上传的压缩文件
for (MultipartFile file : files) {
- processZipFile(file, multipartFiles, tempDir);
+ // 处理每个压缩文件,并上传其中的文件
+ uploadFilesFromZip(file, datasetId, tempDir);
}
-
- // 处理完成后删除临时文件夹
- deleteTempDirectory(tempDir);
-
- // 将 List 转换为数组,方便后续上传
- MultipartFile[] resultFiles = multipartFiles.toArray(new MultipartFile[0]);
- return handleFileUpload(resultFiles, datasetId);
+ return AjaxResult.success("上传成功");
} catch (IOException e) {
// 如果创建临时目录失败,返回错误信息
- return AjaxResult.error("Failed to create temporary directory: " + e.getMessage());
+ return AjaxResult.error("上传失败");
}
}
- // 处理每个压缩文件,解压并将图片文件添加到列表中
- private void processZipFile(MultipartFile file, List multipartFiles, Path tempDir) {
+ // 从压缩文件中解压并上传文件
+ private void uploadFilesFromZip(MultipartFile file, Long datasetId, Path tempDir) {
+ List multipartFiles = new ArrayList<>();
+ int batchSize = 100; // 每批处理的文件数量
+
try (ZipInputStream zis = new ZipInputStream(file.getInputStream())) {
ZipEntry entry;
// 逐个读取压缩文件中的条目
while ((entry = zis.getNextEntry()) != null) {
// 如果条目不是目录且是图片文件
if (!entry.isDirectory() && isImageFile(entry.getName())) {
- // 创建解压后的图片文件
+ // 创建解压后的图片文 件
File imgFile = new File(tempDir.toFile(), entry.getName());
imgFile.getParentFile().mkdirs(); // 确保目录存在
+
// 将条目内容写入到文件
try (FileOutputStream fos = new FileOutputStream(imgFile)) {
byte[] buffer = new byte[1024];
@@ -537,8 +536,27 @@ public class DataSetController extends BaseController {
// 创建 MultipartFile 对象并添加到列表中
MultipartFile multipartFile = new CustomMultipartFile(imgFile);
multipartFiles.add(multipartFile);
+
+ // 每处理 100 个文件时进行一次上传
+ if (multipartFiles.size() >= batchSize) {
+ // 上传当前批次的文件
+ AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir);
+ multipartFiles.clear(); // 清空已上传的文件列表
+ // 检查上传结果
+ if (!result.isSuccess()) {
+ throw new RuntimeException("File upload failed for " + file.getOriginalFilename());
+ }
+ }
}
- zis.closeEntry(); // 关闭当前条目
+ }
+ // 处理剩余未上传的文件
+ if (!multipartFiles.isEmpty()) {
+ AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir);
+ if (!result.isSuccess()) {
+ throw new RuntimeException("File upload failed for " + file.getOriginalFilename());
+ }
+ }else {
+ deleteTempDirectory(tempDir);
}
} catch (IOException e) {
// 抛出运行时异常,包含原始文件名的信息
@@ -547,7 +565,7 @@ public class DataSetController extends BaseController {
}
// 处理文件上传及数据库插入
- private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId) {
+ private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId, Path tempDir) throws IOException {
// 调用远程服务上传文件
AjaxResult ajaxResult = remoteFileService.uploadFile(resultFiles);
if (ajaxResult.isSuccess()) {
@@ -564,7 +582,7 @@ public class DataSetController extends BaseController {
entity.setFileName(map.get("name"));
dataSetService.insertFile(entity);
}
- // 返回成功结果
+ // deleteTempDirectory(tempDir); // 删除临时目录放在上传后再处理
return AjaxResult.success();
} else {
// 上传失败,返回错误信息
@@ -601,4 +619,5 @@ public class DataSetController extends BaseController {
e.printStackTrace(); // 打印错误信息
}
}
+
}
diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java
index d11546e..5e76882 100644
--- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java
+++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/mapper/DataSetMapper.java
@@ -70,7 +70,7 @@ public interface DataSetMapper {
* @return 返回所有符合条件的类别的列表
*/
List getCategories(DataSetCategoryEntity entity);
-
+ List getCategoriesById(Long categoryId);
/**
* 插入新的样本类别数据到数据库。
*
@@ -79,6 +79,15 @@ public interface DataSetMapper {
*/
int insertCategory(DataSetCategoryEntity entity);
+ // 检查类别名称是否已存在
+ int countByNameAndParentId(@Param("categoryName") String categoryName, @Param("parentId") Long parentId, @Param("categoryId") Long categoryId);
+
+ // 获取所有祖先类别名称
+ List getAncestorCategoryNames(@Param("categoryId") Long categoryId);
+
+ // 获取所有祖先类别名称
+ List getAncestorCategoryName(@Param("parentId") Long parentId);
+
/**
* 更新指定类别的详细信息。
*
@@ -93,7 +102,9 @@ public interface DataSetMapper {
* @param categoryId 要删除的类别的唯一标识符
* @return 返回删除操作影响的行数(成功删除的记录数,通常为 1)
*/
- int deleteCategory(Long categoryId);
+ int deleteCategory(@Param("categoryIds") List categoryIds);
+
+ List getCategoryId(Long categoryId);
/**
* 根据 ID 查询数据集文件信息
@@ -281,4 +292,6 @@ public interface DataSetMapper {
* @return 影响的行数
*/
int deleteAlgorithm(Long[] algorithmIds);
+
+
}
diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java
index ff7fa73..ff6eb6d 100644
--- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java
+++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/DataSetService.java
@@ -267,4 +267,6 @@ public interface DataSetService {
* @return 影响的行数
*/
AjaxResult deleteAlgorithm(Long[] algorithmIds);
+
+ AjaxResult getCategoriesById(Long categoryId);
}
diff --git a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java
index 11449e5..97651b7 100644
--- a/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java
+++ b/bonus-modules/bonus-ai/src/main/java/com/bonus/ai/service/impl/DataSetServiceImpl.java
@@ -4,15 +4,19 @@ import com.bonus.ai.domain.*;
import com.bonus.ai.mapper.DataSetMapper;
import com.bonus.ai.service.DataSetService;
import com.bonus.common.core.domain.R;
+import com.bonus.common.core.utils.StringUtils;
import com.bonus.common.core.web.domain.AjaxResult;
import com.bonus.common.security.utils.SecurityUtils;
import com.bonus.system.api.RemoteFileService;
import com.bonus.system.api.domain.SysFile;
import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.ObjectUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -138,6 +142,20 @@ public class DataSetServiceImpl implements DataSetService {
}
}
+ /**
+ * @param categoryId
+ * @return
+ */
+ @Override
+ public AjaxResult getCategoriesById(Long categoryId) {
+ try {
+ List categories = mapper.getCategoriesById(categoryId);
+ return AjaxResult.success(categories);
+ } catch (Exception e) {
+ return AjaxResult.error("获取数据失败");
+ }
+ }
+
/**
* 插入新的样本类别数据到数据库。
*
@@ -147,7 +165,20 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult insertCategory(DataSetCategoryEntity entity) {
try {
+ // 创建一个列表来存储所有要删除的类别ID
entity.setCreatedBy(SecurityUtils.getUserId());
+ List ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId());
+ List ancestorNames = new ArrayList<>(ancestorCategoryName);
+ int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId());
+ if (num > 0) {
+ return AjaxResult.error("新增失败,类别已存在");
+ }
+ List ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames);
+ for (DataSetCategoryEntity e : ancestorCategoryNames) {
+ if (entity.getCategoryName().equals(e.getCategoryName())) {
+ return AjaxResult.error("子类别名称不能与任何祖先类别名称相同");
+ }
+ }
int i = mapper.insertCategory(entity);
if (i > 0) {
return AjaxResult.success("新增成功");
@@ -159,6 +190,18 @@ public class DataSetServiceImpl implements DataSetService {
}
}
+ public List findAncestorCategoryNames(Long categoryId, List ancestorNames) {
+ // 获取当前类别的所有子类别ID
+ List ancestorCategoryNames = mapper.getAncestorCategoryNames(categoryId);
+ // 将子类别ID加入列表
+ ancestorNames.addAll(ancestorCategoryNames);
+ // 递归调用获取每个子类别的子类别
+ for (DataSetCategoryEntity entity : ancestorCategoryNames) {
+ findAncestorCategoryNames(entity.getCategoryId(), ancestorNames);
+ }
+ return ancestorNames;
+ }
+
/**
* 更新指定类别的详细信息。
*
@@ -168,6 +211,18 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult updateCategory(DataSetCategoryEntity entity) {
try {
+ int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId());
+ if (num > 0) {
+ return AjaxResult.error("修改失败,类别已存在");
+ }
+ List ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId());
+ List ancestorNames = new ArrayList<>(ancestorCategoryName);
+ List ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames);
+ for (DataSetCategoryEntity e : ancestorCategoryNames) {
+ if (entity.getCategoryName().equals(e.getCategoryName())) {
+ return AjaxResult.error("子类别名称不能与任何祖先类别名称相同");
+ }
+ }
int i = mapper.updateCategory(entity);
return i > 0 ? AjaxResult.success("修改成功") : AjaxResult.error("修改失败");
} catch (Exception e) {
@@ -184,13 +239,31 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult deleteCategory(Long categoryId) {
try {
- int i = mapper.deleteCategory(categoryId);
+ // 创建一个列表来存储所有要删除的类别ID
+ List categoryIds = new ArrayList<>();
+ // 递归获取所有子类别ID
+ getAllSubcategories(categoryId, categoryIds);
+ // 将当前类别 ID 加入待删除列表
+ categoryIds.add(categoryId);
+ int i = mapper.deleteCategory(categoryIds);
return i > 0 ? AjaxResult.success("删除成功") : AjaxResult.error("删除失败");
} catch (Exception e) {
return AjaxResult.error("删除失败");
}
}
+ // 递归方法获取所有子类别ID
+ private void getAllSubcategories(Long categoryId, List categoryIds) {
+ // 获取当前类别的所有子类别ID
+ List subcategoryIds = mapper.getCategoryId(categoryId);
+ // 将子类别ID加入列表
+ categoryIds.addAll(subcategoryIds);
+ // 递归调用获取每个子类别的子类别
+ for (Long subcategoryId : subcategoryIds) {
+ getAllSubcategories(subcategoryId, categoryIds);
+ }
+ }
+
/**
* 根据 ID 查询数据集文件信息
*
@@ -390,25 +463,28 @@ public class DataSetServiceImpl implements DataSetService {
public AjaxResult insertModel(AiModelEntity entity) {
try {
entity.setCreateBy(SecurityUtils.getUsername());
- AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile());
- if (ajaxResult.isSuccess()) {
- List