#include "java.h"

#include <cstdint>
namespace java {
// Java types this file needs when building JNI type descriptors; the
// descriptor for each enumerator is produced by mapJavaTypeToString.
enum class JavaType {
  Int, Byte, Char, Double, Float, Long,  // descriptors: I B C D F J
  String,                                // Ljava/lang/String;
  Short, Void, Boolean,                  // S V Z
  Object                                 // Ljava/lang/Object;
};
static const char *mapJavaTypeToString(JavaType type) {
switch (type) {
case JavaType::Int: return "I";
case JavaType::Byte: return "B";
case JavaType::Char: return "C";
case JavaType::Double: return "D";
case JavaType::Float: return "F";
case JavaType::Long: return "J";
case JavaType::String: return "Ljava/lang/String;";
case JavaType::Short: return "S";
case JavaType::Void: return "V";
case JavaType::Boolean: return "Z";
case JavaType::Object: return "Ljava/lang/Object;";
}
}
static std::string generateFunctionSignature(
JavaType returnType, std::initializer_list params) {
std::string signature = "(";
for (const auto ¶m : params) {
signature += mapJavaTypeToString(param);
}
signature += ")";
signature += mapJavaTypeToString(returnType);
return signature;
}
void throwArrayFireException(JNIEnv *env, const char *functionName,
const char *file, const int line, const int code) {
// Find and instantiate an ArrayFireException
jclass exceptionClass = env->FindClass("com/arrayfire/ArrayFireException");
if (env->ExceptionCheck()) {
env->ExceptionDescribe();
}
const std::string constructorSig = generateFunctionSignature(
JavaType::Void, {JavaType::Int, JavaType::String});
jmethodID constructor =
env->GetMethodID(exceptionClass, "", constructorSig.c_str());
jthrowable exception = static_cast(
env->NewObject(exceptionClass, constructor, code,
env->NewStringUTF("")));
// Find setLocation method and call it with
// the function name, file and line parameters
const std::string setLocationSig = generateFunctionSignature(
JavaType::Void, {JavaType::String, JavaType::String, JavaType::Int});
jmethodID setLocationID =
env->GetMethodID(exceptionClass, "setLocation", setLocationSig.c_str());
env->CallVoidMethod(exception, setLocationID, env->NewStringUTF(functionName),
env->NewStringUTF(file), line);
env->Throw(exception);
env->DeleteLocalRef(exceptionClass);
}
template
jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args) {
switch (objectType) {
case JavaObjects::FloatComplex: {
jclass cls = env->FindClass("com/arrayfire/FloatComplex");
std::string sig = generateFunctionSignature(
JavaType::Void, {JavaType::Float, JavaType::Float});
jmethodID id = env->GetMethodID(cls, "", sig.c_str());
jobject obj = env->NewObject(cls, id, args...);
return obj;
} break;
case JavaObjects::DoubleComplex: {
jclass cls = env->FindClass("com/arrayfire/DoubleComplex");
std::string sig = generateFunctionSignature(
JavaType::Void, {JavaType::Double, JavaType::Double});
jmethodID id = env->GetMethodID(cls, "", sig.c_str());
jobject obj = env->NewObject(cls, id, args...);
return obj;
} break;
}
}
// Converts a Java-side index object into an ArrayFire af_index_t.
//
// `obj` is queried through these methods:
//   isSeq()     -> "()Z"
//   isBatch()   -> "()Z"
//   getSeq()    -> "()Ljava/lang/Object;" (only when isSeq() is true); the
//                  result must have double fields begin/end/step
//   getArrRef() -> "()J" (otherwise); the value is reinterpreted as af_array
// Method and field lookups are checked only with assert, so in NDEBUG builds
// a missing member is not detected here.
af_index_t jIndexToCIndex(JNIEnv *env, jobject obj) {
  af_index_t index{};
  jclass cls = env->GetObjectClass(obj);

  // isSeq() and isBatch() share the same no-argument boolean signature.
  const std::string booleanGetterSig =
      generateFunctionSignature(JavaType::Boolean, {});
  jmethodID getIsSeqId =
      env->GetMethodID(cls, "isSeq", booleanGetterSig.c_str());
  assert(getIsSeqId != nullptr);
  index.isSeq = env->CallBooleanMethod(obj, getIsSeqId);
  jmethodID getIsBatchId =
      env->GetMethodID(cls, "isBatch", booleanGetterSig.c_str());
  assert(getIsBatchId != nullptr);
  index.isBatch = env->CallBooleanMethod(obj, getIsBatchId);

  if (index.isSeq) {
    const std::string getSeqSig =
        generateFunctionSignature(JavaType::Object, {});
    jmethodID getSeqId = env->GetMethodID(cls, "getSeq", getSeqSig.c_str());
    assert(getSeqId != nullptr);
    jobject seq = env->CallObjectMethod(obj, getSeqId);
    assert(seq != nullptr);
    jclass seqCls = env->GetObjectClass(seq);
#ifndef NDEBUG
    {
      // JNI references are opaque handles: two local references to the same
      // class are generally distinct values, so comparing them with == does
      // not test class identity. IsSameObject compares the referents.
      jclass expectedSeqCls = env->FindClass("com/arrayfire/Seq");
      assert(env->IsSameObject(seqCls, expectedSeqCls));
      env->DeleteLocalRef(expectedSeqCls);
    }
#endif
    const char *doubleSig = mapJavaTypeToString(JavaType::Double);
    jfieldID beginID = env->GetFieldID(seqCls, "begin", doubleSig);
    assert(beginID != nullptr);
    const double begin = env->GetDoubleField(seq, beginID);
    jfieldID endID = env->GetFieldID(seqCls, "end", doubleSig);
    assert(endID != nullptr);
    const double end = env->GetDoubleField(seq, endID);
    jfieldID stepID = env->GetFieldID(seqCls, "step", doubleSig);
    assert(stepID != nullptr);
    const double step = env->GetDoubleField(seq, stepID);
    index.idx.seq = af_make_seq(begin, end, step);
    env->DeleteLocalRef(seqCls);
    env->DeleteLocalRef(seq);
  } else {
    const std::string getArrSig = generateFunctionSignature(JavaType::Long, {});
    jmethodID getArrId = env->GetMethodID(cls, "getArrRef", getArrSig.c_str());
    assert(getArrId != nullptr);
    // jlong is always 64 bits, whereas `long` is only 32 bits on LLP64
    // platforms (Windows). Storing the handle in a `long` would truncate it.
    const jlong arrRef = env->CallLongMethod(obj, getArrId);
    index.idx.arr =
        reinterpret_cast<af_array>(static_cast<std::intptr_t>(arrRef));
  }

  // Release the class reference so repeated calls from within one native
  // method do not accumulate local references.
  env->DeleteLocalRef(cls);
  return index;
}
// Explicit instantiations of createJavaObject for (float, float) and
// (double, double) argument pairs, presumably used for FloatComplex and
// DoubleComplex respectively.
template jobject createJavaObject<float, float>(JNIEnv *, JavaObjects, float,
                                                float);
template jobject createJavaObject<double, double>(JNIEnv *, JavaObjects,
                                                  double, double);
} // namespace java