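// sishu-yolo-sdk/cpp/src/YoloSdk_JNI.cpp
//
// JNI entry points (nativeInit / nativeRelease / nativePredict) that connect
// the Java class YoloSdk to the native YoloDetector declared in YoloCore.h.
//
// NOTE(review): the generated header included below is named for package
// com.bonus.sdk, while the exported functions in this file are named for
// com.mycompany.sdk -- confirm which package name is the intended one.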
#include "com_bonus_sdk_YoloSdk.h"
#include "YoloCore.h" //
#include <string>
#include <vector>
#include <stdexcept>
#include <iostream>
// ---
// Converts a Java string (UTF-16) into a std::wstring.
//
// Returns an empty string when jStr is null or when the JVM cannot provide the
// character buffer; in the latter case GetStringChars has left a Java
// exception pending, which callers can detect with env->ExceptionCheck().
//
// The characters are converted one code unit at a time instead of
// reinterpreting the jchar buffer as wchar_t: jchar is always 16 bits wide,
// whereas wchar_t is 32 bits on most non-Windows platforms, so a
// reinterpret_cast would read past the end of the buffer there.
std::wstring jstringToWString(JNIEnv* env, jstring jStr) {
    if (!jStr) return L"";
    const jchar* raw = env->GetStringChars(jStr, nullptr);
    if (!raw) return L"";
    const jsize len = env->GetStringLength(jStr);

    // When wchar_t can hold a full code point, surrogate pairs are combined;
    // when it is 16 bits wide the UTF-16 code units are copied unchanged.
    const bool wideHoldsCodePoint = sizeof(wchar_t) > sizeof(jchar);

    std::wstring wStr;
    wStr.reserve(static_cast<std::wstring::size_type>(len));
    for (jsize i = 0; i < len; ++i) {
        const unsigned int unit = raw[i];
        const bool isHighSurrogate = unit >= 0xD800u && unit <= 0xDBFFu;
        if (wideHoldsCodePoint && isHighSurrogate && i + 1 < len) {
            const unsigned int next = raw[i + 1];
            if (next >= 0xDC00u && next <= 0xDFFFu) {
                const unsigned int codePoint =
                    0x10000u + ((unit - 0xD800u) << 10) + (next - 0xDC00u);
                wStr.push_back(static_cast<wchar_t>(codePoint));
                ++i;  // the low surrogate has been consumed as well
                continue;
            }
        }
        // BMP character, or an unpaired surrogate that is passed through as-is.
        wStr.push_back(static_cast<wchar_t>(unit));
    }

    env->ReleaseStringChars(jStr, raw);
    return wStr;
}
// Raises a java.lang.RuntimeException carrying `message` in the calling Java
// thread. The exception only becomes visible to Java once the native method
// returns, so callers must return immediately after calling this.
//
// Nothing new is thrown when an exception is already pending (the first
// failure is kept), or when the RuntimeException class cannot be resolved
// (FindClass then leaves its own exception pending).
void throwJavaException(JNIEnv* env, const char* message) {
    if (env->ExceptionCheck()) {
        return;
    }
    jclass runtimeExceptionClass = env->FindClass("java/lang/RuntimeException");
    if (!runtimeExceptionClass) {
        return;
    }
    env->ThrowNew(runtimeExceptionClass, message);
    // Releasing a local reference is permitted while an exception is pending.
    env->DeleteLocalRef(runtimeExceptionClass);
}
// ---
/*
 * Class:     com_mycompany_sdk_YoloSdk
 * Method:    nativeInit
 * Signature: (Ljava/lang/String;II)J
 *
 * Creates a YoloDetector from the given model path and input size and returns
 * its address as an opaque handle. The handle is what nativePredict and
 * nativeRelease cast back to a YoloDetector*.
 *
 * Returns 0 with a Java exception pending when the path cannot be converted
 * or when construction throws; no C++ exception is allowed to leave this
 * function.
 *
 * NOTE(review): this file includes com_bonus_sdk_YoloSdk.h, but the symbol
 * below is named for package com.mycompany.sdk. Confirm that a matching
 * extern "C" declaration is in scope, otherwise the JVM cannot resolve it.
 */
JNIEXPORT jlong JNICALL Java_com_mycompany_sdk_YoloSdk_nativeInit
(JNIEnv* env, jobject thiz, jstring modelPath, jint inputWidth, jint inputHeight) {
    try {
        const std::wstring wpath = jstringToWString(env, modelPath);
        if (env->ExceptionCheck()) {
            // The string conversion failed inside the JVM; let that exception propagate.
            return 0;
        }
        // Ownership passes to the Java side through the returned handle; the
        // object is destroyed in nativeRelease.
        YoloDetector* detector = new YoloDetector(wpath.c_str(), inputWidth, inputHeight);
        return reinterpret_cast<jlong>(detector);
    }
    catch (const std::exception& e) {
        std::string errMsg = "Failed to initialize C++ YoloDetector: " + std::string(e.what());
        throwJavaException(env, errMsg.c_str());
        return 0;
    }
    catch (...) {
        // Exceptions not derived from std::exception must not unwind into the JVM either.
        throwJavaException(env, "Failed to initialize C++ YoloDetector: unknown native exception");
        return 0;
    }
}
/*
 * Class:     com_mycompany_sdk_YoloSdk
 * Method:    nativeRelease
 * Signature: (J)V
 *
 * Destroys the YoloDetector that `handle` points to. A handle of 0 is
 * accepted and ignored.
 */
JNIEXPORT void JNICALL Java_com_mycompany_sdk_YoloSdk_nativeRelease
(JNIEnv* env, jobject thiz, jlong handle) {
    // Deleting a null pointer is a no-op, so a zero handle needs no special case.
    delete reinterpret_cast<YoloDetector*>(handle);
}
/*
 * Class:     com_mycompany_sdk_YoloSdk
 * Method:    nativePredict
 * Signature: (J[BIIFF)[Lcom/mycompany/sdk/Detection;
 *
 * Runs the detector behind `handle` on the bytes of `bgrBytes` and returns the
 * detections as a Java array of com.mycompany.sdk.Detection.
 *
 * Returns null with a Java exception pending when the handle or the input
 * array is null, when a JNI call fails, or when the detector throws; no C++
 * exception is allowed to leave this function.
 *
 * NOTE(review): the length of bgrBytes is not checked against imageWidth and
 * imageHeight, so the detector is trusted not to read past the array. Confirm
 * the expected pixel layout and add a length check here.
 */
JNIEXPORT jobjectArray JNICALL Java_com_mycompany_sdk_YoloSdk_nativePredict
(JNIEnv* env, jobject thiz, jlong handle, jbyteArray bgrBytes,
    jint imageWidth, jint imageHeight, jfloat confThreshold, jfloat iouThreshold) {
    YoloDetector* detector = reinterpret_cast<YoloDetector*>(handle);
    if (!detector) {
        throwJavaException(env, "Native handle is null.");
        return nullptr;
    }
    if (!bgrBytes) {
        throwJavaException(env, "Input image byte array is null.");
        return nullptr;
    }

    // Holds the array elements obtained from the JVM and hands them back on
    // every exit path, including when the detector throws.
    struct ByteArrayLease {
        JNIEnv* env;
        jbyteArray array;
        jbyte* data;

        ByteArrayLease(JNIEnv* e, jbyteArray a)
            : env(e), array(a), data(e->GetByteArrayElements(a, nullptr)) {}
        ByteArrayLease(const ByteArrayLease&) = delete;
        ByteArrayLease& operator=(const ByteArrayLease&) = delete;
        ~ByteArrayLease() { release(); }

        void release() {
            if (data) {
                // JNI_ABORT: the buffer is only read, so nothing is copied back.
                env->ReleaseByteArrayElements(array, data, JNI_ABORT);
                data = nullptr;
            }
        }
    };

    try {
        // 1. Get access to the image bytes.
        ByteArrayLease pixels(env, bgrBytes);
        if (!pixels.data) {
            return nullptr;  // GetByteArrayElements left an exception pending
        }

        // 2. Run the detector on the raw buffer.
        std::vector<Detection> results_cpp = detector->detect(
            reinterpret_cast<unsigned char*>(pixels.data),
            imageWidth, imageHeight, confThreshold, iouThreshold
        );

        // 3. Give the buffer back before any Java objects are created.
        pixels.release();

        // 4. Convert the native results into a Java Detection[].
        jclass detClass = env->FindClass("com/mycompany/sdk/Detection");
        if (!detClass) return nullptr;  // exception pending
        // NOTE(review): the descriptor "(IFFIIII)V" names seven parameters,
        // but NewObject below is given six values. A mismatch between the
        // descriptor and the variadic arguments is undefined behaviour --
        // verify against the Java Detection constructor and fix whichever
        // side is wrong.
        jmethodID detConstructor = env->GetMethodID(detClass, "<init>", "(IFFIIII)V");
        if (!detConstructor) return nullptr;  // exception pending

        const jsize resultCount = static_cast<jsize>(results_cpp.size());
        jobjectArray resultArray = env->NewObjectArray(resultCount, detClass, nullptr);
        if (!resultArray) return nullptr;  // exception pending

        for (jsize i = 0; i < resultCount; ++i) {
            const auto& d = results_cpp[static_cast<size_t>(i)];
            jobject javaDet = env->NewObject(detClass, detConstructor,
                d.class_id, d.score,
                d.x, d.y, d.width, d.height);
            if (!javaDet) return nullptr;  // constructor threw or allocation failed
            env->SetObjectArrayElement(resultArray, i, javaDet);
            // Drop the local reference right away so that a large result set
            // does not exhaust the local reference table.
            env->DeleteLocalRef(javaDet);
        }
        return resultArray;
    }
    catch (const std::exception& e) {
        std::string errMsg = "Error during native prediction: " + std::string(e.what());
        throwJavaException(env, errMsg.c_str());
        return nullptr;
    }
    catch (...) {
        // Exceptions not derived from std::exception must not unwind into the JVM either.
        throwJavaException(env, "Error during native prediction: unknown native exception");
        return nullptr;
    }
}