/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Gibbs sampler for the Hierarchical Dirichlet Process (HDP) topic model in its
 * Chinese Restaurant Franchise (CRF) representation.
 *
 * <p>Every document keeps a set of "tables"; each word sits at one table and each table serves
 * one topic. Per sweep, every word is removed, re-seated at an existing table or a new one
 * (a new table also draws a topic, possibly a brand-new topic), and re-added. After each sweep
 * empty tables and empty topics are compacted away.
 *
 * <p>Typical use: set {@link #alpha}, {@link #beta}, {@link #gamma} and the initial
 * {@code numberOfTopics}, call {@link #addInstances(InstanceList)}, then {@link #run}.
 * Instance data is cast to {@link FeatureSequence}, so instances must carry token sequences.
 *
 * <p>All state is mutated in place with no synchronization; confine an instance to one thread.
 */
public class HDPGibbsSampler {
    /** Smoothing added to every topic-term count when scoring a word against a topic. */
    public double beta = 0.5;
    /** Weight of opening a new topic (scaled by 1/vocabulary size) when a new table is created. */
    public double gamma = 1.5;
    /** Weight of opening a new table within a document. */
    public double alpha = 1.0;
    private final Random random = new Random();
    /** Scratch buffer of cumulative, unnormalized probabilities used while sampling. */
    private double[] p;
    /** Per-topic likelihood of the word currently being resampled; filled by sampleTable. */
    private double[] f;
    /** Per-document state, indexed by array position (the order changes after doShuffle). */
    protected DOCState[] docStates;
    /** Number of non-empty tables serving each topic, summed over all documents. */
    protected int[] numberOfTablesByTopic;
    /** Number of words assigned to each topic. */
    protected int[] wordCountByTopic;
    /** {@code wordCountByTopicAndTerm[k][v]}: number of occurrences of term v assigned to topic k. */
    protected int[][] wordCountByTopicAndTerm;
    protected int sizeOfVocabulary;
    protected int totalNumberOfWords;
    /** Number of live topics; before addInstances it is the initial number of topics. */
    protected int numberOfTopics = 1;
    /** Number of non-empty tables over all documents. */
    protected int totalNumberOfTables;
    private InstanceList data;

    /**
     * Loads a corpus and builds the initial sampler state.
     *
     * <p>The first {@code numberOfTopics} documents are each seated at a single table with their
     * own topic; every remaining document is seated at a single table with a uniformly random
     * topic among the initial ones. Initial topics without any document (corpus smaller than
     * {@code numberOfTopics}) stay empty and are removed by the first {@link #defragment()}.
     *
     * @param corpus instances whose data are {@link FeatureSequence}s over the corpus alphabet
     */
    public void addInstances(InstanceList corpus) {
        this.data = corpus;
        this.sizeOfVocabulary = corpus.getDataAlphabet().size();
        this.totalNumberOfWords = 0;
        this.totalNumberOfTables = 0;
        this.docStates = new DOCState[corpus.size()];
        for (int d = 0; d < corpus.size(); ++d) {
            this.docStates[d] = new DOCState((Instance) corpus.get(d), d);
            this.totalNumberOfWords += this.docStates[d].documentLength;
        }
        this.p = new double[20];
        this.f = new double[20];
        // One spare slot so that a brand-new topic (index numberOfTopics) can be counted at once.
        this.numberOfTablesByTopic = new int[this.numberOfTopics + 1];
        this.wordCountByTopic = new int[this.numberOfTopics + 1];
        this.wordCountByTopicAndTerm = new int[this.numberOfTopics + 1][];
        for (int k = 0; k <= this.numberOfTopics; ++k) {
            this.wordCountByTopicAndTerm[k] = new int[this.sizeOfVocabulary];
        }
        // Guard against corpora with fewer documents than initial topics.
        int seededDocuments = Math.min(this.numberOfTopics, this.docStates.length);
        for (int k = 0; k < seededDocuments; ++k) {
            this.seatWholeDocument(this.docStates[k], k);
        }
        for (int j = seededDocuments; j < this.docStates.length; ++j) {
            this.seatWholeDocument(this.docStates[j], this.random.nextInt(this.numberOfTopics));
        }
    }

    /** Seats every word of {@code doc} at table 0, which serves topic {@code k}. */
    private void seatWholeDocument(DOCState doc, int k) {
        for (int i = 0; i < doc.documentLength; ++i) {
            this.addWord(doc.docID, i, 0, k);
        }
    }

    /** Resamples the table (and, for new tables, the topic) of every word, then compacts state. */
    protected void nextGibbsSweep() {
        for (int d = 0; d < this.docStates.length; ++d) {
            DOCState doc = this.docStates[d];
            for (int i = 0; i < doc.documentLength; ++i) {
                this.removeWord(d, i);
                int table = this.sampleTable(d, i);
                // A fresh table (index == numberOfTables) needs its own topic draw.
                int topic = table == doc.numberOfTables ? this.sampleTopic() : doc.tableToTopic[table];
                this.addWord(d, i, table, topic);
            }
        }
        this.defragment();
    }

    /**
     * Draws a topic for a newly opened table.
     *
     * <p>Existing topic k is weighted by (tables serving k) * f[k]; a new topic (returned as
     * {@code numberOfTopics}) is weighted by gamma / vocabulary size. Relies on {@link #f} having
     * been filled by the immediately preceding {@link #sampleTable} call for the same word.
     *
     * @return a topic index in {@code [0, numberOfTopics]}
     */
    private int sampleTopic() {
        this.p = ensureCapacity(this.p, this.numberOfTopics);
        double pSum = 0.0;
        for (int k = 0; k < this.numberOfTopics; ++k) {
            pSum += this.numberOfTablesByTopic[k] * this.f[k];
            this.p[k] = pSum;
        }
        pSum += this.gamma / this.sizeOfVocabulary;
        this.p[this.numberOfTopics] = pSum;
        return sampleIndex(this.p, this.numberOfTopics, this.random.nextDouble() * pSum);
    }

    /**
     * Draws a table for word {@code i} of document {@code docID}, which must currently be removed.
     *
     * <p>Also fills {@link #f} with the word's smoothed likelihood under every live topic. An
     * existing non-empty table j is weighted by (words at j) * f[topic of j]; a new table
     * (returned as {@code numberOfTables}) by alpha * fNew / (totalNumberOfTables + gamma), where
     * fNew = gamma / V + sum over k of (tables serving k) * f[k].
     *
     * @return a table index in {@code [0, numberOfTables]}
     */
    int sampleTable(int docID, int i) {
        DOCState docState = this.docStates[docID];
        int term = docState.words[i].termIndex;
        double vb = this.sizeOfVocabulary * this.beta;
        this.f = ensureCapacity(this.f, this.numberOfTopics);
        this.p = ensureCapacity(this.p, docState.numberOfTables);
        double fNew = this.gamma / this.sizeOfVocabulary;
        for (int k = 0; k < this.numberOfTopics; ++k) {
            this.f[k] = (this.wordCountByTopicAndTerm[k][term] + this.beta) / (this.wordCountByTopic[k] + vb);
            fNew += this.numberOfTablesByTopic[k] * this.f[k];
        }
        double pSum = 0.0;
        for (int j = 0; j < docState.numberOfTables; ++j) {
            // Empty tables (awaiting defragment) contribute no mass and can never be chosen.
            if (docState.wordCountByTable[j] > 0) {
                pSum += docState.wordCountByTable[j] * this.f[docState.tableToTopic[j]];
            }
            this.p[j] = pSum;
        }
        pSum += this.alpha * fNew / (this.totalNumberOfTables + this.gamma);
        this.p[docState.numberOfTables] = pSum;
        return sampleIndex(this.p, docState.numberOfTables, this.random.nextDouble() * pSum);
    }

    /**
     * Inverse-CDF lookup over an unnormalized cumulative array.
     *
     * @param cumulative non-decreasing cumulative weights; entries 0..last are used
     * @param last index of the final valid entry
     * @param u a draw in {@code [0, cumulative[last])}
     * @return the first index j with {@code u < cumulative[j]}, or {@code last + 1} if none
     */
    private static int sampleIndex(double[] cumulative, int last, double u) {
        int j = 0;
        while (j <= last && !(u < cumulative[j])) {
            ++j;
        }
        return j;
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps, logging the topic and table counts after each one.
     *
     * @param shuffleLag if positive, documents and words are shuffled every {@code shuffleLag}
     *     iterations (never before the first one)
     * @param maxIter number of sweeps
     * @param log destination of the per-iteration progress line
     */
    public void run(int shuffleLag, int maxIter, PrintStream log) throws IOException {
        for (int iter = 0; iter < maxIter; ++iter) {
            if (shuffleLag > 0 && iter > 0 && iter % shuffleLag == 0) {
                this.doShuffle();
            }
            this.nextGibbsSweep();
            log.println("iter = " + iter + " #topics = " + this.numberOfTopics + ", #tables = " + this.totalNumberOfTables);
        }
    }

    /**
     * Removes word {@code i} of document {@code docID} from its table and from all counts.
     * A table left empty is uncounted and its topic is marked {@code -1}; the slot itself is
     * reclaimed by {@link #defragment()}.
     */
    protected void removeWord(int docID, int i) {
        DOCState docState = this.docStates[docID];
        WordState word = docState.words[i];
        int table = word.tableAssignment;
        int k = docState.tableToTopic[table];
        docState.wordCountByTable[table]--;
        this.wordCountByTopic[k]--;
        this.wordCountByTopicAndTerm[k][word.termIndex]--;
        if (docState.wordCountByTable[table] == 0) {
            --this.totalNumberOfTables;
            this.numberOfTablesByTopic[k]--;
            // Same "empty table" marker DOCState.defragment uses; never read for empty tables.
            docState.tableToTopic[table] = -1;
        }
    }

    /**
     * Seats word {@code i} of document {@code docID} at {@code table} with topic {@code k}.
     *
     * <p>If this is the table's first word, the table is opened with topic {@code k}; if
     * {@code k == numberOfTopics} a new topic is created as well. Arrays are grown so that the
     * next table/topic index is always addressable.
     */
    protected void addWord(int docID, int i, int table, int k) {
        DOCState docState = this.docStates[docID];
        WordState word = docState.words[i];
        word.tableAssignment = table;
        docState.wordCountByTable[table]++;
        this.wordCountByTopic[k]++;
        this.wordCountByTopicAndTerm[k][word.termIndex]++;
        if (docState.wordCountByTable[table] != 1) {
            return;
        }
        // First word at this table: open it.
        ++docState.numberOfTables;
        docState.tableToTopic[table] = k;
        ++this.totalNumberOfTables;
        this.numberOfTablesByTopic[k]++;
        docState.tableToTopic = ensureCapacity(docState.tableToTopic, docState.numberOfTables);
        docState.wordCountByTable = ensureCapacity(docState.wordCountByTable, docState.numberOfTables);
        if (k == this.numberOfTopics) {
            // The spare slot just became a real topic; reserve a fresh spare slot.
            ++this.numberOfTopics;
            this.numberOfTablesByTopic = ensureCapacity(this.numberOfTablesByTopic, this.numberOfTopics);
            this.wordCountByTopic = ensureCapacity(this.wordCountByTopic, this.numberOfTopics);
            this.wordCountByTopicAndTerm = add(this.wordCountByTopicAndTerm, new int[this.sizeOfVocabulary], this.numberOfTopics);
        }
    }

    /** Compacts topics with no words to the end, renumbers the rest, then compacts each document. */
    protected void defragment() {
        int[] kOldToKNew = new int[this.numberOfTopics];
        int newNumberOfTopics = 0;
        for (int k = 0; k < this.numberOfTopics; ++k) {
            if (this.wordCountByTopic[k] <= 0) {
                continue;
            }
            kOldToKNew[k] = newNumberOfTopics;
            swap(this.wordCountByTopic, newNumberOfTopics, k);
            swap(this.numberOfTablesByTopic, newNumberOfTopics, k);
            swap(this.wordCountByTopicAndTerm, newNumberOfTopics, k);
            ++newNumberOfTopics;
        }
        this.numberOfTopics = newNumberOfTopics;
        for (DOCState doc : this.docStates) {
            doc.defragment(kOldToKNew);
        }
    }

    /** Shuffles the document order and the word order inside every document. */
    protected void doShuffle() {
        List<DOCState> docs = Arrays.asList(this.docStates);
        Collections.shuffle(docs, this.random);
        this.docStates = docs.toArray(new DOCState[docs.size()]);
        for (DOCState doc : this.docStates) {
            List<WordState> words = Arrays.asList(doc.words);
            Collections.shuffle(words, this.random);
            doc.words = words.toArray(new WordState[words.size()]);
        }
    }

    /** Returns the source of the corpus instance with the given id, or "NA" if it has none. */
    private String sourceOf(int docID) {
        Object source = ((Instance) this.data.get(docID)).getSource();
        return source == null ? "NA" : source.toString();
    }

    /**
     * Prints, per document, its id, source and topic proportions in decreasing order.
     *
     * @param threshold proportions below this value (and all after them) are not printed
     * @param maxNumberOfTopics maximum topics per line; {@code <= 0} means all topics
     */
    public void printDocumentTopics(PrintStream out, double threshold, int maxNumberOfTopics) {
        out.println("#doc name topic proportion ...");
        IDSorter[] sortedTopics = new IDSorter[this.numberOfTopics];
        for (int k = 0; k < this.numberOfTopics; ++k) {
            sortedTopics[k] = new IDSorter(k, k);
        }
        if (maxNumberOfTopics <= 0 || maxNumberOfTopics > this.numberOfTopics) {
            maxNumberOfTopics = this.numberOfTopics;
        }
        for (DOCState doc : this.docStates) {
            // docID (not array position) identifies the instance, since doShuffle reorders docStates.
            out.print(doc.docID + "    ");
            out.print(this.sourceOf(doc.docID) + "    ");
            int[] topicCounts = new int[this.numberOfTopics];
            for (int i = 0; i < doc.documentLength; ++i) {
                topicCounts[doc.tableToTopic[doc.words[i].tableAssignment]]++;
            }
            for (int k = 0; k < this.numberOfTopics; ++k) {
                // Empty documents get zero proportions instead of NaN.
                double proportion = doc.documentLength > 0 ? (double) topicCounts[k] / doc.documentLength : 0.0;
                sortedTopics[k].set(k, proportion);
            }
            Arrays.sort(sortedTopics);
            for (int k = 0; k < maxNumberOfTopics && !(sortedTopics[k].getWeight() < threshold); ++k) {
                out.print(sortedTopics[k].getID() + "    " + sortedTopics[k].getWeight() + "    ");
            }
            out.println();
        }
    }

    /**
     * Prints, per topic, its index followed by its most frequent terms.
     *
     * @param maxNumberOfWords terms per topic; {@code <= 0} or too large means the whole vocabulary
     */
    public void printTopicWords(PrintStream out, int maxNumberOfWords) {
        if (maxNumberOfWords <= 0 || maxNumberOfWords > this.sizeOfVocabulary) {
            maxNumberOfWords = this.sizeOfVocabulary;
        }
        IDSorter[] sortedWords = new IDSorter[this.sizeOfVocabulary];
        for (int v = 0; v < this.sizeOfVocabulary; ++v) {
            sortedWords[v] = new IDSorter(v, v);
        }
        for (int k = 0; k < this.numberOfTopics; ++k) {
            out.print(k + "    ");
            for (int v = 0; v < this.sizeOfVocabulary; ++v) {
                sortedWords[v].set(v, this.wordCountByTopicAndTerm[k][v]);
            }
            Arrays.sort(sortedWords);
            for (int v = 0; v < maxNumberOfWords; ++v) {
                out.print(this.data.getAlphabet().lookupObject(sortedWords[v].getID()) + "    ");
            }
            out.println();
        }
    }

    /**
     * Prints the hyperparameters and then one line per word: document id, source, position,
     * term index, term and topic.
     * NOTE(review): after doShuffle the "pos" column is the shuffled position, not the original.
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        out.println("#alpha : " + this.alpha);
        out.println("#beta : " + this.beta);
        out.println("#gamma : " + this.gamma);
        for (DOCState doc : this.docStates) {
            String source = this.sourceOf(doc.docID);
            for (int i = 0; i < doc.documentLength; ++i) {
                WordState word = doc.words[i];
                int term = word.termIndex;
                out.print(doc.docID + " ");
                out.print(source + " ");
                out.print(i + " ");
                out.print(term + " ");
                out.print(this.data.getAlphabet().lookupObject(term) + " ");
                out.println(doc.tableToTopic[word.tableAssignment]);
            }
        }
    }

    /** Swaps {@code arr[arg1]} and {@code arr[arg2]}. */
    public static void swap(int[] arr, int arg1, int arg2) {
        int t = arr[arg1];
        arr[arg1] = arr[arg2];
        arr[arg2] = t;
    }

    /** Swaps the rows {@code arr[arg1]} and {@code arr[arg2]}. */
    public static void swap(int[][] arr, int arg1, int arg2) {
        int[] t = arr[arg1];
        arr[arg1] = arr[arg2];
        arr[arg2] = t;
    }

    /**
     * Returns {@code arr} if index {@code min} is already addressable, otherwise a zero-padded
     * copy of at least {@code 2 * min} (and always more than {@code min}) elements.
     */
    public static double[] ensureCapacity(double[] arr, int min) {
        if (min < arr.length) {
            return arr;
        }
        return Arrays.copyOf(arr, Math.max(min * 2, min + 1));
    }

    /**
     * Returns {@code arr} if index {@code min} is already addressable, otherwise a zero-padded
     * copy of at least {@code 2 * min} (and always more than {@code min}) elements.
     */
    public static int[] ensureCapacity(int[] arr, int min) {
        if (min < arr.length) {
            return arr;
        }
        return Arrays.copyOf(arr, Math.max(min * 2, min + 1));
    }

    /**
     * Stores {@code newElement} at {@code index}, growing {@code arr} first if needed.
     *
     * @return {@code arr}, or the grown copy that now holds the element
     */
    public static int[][] add(int[][] arr, int[] newElement, int index) {
        if (index >= arr.length) {
            arr = Arrays.copyOf(arr, Math.max(index * 2, index + 1));
        }
        arr[index] = newElement;
        return arr;
    }

    /**
     * Command-line entry point:
     * {@code beta alpha gamma iterations inputFile initialNumberOfTopics outputDir}.
     * Writes state.txt, topics.txt and words.txt into {@code outputDir}.
     */
    public static void main(String[] args) throws IOException {
        int iter = 0;
        String inputFile = null;
        String outputDir = null;
        HDPGibbsSampler state = new HDPGibbsSampler();
        try {
            state.beta = Double.parseDouble(args[0]);
            state.alpha = Double.parseDouble(args[1]);
            state.gamma = Double.parseDouble(args[2]);
            iter = Integer.parseInt(args[3]);
            inputFile = args[4];
            state.numberOfTopics = Integer.parseInt(args[5]);
            outputDir = args[6];
        }
        catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
            System.out.println("CRF Gibbs sampling for the Hierarchical Dirichlet Processes");
            System.out.println("The application needs the following params in exact order");
            System.out.println("beta alpha gamma iterations inputFile initialNumberOfTopics outputDir");
            System.out.println("Example:");
            System.out.println("HDP 0.5 1.5 1.0 2000 ./topic-input.mallet 5 ./output/ ");
            System.exit(1);
        }
        state.addInstances(InstanceList.load(new File(inputFile)));
        System.out.println("sizeOfVocabulary=" + state.sizeOfVocabulary);
        System.out.println("totalNumberOfWords=" + state.totalNumberOfWords);
        System.out.println("NumberOfDocs=" + state.docStates.length);
        state.run(0, iter, System.out);
        // File(parent, child) works whether or not outputDir ends with a separator.
        File outDir = new File(outputDir);
        try (PrintStream out = new PrintStream(new File(outDir, "state.txt"))) {
            state.printState(out);
        }
        try (PrintStream out = new PrintStream(new File(outDir, "topics.txt"))) {
            state.printDocumentTopics(out, 1.0E-4, 0);
        }
        try (PrintStream out = new PrintStream(new File(outDir, "words.txt"))) {
            state.printTopicWords(out, 10);
        }
    }

    /** Sampler state of one word: its vocabulary index and its table within the document. */
    class WordState {
        int termIndex;
        int tableAssignment;

        public WordState(int wordIndex, int tableAssignment) {
            this.termIndex = wordIndex;
            this.tableAssignment = tableAssignment;
        }
    }

    /** Sampler state of one document: its words, its tables and the topic each table serves. */
    class DOCState {
        /** Index of the document's instance in the corpus passed to addInstances. */
        int docID;
        int documentLength;
        /** Number of table slots in use; may include emptied tables until defragment runs. */
        int numberOfTables;
        /** Topic of each table; -1 marks an empty table. */
        int[] tableToTopic;
        int[] wordCountByTable;
        WordState[] words;

        public DOCState(Instance instance, int docID) {
            this.docID = docID;
            this.numberOfTables = 0;
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            this.documentLength = tokens.getLength();
            this.words = new WordState[this.documentLength];
            this.wordCountByTable = new int[2];
            this.tableToTopic = new int[2];
            for (int position = 0; position < this.documentLength; ++position) {
                // -1: not yet seated; addInstances seats every word.
                this.words[position] = new WordState(tokens.getIndexAtPosition(position), -1);
            }
        }

        /**
         * Moves non-empty tables to the front, renumbers topics via {@code kOldToKNew}, and
         * updates every word's table index accordingly.
         */
        public void defragment(int[] kOldToKNew) {
            int[] tOldToTNew = new int[this.numberOfTables];
            int newNumberOfTables = 0;
            for (int t = 0; t < this.numberOfTables; ++t) {
                if (this.wordCountByTable[t] > 0) {
                    tOldToTNew[t] = newNumberOfTables;
                    this.tableToTopic[newNumberOfTables] = kOldToKNew[this.tableToTopic[t]];
                    swap(this.wordCountByTable, newNumberOfTables, t);
                    ++newNumberOfTables;
                    continue;
                }
                this.tableToTopic[t] = -1;
            }
            this.numberOfTables = newNumberOfTables;
            for (int i = 0; i < this.documentLength; ++i) {
                this.words[i].tableAssignment = tOldToTNew[this.words[i].tableAssignment];
            }
        }
    }
}

