/*
 * Decompiled with CFR 0.152.
 */
package pal.eval;

import pal.datatype.DataType;
import pal.eval.ConditionalProbabilityStore;
import pal.eval.FastFourStateLHCalculator;
import pal.eval.LHCalculator;
import pal.eval.PatternInfo;
import pal.eval.SimpleLeafCalculator;
import pal.substmodel.SubstitutionModel;

/**
 * {@link LHCalculator} specialised for data with four character states evaluated under a single
 * rate category. All per-state arithmetic is unrolled over the four states.
 *
 * <p>Obtain instances through {@link #getFactory()} or {@link #getFactory(LHCalculator.Factory)}.
 * The returned factory only builds the specialised generator when asked for four states and one
 * category; every other configuration is delegated to a fallback factory.
 */
public class SimpleModelFastFourStateLHCalculator
implements LHCalculator {
    /** Number of character states this calculator is specialised for. */
    private static final int FOUR_STATES = 4;
    /** Number of rate categories this calculator is specialised for. */
    private static final int ONE_CATEGORY = 1;

    /**
     * Combines the conditional probabilities of two children and carries the result across a
     * branch. For each pattern it computes
     * {@code result[i] = sum_j P[i][j] * (left[j] * right[j])}.
     *
     * @param transitionProbabilityStore the 4x4 branch transition matrix {@code P}
     * @param centerPattern pattern description; its lookup array is read as consecutive
     *        (left pattern index, right pattern index) pairs, one pair per pattern
     * @param leftConditionalProbabilityProbabilties conditional probabilities of the left child
     * @param rightConditionalProbabilityProbabilties conditional probabilities of the right child
     * @param resultStore store whose category-0 arrays receive the per-pattern results
     */
    private static final void calculateExtendedImpl(double[][] transitionProbabilityStore, PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, ConditionalProbabilityStore resultStore) {
        int[] patternLookup = centerPattern.getPatternLookup();
        int numberOfPatterns = centerPattern.getNumberOfPatterns();
        // Index 0: this calculator only ever handles a single category.
        double[][] myPatternStateProbabilities = resultStore.getConditionalProbabilityAccess(numberOfPatterns, false)[0];
        double[][] leftPatternStateProbabilities = leftConditionalProbabilityProbabilties.getCurrentConditionalProbabilities(0);
        double[][] rightPatternStateProbabilities = rightConditionalProbabilityProbabilties.getCurrentConditionalProbabilities(0);
        // The matrix rows are the same for every pattern, so look them up once.
        double[] sa0 = transitionProbabilityStore[0];
        double[] sa1 = transitionProbabilityStore[1];
        double[] sa2 = transitionProbabilityStore[2];
        double[] sa3 = transitionProbabilityStore[3];
        int patternAccess = 0;
        for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
            double[] leftStateProbabilities = leftPatternStateProbabilities[patternLookup[patternAccess++]];
            double[] rightStateProbabilities = rightPatternStateProbabilities[patternLookup[patternAccess++]];
            double[] myStateProbabilities = myPatternStateProbabilities[pattern];
            // Per-state product of the two children ...
            double es0 = leftStateProbabilities[0] * rightStateProbabilities[0];
            double es1 = leftStateProbabilities[1] * rightStateProbabilities[1];
            double es2 = leftStateProbabilities[2] * rightStateProbabilities[2];
            double es3 = leftStateProbabilities[3] * rightStateProbabilities[3];
            // ... pushed through the transition matrix.
            myStateProbabilities[0] = sa0[0] * es0 + sa0[1] * es1 + sa0[2] * es2 + sa0[3] * es3;
            myStateProbabilities[1] = sa1[0] * es0 + sa1[1] * es1 + sa1[2] * es2 + sa1[3] * es3;
            myStateProbabilities[2] = sa2[0] * es0 + sa2[1] * es1 + sa2[2] * es2 + sa2[3] * es3;
            myStateProbabilities[3] = sa3[0] * es0 + sa3[1] * es1 + sa3[2] * es2 + sa3[3] * es3;
        }
    }

    /**
     * Combines the conditional probabilities of two children without a branch. For each pattern
     * it computes {@code result[i] = left[i] * right[i]}.
     *
     * @param centerPattern pattern description; its lookup array is read as consecutive
     *        (left pattern index, right pattern index) pairs, one pair per pattern
     * @param leftConditionalProbabilityProbabilties conditional probabilities of the left child
     * @param rightConditionalProbabilityProbabilties conditional probabilities of the right child
     * @param resultStore store whose category-0 arrays receive the per-pattern results
     */
    private static final void calculateFlatImpl(PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, ConditionalProbabilityStore resultStore) {
        int[] patternLookup = centerPattern.getPatternLookup();
        int numberOfPatterns = centerPattern.getNumberOfPatterns();
        double[][] myPatternStateProbabilities = resultStore.getConditionalProbabilityAccess(numberOfPatterns, false)[0];
        double[][] leftPatternStateProbabilities = leftConditionalProbabilityProbabilties.getCurrentConditionalProbabilities(0);
        double[][] rightPatternStateProbabilities = rightConditionalProbabilityProbabilties.getCurrentConditionalProbabilities(0);
        int patternAccess = 0;
        for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
            double[] leftStateProbabilities = leftPatternStateProbabilities[patternLookup[patternAccess++]];
            double[] rightStateProbabilities = rightPatternStateProbabilities[patternLookup[patternAccess++]];
            double[] myStateProbabilities = myPatternStateProbabilities[pattern];
            myStateProbabilities[0] = leftStateProbabilities[0] * rightStateProbabilities[0];
            myStateProbabilities[1] = leftStateProbabilities[1] * rightStateProbabilities[1];
            myStateProbabilities[2] = leftStateProbabilities[2] * rightStateProbabilities[2];
            myStateProbabilities[3] = leftStateProbabilities[3] * rightStateProbabilities[3];
        }
    }

    /**
     * Returns a factory that uses this calculator for four states and one category, and
     * {@code fallbackFactory} for everything else.
     *
     * @param fallbackFactory factory consulted for unsupported configurations
     * @return the specialising factory
     */
    public static final LHCalculator.Factory getFactory(LHCalculator.Factory fallbackFactory) {
        return new SimpleFactory(fallbackFactory);
    }

    /**
     * Returns a factory that uses this calculator for four states and one category, and
     * {@link FastFourStateLHCalculator#getFactory()} for everything else.
     *
     * @return the specialising factory
     */
    public static final LHCalculator.Factory getFactory() {
        return new SimpleFactory(FastFourStateLHCalculator.getFactory());
    }

    /** Generator producing the leaf, internal and external calculators of this class. */
    private static final class SimpleGenerator
    implements LHCalculator.Generator {
        /** Creates the generator; prints a fixed diagnostic line to standard output. */
        public SimpleGenerator() {
            System.out.println("Simple Model, four state");
        }

        /** Creates a leaf calculator that uses this generator as its parent generator. */
        public LHCalculator.Leaf createNewLeaf(int[] patternStateMatchup, int numberOfPatterns) {
            return new SimpleLeafCalculator(patternStateMatchup, numberOfPatterns, FOUR_STATES, ONE_CATEGORY, this);
        }

        /** Creates a leaf calculator that uses {@code parentGenerator} as its parent generator. */
        public LHCalculator.Leaf createNewLeaf(int[] patternStateMatchup, int numberOfPatterns, LHCalculator.Generator parentGenerator) {
            return new SimpleLeafCalculator(patternStateMatchup, numberOfPatterns, FOUR_STATES, ONE_CATEGORY, parentGenerator);
        }

        /** Creates an internal-node calculator whose result store comes from this generator. */
        public LHCalculator.Internal createNewInternal() {
            return new InternalImpl(this);
        }

        /** Creates an external (root/branch likelihood) calculator. */
        public LHCalculator.External createNewExternal() {
            return new ExternalImpl();
        }

        /**
         * Creates an external calculator; {@code parentGenerator} is not used because the
         * external calculator allocates nothing from a generator.
         */
        public LHCalculator.External createNewExternal(LHCalculator.Generator parentGenerator) throws IllegalArgumentException {
            return new ExternalImpl();
        }

        /** Creates an internal-node calculator whose result store comes from {@code parentGenerator}. */
        public LHCalculator.Internal createNewInternal(LHCalculator.Generator parentGenerator) throws IllegalArgumentException {
            return new InternalImpl(parentGenerator);
        }

        /**
         * Creates a store sized for one category and four states; {@code isForLeaf} is ignored,
         * so leaves and internal nodes get the same kind of store.
         */
        public ConditionalProbabilityStore createAppropriateConditionalProbabilityStore(boolean isForLeaf) {
            return new ConditionalProbabilityStore(ONE_CATEGORY, FOUR_STATES, false);
        }

        /** Always {@code true}. */
        public boolean isAllowCaching() {
            return true;
        }
    }

    /** Factory that selects {@link SimpleGenerator} when applicable, else delegates. */
    private static final class SimpleFactory
    implements LHCalculator.Factory {
        private final LHCalculator.Factory fallbackFactory_;

        public SimpleFactory(LHCalculator.Factory fallbackFactory) {
            this.fallbackFactory_ = fallbackFactory;
        }

        /**
         * Returns a {@link SimpleGenerator} for four-state data with a single category, and
         * otherwise the fallback factory's generator.
         */
        public LHCalculator.Generator createSeries(int numberOfCategories, DataType dt) {
            if (dt.getNumStates() == FOUR_STATES && numberOfCategories == ONE_CATEGORY) {
                System.out.println("OPTIMISING simple model...!");
                return new SimpleGenerator();
            }
            return this.fallbackFactory_.createSeries(numberOfCategories, dt);
        }
    }

    /** Computes likelihoods across a branch or at a root from two sets of conditionals. */
    private static final class ExternalImpl
    implements LHCalculator.External {
        /** Scratch 4x4 transition matrix, refilled by each call that needs it. */
        private final double[][] transitionProbabilityStore_ = new double[FOUR_STATES][FOUR_STATES];

        private ExternalImpl() {
        }

        /** Returns {@code sum_i freqs[i] * states[i]} over the four states. */
        private static double weightedStateSum(double[] freqs, double[] states) {
            return freqs[0] * states[0] + freqs[1] * states[1] + freqs[2] * states[2] + freqs[3] * states[3];
        }

        /** Returns {@code sum_i freqs[i] * (left[i] * right[i])} over the four states. */
        private static double weightedProductSum(double[] freqs, double[] left, double[] right) {
            return freqs[0] * (left[0] * right[0]) + freqs[1] * (left[1] * right[1]) + freqs[2] * (left[2] * right[2]) + freqs[3] * (left[3] * right[3]);
        }

        /**
         * Fills category 0 of {@code tempStore} with, for every pattern,
         * {@code (sum_j P[i][j] * left[j]) * right[i]}, where {@code P} is the transition
         * matrix obtained from {@code model} for {@code distance}.
         *
         * <p>NOTE(review): the branch matrix is applied to the left side only. This equals the
         * right-side-only form only for time-reversible models; confirm that is assumed.
         *
         * @return the per-pattern state arrays written into {@code tempStore}
         */
        private final double[][] getResultStoreValues(double distance, SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftFlatConditionalProbabilities, ConditionalProbabilityStore rightFlatConditionalProbabilities, ConditionalProbabilityStore tempStore) {
            int[] patternLookup = centerPattern.getPatternLookup();
            int numberOfPatterns = centerPattern.getNumberOfPatterns();
            model.getTransitionProbabilities(distance, 0, this.transitionProbabilityStore_);
            double[][] myPatternStateProbabilities = tempStore.getConditionalProbabilityAccess(numberOfPatterns, false)[0];
            double[][] leftPatternStateProbabilities = leftFlatConditionalProbabilities.getCurrentConditionalProbabilities(0);
            double[][] rightPatternStateProbabilities = rightFlatConditionalProbabilities.getCurrentConditionalProbabilities(0);
            // The matrix rows are the same for every pattern, so look them up once.
            double[] sa0 = this.transitionProbabilityStore_[0];
            double[] sa1 = this.transitionProbabilityStore_[1];
            double[] sa2 = this.transitionProbabilityStore_[2];
            double[] sa3 = this.transitionProbabilityStore_[3];
            int patternAccess = 0;
            for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
                double[] left = leftPatternStateProbabilities[patternLookup[patternAccess++]];
                double[] right = rightPatternStateProbabilities[patternLookup[patternAccess++]];
                double[] result = myPatternStateProbabilities[pattern];
                result[0] = (sa0[0] * left[0] + sa0[1] * left[1] + sa0[2] * left[2] + sa0[3] * left[3]) * right[0];
                result[1] = (sa1[0] * left[0] + sa1[1] * left[1] + sa1[2] * left[2] + sa1[3] * left[3]) * right[1];
                result[2] = (sa2[0] * left[0] + sa2[1] * left[1] + sa2[2] * left[2] + sa2[3] * left[3]) * right[2];
                result[3] = (sa3[0] * left[0] + sa3[1] * left[1] + sa3[2] * left[2] + sa3[3] * left[3]) * right[3];
            }
            return myPatternStateProbabilities;
        }

        /** Same as the shared extended computation, using the matrix for {@code distance}. */
        public final void calculateExtended(double distance, SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, ConditionalProbabilityStore resultStore) {
            model.getTransitionProbabilities(distance, 0, this.transitionProbabilityStore_);
            SimpleModelFastFourStateLHCalculator.calculateExtendedImpl(this.transitionProbabilityStore_, centerPattern, leftConditionalProbabilityProbabilties, rightConditionalProbabilityProbabilties, resultStore);
        }

        /** Stores the per-state product of the two children in {@code resultStore}. */
        public final void calculateFlat(PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, ConditionalProbabilityStore resultStore) {
            SimpleModelFastFourStateLHCalculator.calculateFlatImpl(centerPattern, leftConditionalProbabilityProbabilties, rightConditionalProbabilityProbabilties, resultStore);
        }

        /**
         * Returns the pattern-weighted log likelihood across a branch of length
         * {@code distance}: {@code sum_p weight[p] * log(sum_i pi[i] * s_p[i])}, where
         * {@code s_p} is the array produced by the left-side branch combination and {@code pi}
         * the model's equilibrium frequencies. {@code tempStore} is overwritten.
         */
        public double calculateLogLikelihood(double distance, SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftFlatConditionalProbabilities, ConditionalProbabilityStore rightFlatConditionalProbabilities, ConditionalProbabilityStore tempStore) {
            int[] patternWeights = centerPattern.getPatternWeights();
            int numberOfPatterns = centerPattern.getNumberOfPatterns();
            double[][] myPatternStateProbabilities = this.getResultStoreValues(distance, model, centerPattern, leftFlatConditionalProbabilities, rightFlatConditionalProbabilities, tempStore);
            double[] equilibriumFrequencies = model.getEquilibriumFrequencies();
            double logLikelihood = 0.0;
            for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
                double total = ExternalImpl.weightedStateSum(equilibriumFrequencies, myPatternStateProbabilities[pattern]);
                logLikelihood += Math.log(total) * (double)patternWeights[pattern];
            }
            return logLikelihood;
        }

        /**
         * Writes, for each pattern, {@code sum_i pi[i] * left[i] * right[i]} into
         * {@code categoryPatternLogLikelihoodStore[0]}. Despite the parameter name, the values
         * written are plain probabilities; no logarithm is taken.
         */
        public final void calculateCategoryPatternProbabilities(SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilities, ConditionalProbabilityStore rightConditionalProbabilities, double[][] categoryPatternLogLikelihoodStore) {
            int[] patternLookup = centerPattern.getPatternLookup();
            int numberOfPatterns = centerPattern.getNumberOfPatterns();
            double[] equilibriumFrequencies = model.getEquilibriumFrequencies();
            double[][] leftValues = leftConditionalProbabilities.getCurrentConditionalProbabilities(0);
            double[][] rightValues = rightConditionalProbabilities.getCurrentConditionalProbabilities(0);
            double[] patternProbabilities = categoryPatternLogLikelihoodStore[0];
            int patternIndex = 0;
            for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
                double[] left = leftValues[patternLookup[patternIndex++]];
                double[] right = rightValues[patternLookup[patternIndex++]];
                patternProbabilities[pattern] = ExternalImpl.weightedProductSum(equilibriumFrequencies, left, right);
            }
        }

        /**
         * Writes, for each pattern, the equilibrium-weighted sum of the left-side branch
         * combination (see {@link #getResultStoreValues}) into
         * {@code categoryPatternLogLikelihoodStore[0]}. The values written are plain
         * probabilities, not logs. {@code tempStore} is overwritten.
         */
        public final void calculateCategoryPatternProbabilities(double distance, SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftFlatConditionalProbabilities, ConditionalProbabilityStore rightFlatConditionalProbabilities, ConditionalProbabilityStore tempStore, double[][] categoryPatternLogLikelihoodStore) {
            double[][] myPatternStateProbabilities = this.getResultStoreValues(distance, model, centerPattern, leftFlatConditionalProbabilities, rightFlatConditionalProbabilities, tempStore);
            double[] equilibriumFrequencies = model.getEquilibriumFrequencies();
            int numberOfPatterns = centerPattern.getNumberOfPatterns();
            double[] patternProbabilities = categoryPatternLogLikelihoodStore[0];
            for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
                patternProbabilities[pattern] = ExternalImpl.weightedStateSum(equilibriumFrequencies, myPatternStateProbabilities[pattern]);
            }
        }

        /**
         * Returns the pattern-weighted log likelihood
         * {@code sum_p weight[p] * log(sum_i pi[i] * left_p[i] * right_p[i])}, with no branch
         * applied between the two sides.
         */
        public double calculateLogLikelihood(SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilities, ConditionalProbabilityStore rightConditionalProbabilities) {
            int[] patternWeights = centerPattern.getPatternWeights();
            int[] patternLookup = centerPattern.getPatternLookup();
            int numberOfPatterns = centerPattern.getNumberOfPatterns();
            double[] equilibriumFrequencies = model.getEquilibriumFrequencies();
            double[][] leftValues = leftConditionalProbabilities.getCurrentConditionalProbabilities(0);
            double[][] rightValues = rightConditionalProbabilities.getCurrentConditionalProbabilities(0);
            double logLikelihood = 0.0;
            int patternIndex = 0;
            for (int pattern = 0; pattern < numberOfPatterns; pattern++) {
                double[] left = leftValues[patternLookup[patternIndex++]];
                double[] right = rightValues[patternLookup[patternIndex++]];
                double total = ExternalImpl.weightedProductSum(equilibriumFrequencies, left, right);
                logLikelihood += Math.log(total) * (double)patternWeights[pattern];
            }
            return logLikelihood;
        }
    }

    /** Computes conditional probabilities for an internal node into its own result store. */
    private static final class InternalImpl
    implements LHCalculator.Internal {
        private final ConditionalProbabilityStore myResultStore_;
        private final double[][] transitionProbabilityStore_ = new double[FOUR_STATES][FOUR_STATES];
        /** Distance the cached matrix was computed for; negative means "nothing cached yet". */
        private double lastDistance_ = -1.0;

        private InternalImpl(LHCalculator.Generator parentGenerator) {
            this.myResultStore_ = parentGenerator.createAppropriateConditionalProbabilityStore(false);
        }

        /**
         * Combines the children across a branch of length {@code distance} into this node's
         * store and returns that store. The transition matrix is refetched only when the model
         * changed, the distance differs from the cached one, or nothing is cached yet.
         * {@code childrenChanged} is not consulted; the combination is always recomputed.
         */
        public final ConditionalProbabilityStore calculateExtended(double distance, SubstitutionModel model, PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, boolean modelChangedSinceLastCall, boolean childrenChanged) {
            if (modelChangedSinceLastCall || distance != this.lastDistance_ || this.lastDistance_ < 0.0) {
                model.getTransitionProbabilities(distance, 0, this.transitionProbabilityStore_);
                this.lastDistance_ = distance;
            }
            SimpleModelFastFourStateLHCalculator.calculateExtendedImpl(this.transitionProbabilityStore_, centerPattern, leftConditionalProbabilityProbabilties, rightConditionalProbabilityProbabilties, this.myResultStore_);
            return this.myResultStore_;
        }

        /**
         * Stores the per-state product of the children in this node's store and returns it.
         * {@code childrenChanged} is not consulted.
         */
        public final ConditionalProbabilityStore calculateFlat(PatternInfo centerPattern, ConditionalProbabilityStore leftConditionalProbabilityProbabilties, ConditionalProbabilityStore rightConditionalProbabilityProbabilties, boolean childrenChanged) {
            SimpleModelFastFourStateLHCalculator.calculateFlatImpl(centerPattern, leftConditionalProbabilityProbabilties, rightConditionalProbabilityProbabilties, this.myResultStore_);
            return this.myResultStore_;
        }
    }
}

