// Copyright (C)  2000 Intel Corporation.  All rights reserved.
//
// $Header: /usr/development/orp/orp/base_natives/common/find_natives.cpp,v 1.14 2001/12/19 02:08:10 michal Exp $
//

#ifndef  OBJECT_LOCK_V2
#include "platform.h"
#include <assert.h>
#include "object_layout.h"
#include "orp_utils.h"
#include "exceptions.h"
#include "jni.h"

#ifdef ORP_POSIX
#include <dlfcn.h>
#endif


#ifdef USE_IA64_JIT
#define IA64_HACK_NO_JAVA_NET
#endif

#include "find_natives.h"



#ifdef ORP_POSIX
#define DLL_HANDLE void *
#else
#define DLL_HANDLE HINSTANCE
#endif


class Dll_Record {
    char *file_name;
    DLL_HANDLE handle;
    Dll_Record *next;

public:

    Dll_Record *get_next() { return next; };
    const char *get_file_name() { return file_name; };

    DLL_HANDLE get_handle() {return handle; };

    Dll_Record(const char *fname, DLL_HANDLE h, Dll_Record *n);

    void *find_native_method(const char *name);

    static Dll_Record *system_loader_dlls;
};


Dll_Record *Dll_Record::system_loader_dlls = NULL;


Dll_Record::Dll_Record(const char *fname, DLL_HANDLE h, Dll_Record *n)
{
  file_name = (char *)malloc(strlen(fname) + 1);
  strcpy(file_name, fname);
  handle = h;
  next   = n;
  //printf("Dll_Record::Dll_Record for '%s'\n", file_name);
} //Dll_Record::Dll_Record



// Binary-searches a table of built-in native methods for a mangled name.
//
// pname       - mangled name to look up.
// native_intf - out: set to the matching entry's native_intf; left untouched
//               when nothing matches.
// _built_ins  - the table; must be sorted in strictly increasing strcmp()
//               order of pname, otherwise the search can miss entries.
// num_funcs   - number of valid entries in _built_ins.
//
// Returns the matching entry's pfunc, or 0 if pname is not in the table.
static void *find_builtin_native_method(const char *pname,
                                        NI_TYPE *native_intf,
                                        Built_In_Method_Entry _built_ins[],
                                        int num_funcs)
{
    // Half-open search window [L, R). Only indices < num_funcs are ever read.
    int L = 0, R = num_funcs;

//#define JNI_DEBUG
#ifdef JNI_DEBUG
    // Make sure that the array is sorted
    for (int i = 1; i < R; i++) {
        int cmp = strcmp(_built_ins[i - 1].pname, _built_ins[i].pname);
        if (cmp >= 0)   // DER - strcmp may return any negative value
            orp_cout << _built_ins[i-1].pname << " not less than " << _built_ins[i].pname << endl;
        assert(cmp < 0);
    }
#endif
    while(L < R) {
        int M = L + (R - L) / 2;  // avoids overflow of L + R
        int cmp = strcmp(pname, _built_ins[M].pname);
        if(cmp < 0) {
            R = M;
        } else if(cmp > 0) {
            L = M + 1;
        } else {
            *native_intf = _built_ins[M].native_intf;
            return _built_ins[M].pfunc;
        }
    }

    // Not found
    return 0;
} //find_builtin_native_method



// Mangles old_name into new_name using the JNI-style substitutions
//     '/' -> "_"   '_' -> "_1"   ';' -> "_2"   '[' -> "_3"
// copying every other character unchanged, and NUL-terminates the result.
//
// new_name may be NULL, in which case nothing is written and the call only
// measures. Returns the length of the mangled string (excluding the NUL),
// so callers can size a buffer with translate_name(0, s) + 1.
static int translate_name(char *new_name, const char *old_name)
{
    int produced = 0;
    char *out = new_name;

    for(const char *in = old_name; *in != '\0'; ++in) {
        // Every input character expands to one or two output characters.
        char first  = *in;
        char second = '\0';
        switch(*in) {
        case '/': first = '_';                 break;
        case '_':              second = '1';   break;  // first stays '_'
        case ';': first = '_'; second = '2';   break;
        case '[': first = '_'; second = '3';   break;
        default:                               break;
        }

        if(out) {
            *out++ = first;
            if(second != '\0') {
                *out++ = second;
            }
        }
        produced += (second != '\0') ? 2 : 1;
    }

    if(out) {
        *out = '\0';
    }
    return produced;
} //translate_name



