sishu-yolo-sdk/cpp/src/YoloSdk_JNI.cpp

106 lines
3.4 KiB
C++

#include "com_bonus_sdk_YoloSdk.h"
#include "YoloCore.h"
#include <string>
#include <vector>
#include <stdexcept>
#include <iostream>
/**
 * Raises a java.lang.RuntimeException carrying `message` in the calling Java thread.
 *
 * The exception only becomes visible to Java once the native method returns, so
 * callers must return promptly after invoking this helper.
 *
 * @param env     JNI environment of the current thread.
 * @param message Text used as the exception message.
 */
void throwJavaException(JNIEnv *env, const char *message) {
    // With an exception already pending, FindClass/ThrowNew must not be called;
    // keep the first (root-cause) exception instead of replacing it.
    if (env->ExceptionCheck()) {
        return;
    }

    jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
    if (exceptionClass == nullptr) {
        // FindClass has already raised its own error; passing a null class to
        // ThrowNew would crash the VM.
        return;
    }

    env->ThrowNew(exceptionClass, message);
    // Drop the local reference so repeated calls from long-running native code
    // do not accumulate entries in the local reference table.
    env->DeleteLocalRef(exceptionClass);
}
/*
 * Class:     com_bonus_sdk_YoloSdk
 * Method:    nativeInit
 * Signature: (Ljava/lang/String;II)J
 *
 * Constructs a YoloDetector from the given model path and input size and
 * returns its address as an opaque handle. The handle is what nativePredict
 * and nativeRelease cast back to a YoloDetector*.
 *
 * Returns 0 with a Java exception pending when the path is null, the path
 * string cannot be read, or the YoloDetector constructor throws.
 */
JNIEXPORT jlong JNICALL Java_com_bonus_sdk_YoloSdk_nativeInit
(JNIEnv *env, jobject thiz, jstring modelPath, jint inputWidth, jint inputHeight) {
    // GetStringUTFChars must not be handed a null jstring.
    if (modelPath == nullptr) {
        throwJavaException(env, "Model path must not be null.");
        return 0;
    }

    const char* c_model_path = env->GetStringUTFChars(modelPath, nullptr);
    if (c_model_path == nullptr) {
        // The JVM normally leaves an error pending here; only raise our own
        // exception when it did not, since throwing on top of a pending
        // exception is not permitted.
        if (!env->ExceptionCheck()) {
            throwJavaException(env, "Failed to get model path from Java string.");
        }
        return 0;
    }

    try {
        YoloDetector* detector = new YoloDetector(c_model_path, inputWidth, inputHeight);
        env->ReleaseStringUTFChars(modelPath, c_model_path);
        // Ownership passes to the Java side, which must call nativeRelease.
        return reinterpret_cast<jlong>(detector);
    } catch (const std::exception& e) {
        env->ReleaseStringUTFChars(modelPath, c_model_path);
        std::string errMsg = "Failed to initialize C++ YoloDetector: " + std::string(e.what());
        throwJavaException(env, errMsg.c_str());
        return 0;
    } catch (...) {
        // A C++ exception must never unwind through the JNI boundary.
        env->ReleaseStringUTFChars(modelPath, c_model_path);
        throwJavaException(env, "Failed to initialize C++ YoloDetector: unknown native exception.");
        return 0;
    }
}
/*
 * Class:     com_bonus_sdk_YoloSdk
 * Method:    nativeRelease
 * Signature: (J)V
 *
 * Destroys the YoloDetector whose address is stored in `handle`.
 * A zero handle is accepted and does nothing.
 */
JNIEXPORT void JNICALL Java_com_bonus_sdk_YoloSdk_nativeRelease
(JNIEnv *env, jobject thiz, jlong handle) {
    // Deleting a null pointer is a no-op, so no explicit guard is required.
    delete reinterpret_cast<YoloDetector*>(handle);
}
/*
 * Class:     com_bonus_sdk_YoloSdk
 * Method:    nativePredict
 * Signature: (J[BIIFF)[Lcom/bonus/sdk/Detection;
 *
 * Runs YoloDetector::detect on the supplied image bytes and converts every
 * native Detection into a com.bonus.sdk.Detection object.
 *
 * Returns the array of detections, or nullptr with a Java exception pending
 * when an argument is invalid, the detector throws, or a JNI call fails.
 */
