/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine;

import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.annotation.Ingester;
import org.opensearch.ml.engine.annotation.Processor;
import org.opensearch.ml.engine.processor.MLProcessorType;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/**
 * Registry that maps ML engine type keys to implementation classes and instantiates them
 * reflectively.
 *
 * <p>When this class is initialized, classpath scanning populates the lookup tables from classes
 * annotated with {@link Function} and {@link ConnectorExecutor} (package
 * {@code org.opensearch.ml.engine.algorithms}), {@link Ingester} (package
 * {@code org.opensearch.ml.engine.ingest}) and {@link Processor} (package
 * {@code org.opensearch.ml.engine.processor}). Objects registered via {@link #register} take
 * precedence over reflective instantiation.
 *
 * <p>NOTE(review): all maps are plain {@link HashMap}s with no synchronization; if
 * {@link #register}/{@link #deregister} can run concurrently with {@link #initInstance},
 * confirm that callers serialize access.
 */
public class MLEngineClassLoader {
    private static final Logger logger = LogManager.getLogger(MLEngineClassLoader.class);

    /** Algorithm classes keyed by the {@link FunctionName} declared in their {@link Function} annotation. */
    private static final Map<Enum<?>, Class<?>> mlAlgoClassMap = new HashMap<>();
    /** Connector executor classes keyed by their {@link ConnectorExecutor} annotation value. */
    private static final Map<String, Class<?>> connectorExecutorMap = new HashMap<>();
    /** Ingester classes keyed by their {@link Ingester} annotation value. */
    private static final Map<String, Class<?>> ingesterMap = new HashMap<>();
    /** Processor classes keyed by the {@link MLProcessorType} in their {@link Processor} annotation. */
    private static final Map<MLProcessorType, Class<?>> mlProcessorMap = new HashMap<>();
    /** Pre-built objects returned by {@link #initInstance} instead of creating a new instance. */
    private static final Map<Enum<?>, Object> mlObjects = new HashMap<>();

    /**
     * Registers a pre-built object under the given key; subsequent {@link #initInstance} calls
     * for that key return this object instead of instantiating a class.
     *
     * @param functionName key to register under
     * @param obj object to return for that key
     */
    public static void register(Enum<?> functionName, Object obj) {
        mlObjects.put(functionName, obj);
    }

    /**
     * Removes a previously registered object.
     *
     * @param functionName key to remove
     * @return the object that was registered under the key, or {@code null} if none was
     */
    public static Object deregister(Enum<?> functionName) {
        return mlObjects.remove(functionName);
    }

    /** Scans the algorithms package for {@link Function} and {@link ConnectorExecutor} classes. */
    private static void loadClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.engine.algorithms", new Scanner[0]);

        Set<Class<?>> classes = reflections.getTypesAnnotatedWith(Function.class);
        for (Class<?> clazz : classes) {
            Function function = clazz.getAnnotation(Function.class);
            FunctionName functionName = function.value();
            // Defensive: annotation values cannot normally be null, but skip rather than map a null key.
            if (functionName != null) {
                mlAlgoClassMap.put(functionName, clazz);
            }
        }

