/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBArray;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBValue;
import com.google.common.primitives.Floats;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;
import org.jpmml.xgboost.GradientBooster;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.RegTree;
import org.jpmml.xgboost.UBJSONUtil;
import org.jpmml.xgboost.XGBoostDataInput;

/**
 * Gradient booster whose model is held as an array of {@link RegTree} instances together
 * with a parallel {@code tree_info} int array.
 *
 * <p>The state is populated by one of the {@code load*} methods (binary stream, JSON or
 * UBJSON); {@link #encodeMiningModel} then hands the loaded trees to an {@link ObjFunction}.
 */
public class GBTree extends GradientBooster {
    private int num_trees;
    private int num_roots;
    private int num_feature;
    private int num_output_group;
    private int size_leaf_vector;
    private RegTree[] trees;
    private int[] tree_info;

    /**
     * @return the fixed algorithm name {@code "GBTree"}.
     */
    @Override
    public String getAlgorithmName() {
        return "GBTree";
    }

    /**
     * Reads the booster state from a binary input.
     *
     * <p>The reads below are strictly sequential; their order defines the layout that is
     * consumed from {@code input} and must not be rearranged.
     *
     * @param input the data input to read from.
     * @throws IOException if any of the underlying reads fails.
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        // Header: three counters, then a reserved run of 3.
        this.num_trees = input.readInt();
        this.num_roots = input.readInt();
        this.num_feature = input.readInt();
        input.readReserved(3);

        // Two more counters, then a reserved run of 32.
        this.num_output_group = input.readInt();
        this.size_leaf_vector = input.readInt();
        input.readReserved(32);

        // Payload: num_trees tree objects followed by num_trees ints.
        this.trees = (RegTree[])input.readObjectArray(RegTree.class, this.num_trees);
        this.tree_info = input.readIntArray(this.num_trees);
    }

    /**
     * Loads the booster state from a Gson JSON object by converting it to its UBJSON
     * representation and delegating to {@link #loadUBJSON(UBObject)}.
     *
     * @param gradientBooster the JSON object describing the gradient booster.
     */
    @Override
    public void loadJSON(JsonObject gradientBooster) {
        UBValue converted = GsonUtil.toUBValue((JsonElement)gradientBooster);
        UBObject boosterObject = converted.asObject();

        this.loadUBJSON(boosterObject);
    }

    /**
     * Loads the booster state from a UBJSON object.
     *
     * <p>Reads {@code num_trees} and {@code size_leaf_vector} from
     * {@code model.gbtree_model_param}, one tree per element of {@code model.trees}
     * (the first {@code num_trees} elements), and {@code model.tree_info} as an int array.
     *
     * @param gradientBooster the UBJSON object describing the gradient booster.
     */
    @Override
    public void loadUBJSON(UBObject gradientBooster) {
        UBObject modelObject = gradientBooster.get((Object)"model").asObject();

        UBObject params = modelObject.get((Object)"gbtree_model_param").asObject();
        this.num_trees = params.get((Object)"num_trees").asInt();
        this.size_leaf_vector = params.get((Object)"size_leaf_vector").asInt();

        UBArray treeArray = modelObject.get((Object)"trees").asArray();
        this.trees = new RegTree[this.num_trees];

        int index = 0;
        while (index < this.num_trees) {
            UBObject treeObject = treeArray.get(index).asObject();

            // Store the new tree in its slot before loading it, as the slot is
            // expected to be occupied even if the load itself fails.
            RegTree regTree = new RegTree();
            this.trees[index] = regTree;
            regTree.loadUBJSON(treeObject);

            index++;
        }

        this.tree_info = UBJSONUtil.toIntArray(modelObject.get((Object)"tree_info"));
    }

    /**
     * @return {@code true} as soon as one of the first {@code num_trees} trees reports
     *         categorical splits; {@code false} if none does.
     */
    public boolean hasCategoricalSplits() {
        boolean found = false;

        for (int index = 0; !found && index < this.num_trees; index++) {
            found = this.trees[index].hasCategoricalSplits();
        }

        return found;
    }

    /**
     * Collects the split types that the trees report for the given split index.
     *
     * @param splitIndex the split index passed through to every tree.
     * @return a new set holding the union of the per-tree results; empty when there are
     *         no trees.
     */
    public Set<Integer> getSplitType(int splitIndex) {
        Set<Integer> union = new HashSet<>();

        for (int index = 0; index < this.num_trees; index++) {
            union.addAll(this.trees[index].getSplitType(splitIndex));
        }

        return union;
    }

    /**
     * Merges the split categories that the trees report for the given split index.
     *
     * @param splitIndex the split index passed through to every tree.
     * @return a new bit set that is the OR of all non-{@code null} per-tree results, or
     *         {@code null} when every tree returned {@code null} (or there are no trees).
     */
    public BitSet getSplitCategories(int splitIndex) {
        BitSet merged = null;

        for (int index = 0; index < this.num_trees; index++) {
            BitSet treeCategories = this.trees[index].getSplitCategories(splitIndex);

            if (treeCategories != null) {
                // Allocate lazily so that "no tree contributed" stays distinguishable as null.
                if (merged == null) {
                    merged = new BitSet();
                }
                merged.or(treeCategories);
            }
        }

        return merged;
    }

    /**
     * Encodes this booster as a mining model by delegating to the objective function.
     *
     * <p>The trees and weights are obtained through {@link #trees()} and
     * {@link #tree_weights()}; a {@code null} weight array is forwarded as a {@code null}
     * weight list.
     *
     * @param obj        the objective function that performs the encoding.
     * @param base_score forwarded unchanged to {@code obj}.
     * @param ntreeLimit forwarded unchanged to {@code obj}.
     * @param numeric    forwarded unchanged to {@code obj}.
     * @param schema     forwarded unchanged to {@code obj}.
     * @return the mining model produced by {@code obj}.
     */
    public MiningModel encodeMiningModel(ObjFunction obj, float base_score, Integer ntreeLimit, boolean numeric, Schema schema) {
        RegTree[] treeArray = this.trees();
        float[] weightArray = this.tree_weights();

        return obj.encodeMiningModel(
            Arrays.asList(treeArray),
            weightArray == null ? null : Floats.asList(weightArray),
            base_score,
            ntreeLimit,
            numeric,
            schema
        );
    }

    /**
     * @return the number of trees recorded by the last load.
     */
    public int num_trees() {
        return this.num_trees;
    }

    /**
     * @return the internal tree array itself (not a copy); {@code null} before any load.
     */
    public RegTree[] trees() {
        return this.trees;
    }

    /**
     * @return always {@code null} in this class, i.e. no per-tree weights.
     */
    public float[] tree_weights() {
        return null;
    }
}