// Builds the JNI "long" symbol name used to resolve overloaded native methods:
//
//     <prefix>Java_<class>_<method>[__<argument signature>]<suffix>
//
// The class and method names are mangled with translate_name(). The argument
// signature (the descriptor text between '(' and ')') is mangled with the JNI
// escapes '/' -> "_", '_' -> "_1", ';' -> "_2", '[' -> "_3" and
// '$' -> "_00024". The "__<argument signature>" part is omitted when the
// method takes no arguments.
//
// Returns a malloc()'d string that the caller must free().
//
// TODO(review): other non-ASCII characters (which JNI escapes as "_0xxxx")
// are copied through unchanged, and translate_name() does not escape '$' in
// the class or method name itself.
static char *create_overloaded_jni_name(const char *prefix, Method *method, const char *suffix)
{
    const char *args = method->get_descriptor() + 1;  // skip the leading '('
    bool no_args = (*args == ')');

    int len = 9;  // "Java_" "_" "__" "\0"
    len += method->get_class()->name->len;
    len += method->get_name()->len;

    // Count the argument signature, allowing for the characters that grow
    // when escaped. "Lfoo/bar;" becomes "Lfoo_bar_2" (one extra char for ';').
    const char *descr;
    for(descr = args; *descr != ')'; descr++) {
        switch(*descr) {
        case '$':
            len += 5;
            break;
        case '[':
        case '_':
        case ';':
            len++;
            break;
        default:
            break;
        }
        len++;
    }

    // The worst case expansion factor is 6:  '$' ===> "_00024"
    // We also add a constant overhead for the prefix, "Java_", "__", postfix, etc.
    len = 6 * len + 100;
    char *name = (char *)malloc(len + 1);
    sprintf(name,
            "%sJava_",
            prefix);
    int pos = strlen(name);
    pos += translate_name(name + pos, method->get_class()->name->bytes);
    name[pos++] = '_';
    pos += translate_name(name + pos, method->get_name()->bytes);
    // translate_name() NUL-terminates, so name is already a complete string
    // here even when the method has no arguments.

    char *p = name + pos;
    if(!no_args) {
        *p++ = '_';
        *p++ = '_';
        for(descr = args; *descr != ')'; descr++) {
            switch(*descr) {
            case '[':
                *p++ = '_';
                *p++ = '3';
                break;
            case 'L':
                // Copy "L<internal class name>" up to (not including) the ';',
                // escaping the characters that are special in JNI symbols.
                for(; *descr != ';'; descr++) {
                    switch(*descr) {
                    case '/':
                        *p++ = '_';
                        break;
                    case '_':
                        *p++ = '_';
                        *p++ = '1';
                        break;
                    case '$':
                        *p++ = '_';
                        *p++ = '0';
                        *p++ = '0';
                        *p++ = '0';
                        *p++ = '2';
                        *p++ = '4';
                        break;
                    default:
                        *p++ = *descr;
                        break;
                    }
                }
                *p++ = '_';
                *p++ = '2';
                break;
            default:
                // Anything else (primitive type codes in a well-formed
                // descriptor) is copied as-is.
                *p++ = *descr;
                break;
            }
        }
        *p = '\0';
    }

    strcat(name, suffix);

    return name;
} //create_overloaded_jni_name