JNIEXPORT jobjectArray JNICALL Java_com_bonus_sdk_YoloSdk_nativePredict
(JNIEnv *env, jobject thiz, jlong handle, jbyteArray bgrBytes,
 jint imageWidth, jint imageHeight, jfloat confThreshold, jfloat iouThreshold) {
    YoloDetector* detector = reinterpret_cast<YoloDetector*>(handle);
    if (!detector) {
        throwJavaException(env, "Native handle is null.");
        return nullptr;
    }
    if (bgrBytes == nullptr) {
        throwJavaException(env, "Image byte array is null.");
        return nullptr;
    }
    if (imageWidth <= 0 || imageHeight <= 0) {
        throwJavaException(env, "Image dimensions must be positive.");
        return nullptr;
    }

    // NOTE(review): assumes tightly packed 3-byte BGR pixels, as the parameter
    // name suggests - confirm against YoloDetector::detect.
    // Dividing the length instead of multiplying the pixel count keeps the
    // comparison free of 64-bit overflow for any pair of jint dimensions.
    const jlong pixelCount = static_cast<jlong>(imageWidth) * static_cast<jlong>(imageHeight);
    const jlong availablePixels = static_cast<jlong>(env->GetArrayLength(bgrBytes)) / 3;
    if (pixelCount > availablePixels) {
        throwJavaException(env, "Image byte array is smaller than width * height * 3.");
        return nullptr;
    }

    // Hands the array elements back to the JVM on every exit path, including
    // when detect() throws. JNI_ABORT: the pixels are only read, so nothing
    // needs to be copied back.
    struct ElementsGuard {
        JNIEnv* env;
        jbyteArray array;
        jbyte* data;
        ~ElementsGuard() {
            if (data != nullptr) {
                env->ReleaseByteArrayElements(array, data, JNI_ABORT);
            }
        }
    };

    try {
        std::vector<Detection> results_cpp;
        {
            ElementsGuard pixels{env, bgrBytes, env->GetByteArrayElements(bgrBytes, nullptr)};
            if (pixels.data == nullptr) {
                // Usually an error is already pending; do not throw on top of it.
                if (!env->ExceptionCheck()) {
                    throwJavaException(env, "Failed to access image bytes.");
                }
                return nullptr;
            }
            results_cpp = detector->detect(
                reinterpret_cast<unsigned char*>(pixels.data),
                imageWidth, imageHeight, confThreshold, iouThreshold
            );
        }   // elements released here, before any Java objects are created

        // For the lookups and allocations below, a null result means the JVM
        // has already raised the matching exception.
        jclass detClass = env->FindClass("com/bonus/sdk/Detection");
        if (!detClass) return nullptr;
        jmethodID detConstructor = env->GetMethodID(detClass, "<init>", "(IFFIIII)V");
        if (!detConstructor) return nullptr;

        const jsize count = static_cast<jsize>(results_cpp.size());
        jobjectArray resultArray = env->NewObjectArray(count, detClass, nullptr);
        if (!resultArray) return nullptr;

        for (jsize i = 0; i < count; ++i) {
            const auto& d = results_cpp[static_cast<size_t>(i)];
            // NOTE(review): the descriptor "(IFFIIII)V" declares seven
            // parameters but only six arguments are passed here. Confirm the
            // constructor of com.bonus.sdk.Detection and the fields of the
            // native Detection struct, then fix whichever side is wrong.
            jobject javaDet = env->NewObject(detClass, detConstructor,
                                             d.class_id, d.score,
                                             d.x, d.y, d.width, d.height);
            if (javaDet == nullptr || env->ExceptionCheck()) {
                // Allocation failed or the constructor threw; no further JNI
                // calls are allowed while that exception is pending.
                return nullptr;
            }
            env->SetObjectArrayElement(resultArray, i, javaDet);
            // Free the local reference eagerly so large result sets do not
            // overflow the local reference table.
            env->DeleteLocalRef(javaDet);
        }
        return resultArray;
    } catch (const std::exception& e) {
        std::string errMsg = "Error during native prediction: " + std::string(e.what());
        throwJavaException(env, errMsg.c_str());
        return nullptr;
    } catch (...) {
        // A C++ exception must never unwind through the JNI boundary.
        throwJavaException(env, "Error during native prediction: unknown native exception.");
        return nullptr;
    }
}