/*
 * Decompiled with CFR 0.152.
 */
package pal.eval;

import pal.alignment.SitePattern;
import pal.datatype.DataType;
import pal.eval.LikelihoodCalculator;
import pal.substmodel.RateMatrix;
import pal.tree.Node;
import pal.tree.NodeUtils;
import pal.tree.Tree;
import pal.tree.TreeUtils;

/**
 * Computes the log-likelihood of a {@link SitePattern} on a {@link Tree} under a
 * {@link RateMatrix}, using a post-order pass over the tree in the style of
 * Felsenstein's pruning algorithm.
 *
 * <p>A buffer of "partials" indexed {@code [pattern][state]} is kept per node. Once a
 * node's branch has been processed, entry {@code [l][s]} holds the probability of the
 * data below that branch for pattern {@code l}, given state {@code s} at the parent end
 * of the branch. The transition along the branch itself is already included.
 *
 * <p>Instances mutate their scratch buffers (and call {@code setDistance} on the model)
 * during evaluation and perform no synchronization.
 */
public class SimpleLikelihoodCalculator
implements LikelihoodCalculator {
    /** Alignment patterns whose likelihood is evaluated. */
    SitePattern sitePattern_;
    /** Tree the likelihood is evaluated on; {@link #setTree(Tree)} attaches pattern rows to its leaves. */
    Tree tree_;
    /** Substitution model supplying transition probabilities and equilibrium frequencies. */
    RateMatrix model_;
    /** Data type of {@link #sitePattern_}, used to recognise unknown leaf states. */
    DataType patternDatatype_;
    /** Not read or written by this class; kept for compatibility with package-level code. */
    boolean modelChanged_ = false;
    /** Partials indexed {@code [nodeKey][pattern][state]}; node keys come from {@link #getKey(Node)}. */
    private double[][][] partials_;
    /** Number of states of the model's data type. */
    private int numberOfStates_;
    /** Number of distinct site patterns. */
    private int numberOfPatterns_;
    /** Equilibrium state frequencies cached from the model by {@link #setRateMatrix(RateMatrix)}. */
    private double[] frequency_;
    /** Per-pattern log-likelihoods written by the most recent evaluation. */
    private double[] siteLogL_;

    /**
     * Creates a calculator for {@code pattern} only. The caller must still call
     * {@link #setTree(Tree)} and {@link #setRateMatrix(RateMatrix)} before computing a
     * likelihood.
     *
     * @param pattern the site patterns to evaluate
     */
    public SimpleLikelihoodCalculator(SitePattern pattern) {
        this.setPattern(pattern);
    }

    /**
     * Creates a fully configured calculator.
     *
     * @param pattern the site patterns to evaluate
     * @param tree    the tree to evaluate on (must not be {@code null})
     * @param model   the substitution model (must not be {@code null})
     */
    public SimpleLikelihoodCalculator(SitePattern pattern, Tree tree, RateMatrix model) {
        this.setPattern(pattern);
        this.setTree(tree);
        this.setRateMatrix(model);
    }

    /** Caches the pattern, its data type and pattern count, and sizes the per-site log-likelihood array. */
    private void setPattern(SitePattern pattern) {
        this.sitePattern_ = pattern;
        this.patternDatatype_ = pattern.getDataType();
        this.numberOfPatterns_ = pattern.numPatterns;
        this.siteLogL_ = new double[this.numberOfPatterns_];
    }

    /** No resources are held, so this is a no-op. */
    public void release() {
    }

    /**
     * Recomputes every partial and returns the pattern-weighted sum of per-pattern
     * log-likelihoods. Requires both a tree and a rate matrix to be set.
     *
     * @return the log-likelihood of the site patterns
     */
    public double calculateLogLikelihood() {
        return this.treeLikelihood();
    }

    /** @return the site patterns this calculator evaluates */
    public SitePattern getSitePattern() {
        return this.sitePattern_;
    }

    /** @return the current tree, or {@code null} if none has been set */
    public Tree getTree() {
        return this.tree_;
    }

    /**
     * Installs a substitution model, caches its equilibrium frequencies and state count,
     * and (re)allocates the partials buffer if its shape changed.
     *
     * @param m the model to use
     * @throws RuntimeException if {@code m} is {@code null}
     */
    public void setRateMatrix(RateMatrix m) {
        if (m == null) {
            throw new RuntimeException("Assertion error : SetModel called with null model!");
        }
        this.model_ = m;
        this.frequency_ = m.getEquilibriumFrequencies();
        this.numberOfStates_ = m.getDataType().getNumStates();
        // Clamp so an empty alignment cannot produce a negative array size.
        int maxNodes = Math.max(0, 2 * this.sitePattern_.getSequenceCount() - 2);
        this.allocatePartialMemory(maxNodes);
    }

    /**
     * Installs a tree and attaches a pattern row to each of its leaves. Each leaf gets its
     * row through the leaf-to-sequence index mapping returned by
     * {@link TreeUtils#mapExternalIdentifiers}. This mutates the leaves of {@code t}.
     *
     * <p>NOTE(review): the partials buffer is not resized here; it is sized from the
     * alignment's sequence count in {@link #setRateMatrix(RateMatrix)}.
     *
     * @param t the tree to use
     * @throws RuntimeException if {@code t} is {@code null}; the current tree is left unchanged
     */
    public void setTree(Tree t) {
        // Validate before mutating so a rejected call does not wipe out the current tree.
        if (t == null) {
            throw new RuntimeException("Assertion error : SetTree called with null tree!");
        }
        this.tree_ = t;
        int[] alias = TreeUtils.mapExternalIdentifiers(this.sitePattern_, t);
        for (int i = 0; i < t.getExternalNodeCount(); i++) {
            t.getExternalNode(i).setSequence(this.sitePattern_.pattern[alias[i]]);
        }
    }

    /** Re-reads cached data (frequencies, state count) from the current model. */
    public final void modelUpdated() {
        this.setRateMatrix(this.model_);
    }

    /** Re-attaches pattern rows to the leaves of the current tree. */
    public final void treeUpdated() {
        this.setTree(this.tree_);
    }

    /**
     * Ensures {@link #partials_} has shape {@code [numberOfNodes][numberOfPatterns_][numberOfStates_]}.
     * It keeps the existing buffer when the shape already matches.
     */
    private void allocatePartialMemory(int numberOfNodes) {
        // Empty dimensions are checked first so the shape probe never indexes a zero-length array.
        boolean reusable = this.partials_ != null
                && this.partials_.length == numberOfNodes
                && (numberOfNodes == 0
                    || (this.partials_[0].length == this.numberOfPatterns_
                        && (this.numberOfPatterns_ == 0
                            || this.partials_[0][0].length == this.numberOfStates_)));
        if (!reusable) {
            this.partials_ = new double[numberOfNodes][this.numberOfPatterns_][this.numberOfStates_];
        }
    }

    /**
     * Maps a node to its slot in {@link #partials_}. Leaves use their node number;
     * internal nodes are offset by the number of leaves.
     *
     * <p>NOTE(review): the buffer has {@code 2 * sequenceCount - 2} slots. That fits only if
     * the leaf count equals the sequence count and any internal key past the end belongs to
     * the root, whose partial is never read. Confirm this against the tree's node numbering.
     */
    private int getKey(Node node) {
        if (node.isLeaf()) {
            return node.getNumber();
        }
        return node.getNumber() + this.tree_.getExternalNodeCount();
    }

    /**
     * Returns the live partials buffer for {@code branch}, indexed {@code [pattern][state]}.
     * Writes to the returned array modify the calculator's state.
     */
    protected double[][] getPartial(Node branch) {
        return this.partials_[this.getKey(branch)];
    }

    /**
     * Steps around {@code center}. It returns the child after {@code branch}, or
     * {@code center} itself after the last child. If {@code branch} is not a child of
     * {@code center}, it returns the first child (or {@code center} when there are no children).
     */
    private Node getNextBranchOrRoot(Node branch, Node center) {
        int numChildren = center.getChildCount();
        int next = 0;
        while (next < numChildren && center.getChild(next) != branch) {
            next++;
        }
        next++;
        if (next > numChildren) {
            // branch was not found among the children: wrap to the first one
            next = 0;
        }
        return next == numChildren ? center : center.getChild(next);
    }

    /**
     * Like {@link #getNextBranchOrRoot(Node, Node)}, except that a root result is replaced
     * by the root's first child.
     */
    protected Node getNextBranch(Node branch, Node center) {
        Node b = this.getNextBranchOrRoot(branch, center);
        if (b.isRoot()) {
            b = b.getChild(0);
        }
        return b;
    }

    /**
     * Multiplies the partials of all children of {@code center} element-wise into the
     * buffer of its first child, in place. The first child's buffer then holds the product
     * and its previous contents are overwritten.
     */
    protected void productPartials(Node center) {
        double[][] accumulated = this.getPartial(center.getChild(0));
        for (int i = 1; i < center.getChildCount(); i++) {
            double[][] factor = this.getPartial(center.getChild(i));
            for (int l = 0; l < this.numberOfPatterns_; l++) {
                double[] acc = accumulated[l];
                double[] f = factor[l];
                for (int s = 0; s < this.numberOfStates_; s++) {
                    acc[s] *= f[s];
                }
            }
        }
    }

    /**
     * Computes the partials of internal node {@code center}'s branch. It pushes the product
     * of the children's partials, which {@link #productPartials(Node)} leaves in the first
     * child's buffer, back through the transition probabilities for {@code center}'s branch length.
     */
    protected void partialsInternal(Node center) {
        double[][] partial = this.getPartial(center);
        double[][] childProduct = this.getPartial(center.getChild(0));
        this.model_.setDistance(center.getBranchLength());
        for (int l = 0; l < this.numberOfPatterns_; l++) {
            double[] out = partial[l];
            double[] below = childProduct[l];
            for (int from = 0; from < this.numberOfStates_; from++) {
                double sum = 0.0;
                for (int to = 0; to < this.numberOfStates_; to++) {
                    sum += this.model_.getTransitionProbability(from, to) * below[to];
                }
                out[from] = sum;
            }
        }
    }

    /**
     * Computes the partials of leaf {@code branch}. For each pattern, entry {@code [s]} is
     * the probability of moving from {@code s} to the observed leaf state over the branch
     * length, or {@code 1.0} for every {@code s} when the leaf state is unknown.
     */
    protected void partialsExternal(Node branch) {
        double[][] partial = this.getPartial(branch);
        byte[] seq = branch.getSequence();
        this.model_.setDistance(branch.getBranchLength());
        for (int l = 0; l < this.numberOfPatterns_; l++) {
            double[] p = partial[l];
            byte endState = seq[l];
            if (this.patternDatatype_.isUnknownState(endState)) {
                for (int s = 0; s < this.numberOfStates_; s++) {
                    p[s] = 1.0;
                }
            } else {
                for (int s = 0; s < this.numberOfStates_; s++) {
                    p[s] = this.model_.getTransitionProbability(s, endState);
                }
            }
        }
    }

    /**
     * Fills partials bottom-up. Leaves get {@link #partialsExternal(Node)}; every non-root
     * internal node gets {@link #productPartials(Node)} followed by
     * {@link #partialsInternal(Node)} once its children are done. The root is left for
     * {@link #treeLikelihood()}.
     */
    private void traverseTree(Node currentNode) {
        if (currentNode.isLeaf()) {
            this.partialsExternal(currentNode);
            return;
        }
        for (int i = 0; i < currentNode.getChildCount(); i++) {
            this.traverseTree(currentNode.getChild(i));
        }
        if (!currentNode.isRoot()) {
            this.productPartials(currentNode);
            this.partialsInternal(currentNode);
        }
    }

    /**
     * Runs the full traversal and combines the root's children. For each pattern it stores
     * {@code log(sum_s freq[s] * partial[s])} in {@link #siteLogL_}, and it returns the sum
     * of those values weighted by the pattern weights.
     */
    private double treeLikelihood() {
        Node root = this.tree_.getRoot();
        this.traverseTree(root);
        // Fold every root branch into the first child's buffer.
        this.productPartials(root);
        double[][] rootPartial = this.getPartial(root.getChild(0));
        double logL = 0.0;
        for (int l = 0; l < this.numberOfPatterns_; l++) {
            double[] p = rootPartial[l];
            double sum = 0.0;
            for (int s = 0; s < this.numberOfStates_; s++) {
                sum += this.frequency_[s] * p[s];
            }
            this.siteLogL_[l] = Math.log(sum);
            logL += this.siteLogL_[l] * (double) this.sitePattern_.weight[l];
        }
        return logL;
    }
}