// Produces the "short" (non-overloaded) mangled name for a method. It joins
// prefix + class + "/" + method + suffix and passes the result through
// translate_name().
//
// decorated_name is caller-provided scratch space. It receives the
// un-mangled concatenation and must be large enough to hold it.
// descriptor is accepted but not used.
// Returns a malloc()'d string that the caller must free().
static char *create_non_overloaded_name(const char *prefix,
                                        char *decorated_name,
                                        const String *clss_name,
                                        const String *meth_name,
                                        const String *descriptor,
                                        const char *suffix)
{
    const char *cls  = clss_name->bytes;
    const char *meth = meth_name->bytes;
    sprintf(decorated_name, "%s%s/%s%s", prefix, cls, meth, suffix);

    // First pass only measures; second pass writes into the exact-size buffer.
    const int mangled_len = translate_name(0, decorated_name);
    char *mangled = (char *)malloc(mangled_len + 1);
    translate_name(mangled, decorated_name);
    return mangled;
} //create_non_overloaded_name



// Disabled self-test. It walks a built-in table and asserts that
// find_native_method() resolves every entry.
// NOTE(review): this looks stale. It refers to a table named _built_ins (the
// tables used elsewhere in this file are _built_ins_base/_built_ins_extra).
// It also calls find_native_method() with a name string rather than a
// Method *, so it likely will not compile if re-enabled.
#if 0
void test_find_native_method()
{
    cout << "Testing native methods..." << endl;
    int num_funcs = sizeof(_built_ins) / sizeof(_built_ins[0]);
    for(int i = 0; i < num_funcs; i++) {
        char *n = _built_ins[i].pname;
        void *func;
        NI_TYPE native_intf;
        find_native_method(n, &func, &native_intf);
        assert(func);
    }
    cout << "OK." << endl;
} //test_find_native_method
#endif



///////////////////////////////////////////////////////////////////////////////
// begin user native code support


// Resolves name as an exported symbol of this library via dlsym() (POSIX)
// or GetProcAddress() (otherwise). Both return NULL when the symbol is absent.
void *Dll_Record::find_native_method(const char *name)
{
#ifdef JNI_DEBUG
    orp_cout << "  Looking for " << name << " in " << file_name << endl;
#endif
    void *symbol;
#ifdef ORP_POSIX
    symbol = dlsym(handle, name);
#else
    symbol = (void *)GetProcAddress(handle, name);
#endif
    return symbol;
} //Dll_Record::find_native_method


// Walks the list starting at head and returns the first record whose file
// name is exactly (strcmp) equal to file_name, or 0 if there is none.
Dll_Record *find_dll(const char *file_name, Dll_Record *head)
{
    Dll_Record *rec = head;
    while(rec != 0 && strcmp(file_name, rec->get_file_name()) != 0) {
        rec = rec->get_next();
    }
    return rec;
} //find_dll



// Head of the list of native libraries loaded by add_dll(), newest first.
// NOTE(review): this file-scope variable is what add_dll() and
// find_user_native_method_by_name() use. It is distinct from the class member
// Dll_Record::system_loader_dlls, which does not appear to be referenced.
static Dll_Record *system_loader_dlls = 0;