        Set<Class<?>> connectorExecutorClasses = reflections.getTypesAnnotatedWith(ConnectorExecutor.class);
        for (Class<?> clazz : connectorExecutorClasses) {
            ConnectorExecutor connectorExecutor = clazz.getAnnotation(ConnectorExecutor.class);
            String connectorName = connectorExecutor.value();
            if (connectorName != null) {
                connectorExecutorMap.put(connectorName, clazz);
            }
        }
    }

    /** Scans the processor package for {@link Processor}-annotated classes. */
    private static void loadMLProcessorClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.engine.processor", new Scanner[0]);
        Set<Class<?>> processorClasses = reflections.getTypesAnnotatedWith(Processor.class);
        for (Class<?> clazz : processorClasses) {
            Processor processorExecutor = clazz.getAnnotation(Processor.class);
            MLProcessorType processorType = processorExecutor.value();
            if (processorType != null) {
                mlProcessorMap.put(processorType, clazz);
            }
        }
    }

    /** Scans the ingest package for {@link Ingester}-annotated classes. */
    private static void loadIngestClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.engine.ingest", new Scanner[0]);
        Set<Class<?>> ingesterClasses = reflections.getTypesAnnotatedWith(Ingester.class);
        for (Class<?> clazz : ingesterClasses) {
            Ingester ingester = clazz.getAnnotation(Ingester.class);
            String ingesterSource = ingester.value();
            if (ingesterSource != null) {
                ingesterMap.put(ingesterSource, clazz);
            }
        }
    }

    /**
     * Same as {@link #initInstance(Object, Object, Class, Map)} with {@code null} properties.
     */
    public static <T, S, I> S initInstance(T type, I in, Class<?> constructorParamClass) {
        return MLEngineClassLoader.initInstance(type, in, constructorParamClass, null);
    }

    /**
     * Returns a registered object for {@code type}, or creates a new instance of the class mapped
     * to {@code type}.
     *
     * <p>The class is looked up in this order: algorithms, connector executors, ingesters,
     * processors. A constructor taking {@code constructorParamClass} is tried first with
     * {@code in} as its argument; if none exists, the public no-arg constructor is used.
     * {@code properties} are then applied to the instance via {@link BeanUtils#populate}.
     *
     * @param type lookup key (enum or string, depending on the registry)
     * @param in argument for the single-parameter constructor
     * @param constructorParamClass parameter type of the preferred constructor
     * @param properties bean properties to set on the new instance (may be {@code null})
     * @return the registered or newly created instance, or {@code null} if instantiation failed
     *     for a reason other than an {@link MLException} (the failure is logged)
     * @throws IllegalArgumentException if no class is mapped to {@code type}
     * @throws MLException if the constructor or a property setter failed with an MLException
     */
    @SuppressWarnings("unchecked")
    public static <T, S, I> S initInstance(T type, I in, Class<?> constructorParamClass, Map<String, Object> properties) {
        if (mlObjects.containsKey(type)) {
            return (S) mlObjects.get(type);
        }
        Class<?> clazz = findClass(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        try {
            Object instance = newInstance(clazz, in, constructorParamClass);
            BeanUtils.populate(instance, properties);
            return (S) instance;
        } catch (Exception e) {
            // Reflection wraps constructor/setter failures; surface MLExceptions to the caller unwrapped.
            Throwable cause = e.getCause();
            if (cause instanceof MLException) {
                throw (MLException) cause;
            }
            // Deliberate best-effort: log and return null rather than propagate other failures.
            logger.error("Failed to init instance for type " + type, e);
            return null;
        }
    }

    /** Looks up the implementation class for {@code type} across all registries; null if none. */
    private static Class<?> findClass(Object type) {
        Class<?> clazz = mlAlgoClassMap.get(type);
        if (clazz == null) {
            clazz = connectorExecutorMap.get(type);
        }
        if (clazz == null) {
            clazz = ingesterMap.get(type);
        }
        if (clazz == null) {
            clazz = mlProcessorMap.get(type);
        }
        return clazz;
    }

    /** Invokes the one-arg constructor if present, otherwise the public no-arg constructor. */
    private static Object newInstance(Class<?> clazz, Object in, Class<?> constructorParamClass)
        throws ReflectiveOperationException {
        try {
            Constructor<?> constructor = clazz.getConstructor(constructorParamClass);
            return constructor.newInstance(in);
        } catch (NoSuchMethodException e) {
            Constructor<?> constructor = clazz.getConstructor();
            return constructor.newInstance();
        }
    }

    static {
        try {
            // Explicit cast: a bare lambda matches both doPrivileged overloads (ambiguous), and the
            // PrivilegedActionException catch below only applies to the exception-action overload.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                loadClassMapping();
                loadIngestClassMapping();
                loadMLProcessorClassMapping();
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML engine", e);
        }
    }
}

