// sishu-yolo-sdk/src/main/java/com/bonus/sdk/YoloSdk.java
//
// (file-viewer metadata: 163 lines, 5.9 KiB, Java)

package com.bonus.sdk;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.WritableRaster;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
/**
 * JNI binding for a native YOLO object-detection SDK.
 *
 * <p>Loading this class extracts the platform's native SDK library and its bundled
 * dependencies from the JAR (under {@code /lib/win-x64/} or {@code /lib/linux-x86_64/}) into a
 * new temporary directory and loads them with {@link System#load(String)}. Only x86-64 Windows
 * and x86-64 Linux are supported; on any other platform class initialization fails.
 *
 * <p>Each instance owns one native handle, created by the constructor and released by
 * {@link #close()}, so instances are meant to be used with try-with-resources.
 * NOTE(review): this class does no synchronization; presumably an instance must not be shared
 * across threads (in particular {@code predict} must not race {@code close}) — confirm against
 * the native implementation.
 */
public class YoloSdk implements AutoCloseable {

    /** Handle returned by {@code nativeInit}; reset to 0 once the instance is closed. */
    private long nativeHandle;

    /** File names this class has already passed to {@link System#load(String)}. */
    private static final Set<String> loadedLibraries = new HashSet<>();

    /** Temporary directory that receives the extracted native libraries. */
    private static Path tempDir;

    static {
        try {
            tempDir = Files.createTempDirectory("yolo_sdk_native_libs_");
            // File.deleteOnExit deletes in reverse registration order. Registering the directory
            // before its contents therefore removes the files first and the (then empty)
            // directory last. NOTE(review): Windows will not delete DLLs that are still loaded,
            // so the extracted files may be left behind there.
            tempDir.toFile().deleteOnExit();
            loadSdkLibrary(tempDir);
        } catch (IOException e) {
            // Reached for temp-directory creation failures AND for extraction failures
            // (e.g. a library missing from the JAR), so the message names both.
            throw new RuntimeException(
                    "CRITICAL: Failed to extract native libraries to a temp directory", e);
        }
    }

    /**
     * Extracts the native libraries for the current platform into {@code tempDir}, then loads
     * the dependencies followed by the SDK library itself.
     *
     * @param tempDir directory to extract the libraries into
     * @throws IOException                   if a library is missing from the JAR or cannot be copied
     * @throws UnsupportedOperationException if no binaries are bundled for this OS/architecture
     * @throws UnsatisfiedLinkError          if a library fails to load (diagnostics go to stderr)
     */
    private static void loadSdkLibrary(Path tempDir) throws IOException {
        String osName = System.getProperty("os.name").toLowerCase(Locale.ROOT);
        String osArch = System.getProperty("os.arch").toLowerCase(Locale.ROOT);
        // Only x86-64 binaries are bundled. A plain contains("64") test would also accept
        // "aarch64", "ppc64le", ... and then fail later with an opaque UnsatisfiedLinkError.
        boolean isX64 = osArch.equals("amd64") || osArch.equals("x86_64")
                || osArch.equals("x86-64") || osArch.equals("x64");

        String libPathInJar;
        String sdkLibName;
        String[] dependencyLibs;
        if (osName.contains("win") && isX64) {
            libPathInJar = "/lib/win-x64/";
            // NOTE(review): this order is presumably the dependency order, so each DLL's
            // imports are already loaded when System.load reaches it. Confirm before reordering.
            dependencyLibs = new String[]{
                    "abseil_dll.dll",
                    "libprotobuf.dll",
                    "zlib1.dll",
                    "onnxruntime.dll",
                    "opencv_core4.dll",
                    "opencv_imgproc4.dll",
                    "opencv_dnn4.dll"
            };
            sdkLibName = "my_yolo_sdk.dll";
        } else if ((osName.contains("nix") || osName.contains("nux")) && isX64) {
            libPathInJar = "/lib/linux-x86_64/";
            dependencyLibs = new String[]{
                    "libonnxruntime.so.1.23.2",
                    "libopencv_core.so.4.6.0",
                    "libopencv_imgproc.so.4.6.0",
                    "libopencv_dnn.so.4.6.0"
            };
            sdkLibName = "libmy_yolo_sdk.so";
        } else {
            throw new UnsupportedOperationException("Unsupported OS/Arch: " + osName + "/" + osArch);
        }

        // Extract everything before loading anything, so a missing resource fails before any
        // library has been loaded.
        for (String lib : dependencyLibs) {
            extractLibraryFromJar(libPathInJar + lib, tempDir);
        }
        extractLibraryFromJar(libPathInJar + sdkLibName, tempDir);

        try {
            for (String lib : dependencyLibs) {
                loadLibraryOnce(tempDir, lib);
            }
            loadLibraryOnce(tempDir, sdkLibName);
        } catch (UnsatisfiedLinkError e) {
            System.err.println("--- NATIVE LIBRARY LOAD FAILED ---");
            System.err.println("Failed to load native library. This often means a dependency is missing on the host system.");
            System.err.println("For Windows: Ensure 'Visual C++ Redistributable for Visual Studio (x64)' is installed.");
            System.err.println("Libraries were extracted to: " + tempDir.toAbsolutePath());
            System.err.println("-----------------------------------");
            throw e;
        }
    }

    /** Loads {@code libName} from {@code dir} unless this class has already loaded it. */
    private static void loadLibraryOnce(Path dir, String libName) {
        if (!loadedLibraries.contains(libName)) {
            System.load(dir.resolve(libName).toAbsolutePath().toString());
            loadedLibraries.add(libName);
        }
    }