// Loads the native library file_name for the system class loader (only
// loader == NULL is supported) and records it. This lets
// find_user_native_method_by_name() search it. A file name that is already
// recorded is not loaded again. A library that fails to load is ignored; a
// diagnostic is printed only when JNI_DEBUG is defined.
ORPExport void add_dll(const char *file_name, Class_Loader *loader)
{
    assert(!loader);  // multiple class loaders are not supported yet

#ifdef NON_ORP_NATIVE_LIBS
#ifdef JNI_DEBUG
    orp_cout << "** (add_dll) Loading dll " << file_name << endl;
#endif
#endif //NON_ORP_NATIVE_LIBS
    if(find_dll(file_name, system_loader_dlls))  {
        // Is it already loaded?
#ifdef NON_ORP_NATIVE_LIBS
#ifdef JNI_DEBUG
        orp_cout << "** (add_dll) dll already loaded " << file_name << endl;
#endif
#endif //NON_ORP_NATIVE_LIBS
        return;
    }

#ifdef ORP_POSIX
    DLL_HANDLE handle = dlopen(file_name, RTLD_LAZY);
#else
    DLL_HANDLE handle = LoadLibrary(file_name);
#endif
    if(handle) {
        // Push the new library onto the front of the list.
        system_loader_dlls = new Dll_Record(file_name, handle, system_loader_dlls);
        return;
    }

#ifdef JNI_DEBUG
    const char *shown_name = file_name;
#ifndef NON_ORP_NATIVE_LIBS
#ifdef ORP_POSIX
    // Skip a "lib" prefix before matching, but only if it is really there
    // (the name may be shorter than 3 chars or lack the prefix).
    if(strncmp(shown_name, "lib", 3) == 0) {
        shown_name += 3;
    }
#endif
    // Load failures for these library names are not reported.
    if(strncmp(shown_name, "javalang.", 9) &&
       strncmp(shown_name, "javalangmath.", 13) &&
       strncmp(shown_name, "java_lang_Math.", 15) &&
       strncmp(shown_name, "javaio.", 7) &&
       strncmp(shown_name, "javanet.", 8) &&
       strncmp(shown_name, "javautil.", 9) &&
       strncmp(shown_name, "javalangreflect.", 16) &&
       //strncmp(shown_name, "bigint.", 7) &&
       strncmp(shown_name, "runtime.", 8) )
#endif //NON_ORP_NATIVE_LIBS
    {
        orp_cout << "** Couldn't open dll: " << shown_name << endl;
#ifdef ORP_POSIX
        // dlerror() may return NULL; fputs(NULL, ...) is undefined.
        const char *err = dlerror();
        if(err) {
            fputs(err, stderr);
        }
#endif
    }
#endif //JNI_DEBUG
} //add_dll



// Searches every library loaded by add_dll(), newest first, for an exported
// symbol matching decorated_name. On success it sets *native_intf to
// NI_IS_JNI and returns the symbol address. Otherwise it returns 0 and leaves
// *native_intf untouched.
//
// On ORP_POSIX / _IA64_ builds the name is first undecorated:
//  - a leading '_' is dropped, if present;
//  - everything from the first '@' (the argument-size suffix) is cut off,
//    if present.
// Without WINNT64, create_mangled_names() produces names like "_Java_X_f@12",
// which become "Java_X_f". Names that carry no decoration pass through
// unchanged.
static void *find_user_native_method_by_name(char *decorated_name, NI_TYPE *native_intf)
{
#ifdef JNI_DEBUG
    orp_cout << "**Looking for a native method '"
             << decorated_name << "'"
             << endl;
#endif
    const char *lookup_name = decorated_name;
#if defined(ORP_POSIX) || defined(_IA64_)
    // (mjc 20000726) Strip the Win32-style decoration the mangled names carry.
    const char *start = decorated_name;
    if(*start == '_') {
        start++;
    }
    char *undecorated = (char *)malloc(strlen(start) + 1);
    if(!undecorated) {
        return 0;
    }
    strcpy(undecorated, start);
    char *at = strchr(undecorated, '@');
    if(at) {
        *at = '\0';
    }
    lookup_name = undecorated;
#endif

    void *f = 0;
    for(Dll_Record *r = system_loader_dlls; r; r = r->get_next()) {
#ifdef JNI_DEBUG
        orp_cout << "** Searching dll: " << r->get_file_name() << endl;
#endif
        f = r->find_native_method(lookup_name);
        if(f) {
            *native_intf = NI_IS_JNI;
#ifdef JNI_DEBUG
            orp_cout << "*** Found: " << f << endl;
#endif
            break;
        }
    }

#if defined(ORP_POSIX) || defined(_IA64_)
    free(undecorated);  // previously leaked on every lookup
#endif
    return f;
} //find_user_native_method_by_name






