/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class DMRInferencer
extends TopicInferencer
implements Serializable {
    protected MaxEnt dmrParameters = null;
    protected int numFeatures;
    protected int defaultFeatureIndex;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    public DMRInferencer(int[][] typeTopicCounts, int[] tokensPerTopic, MaxEnt dmrParameters, Alphabet alphabet, double beta, double betaSum) {
        this.dmrParameters = dmrParameters;
        this.numFeatures = dmrParameters.getAlphabet().size();
        this.defaultFeatureIndex = dmrParameters.getDefaultFeatureIndex();
        this.tokensPerTopic = tokensPerTopic;
        this.typeTopicCounts = typeTopicCounts;
        this.alphabet = alphabet;
        this.numTopics = tokensPerTopic.length;
        this.numTypes = typeTopicCounts.length;
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.beta = beta;
        this.betaSum = betaSum;
        this.cachedCoefficients = new double[this.numTopics];
        this.alpha = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.smoothingOnlyMass += this.alpha[topic] * beta / ((double)tokensPerTopic[topic] + betaSum);
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)tokensPerTopic[topic] + betaSum);
        }
        this.random = new Randoms();
    }

    @Override
    public double[] getSampledDistribution(Instance instance, int numIterations, int thinning, int burnIn) {
        FeatureVector features = (FeatureVector)instance.getTarget();
        double[] parameters = this.dmrParameters.getParameters();
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.alpha[topic] = parameters[topic * this.numFeatures + this.defaultFeatureIndex] + MatrixOps.rowDotProduct(parameters, this.numFeatures, topic, features, this.defaultFeatureIndex, null);
            this.alpha[topic] = Math.exp(this.alpha[topic]);
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        return super.getSampledDistribution(instance, numIterations, thinning, burnIn);
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        System.out.println("writing");
        out.writeInt(0);
        out.writeObject(this.dmrParameters);
        out.writeObject(this.alphabet);
        out.writeInt(this.numTopics);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeInt(this.numTypes);
        out.writeObject(this.alpha);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeObject(this.typeTopicCounts);
        out.writeObject(this.tokensPerTopic);
        out.writeObject(this.random);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
        System.out.println("done");
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        this.dmrParameters = (MaxEnt)in.readObject();
        this.numFeatures = this.dmrParameters.getAlphabet().size();
        this.defaultFeatureIndex = this.dmrParameters.getDefaultFeatureIndex();
        this.alphabet = (Alphabet)in.readObject();
        this.numTopics = in.readInt();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.numTypes = in.readInt();
        this.alpha = (double[])in.readObject();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.typeTopicCounts = (int[][])in.readObject();
        this.tokensPerTopic = (int[])in.readObject();
        this.random = (Randoms)in.readObject();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[])in.readObject();
    }

    public static DMRInferencer read(File f) throws Exception {
        DMRInferencer inferencer = null;
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));
        inferencer = (DMRInferencer)ois.readObject();
        ois.close();
        return inferencer;
    }
}