    /**
     * Copies the classpath resource {@code pathInJar} into {@code tempDir} under its own file
     * name, and schedules the copy for deletion when the JVM exits.
     *
     * @param pathInJar absolute, '/'-separated resource path, e.g. {@code /lib/win-x64/zlib1.dll}
     * @param tempDir   destination directory
     * @throws FileNotFoundException if the resource is not on the classpath
     * @throws IOException           if the copy fails
     */
    private static void extractLibraryFromJar(String pathInJar, Path tempDir) throws IOException {
        // Resource paths always use '/', whatever the host's file separator is.
        String libName = pathInJar.substring(pathInJar.lastIndexOf('/') + 1);
        try (InputStream in = YoloSdk.class.getResourceAsStream(pathInJar)) {
            if (in == null) {
                throw new FileNotFoundException("Library " + pathInJar + " not found in JAR.");
            }
            Path targetFile = tempDir.resolve(libName);
            Files.copy(in, targetFile, StandardCopyOption.REPLACE_EXISTING);
            targetFile.toFile().deleteOnExit();
        }
    }

    /** Creates the native detector; returns 0 on failure. */
    private native long nativeInit(String modelPath, int inputWidth, int inputHeight);

    /** Frees the native detector identified by {@code handle}. */
    private native void nativeRelease(long handle);

    /** Runs detection on {@code imageWidth * imageHeight} pixels packed as 3-byte BGR. */
    private native Detection[] nativePredict(
            long handle, byte[] bgrBytes, int imageWidth, int imageHeight,
            float confThreshold, float iouThreshold
    );

    /**
     * Creates a native detector for the given model.
     *
     * @param modelPath   path to the model file, passed to the native layer unchanged
     * @param inputWidth  forwarded to {@code nativeInit}; presumably the network input width
     * @param inputHeight forwarded to {@code nativeInit}; presumably the network input height
     * @throws NullPointerException if {@code modelPath} is null
     * @throws RuntimeException     if the native layer returns a 0 handle
     */
    public YoloSdk(String modelPath, int inputWidth, int inputHeight) {
        // Fail with a Java exception rather than handing null to native code that may not check it.
        Objects.requireNonNull(modelPath, "modelPath");
        this.nativeHandle = nativeInit(modelPath, inputWidth, inputHeight);
        if (this.nativeHandle == 0) {
            throw new RuntimeException("Failed to initialize native YOLO SDK. Check logs for details.");
        }
    }

    /**
     * Runs detection on {@code image}.
     *
     * <p>The pixels are passed to the native layer as a tightly packed 3-byte BGR array.
     * Images of any other layout are converted first.
     *
     * @param image         image to analyse; must not be null
     * @param confThreshold forwarded to the native layer (presumably the minimum confidence)
     * @param iouThreshold  forwarded to the native layer (presumably the NMS IoU threshold)
     * @return the detections array produced by the native layer
     * @throws IllegalStateException if this instance has been closed
     * @throws NullPointerException  if {@code image} is null
     */
    public Detection[] predict(BufferedImage image, float confThreshold, float iouThreshold) {
        if (this.nativeHandle == 0) {
            throw new IllegalStateException("SDK already closed or failed to initialize.");
        }
        Objects.requireNonNull(image, "image");
        byte[] bgrBytes = getBgrBytes(image);
        return nativePredict(
                this.nativeHandle, bgrBytes, image.getWidth(), image.getHeight(),
                confThreshold, iouThreshold
        );
    }

    /** Releases the native handle. Calling this more than once has no further effect. */
    @Override
    public void close() {
        if (this.nativeHandle != 0) {
            nativeRelease(this.nativeHandle);
            this.nativeHandle = 0;
        }
    }

    /**
     * Returns the image's pixels as exactly {@code width * height * 3} bytes in B,G,R order,
     * row-major with no padding.
     *
     * <p>If the image already uses that exact layout, its live backing array is returned
     * without copying. Otherwise the image is drawn into a fresh {@code TYPE_3BYTE_BGR} image.
     * The type check alone is not enough: a {@link BufferedImage#getSubimage} child, or a
     * raster with an offset or padded rows, reports {@code TYPE_3BYTE_BGR} but shares a larger
     * buffer whose bytes do not line up with this image's pixels.
     *
     * @param image source image
     * @return tightly packed BGR bytes for {@code image}
     */
    private byte[] getBgrBytes(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        if (image.getType() == BufferedImage.TYPE_3BYTE_BGR) {
            WritableRaster raster = image.getRaster();
            DataBufferByte buffer = (DataBufferByte) raster.getDataBuffer();
            byte[] data = buffer.getData();
            boolean tightlyPacked = buffer.getNumBanks() == 1
                    && buffer.getOffset() == 0
                    && raster.getSampleModelTranslateX() == 0
                    && raster.getSampleModelTranslateY() == 0
                    // long arithmetic: w*h*3 can overflow int for very large images.
                    && data.length == (long) width * height * 3;
            if (tightlyPacked) {
                return data;
            }
        }
        BufferedImage bgrImage = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);
        Graphics2D g = bgrImage.createGraphics();
        try {
            g.drawImage(image, 0, 0, null);
        } finally {
            g.dispose();
        }
        return ((DataBufferByte) bgrImage.getRaster().getDataBuffer()).getData();
    }
}