// Looks for a JNI implementation of method in the user-loaded libraries. It
// tries the short (non-overloaded) name first, then the long (overloaded)
// name. Returns 0 if neither is found.
static void *find_user_native_method(Method *method,
                                     NI_TYPE *native_intf,
                                     char *non_overloaded_jni_name,
                                     char *overloaded_jni_name)
{
    Class *owner = method->get_class();
    assert(!owner->class_loader);  // multiple class loaders are not supported yet

    void *entry = find_user_native_method_by_name(non_overloaded_jni_name, native_intf);
    if(!entry) {
        entry = find_user_native_method_by_name(overloaded_jni_name, native_intf);
    }
    return entry;
} //find_user_native_method



// end user native code support
///////////////////////////////////////////////////////////////////////////////




///////////////////////////////////////////////////////////////////////////////
// begin register natives


// Records f as the implementation of clss's method name+sig. The entry is
// pushed onto the front of clss->registered_natives. find_registered_method()
// scans from the front, so the most recent registration for a given
// name/sig wins.
void register_native_method(Class *clss, String *name, String *sig, void *f)
{
#if 0
    printf("Registering %s.%s%s --> 0x%x\n",
           clss->name->bytes, name->bytes, sig->bytes, f);
#endif
    Registered_Native *entry = (Registered_Native *)malloc(sizeof(Registered_Native));
    entry->next = clss->registered_natives;
    entry->func = f;
    entry->sig  = sig;
    entry->name = name;
    clss->registered_natives = entry;
} //register_native_method



// Frees every entry registered for clss and leaves its list empty.
void unregister_native_methods(Class *clss)
{
    Registered_Native *cur = clss->registered_natives;
    while(cur) {
        Registered_Native *following = cur->next;
        free(cur);
        cur = following;
    }
    clss->registered_natives = 0;
} //unregister_native_methods



// Returns the function registered for method via register_native_method(),
// or 0 if there is none. On a hit it sets *native_intf to NI_IS_JNI.
// Name and descriptor are matched by String pointer identity.
// NOTE(review): this presumes Strings are interned -- confirm.
static void *find_registered_method(Method *method, NI_TYPE *native_intf)
{
    const String *wanted_name = method->get_name();
    const String *wanted_sig  = method->get_signature()->descriptor;

    Registered_Native *rn = method->get_class()->registered_natives;
    while(rn && !(rn->name == wanted_name && rn->sig == wanted_sig)) {
        rn = rn->next;
    }
    if(!rn) {
        return 0;
    }
    *native_intf = NI_IS_JNI;
    return rn->func;
} //find_registered_method


// end register natives
///////////////////////////////////////////////////////////////////////////////




