/*
 * Decompiled with CFR 0.152.
 */
package pal.tree;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import pal.alignment.AbstractAlignment;
import pal.datatype.DataType;
import pal.math.MersenneTwisterFast;
import pal.misc.SimpleIdGroup;
import pal.substmodel.SubstitutionModel;
import pal.tree.Node;
import pal.tree.NodeUtils;
import pal.tree.Tree;
import pal.util.AlgorithmCallback;

/**
 * An alignment whose sequences are generated by simulating character evolution
 * from a root sequence down a {@link Tree} under a {@link SubstitutionModel}.
 *
 * <p>One sequence is produced per external node of the tree. The constructor
 * attaches each row of the internal state matrix to the corresponding external
 * node, so a simulation run writes its leaf sequences directly into this
 * alignment. Before mutating down the tree, every site is assigned a transition
 * (rate) category drawn from the model's category probabilities.
 */
public class SimulatedAlignment extends AbstractAlignment {
    private Tree tree;
    private SubstitutionModel model;
    /** Scratch buffer of cumulative state probabilities (one entry per state). */
    private double[] cumFreqs;
    /** Transition category index chosen for each site by {@link #assignRates()}. */
    private int[] rateAtSite;
    /** Scratch buffer of cumulative transition-category probabilities. */
    private double[] cumRateProbs;
    private int numStates;
    /** State matrix indexed {@code [sequence][site]}. */
    private byte[][] stateData;
    private MersenneTwisterFast rng = new MersenneTwisterFast();

    /**
     * Custom serialization: writes a format version byte (currently 1) followed by
     * the simulation state. The random number generator is not written; a fresh
     * one is created on deserialization.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeByte(1);
        out.writeObject(this.tree);
        out.writeObject(this.model);
        out.writeObject(this.cumFreqs);
        out.writeObject(this.rateAtSite);
        out.writeObject(this.cumRateProbs);
        out.writeObject(this.stateData);
    }

    /**
     * Reads the state written by {@link #writeObject(ObjectOutputStream)} and
     * re-derives the transient values (state count, random number generator).
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Format version byte; only one format exists so far, so it is consumed and ignored.
        in.readByte();
        this.tree = (Tree) in.readObject();
        this.model = (SubstitutionModel) in.readObject();
        this.cumFreqs = (double[]) in.readObject();
        this.rateAtSite = (int[]) in.readObject();
        this.cumRateProbs = (double[]) in.readObject();
        this.stateData = (byte[][]) in.readObject();
        this.numStates = this.getDataType().getNumStates();
        this.rng = new MersenneTwisterFast();
    }

    /**
     * Prepares (but does not run) a simulation.
     *
     * @param sites number of alignment columns to simulate
     * @param t     tree to simulate along; its node list is (re)built and every
     *              non-root node is given a sequence buffer
     * @param m     substitution model; also determines the alignment's data type
     */
    public SimulatedAlignment(int sites, Tree t, SubstitutionModel m) {
        this.setDataType(m.getDataType());
        this.numStates = this.getDataType().getNumStates();
        this.model = m;
        this.tree = t;
        this.tree.createNodeList();
        this.numSeqs = this.tree.getExternalNodeCount();
        this.numSites = sites;
        this.idGroup = new SimpleIdGroup(this.numSeqs);
        for (int i = 0; i < this.numSeqs; i++) {
            this.idGroup.setIdentifier(i, this.tree.getExternalNode(i).getIdentifier());
        }
        this.stateData = new byte[this.numSeqs][this.numSites];
        // Leaf sequences are the alignment rows themselves, so simulation fills the alignment.
        for (int i = 0; i < this.numSeqs; i++) {
            this.tree.getExternalNode(i).setSequence(this.stateData[i]);
        }
        // Every non-root internal node needs its own buffer; the root's sequence is
        // attached by simulate(byte[]). Checking isRoot() avoids assuming the root's
        // position in the internal node list.
        for (int i = 0; i < this.tree.getInternalNodeCount(); i++) {
            Node internal = this.tree.getInternalNode(i);
            if (!internal.isRoot()) {
                internal.setSequence(new byte[this.numSites]);
            }
        }
        this.rateAtSite = new int[this.numSites];
        this.cumFreqs = new double[this.numStates];
        this.cumRateProbs = new double[m.getNumberOfTransitionCategories()];
    }

