/*
 * Decompiled with CFR 0.152.
 */
package edu.ucla.uclapl;

import cern.colt.function.DoubleDoubleFunction;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.jet.math.Functions;
import edu.ucla.uclapl.IIS;
import edu.ucla.uclapl.IIS_gaussian_prior;
import edu.ucla.uclapl.RandomField;
import edu.ucla.uclapl.UCLAPhonotacticLearner;
import org.apache.commons.math.analysis.BrentSolver;
import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction;
import org.apache.commons.math.analysis.NewtonSolver;
import org.apache.commons.math.analysis.UnivariateRealFunction;

/**
 * Re-estimates the constraint weights of a {@link RandomField} with an iterative-scaling
 * style procedure.
 *
 * <p>On every pass a fresh random sample is drawn and a per-constraint update
 * {@code delta_i} is found as a root of the {@link IIS_gaussian_prior} objective. The
 * weights then become {@code max(WEIGHT_MINIMUM, lambda - delta)}.
 *
 * <p>Configuration lives in package-visible static fields. The per-constraint inputs of
 * the objective are handed over through static fields of {@link IIS}, and the solvers are
 * stored in static fields. Concurrent calls to {@link #optimize} would therefore
 * interfere with each other.
 */
public class ImprovedIterativeScaling {
    static Functions functions = Functions.functions;
    /** Brent root finder used when {@link #SOLVER} is 0; recreated for every constraint. */
    static BrentSolver brent = null;
    /** Newton root finder used when {@link #SOLVER} is 1; recreated for every constraint. */
    static NewtonSolver newton = null;
    /** Objective selector; every value currently yields {@link IIS_gaussian_prior}. */
    static int METHOD = 1;
    /**
     * Root-finding strategy: 0 uses a bisection bracket followed by Brent, and 1 uses
     * Newton. Any other value leaves every delta at 0.
     */
    static int SOLVER = 0;
    /** Maximum number of sample/update passes run by {@link #optimize} (at least one runs). */
    static int ITERATION_MAX = UCLAPhonotacticLearner.WEIGHTING_ITERATION;
    /** A pass counts as converged when no |delta_i| exceeds this value. */
    static double TOL = 0.25;
    /** Lower bound applied to every weight after each update. */
    static double WEIGHT_MINIMUM = 1.0E-6;
    /** When true, suppresses the per-pass progress messages printed by {@link #optimize}. */
    static boolean silent = false;

    /** Lower end of the initial search interval for each delta. */
    private static final double DELTA_SEARCH_LOWER = -10.0;
    /** Upper end of the initial search interval for each delta. */
    private static final double DELTA_SEARCH_UPPER = 10.0;
    /** Number of halvings performed by {@link #bracketByBisection}. */
    private static final int BISECTION_STEPS = 10;