// Computes the four candidate symbol names for method. Each output is a
// malloc()'d string that the caller must free():
//   *non_overloaded_jni_name        - "Java_<class>_<method>"; without WINNT64
//                                     it gets a leading '_' and an
//                                     "@<arg bytes + 4>" suffix.
//   *overloaded_jni_name            - the JNI long name (see
//                                     create_overloaded_jni_name), decorated
//                                     the same way.
//   *non_overloaded_ini_name        - "<class>_<method>", used to search the
//                                     built-in tables.
//   *non_overloaded_ini_name_in_dll - the INI name; without WINNT64 it gets a
//                                     leading '_' and an "@<arg bytes>" suffix.
static void create_mangled_names(Method *method,
                                 char **non_overloaded_jni_name,
                                 char **overloaded_jni_name,
                                 char **non_overloaded_ini_name,
                                 char **non_overloaded_ini_name_in_dll
                                 )
{
    Class *clss = method->get_class();
    String *descriptor = method->get_signature()->descriptor;
    String *name = method->get_signature()->name;

    // Scratch buffer that create_non_overloaded_name() formats the
    // un-mangled "<prefix><class>/<method><suffix>" string into.
    char *decorated_name = (char *)malloc(6 * (clss->name->len + name->len + descriptor->len) + 20);

    int num_arg_bytes = method->get_num_arg_bytes();
    if(method->is_static()) {
        num_arg_bytes += 4;
    }
    // 16 bytes holds '@', any 32-bit int in decimal (up to 11 chars) and NUL.
    char ini_arg_size[16];
    sprintf(ini_arg_size, "@%d", num_arg_bytes);
    char jni_arg_size[16];
    sprintf(jni_arg_size, "@%d", num_arg_bytes + 4);  // extra 4 bytes for the env.

    // First assume the method is not overloaded
#ifdef WINNT64
    *non_overloaded_jni_name = create_non_overloaded_name("Java/", decorated_name, clss->name, name, descriptor, "");
#else
    *non_overloaded_jni_name = create_non_overloaded_name("/Java/", decorated_name, clss->name, name, descriptor, jni_arg_size);
#endif

    // Now assume the method is overloaded
#ifdef WINNT64
    *overloaded_jni_name = create_overloaded_jni_name("", method, "");
#else
    *overloaded_jni_name = create_overloaded_jni_name("_", method, jni_arg_size);
#endif

    // This name is used by our internal native interface
    *non_overloaded_ini_name = create_non_overloaded_name("", decorated_name, clss->name, name, descriptor, "");

    *non_overloaded_ini_name_in_dll = (char *)malloc(strlen(*non_overloaded_ini_name) + 50);
#ifdef WINNT64
    sprintf(*non_overloaded_ini_name_in_dll, "%s%s", *non_overloaded_ini_name, "");
#else
    sprintf(*non_overloaded_ini_name_in_dll, "_%s%s", *non_overloaded_ini_name, ini_arg_size);
#endif

    // The returned names are fresh allocations (create_non_overloaded_name
    // mallocs its result), so the scratch buffer can be released. It
    // previously leaked on every native-method lookup.
    free(decorated_name);
} //create_mangled_names



// Handle to the standard native library "rt". It is loaded lazily and only
// referenced when USE_DLL_FOR_STD_LIBS is defined.
static DLL_HANDLE rt_handle;