    /**
     * Returns the character for the state at the given sequence and site.
     */
    public char getData(int seq, int site) {
        return this.getChar(this.stateData[seq][site]);
    }

    /**
     * Runs a simulation from a root sequence drawn from the model's equilibrium
     * frequencies.
     */
    public void simulate() {
        this.simulate(this.makeRandomRootSequence());
    }

    /**
     * Runs a simulation from the given root sequence, converted to states using
     * the model's data type.
     */
    public void simulate(String givenRootSequence) {
        this.simulate(DataType.Utils.getByteStates(givenRootSequence, this.model.getDataType()));
    }

    /**
     * Runs a simulation from the given root sequence of state indices.
     *
     * <p>Site categories are redrawn, then every non-root node's sequence is
     * derived from its parent's in preorder (parents before children).
     *
     * @param rootSeq root states; must have at least as many entries as the
     *                alignment has sites and contain only valid state indices
     * @throws IllegalArgumentException if {@code rootSeq} is too short or
     *                                  contains a state outside {@code [0, numStates)}
     */
    public void simulate(byte[] rootSeq) {
        if (rootSeq.length < this.numSites) {
            throw new IllegalArgumentException("Root sequence length " + rootSeq.length
                    + " is shorter than alignment length " + this.numSites);
        }
        for (int i = 0; i < this.numSites; i++) {
            if (rootSeq[i] >= this.numStates || rootSeq[i] < 0) {
                throw new IllegalArgumentException("Root sequence contains illegal state (?,-, etc.)");
            }
        }
        double[][][] transitionStore = SubstitutionModel.Utils.generateTransitionProbabilityTables(this.model);
        // Attach the root sequence to the same node the traversal starts from, so the
        // root's children read it through getParent().
        Node root = this.tree.getRoot();
        root.setSequence(rootSeq);
        this.assignRates();
        Node node = NodeUtils.preorderSuccessor(root);
        do {
            this.determineMutatedSequence(node, transitionStore);
        } while ((node = NodeUtils.preorderSuccessor(node)) != root);
    }

    /**
     * Fills {@code node}'s sequence by sampling, per site, from the transition
     * probabilities along its branch given the parent's state and the site's
     * category.
     *
     * @throws IllegalArgumentException if {@code node} is the root
     */
    private void determineMutatedSequence(Node node, double[][][] transitionStore) {
        if (node.isRoot()) {
            throw new IllegalArgumentException("Root node not allowed");
        }
        this.model.getTransitionProbabilities(node.getBranchLength(), transitionStore);
        byte[] oldS = node.getParent().getSequence();
        byte[] newS = node.getSequence();
        for (int i = 0; i < this.numSites; i++) {
            double[] freqs = transitionStore[this.rateAtSite[i]][oldS[i]];
            this.cumFreqs[0] = freqs[0];
            for (int j = 1; j < this.numStates; j++) {
                this.cumFreqs[j] = this.cumFreqs[j - 1] + freqs[j];
            }
            newS[i] = (byte) this.randomChoice(this.cumFreqs);
        }
    }

    /**
     * Draws a root sequence with each site independently sampled from the model's
     * equilibrium frequencies.
     */
    private byte[] makeRandomRootSequence() {
        double[] frequencies = this.model.getEquilibriumFrequencies();
        this.cumFreqs[0] = frequencies[0];
        for (int i = 1; i < this.numStates; i++) {
            this.cumFreqs[i] = this.cumFreqs[i - 1] + frequencies[i];
        }
        byte[] rootSequence = new byte[this.numSites];
        for (int i = 0; i < this.numSites; i++) {
            rootSequence[i] = (byte) this.randomChoice(this.cumFreqs);
        }
        return rootSequence;
    }

