/*
 * Decompiled with CFR 0.152.
 */
package de.dfki.s2m2de.expression.svm;

import de.dfki.s2m2de.expression.svm.ModelSearchSpace;
import de.dfki.s2m2de.expression.svm.SVM;
import de.dfki.s2m2de.expression.svm.SVMException;
import de.dfki.s2m2de.expression.svm.SVMParameters;
import de.dfki.s2m2de.expression.svm.Sample;
import de.dfki.s2m2de.expression.svm.TrainingSet;
import java.util.LinkedList;
import org.apache.log4j.Logger;

/**
 * Base class for {@link SVM} implementations that supplies n-fold cross
 * validation and a cross-validation driven model search on top of the
 * abstract {@link #train} and {@link #predict} operations.
 *
 * <p>Cross validation creates fresh instances of the concrete subclass via
 * reflection, so subclasses need an accessible no-argument constructor.
 *
 * @param <Domain> the type parameter passed through to {@link Sample} and
 *                 {@link TrainingSet}
 */
public abstract class AbstractSVM<Domain>
implements SVM<Domain> {
    /**
     * Classifies a single sample.
     *
     * @param sample the sample to classify
     * @return the predicted label; cross validation compares it against
     *         {@code sample.isRelevant()}
     * @throws SVMException if the prediction fails
     */
    @Override
    public abstract boolean predict(Sample<Domain> sample) throws SVMException;

    /**
     * Trains this SVM on the given training set.
     *
     * @param trainingSet the samples to train on
     * @param parameters  the model parameters to train with
     * @throws SVMException if training fails
     */
    @Override
    public abstract void train(TrainingSet<Domain> trainingSet, SVMParameters parameters) throws SVMException;

    /**
     * Estimates the accuracy of the given parameters by n-fold cross validation.
     *
     * <p>The training set is split with {@code trainingSet.fold(n, seed)}. In each
     * of the {@code n} rounds the first fold is held out for testing, a new
     * instance of this object's class is trained on the merge of the remaining
     * folds, and every sample of the held-out fold is predicted. A round that
     * fails with an {@link SVMException} is logged and contributes no samples to
     * the result; the remaining rounds still run.
     *
     * @param trainingSet the samples to split into folds
     * @param parameters  the model parameters to evaluate
     * @param n           the number of rounds to run
     * @param seed        the seed handed to {@code trainingSet.fold}
     * @return the percentage (0 to 100) of correctly predicted samples, or NaN
     *         if no sample was evaluated at all (for example because every
     *         round failed)
     * @throws SVMException if a new instance of the concrete SVM class cannot be
     *                      created
     */
    @Override
    public double nFoldCrossValidation(TrainingSet<Domain> trainingSet, SVMParameters parameters, int n, long seed) throws SVMException {
        Logger logger = Logger.getLogger(this.getClass());
        LinkedList<TrainingSet<Domain>> folds = trainingSet.fold(n, seed);
        int count = 0;
        int correct = 0;
        for (int i = 0; i < n; ++i) {
            // Rotate through the folds: the head is held out for testing and
            // appended again at the end of the round.
            TrainingSet<Domain> testFold = folds.removeFirst();
            TrainingSet<Domain> trainFolds = TrainingSet.merge(folds);
            // Instantiation problems abort the whole validation, so the
            // instance is created outside the per-round try block.
            SVM<Domain> svm = this.newInstanceOfSameClass(logger);
            try {
                svm.train(trainFolds, parameters);
                for (Sample<Domain> sample : testFold.getTrainingSet()) {
                    ++count;
                    if (svm.predict(sample) == sample.isRelevant()) {
                        ++correct;
                    }
                }
            }
            catch (SVMException e) {
                // Best effort: a failing round is skipped, not fatal.
                logger.warn("Problem during cross validation test.", e);
            }
            folds.addLast(testFold);
        }
        // With count == 0 this is 0.0 / 0.0, i.e. NaN.
        return (double)correct / (double)count * 100.0;
    }

    /**
     * Searches the given model space for the parameters with the highest
     * cross-validation accuracy.
     *
     * <p>Every candidate returned by the search space is scored with
     * {@link #nFoldCrossValidation}. Candidates whose evaluation throws an
     * {@link SVMException} are logged and skipped, as are candidates whose
     * accuracy is NaN. Of several candidates with equal accuracy the first one
     * encountered wins.
     *
     * @param trainingSet      the samples used for cross validation
     * @param modelSearchSpace the source of candidate parameters
     * @param n                the number of cross-validation rounds
     * @param seed             the seed used for folding the training set
     * @return the best parameters found, with their accuracy stored via
     *         {@code setAccuracy}
     * @throws SVMException if no candidate could be evaluated successfully
     */
    @Override
    public SVMParameters searchModel(TrainingSet<Domain> trainingSet, ModelSearchSpace modelSearchSpace, int n, long seed) throws SVMException {
        Logger logger = Logger.getLogger(this.getClass());
        SVMParameters bestParameters = null;
        // Double.MIN_VALUE is the smallest POSITIVE double, so it would reject
        // a legitimate accuracy of 0.0; negative infinity is the real floor.
        double bestAccuracy = Double.NEGATIVE_INFINITY;
        SVMException lastFailure = null;
        while (modelSearchSpace.hasMoreModels) {
            try {
                SVMParameters candidate = modelSearchSpace.getNextModel();
                double accuracy = this.nFoldCrossValidation(trainingSet, candidate, n, seed);
                // A NaN accuracy compares false here and is skipped.
                if (accuracy > bestAccuracy) {
                    bestAccuracy = accuracy;
                    bestParameters = candidate;
                }
            }
            catch (SVMException e) {
                lastFailure = e;
                logger.warn("Problem with model search.", e);
            }
        }
        if (bestParameters == null) {
            // Nothing was accepted: report it instead of dereferencing null.
            throw new SVMException("No model could be evaluated successfully during model search.", lastFailure);
        }
        logger.info("Cross validation accuracy of found model: " + bestAccuracy);
        bestParameters.setAccuracy(Double.valueOf(bestAccuracy));
        return bestParameters;
    }

    /**
     * Creates a new, untrained instance of this object's runtime class using
     * its no-argument constructor.
     *
     * @param logger the logger that receives a warning when creation fails
     * @return the new instance
     * @throws SVMException if the class cannot be instantiated or accessed
     */
    private SVM<Domain> newInstanceOfSameClass(Logger logger) throws SVMException {
        try {
            // The runtime class is a subclass of AbstractSVM<Domain>, so the
            // new instance handles the same Domain as this one.
            @SuppressWarnings("unchecked")
            SVM<Domain> instance = (SVM<Domain>)this.getClass().newInstance();
            return instance;
        }
        catch (IllegalAccessException e) {
            logger.warn("Unable to perform n-fold cross validation.", e);
            throw new SVMException("Unable to perform n-fold cross validation.", e);
        }
        catch (InstantiationException e) {
            logger.warn("Unable to perform n-fold cross validation.", e);
            throw new SVMException("Unable to perform n-fold cross validation.", e);
        }
    }
}