    /**
     * Iteratively updates {@code lambda} in place.
     *
     * <p>Each pass does the following:
     * <ol>
     *   <li>Resets and redraws the random sample of {@code RF}.</li>
     *   <li>Builds the sample-by-constraint score matrix {@code f} and its row sums
     *       {@code fsharp}.</li>
     *   <li>Solves for one update per constraint.</li>
     * </ol>
     * The method stops as soon as no update exceeds {@link #TOL}. Otherwise it applies
     * {@code lambda = max(WEIGHT_MINIMUM, lambda - delta)} and records the new weights.
     * At most {@link #ITERATION_MAX} passes are run, and at least one always runs. The
     * weight history is printed on exit.
     *
     * @param RF     random field supplying empirical counts, the sample and feature scores
     * @param lambda constraint weights; modified in place
     * @param tol    currently unused (convergence is governed by {@link #TOL})
     * @param prin   currently unused
     */
    public void optimize(RandomField RF, DoubleMatrix1D lambda, double tol, int prin) {
        // The first history row is all zeros; every applied update appends one row.
        DoubleMatrix2D weightHistory = DoubleFactory2D.dense.make(1, lambda.size());
        weightHistory.assign(0.0);
        // METHOD has no effect at present: every value selects the Gaussian-prior objective.
        IIS_gaussian_prior objective = new IIS_gaussian_prior();
        DoubleMatrix1D O = RF.getEmpiricalCounts();
        int iterations = 1;
        do {
            RF.resetSample();
            RF.getRandomSample(true);
            if (!silent) {
                System.out.println("\toptimize(): initializing f");
            }
            DoubleMatrix2D f = DoubleFactory2D.sparse.make(RF.sample.data.length, lambda.size());
            if (!silent) {
                System.out.println("\toptimize(): initializing fsharp");
            }
            DoubleMatrix1D fsharp = DoubleFactory1D.sparse.make(RF.sample.data.length);
            for (int w = 0; w < RF.sample.data.length; ++w) {
                f.viewRow(w).assign(RF.scores(RF.sample.data[w]));
                fsharp.set(w, f.viewRow(w).zSum());
            }
            if (!silent) {
                System.out.println("\tfinding delta for each constraint weight");
            }
            DoubleMatrix1D delta = DoubleFactory1D.dense.make(lambda.size());
            for (int i = 0; i < lambda.size(); ++i) {
                // The objective reads its per-constraint inputs from these static fields.
                IIS.O_i = O.get(i);
                IIS.N = RF.corpus.size;
                IIS.lambda_i = lambda.get(i);
                IIS.f_i = f.viewColumn(i);
                IIS.fsharp = fsharp;
                delta.set(i, this.solveDelta(objective));
            }
            // Release the per-pass matrices; they are rebuilt on the next pass.
            f = null;
            fsharp = null;
            if (isConverged(delta)) {
                break;
            }
            lambda.assign(delta, Functions.minus);
            lambda.assign(Functions.bindArg1((DoubleDoubleFunction) Functions.max, WEIGHT_MINIMUM));
            weightHistory = DoubleFactory2D.dense.appendRows(weightHistory, DoubleFactory2D.dense.make(lambda.toArray(), 1));
        } while (++iterations <= ITERATION_MAX);
        System.out.println("\nevolution of the weights:");
        System.out.println(weightHistory);
    }

    /**
     * Finds the update for the constraint currently loaded into the {@link IIS} statics,
     * using the solver selected by {@link #SOLVER}. Returns 0 for an unknown solver.
     */
    private double solveDelta(IIS_gaussian_prior objective) {
        switch (SOLVER) {
            case 0:
                return this.solveWithBrent(objective);
            case 1:
                return solveWithNewton(objective);
            default:
                return 0.0;
        }
    }

    /**
     * Brackets a root of the objective by bisection and refines it with Brent's method.
     *
     * <p>Returns 0 when no sign change is found in the search interval. A solver failure
     * prints the exception and exits the JVM with status 1.
     */
    private double solveWithBrent(IIS_gaussian_prior objective) {
        brent = new BrentSolver((UnivariateRealFunction) objective);
        double[] bracket = this.bracketByBisection((UnivariateRealFunction) objective, DELTA_SEARCH_LOWER, DELTA_SEARCH_UPPER);
        if (bracket == null) {
            return 0.0;
        }
        try {
            // The two-argument form is used on purpose. A fixed initial guess such as 0.0
            // usually lies outside the narrowed bracket, and solver versions that check
            // min < initial < max would reject it.
            return brent.solve(bracket[0], bracket[1]);
        } catch (Exception e) {
            System.out.println(e);
            System.exit(1);
            return 0.0; // not reached
        }
    }

    /**
     * Solves for the objective's root with Newton's method over the search interval.
     * Returns 0 after printing the exception when the solver fails.
     */
    private static double solveWithNewton(IIS_gaussian_prior objective) {
        newton = new NewtonSolver((DifferentiableUnivariateRealFunction) objective);
        try {
            return newton.solve(DELTA_SEARCH_LOWER, DELTA_SEARCH_UPPER);
        } catch (Exception e) {
            System.out.println(e);
            return 0.0;
        }
    }