    /**
     * Assigns every site a transition category sampled from the model's
     * category probabilities.
     */
    private void assignRates() {
        double[] categoryProbabilities = this.model.getTransitionCategoryProbabilities();
        this.cumRateProbs[0] = categoryProbabilities[0];
        for (int i = 1; i < categoryProbabilities.length; i++) {
            this.cumRateProbs[i] = this.cumRateProbs[i - 1] + categoryProbabilities[i];
        }
        for (int i = 0; i < this.numSites; i++) {
            this.rateAtSite[i] = this.randomChoice(this.cumRateProbs);
        }
    }

    /**
     * Samples an index from a cumulative distribution.
     *
     * @param cf non-decreasing cumulative probabilities (nominally ending at 1.0)
     * @return the first index {@code s} with {@code rnd <= cf[s]}, or the last
     *         index if rounding left {@code cf[cf.length - 1]} below the draw;
     *         always within {@code [0, cf.length)}
     */
    private int randomChoice(double[] cf) {
        double rnd = this.rng.nextDouble();
        int last = cf.length - 1;
        for (int s = 0; s < last; s++) {
            if (rnd <= cf[s]) {
                return s;
            }
        }
        // Clamp: floating-point summation can leave the final cumulative value
        // slightly below 1.0, which must not yield an out-of-range index.
        return last;
    }

    /**
     * Creates simulated alignments of a fixed length under a fixed model.
     */
    public static final class Factory {
        private int sequenceLength_;
        private SubstitutionModel model_;

        /**
         * @param sequenceLength number of sites per generated alignment
         * @param model          substitution model to simulate under
         * @throws IllegalArgumentException if {@code sequenceLength < 1}
         */
        public Factory(int sequenceLength, SubstitutionModel model) {
            if (sequenceLength < 1) {
                throw new IllegalArgumentException("Invalid sequence length:" + sequenceLength);
            }
            this.sequenceLength_ = sequenceLength;
            this.model_ = model;
        }

        /**
         * Creates and simulates one alignment along {@code tree}.
         *
         * @throws IllegalArgumentException if the tree's units code is neither 0 nor 5
         */
        public final SimulatedAlignment generateAlignment(Tree tree) {
            // Units codes 0 and 5 presumably denote expected substitutions and unknown
            // (per the error message) — NOTE(review): confirm against pal.misc.Units.
            if (tree.getUnits() != 0 && tree.getUnits() != 5) {
                throw new IllegalArgumentException("Tree units must be Expected Substitutions (or reluctantly Unknown)");
            }
            SimulatedAlignment sa = new SimulatedAlignment(this.sequenceLength_, tree, this.model_);
            sa.simulate();
            return sa;
        }

        /**
         * Creates one simulated alignment per tree, reporting progress to
         * {@code callback}.
         *
         * @return all alignments, or only those finished before
         *         {@code callback.isPleaseStop()} returned true
         */
        public final SimulatedAlignment[] generateAlignments(Tree[] trees, AlgorithmCallback callback) {
            SimulatedAlignment[] as = new SimulatedAlignment[trees.length];
            for (int i = 0; i < trees.length; i++) {
                if (callback.isPleaseStop()) {
                    SimulatedAlignment[] partial = new SimulatedAlignment[i];
                    System.arraycopy(as, 0, partial, 0, i);
                    return partial;
                }
                // generateAlignment already runs simulate(); do not simulate a second time.
                as[i] = this.generateAlignment(trees[i]);
                callback.updateProgress((double) (i + 1) / (double) trees.length);
            }
            callback.clearProgress();
            return as;
        }
    }
}

