/*
 * Decompiled with CFR 0.152.
 */
package weka.knowledgeflow.steps;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.clusterers.Clusterer;
import weka.clusterers.DensityBasedClusterer;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.WekaException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.steps.BaseStep;
import weka.knowledgeflow.steps.KFStep;

/**
 * Knowledge Flow step that appends predictions from classifiers or clusterers
 * to the incoming data.
 * <p>
 * Accepts a single incoming {@code batchClassifier}, {@code incrementalClassifier}
 * or {@code batchClusterer} connection. The batch connections produce
 * {@code trainingSet}/{@code testSet} output. The incremental connection emits one
 * {@code instance} output per incoming test instance.
 * <p>
 * Each output instance is a copy of the input with extra attribute(s) appended:
 * <ul>
 * <li>by default, a single attribute holding the predicted label, value or cluster;</li>
 * <li>when {@link #getAppendProbabilities()} is true, one attribute per class value
 * (classifiers) or per cluster (density-based clusterers) holding the predicted
 * probability.</li>
 * </ul>
 * Classifiers with a numeric class always get the single predicted value.
 */
@KFStep(name="PredictionAppender", category="Evaluation", toolTipText="Append predictions from classifiers or clusterers to incoming data ", iconPath="weka/gui/knowledgeflow/icons/PredictionAppender.gif")
public class PredictionAppender
extends BaseStep {
    private static final long serialVersionUID = 3558618759400903936L;

    /** When true, append per-class / per-cluster probabilities instead of a single prediction. */
    protected boolean m_appendProbabilities;

    /** Output header for the incremental case; built lazily from the first streamed instance. */
    protected Instances m_streamingOutputStructure;

    /** Reused container for the instances emitted in the incremental case. */
    protected Data m_instanceData = new Data("instance");

    /** Indexes of string attributes in the streamed input; their values are copied explicitly. */
    protected List<Integer> m_stringAttIndexes;

    /**
     * Resets per-run streaming state so the output header is rebuilt from the
     * next incoming instance.
     */
    @Override
    public void stepInit() throws WekaException {
        this.m_streamingOutputStructure = null;
        this.m_stringAttIndexes = null;
    }

    /**
     * Lists the connection types this step can accept. Only one incoming
     * connection is allowed, so once one exists nothing more is accepted.
     *
     * @return acceptable incoming connection types (empty if already connected)
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        if (this.getStepManager().numIncomingConnections() == 0) {
            // batchClusterer is handled by processIncoming and
            // getOutgoingConnectionTypes, so it must be advertised here too.
            return Arrays.asList("batchClassifier", "incrementalClassifier", "batchClusterer");
        }
        return new ArrayList<String>();
    }

    /**
     * Lists the connection types this step can produce, based on the type of
     * its current incoming connection.
     *
     * @return {@code trainingSet}/{@code testSet} for batch input, {@code instance}
     *         for incremental input, otherwise an empty list
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        ArrayList<String> result = new ArrayList<String>();
        if (this.getStepManager().numIncomingConnectionsOfType("batchClassifier") > 0 || this.getStepManager().numIncomingConnectionsOfType("batchClusterer") > 0) {
            result.add("trainingSet");
            result.add("testSet");
        } else if (this.getStepManager().numIncomingConnectionsOfType("incrementalClassifier") > 0) {
            result.add("instance");
        }
        return result;
    }

    /**
     * Dispatches incoming data to the batch-classifier, incremental-classifier or
     * batch-clusterer handler according to the incoming connection type.
     *
     * @param data the incoming data object
     * @throws WekaException if processing fails
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        Instances trainingData = (Instances)data.getPayloadElement("aux_trainingSet");
        // NOTE(review): the key is spelled "aux_testsSet". Presumably it must match the
        // name used by the producing step; confirm before "correcting" it.
        Instances testData = (Instances)data.getPayloadElement("aux_testsSet");
        Instance streamInstance = (Instance)data.getPayloadElement("aux_testInstance");
        if (this.getStepManager().numIncomingConnectionsOfType("batchClassifier") > 0) {
            this.processBatchClassifierCase(data, trainingData, testData);
        } else if (this.getStepManager().numIncomingConnectionsOfType("incrementalClassifier") > 0) {
            this.processIncrementalClassifier(data, streamInstance);
        } else if (this.getStepManager().numIncomingConnectionsOfType("batchClusterer") > 0) {
            this.processBatchClustererCase(data, trainingData, testData);
        }
    }

    /**
     * Handles one streamed test instance. It appends the incremental classifier's
     * prediction(s) and emits the result on the {@code instance} connection.
     * When the stream is finished, completion is signalled instead.
     *
     * @param data the incoming data object (carries the classifier)
     * @param inst the streamed test instance
     * @throws WekaException if the classifier or instance is missing, the classifier
     *         is not updateable, or prediction fails
     */
    protected void processIncrementalClassifier(Data data, Instance inst) throws WekaException {
        if (this.isStopRequested()) {
            return;
        }
        if (this.getStepManager().isStreamFinished(data)) {
            Data d = new Data("instance");
            this.getStepManager().throughputFinished(d);
            return;
        }
        this.getStepManager().throughputUpdateStart();
        Classifier classifier = (Classifier)data.getPayloadElement("incrementalClassifier");
        // Checked on every instance, not just the first, to avoid a later NPE.
        if (classifier == null) {
            throw new WekaException("No classifier in incoming data object!");
        }
        if (inst == null) {
            throw new WekaException("No test instance in incoming data object!");
        }
        // A single predicted label/value is appended unless probabilities were
        // requested for a non-numeric class.
        boolean labelOrNumeric = !this.m_appendProbabilities || inst.classAttribute().isNumeric();
        if (this.m_streamingOutputStructure == null) {
            if (!(classifier instanceof UpdateableClassifier)) {
                throw new WekaException("Classifier in data object is not an UpdateableClassifier!");
            }
            this.m_stringAttIndexes = new ArrayList<Integer>();
            for (int i = 0; i < inst.numAttributes(); ++i) {
                if (inst.attribute(i).isString()) {
                    this.m_stringAttIndexes.add(i);
                }
            }
            try {
                this.m_streamingOutputStructure = this.makeOutputDataClassifier(inst.dataset(), classifier, !labelOrNumeric, "_with_predictions");
            } catch (Exception ex) {
                throw new WekaException(ex);
            }
        }
        // Copy the original values; the appended slot(s) at the end are filled below.
        double[] instanceVals = new double[this.m_streamingOutputStructure.numAttributes()];
        for (int i = 0; i < inst.numAttributes(); ++i) {
            instanceVals[i] = inst.value(i);
        }
        try {
            if (labelOrNumeric) {
                instanceVals[instanceVals.length - 1] = classifier.classifyInstance(inst);
            } else {
                double[] preds = classifier.distributionForInstance(inst);
                int numValues = inst.classAttribute().numValues();
                System.arraycopy(preds, 0, instanceVals, instanceVals.length - numValues, numValues);
            }
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
        DenseInstance newInst = new DenseInstance(inst.weight(), instanceVals);
        newInst.setDataset(this.m_streamingOutputStructure);
        if (this.m_stringAttIndexes != null) {
            // String values live in the header, so push the current string into
            // the output structure's attribute.
            for (int i = 0; i < this.m_stringAttIndexes.size(); ++i) {
                int index = this.m_stringAttIndexes.get(i);
                this.m_streamingOutputStructure.attribute(index).setStringValue(inst.stringValue(index));
            }
        }
        this.m_instanceData.setPayloadElement("instance", newInst);
        if (this.isStopRequested()) {
            return;
        }
        this.getStepManager().throughputUpdateEnd();
        this.getStepManager().outputData(this.m_instanceData.getConnectionName(), this.m_instanceData);
    }

    /**
     * Appends a batch clusterer's assignments (or cluster membership probabilities)
     * to the auxiliary training and/or test sets. The results are emitted on the
     * outgoing {@code trainingSet}, {@code testSet} and {@code dataSet} connections
     * that exist.
     *
     * @param data the incoming data object (carries the clusterer and set numbers)
     * @param trainingData the training set, may be null
     * @param testData the test set, may be null
     * @throws WekaException if the clusterer is missing, probabilities are requested
     *         from a non-density-based clusterer, or clustering fails
     */
    protected void processBatchClustererCase(Data data, Instances trainingData, Instances testData) throws WekaException {
        if (this.isStopRequested()) {
            this.getStepManager().interrupted();
            return;
        }
        Clusterer clusterer = (Clusterer)data.getPayloadElement("batchClusterer");
        if (clusterer == null) {
            throw new WekaException("No clusterer in incoming data object!");
        }
        int setNum = (Integer)data.getPayloadElement("aux_set_num");
        int maxSetNum = (Integer)data.getPayloadElement("aux_max_set_num");
        String relationNameModifier = "_set_" + setNum + "_of_" + maxSetNum;
        if (this.m_appendProbabilities && !(clusterer instanceof DensityBasedClusterer)) {
            throw new WekaException("Only DensityBasedClusterers can append probabilities.");
        }
        // The check above guarantees a DensityBasedClusterer whenever this is false.
        boolean clusterLabel = !this.m_appendProbabilities;
        try {
            this.getStepManager().processing();
            Instances newTrainInstances = trainingData != null
                ? this.makeOutputDataClusterer(trainingData, clusterer, !clusterLabel, relationNameModifier)
                : null;
            Instances newTestInstances = testData != null
                ? this.makeOutputDataClusterer(testData, clusterer, !clusterLabel, relationNameModifier)
                : null;
            if (newTrainInstances != null && this.getStepManager().numOutgoingConnectionsOfType("trainingSet") > 0) {
                this.appendClustererPredictions(clusterer, newTrainInstances, clusterLabel);
                if (this.isStopRequested()) {
                    this.getStepManager().interrupted();
                    return;
                }
                this.outputBatch("trainingSet", newTrainInstances, setNum, maxSetNum);
            }
            if (newTestInstances != null && (this.getStepManager().numOutgoingConnectionsOfType("testSet") > 0 || this.getStepManager().numOutgoingConnectionsOfType("dataSet") > 0)) {
                this.appendClustererPredictions(clusterer, newTestInstances, clusterLabel);
                if (this.isStopRequested()) {
                    this.getStepManager().interrupted();
                    return;
                }
                this.outputTestResults(newTestInstances, setNum, maxSetNum);
            }
            this.getStepManager().finished();
        } catch (WekaException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Appends a batch classifier's predictions (or class probabilities) to the
     * auxiliary training and/or test sets. The results are emitted on the
     * outgoing {@code trainingSet}, {@code testSet} and {@code dataSet} connections
     * that exist.
     *
     * @param data the incoming data object (carries the classifier and set numbers)
     * @param trainingData the training set, may be null
     * @param testData the test set, may be null
     * @throws WekaException if the classifier is missing, both data sets are
     *         missing, or prediction fails
     */
    protected void processBatchClassifierCase(Data data, Instances trainingData, Instances testData) throws WekaException {
        if (this.isStopRequested()) {
            this.getStepManager().interrupted();
            return;
        }
        Classifier classifier = (Classifier)data.getPayloadElement("batchClassifier");
        if (classifier == null) {
            throw new WekaException("No classifier in incoming data object!");
        }
        if (trainingData == null && testData == null) {
            throw new WekaException("No training or test data in incoming data object!");
        }
        int setNum = (Integer)data.getPayloadElement("aux_set_num");
        int maxSetNum = (Integer)data.getPayloadElement("aux_max_set_num");
        String relationNameModifier = "_set_" + setNum + "_of_" + maxSetNum;
        Instances header = trainingData != null ? trainingData : testData;
        boolean labelOrNumeric = !this.m_appendProbabilities || header.classAttribute().isNumeric();
        try {
            this.getStepManager().processing();
            Instances newTrainInstances = trainingData != null
                ? this.makeOutputDataClassifier(trainingData, classifier, !labelOrNumeric, relationNameModifier)
                : null;
            Instances newTestInstances = testData != null
                ? this.makeOutputDataClassifier(testData, classifier, !labelOrNumeric, relationNameModifier)
                : null;
            if (newTrainInstances != null && this.getStepManager().numOutgoingConnectionsOfType("trainingSet") > 0) {
                this.appendClassifierPredictions(classifier, newTrainInstances, labelOrNumeric);
                if (this.isStopRequested()) {
                    this.getStepManager().interrupted();
                    return;
                }
                this.outputBatch("trainingSet", newTrainInstances, setNum, maxSetNum);
            }
            if (newTestInstances != null && (this.getStepManager().numOutgoingConnectionsOfType("testSet") > 0 || this.getStepManager().numOutgoingConnectionsOfType("dataSet") > 0)) {
                this.appendClassifierPredictions(classifier, newTestInstances, labelOrNumeric);
                if (this.isStopRequested()) {
                    this.getStepManager().interrupted();
                    return;
                }
                this.outputTestResults(newTestInstances, setNum, maxSetNum);
            }
            this.getStepManager().finished();
        } catch (WekaException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /** Fills the appended attribute of every instance with a clusterer prediction. */
    private void appendClustererPredictions(Clusterer clusterer, Instances insts, boolean clusterLabel) throws WekaException {
        for (int i = 0; i < insts.numInstances(); ++i) {
            if (clusterLabel) {
                this.predictLabelClusterer(clusterer, insts.instance(i));
            } else {
                this.predictProbabilitiesClusterer((DensityBasedClusterer)clusterer, insts.instance(i));
            }
        }
    }

    /** Fills the appended attribute(s) of every instance with a classifier prediction. */
    private void appendClassifierPredictions(Classifier classifier, Instances insts, boolean labelOrNumeric) throws WekaException {
        for (int i = 0; i < insts.numInstances(); ++i) {
            if (labelOrNumeric) {
                this.predictLabelClassifier(classifier, insts.instance(i));
            } else {
                this.predictProbabilitiesClassifier(classifier, insts.instance(i));
            }
        }
    }

    /** Emits processed test data on whichever of the testSet / dataSet connections exist. */
    private void outputTestResults(Instances newTestInstances, int setNum, int maxSetNum) throws WekaException {
        if (this.getStepManager().numOutgoingConnectionsOfType("testSet") > 0) {
            this.outputBatch("testSet", newTestInstances, setNum, maxSetNum);
        }
        // Must check OUTGOING connections: this decides whether to emit dataSet output.
        if (this.getStepManager().numOutgoingConnectionsOfType("dataSet") > 0) {
            this.outputBatch("dataSet", newTestInstances, setNum, maxSetNum);
        }
    }

    /** Packages instances plus set numbering into a Data object and outputs it. */
    private void outputBatch(String connectionType, Instances instances, int setNum, int maxSetNum) throws WekaException {
        Data out = new Data(connectionType);
        out.setPayloadElement(connectionType, instances);
        out.setPayloadElement("aux_set_num", setNum);
        out.setPayloadElement("aux_max_set_num", maxSetNum);
        this.getStepManager().outputData(out);
    }

    /**
     * Writes the clusterer's cluster assignment into the last attribute of {@code inst}.
     *
     * @param clusterer the clusterer to apply
     * @param inst the instance to update in place
     * @throws WekaException wrapping any clustering failure
     */
    protected void predictLabelClusterer(Clusterer clusterer, Instance inst) throws WekaException {
        try {
            int cluster = clusterer.clusterInstance(inst);
            inst.setValue(inst.numAttributes() - 1, (double)cluster);
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Writes the clusterer's membership distribution into the trailing attributes of
     * {@code inst}, one per entry of the distribution.
     *
     * @param clusterer the density-based clusterer to apply
     * @param inst the instance to update in place
     * @throws WekaException wrapping any clustering failure
     */
    protected void predictProbabilitiesClusterer(DensityBasedClusterer clusterer, Instance inst) throws WekaException {
        try {
            double[] preds = clusterer.distributionForInstance(inst);
            int offset = inst.numAttributes() - preds.length;
            for (int i = 0; i < preds.length; ++i) {
                inst.setValue(offset + i, preds[i]);
            }
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Writes the classifier's prediction into the last attribute of {@code inst}.
     *
     * @param classifier the classifier to apply
     * @param inst the instance to update in place
     * @throws WekaException wrapping any prediction failure
     */
    protected void predictLabelClassifier(Classifier classifier, Instance inst) throws WekaException {
        try {
            double pred = classifier.classifyInstance(inst);
            inst.setValue(inst.numAttributes() - 1, pred);
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Writes the classifier's class distribution into the trailing attributes of
     * {@code inst}, one per entry of the distribution.
     *
     * @param classifier the classifier to apply
     * @param inst the instance to update in place
     * @throws WekaException wrapping any prediction failure
     */
    protected void predictProbabilitiesClassifier(Classifier classifier, Instance inst) throws WekaException {
        try {
            double[] preds = classifier.distributionForInstance(inst);
            int offset = inst.numAttributes() - preds.length;
            for (int i = 0; i < preds.length; ++i) {
                inst.setValue(offset + i, preds[i]);
            }
        } catch (Exception ex) {
            throw new WekaException(ex);
        }
    }

    /**
     * Builds a copy of {@code inputData} with clusterer output attribute(s) appended.
     * The appended values themselves are not filled in here.
     *
     * @param inputData the data to extend
     * @param clusterer the (built) clusterer, used for its name and cluster count
     * @param distribution true for one numeric {@code prob_cluster<i>} attribute per
     *        cluster, false for a single nominal assigned-cluster attribute
     * @param relationNameModifier suffix appended to the relation name
     * @return the extended data set
     * @throws Exception if the Add filter fails
     */
    protected Instances makeOutputDataClusterer(Instances inputData, Clusterer clusterer, boolean distribution, String relationNameModifier) throws Exception {
        String clustererName = PredictionAppender.shortClassName(clusterer);
        Instances newData = new Instances(inputData);
        if (distribution) {
            for (int i = 0; i < clusterer.numberOfClusters(); ++i) {
                Add addF = new Add();
                addF.setAttributeIndex("last");
                addF.setAttributeName("prob_cluster" + i);
                addF.setInputFormat(newData);
                newData = Filter.useFilter(newData, addF);
            }
        } else {
            Add addF = new Add();
            addF.setAttributeIndex("last");
            addF.setAttributeName("assigned_cluster: " + clustererName);
            // Nominal labels "0,1,...,k-1" (always at least "0").
            StringBuilder clusterLabels = new StringBuilder("0");
            for (int i = 1; i < clusterer.numberOfClusters(); ++i) {
                clusterLabels.append(',').append(i);
            }
            addF.setNominalLabels(clusterLabels.toString());
            addF.setInputFormat(newData);
            newData = Filter.useFilter(newData, addF);
        }
        newData.setRelationName(inputData.relationName() + relationNameModifier);
        return newData;
    }

    /**
     * Builds a copy of {@code inputData} with classifier output attribute(s) appended.
     * The appended values themselves are not filled in here.
     *
     * @param inputData the data to extend (its class attribute must be set)
     * @param classifier the classifier, used for its name
     * @param distribution true for one numeric probability attribute per class value,
     *        false for a single predicted-class attribute (nominal with the class
     *        labels when the class is nominal, otherwise numeric)
     * @param relationNameModifier suffix appended to the relation name
     * @return the extended data set
     * @throws Exception if the Add filter fails
     */
    protected Instances makeOutputDataClassifier(Instances inputData, Classifier classifier, boolean distribution, String relationNameModifier) throws Exception {
        String classifierName = PredictionAppender.shortClassName(classifier);
        Instances newData = new Instances(inputData);
        if (distribution) {
            for (int i = 0; i < inputData.classAttribute().numValues(); ++i) {
                Add addF = new Add();
                addF.setAttributeIndex("last");
                addF.setAttributeName(classifierName + "_prob_" + inputData.classAttribute().value(i));
                addF.setInputFormat(newData);
                newData = Filter.useFilter(newData, addF);
            }
        } else {
            Add addF = new Add();
            addF.setAttributeIndex("last");
            addF.setAttributeName("class_predicted_by: " + classifierName);
            if (inputData.classAttribute().isNominal()) {
                StringBuilder classLabels = new StringBuilder(inputData.classAttribute().value(0));
                for (int i = 1; i < inputData.classAttribute().numValues(); ++i) {
                    classLabels.append(',').append(inputData.classAttribute().value(i));
                }
                addF.setNominalLabels(classLabels.toString());
            }
            addF.setInputFormat(newData);
            newData = Filter.useFilter(newData, addF);
        }
        newData.setRelationName(inputData.relationName() + relationNameModifier);
        return newData;
    }

    /** Returns the class name of {@code o} with its package prefix stripped. */
    private static String shortClassName(Object o) {
        String name = o.getClass().getName();
        return name.substring(name.lastIndexOf('.') + 1);
    }

    /**
     * Sets whether probabilities are appended instead of a single prediction.
     *
     * @param append true to append per-class / per-cluster probabilities
     */
    public void setAppendProbabilities(boolean append) {
        this.m_appendProbabilities = append;
    }

    /**
     * Returns whether probabilities are appended instead of a single prediction.
     *
     * @return true if per-class / per-cluster probabilities are appended
     */
    @OptionMetadata(displayName="Append probabilities", description="Append probabilities")
    public boolean getAppendProbabilities() {
        return this.m_appendProbabilities;
    }
}