//
// Control flow:
// 1. Look for built-in, non-overloaded methods in the 'base' array.
// 2. Look for built-in, overloaded methods in the 'base' array.
// if !defined(USE_DLL_FOR_STD_LIBS)
//   3. Look for built-in, non-overloaded methods in the 'extra' array.
//   4. Look for built-in, overloaded methods in the 'extra' array.
//   (steps 3 and 4 are skipped when NON_ORP_NATIVE_LIBS is defined)
// else
//   3. Look for non-overloaded JNI methods in rt.dll.
//   4. Look for overloaded JNI methods in rt.dll.
//   5. Look for non-overloaded INI methods in rt.dll.
// endif
// 6. Look for JNI (non-overloaded and overloaded) in user dlls.
// 7. Look for registered user methods.
// 8. (CLI_TESTING only) Look for a CLI native.
//
// On success *func is set to the implementation and *native_intf to its
// native-interface kind. If nothing is found, *func is set to 0.
//
static void find_native_method1(Method *method,
                                void **func,
                                NI_TYPE *native_intf,
                                char *non_overloaded_jni_name,
                                char *overloaded_jni_name,
                                char *non_overloaded_ini_name,
                                char *non_overloaded_ini_name_in_dll
                                )
{
    void *f;
    f = find_builtin_native_method(non_overloaded_ini_name, native_intf, _built_ins_base, sizeof_built_ins_base);
    if(f) {
        *func = f;
#ifdef JNI_DEBUG
        printf("found! %s\n", non_overloaded_ini_name);
#endif
        return;
    }

    // For overloaded, built-in, native methods, we follow the JNI convention.
    f = find_builtin_native_method(overloaded_jni_name, native_intf, _built_ins_base, sizeof_built_ins_base);
    if(f) {
        *func = f;
        return;
    }

#ifndef USE_DLL_FOR_STD_LIBS
#ifndef NON_ORP_NATIVE_LIBS
    f = find_builtin_native_method(non_overloaded_ini_name, native_intf, _built_ins_extra, sizeof_built_ins_extra);
    if(f) {
        *func = f;
        return;
    }

    // For overloaded, built-in, native methods, we follow the JNI convention.
    f = find_builtin_native_method(overloaded_jni_name, native_intf, _built_ins_extra, sizeof_built_ins_extra);
    if(f) {
        *func = f;
        return;
    }
#endif //NON_ORP_NATIVE_LIBS

#else //#ifndef USE_DLL_FOR_STD_LIBS

    // Load the standard library.
    //
    // No need for synchronization, because a first native method is invoked
    // while the app is still single-threaded.
    static bool firstTime = true;
    if(firstTime) {
        firstTime = false;
        const char *rt_lib_name = "rt";
        rt_handle = LoadLibrary(rt_lib_name);
        if(!rt_handle) {
            printf("Couldn't open standard library: %s\n", rt_lib_name);
            orp_exit(1);
        }
    }
    assert(rt_handle);

    // Look in the standard dll. GetProcAddress() returns a function pointer
    // (FARPROC), which C++ does not implicitly convert to void *.
    f = (void *)GetProcAddress(rt_handle, non_overloaded_jni_name);
    if(f) {
        *native_intf = NI_IS_JNI;
        *func = f;
        return;
    }
    f = (void *)GetProcAddress(rt_handle, overloaded_jni_name);
    if(f) {
        *native_intf = NI_IS_JNI;
        *func = f;
        return;
    }
    f = (void *)GetProcAddress(rt_handle, non_overloaded_ini_name_in_dll);
    if(f) {
        *native_intf = NI_IS_RNI;
        *func = f;
        return;
    }

#endif //#ifndef USE_DLL_FOR_STD_LIBS

    // Look in user dlls
    f = find_user_native_method(method, native_intf, non_overloaded_jni_name, overloaded_jni_name);
    if(f) {
        *func = f;
        return;
    }

    // Look for registered methods
    f = find_registered_method(method, native_intf);
    if(f) {
        *func = f;
        return;
    }

#ifdef CLI_TESTING
    f = find_a_cli_native(method, native_intf);
    if(f) {
        *func = f;
        return;
    }
#endif

#ifdef JNI_DEBUG
    printf("Couldn't find native: %s.%s%s\n",
           method->get_class()->name->bytes,
           method->get_name()->bytes,
           method->get_descriptor()
           );
#endif
    *func = 0;
} //find_native_method1



// Resolves the native implementation of method. On return *func is the
// implementation, or 0 if none was found; when one is found, *native_intf
// tells which native interface it uses. All candidate names built along the
// way are freed before returning.
void find_native_method(Method *method, void **func, NI_TYPE *native_intf)
{
    // Candidate symbol names, each malloc()'d by create_mangled_names().
    struct {
        char *jni;             // non-overloaded JNI name
        char *jni_overloaded;  // overloaded JNI name
        char *ini;             // internal-native-interface name
        char *ini_in_dll;      // decorated INI name as exported from a dll
    } names = { 0, 0, 0, 0 };

    create_mangled_names(method,
                         &names.jni,
                         &names.jni_overloaded,
                         &names.ini,
                         &names.ini_in_dll);
#ifdef JNI_DEBUG
    printf("Mangled names of native: %s.%s%s\nare:\n%s\n%s\n%s\n%s\n",
           method->get_class()->name->bytes,
           method->get_name()->bytes,
           method->get_descriptor(),
           names.jni,
           names.jni_overloaded,
           names.ini,
           names.ini_in_dll
           );
#endif

    find_native_method1(method, func, native_intf,
                        names.jni, names.jni_overloaded,
                        names.ini, names.ini_in_dll);

    free(names.jni);
    free(names.jni_overloaded);
    free(names.ini);
    free(names.ini_in_dll);
} //find_native_method

#endif //#ifndef  OBJECT_LOCK_V2