    /**
     * Returns true when no component of {@code delta} is greater than {@link #TOL} in
     * absolute value. NaN components do not count as exceeding the threshold.
     */
    private static boolean isConverged(DoubleMatrix1D delta) {
        for (int i = 0; i < delta.size(); ++i) {
            if (Math.abs(delta.get(i)) > TOL) {
                return false;
            }
        }
        return true;
    }

    /**
     * Evaluates {@code f} at {@code x}. On failure prints the exception and exits the JVM
     * with status 1.
     */
    private static double evaluate(UnivariateRealFunction f, double x) {
        try {
            return f.value(x);
        } catch (Exception e) {
            System.out.println(e);
            System.exit(1);
            return Double.NaN; // not reached
        }
    }

    /**
     * Looks for a sign change of {@code f} by widening the interval tenfold between
     * attempts (two attempts in total). Not referenced within this class.
     *
     * @return {@code {lower, upper}} whose function values have opposite signs, or
     *         {@code null} if none was found
     */
    private double[] bracket(UnivariateRealFunction f, double lower, double upper) {
        final int maxAttempts = 2;
        for (int attempt = 0; attempt < maxAttempts; ++attempt) {
            double lo = evaluate(f, lower);
            double hi = evaluate(f, upper);
            if (lo * hi < 0.0) {
                return new double[]{lower, upper};
            }
            lower *= 10.0;
            upper *= 10.0;
        }
        return null;
    }

    /**
     * Narrows a sign-change interval of {@code f} by bisection.
     *
     * <p>Returns {@code null} unless {@code f(lower)} and {@code f(upper)} have strictly
     * opposite signs; an endpoint value of NaN also yields {@code null}. Otherwise the
     * method performs up to {@link #BISECTION_STEPS} halvings. Each halving keeps the half
     * whose endpoint values still differ in sign, and the final {@code {lower, upper}} is
     * returned. Endpoint values are cached, so each step costs a single evaluation of
     * {@code f}.
     */
    private double[] bracketByBisection(UnivariateRealFunction f, double lower, double upper) {
        double fUpper = evaluate(f, upper);
        double fLower = evaluate(f, lower);
        if (!(fUpper * fLower < 0.0)) {
            return null;
        }
        for (int step = 0; step < BISECTION_STEPS; ++step) {
            double midpoint = (lower + upper) / 2.0;
            double fMid = evaluate(f, midpoint);
            if (fUpper * fMid > 0.0) {
                upper = midpoint;
                fUpper = fMid;
            } else if (fLower * fMid > 0.0) {
                lower = midpoint;
                fLower = fMid;
            } else {
                // f(midpoint) is zero, NaN, or its products underflowed. The interval
                // would not change, so further steps would only repeat this evaluation.
                break;
            }
        }
        return new double[]{lower, upper};
    }

    public static void main(String[] args) throws Exception {
        ImprovedIterativeScaling iiser = new ImprovedIterativeScaling();
        iiser.test();
    }

    /**
     * Manual smoke test: loads fixed inputs into the {@link IIS} statics, solves the
     * Gaussian-prior objective with Brent's method on [-100, 100], and prints the root
     * and the objective value there.
     */
    public void test() throws Exception {
        UCLAPhonotacticLearner.SIGMA2 = 1.0;
        IIS_gaussian_prior objective = new IIS_gaussian_prior();
        IIS.O_i = 10.0;
        IIS.lambda_i = -10.0;
        IIS.f_i = DoubleFactory1D.dense.make(new double[]{20.0, 20.0, 20.0, 20.0, 20.0});
        IIS.fsharp = DoubleFactory1D.dense.make(new double[]{20.0, 20.0, 20.0, 20.0, 20.0});
        BrentSolver solver = new BrentSolver((UnivariateRealFunction) objective);
        solver.solve(-100.0, 100.0, 0.0);
        double x = solver.getResult();
        System.out.println(x + " -> " + ((IIS) objective).value(x));
    }
}

