/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.LikelihoodTreeTraversal;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.TreeTraversal;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.EmptyTraitDataModel;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.MultivariateIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.AbstractModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.Matrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.Arrays;
import java.util.List;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Exposes Wishart sufficient statistics computed from a continuous-trait {@link TreeDataLikelihood}.
 *
 * <p>The statistics are produced by running a post-order traversal on an "outer product" delegate with
 * {@code setComputeWishartStatistics(true)}. When the likelihood delegate uses a {@link MultivariateIntegrator},
 * a separate observed-data-only delegate is created and fed with the tip values sampled through the
 * tip trait; otherwise the likelihood delegate itself is used.</p>
 *
 * <p>Results are cached and invalidated on any model or variable change; the cache participates in
 * store/restore.</p>
 */
public class WishartStatisticsWrapper
extends AbstractModel
implements ConjugateWishartStatisticsProvider,
Loggable,
Reportable {
    public static final String PARSER_NAME = "wishartStatistics";
    public static final String TRAIT_NAME = "traitName";

    /** XML parser: {@code <wishartStatistics traitName="..."> <treeDataLikelihood .../> </wishartStatistics>}. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] syntax = new XMLSyntaxRule[]{
                new ElementRule(TreeDataLikelihood.class),
                AttributeRule.newStringRule(TRAIT_NAME, true)
        };

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = xo.hasId() ? xo.getId() : PARSER_NAME;
            String traitName = xo.getAttribute(TRAIT_NAME, "trait");
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

            DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
            if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
                throw new XMLParseException("May not provide a sequence data likelihood in the precision Gibbs sampler");
            }
            ContinuousDataLikelihoodDelegate continuousDelegate = (ContinuousDataLikelihoodDelegate) delegate;
            return new WishartStatisticsWrapper(id, traitName, treeDataLikelihood, continuousDelegate);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.syntax;
        }

        @Override
        public String getParserDescription() {
            return "Provides Wishart sufficient statistics for the diffusion precision, "
                    + "computed from a continuous-trait tree data likelihood";
        }

        @Override
        public Class getReturnType() {
            return WishartStatisticsWrapper.class;
        }

        @Override
        public String getParserName() {
            return PARSER_NAME;
        }
    };

    private final LikelihoodTreeTraversal treeTraversalDelegate;
    private final TreeTrait tipSampleTrait;
    private final int dimTrait;
    private final int numTrait;
    private final int tipCount;
    /** Length of one per-trait tip partial: {@code dimTrait} values plus one trailing slot. */
    private final int dimPartial;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;
    /** Delegate that actually accumulates the Wishart statistics (may be {@link #likelihoodDelegate}). */
    private final ContinuousDataLikelihoodDelegate outerProductDelegate;
    private final TreeDataLikelihood dataLikelihood;
    private boolean traitDataKnown;
    private boolean outerProductsKnown;
    private boolean savedTraitDataKnown;
    private boolean savedOuterProductsKnown;
    private WishartSufficientStatistics wishartStatistics;
    private WishartSufficientStatistics savedWishartStatistics;

    /**
     * Creates the wrapper and registers it with the tree data likelihood.
     *
     * @param name               model id
     * @param traitName          name of the continuous trait; used to look up the realized tip trait
     * @param dataLikelihood     tree data likelihood whose data are summarised
     * @param likelihoodDelegate continuous delegate of {@code dataLikelihood}
     */
    public WishartStatisticsWrapper(String name, String traitName, TreeDataLikelihood dataLikelihood,
                                    ContinuousDataLikelihoodDelegate likelihoodDelegate) {
        super(name);
        this.dataLikelihood = dataLikelihood;
        this.likelihoodDelegate = likelihoodDelegate;
        this.dimTrait = likelihoodDelegate.getTraitDim();
        this.numTrait = likelihoodDelegate.getTraitCount();
        this.tipCount = dataLikelihood.getTree().getExternalNodeCount();
        this.dimPartial = this.dimTrait + 1;

        // Listen to the likelihood, and register this model with it so it receives store/restore calls.
        dataLikelihood.addModelListener(this);
        dataLikelihood.addModel(this);

        String tipTraitName = AbstractRealizedContinuousTraitDelegate.getTipTraitName(traitName);
        this.tipSampleTrait = dataLikelihood.getTreeTrait(tipTraitName);

        this.treeTraversalDelegate = new LikelihoodTreeTraversal(dataLikelihood.getTree(),
                dataLikelihood.getBranchRateModel(), TreeTraversal.TraversalType.POST_ORDER);

        if (likelihoodDelegate.getIntegrator() instanceof MultivariateIntegrator) {
            // Statistics are computed on a separate observed-data-only delegate; fall back to an empty
            // scalar-precision data model when the original one cannot supply Wishart statistics.
            ContinuousTraitPartialsProvider dataModel = likelihoodDelegate.getDataModel();
            if (!dataModel.suppliesWishartStatistics()) {
                dataModel = new EmptyTraitDataModel(traitName, dataModel.getParameter(),
                        dataModel.getTraitDimension(), PrecisionType.SCALAR);
            }
            this.outerProductDelegate =
                    ContinuousDataLikelihoodDelegate.createObservedDataOnly(likelihoodDelegate, dataModel);
        } else {
            this.outerProductDelegate = likelihoodDelegate;
        }

        this.traitDataKnown = false;
        this.outerProductsKnown = false;
    }

    /**
     * Returns the cached Wishart sufficient statistics, recomputing them first if they are stale.
     */
    @Override
    public WishartSufficientStatistics getWishartStatistics() {
        if (!this.outerProductsKnown) {
            this.computeOuterProducts();
            this.outerProductsKnown = true;
        }
        return this.wishartStatistics;
    }

    /**
     * Copies the current tip-trait samples into the tip partials of {@link #outerProductDelegate}.
     *
     * <p>Each per-trait partial is laid out as {@code dimTrait} values followed by one trailing slot
     * that is set to {@code +Infinity} (presumably the scalar precision, marking the value as exactly
     * observed — confirm against the integrator's partial layout).</p>
     */
    public void simulateMissingTraits() {
        this.likelihoodDelegate.fireModelChanged();

        double[] tipSamples = (double[]) this.tipSampleTrait.getTrait(this.dataLikelihood.getTree(), null);

        ContinuousDiffusionIntegrator integrator = this.outerProductDelegate.getIntegrator();
        assert (integrator instanceof ContinuousDiffusionIntegrator.Basic);

        double[] tipPartial = new double[this.dimPartial * this.numTrait];
        for (int trait = 0; trait < this.numTrait; ++trait) {
            tipPartial[trait * this.dimPartial + this.dimTrait] = Double.POSITIVE_INFINITY;
        }

        for (int tip = 0; tip < this.tipCount; ++tip) {
            // Tip samples are stored tip-major: numTrait consecutive blocks of dimTrait values per tip.
            int source = tip * this.dimTrait * this.numTrait;
            int destination = 0;
            for (int trait = 0; trait < this.numTrait; ++trait) {
                System.arraycopy(tipSamples, source, tipPartial, destination, this.dimTrait);
                source += this.dimTrait;
                destination += this.dimPartial;
            }
            this.outerProductDelegate.setTipDataDirectly(tip, tipPartial);
        }
    }

    /**
     * Recomputes {@link #wishartStatistics} via a full post-order traversal with Wishart-statistic
     * accumulation switched on.
     */
    private void computeOuterProducts() {
        // Ensure the main likelihood (and anything depending on it, such as tip samples) is up to date.
        this.dataLikelihood.getLogLikelihood();
        if (this.likelihoodDelegate != this.outerProductDelegate) {
            this.simulateMissingTraits();
        }

        this.treeTraversalDelegate.updateAllNodes();
        this.treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();
        List<ProcessOnTreeDelegate.BranchOperation> branchOperations = this.treeTraversalDelegate.getBranchOperations();
        List<ProcessOnTreeDelegate.NodeOperation> nodeOperations = this.treeTraversalDelegate.getNodeOperations();
        NodeRef root = this.dataLikelihood.getTree().getRoot();

        this.outerProductDelegate.setComputeWishartStatistics(true);
        this.outerProductDelegate.calculateLikelihood(branchOperations, nodeOperations, root.getNumber());
        this.outerProductDelegate.setComputeWishartStatistics(false);

        this.wishartStatistics = this.outerProductDelegate.getWishartStatistics();
    }

    /** Returns the precision parameter of the likelihood delegate's diffusion model. */
    @Override
    public MatrixParameterInterface getPrecisionParameter() {
        return this.likelihoodDelegate.getDiffusionModel().getPrecisionParameter();
    }

    @Override
    protected void storeState() {
        this.savedTraitDataKnown = this.traitDataKnown;
        this.savedOuterProductsKnown = this.outerProductsKnown;
        if (this.outerProductsKnown) {
            if (this.savedWishartStatistics == null) {
                this.savedWishartStatistics = this.wishartStatistics.clone();
            } else {
                this.wishartStatistics.copyTo(this.savedWishartStatistics);
            }
        }
    }

    @Override
    protected void restoreState() {
        this.traitDataKnown = this.savedTraitDataKnown;
        this.outerProductsKnown = this.savedOuterProductsKnown;
        if (this.outerProductsKnown) {
            // Swap buffers rather than copying back.
            WishartSufficientStatistics current = this.wishartStatistics;
            this.wishartStatistics = this.savedWishartStatistics;
            this.savedWishartStatistics = current;
        }
    }

    @Override
    protected void acceptState() {
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        this.outerProductsKnown = false;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        this.outerProductsKnown = false;
    }

    /**
     * Returns one column per entry of the {@code dimTrait x dimTrait} scale matrix ("OPij"), followed by
     * one column per tip-sample value ("TIPk") when the tip trait is available.
     */
    @Override
    public LogColumn[] getColumns() {
        int tipValueCount = 0;
        if (this.tipSampleTrait != null) {
            double[] tipValues = (double[]) this.tipSampleTrait.getTrait(this.dataLikelihood.getTree(), null);
            tipValueCount = tipValues.length;
        }

        LogColumn[] columns = new LogColumn[this.dimTrait * this.dimTrait + tipValueCount];
        int column = 0;
        for (int row = 0; row < this.dimTrait; ++row) {
            for (int col = 0; col < this.dimTrait; ++col) {
                columns[column] = new OuterProductColumn("OP" + (row + 1) + "" + (col + 1), column);
                ++column;
            }
        }
        for (int tip = 0; tip < tipValueCount; ++tip) {
            columns[column] = new TipSampleColumn("TIP" + (tip + 1), tip);
            ++column;
        }
        return columns;
    }

    /**
     * Compares the recursively computed scale matrix against a naive one,
     * {@code (Y - 1 mu^T)^T P (Y - 1 mu^T)}, where {@code Y} holds the tip samples, {@code mu} the root
     * prior mean and {@code P} the tree precision matrix.
     */
    @Override
    public String getReport() {
        WishartSufficientStatistics statistics = this.getWishartStatistics();

        double[][] treePrecision = this.likelihoodDelegate.getTreePrecision();
        int traitDim = this.likelihoodDelegate.getTraitDim();
        int taxonCount = treePrecision.length;

        double[] tipSamples = (double[]) this.tipSampleTrait.getTrait(this.dataLikelihood.getTree(), null);
        double[] rootMean = this.likelihoodDelegate.getRootPrior().getMean();

        // Build the centred data Y - 1 mu^T in a fresh matrix (so the array from getTrait() is never
        // modified). NOTE(review): row i reads tipSamples[i * traitDim ...], which matches the tip layout
        // only when there is a single trait (numTrait == 1).
        DenseMatrix64F centered = new DenseMatrix64F(taxonCount, traitDim);
        for (int i = 0; i < taxonCount; ++i) {
            for (int j = 0; j < traitDim; ++j) {
                centered.set(i, j, tipSamples[i * traitDim + j] - rootMean[j]);
            }
        }

        DenseMatrix64F precision = new DenseMatrix64F(taxonCount, taxonCount);
        for (int i = 0; i < taxonCount; ++i) {
            for (int j = 0; j < taxonCount; ++j) {
                precision.set(i, j, treePrecision[i][j]);
            }
        }

        DenseMatrix64F precisionTimesData = new DenseMatrix64F(taxonCount, traitDim);
        CommonOps.mult(precision, centered, precisionTimesData);
        DenseMatrix64F naiveScale = new DenseMatrix64F(traitDim, traitDim);
        CommonOps.multTransA(centered, precisionTimesData, naiveScale);

        StringBuilder sb = new StringBuilder("WishartStatisticsWrapper report:\n\n");
        sb.append("Scale matrix (naive):\n");
        sb.append(new Matrix(naiveScale.data, traitDim, traitDim));
        sb.append("\n");
        sb.append("Scale matrix (recursive):\n");
        sb.append(new Matrix(statistics.getScaleMatrix(), traitDim, traitDim));
        sb.append("\n\n");
        return sb.toString();
    }

    /** Log column reporting one entry of the (flattened) Wishart scale matrix. */
    private class OuterProductColumn
    extends NumberColumn {
        private int index;

        private OuterProductColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            WishartSufficientStatistics statistics = WishartStatisticsWrapper.this.getWishartStatistics();
            return statistics.getScaleMatrix()[this.index];
        }
    }

    /** Log column reporting one value of the current tip-trait sample. */
    private class TipSampleColumn
    extends NumberColumn {
        private int index;

        private TipSampleColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            double[] tipValues = (double[]) WishartStatisticsWrapper.this.tipSampleTrait.getTrait(
                    WishartStatisticsWrapper.this.dataLikelihood.getTree(), null);
            return tipValues[this.index];
        }
    }
}

