/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;
import java.util.Collection;

/**
 * A discrete table factor whose entries are stored in log space.
 *
 * <p>The inherited {@code probs} matrix holds the natural logarithm of each
 * factor value: {@link #value(Assignment)} exponentiates a stored entry and
 * {@link #setValue(AssignmentIterator, double)} stores {@code Math.log} of its
 * argument. Multiplying, dividing and adding factors therefore become
 * addition, subtraction and log-space addition ({@link Maths#sumLogProb}) of
 * stored entries. Floating-point work is tallied through {@link Flops}.
 */
public class LogTableFactor
extends AbstractTableFactor {
    /**
     * Copy constructor: takes the variables of {@code in} and stores a clone of
     * its log-value matrix as this factor's table.
     *
     * @param in factor to copy (may be any {@link AbstractTableFactor})
     */
    public LogTableFactor(AbstractTableFactor in) {
        super(in);
        this.probs = (Matrix)in.getLogValueMatrix().cloneMatrix();
    }

    /**
     * Creates a factor over a single variable. Initial table contents are
     * whatever {@link AbstractTableFactor} sets up.
     */
    public LogTableFactor(Variable var) {
        super(var);
    }

    /**
     * Creates a factor over the given variables. Initial table contents are
     * whatever {@link AbstractTableFactor} sets up.
     */
    public LogTableFactor(Variable[] allVars) {
        super(allVars);
    }

    /**
     * Creates a factor over a collection of {@link Variable}s. Initial table
     * contents are whatever {@link AbstractTableFactor} sets up.
     */
    public LogTableFactor(Collection allVars) {
        super(allVars);
    }

    /** Creates a factor whose table is initialised directly from {@code logValues} (already in log space). */
    private LogTableFactor(Variable[] vars, double[] logValues) {
        super(vars, logValues);
    }

    /** Creates a factor that uses {@code probsIn} (already in log space) as its table. */
    private LogTableFactor(Variable[] allVars, Matrix probsIn) {
        super(allVars, probsIn);
    }

    /**
     * Builds a factor from ordinary (non-log) values.
     *
     * @param vars variables of the factor
     * @param vals one value per joint assignment; a value of 0 becomes
     *             {@code -Infinity} in the stored table
     * @return a new factor storing {@code Math.log} of each value
     */
    public static LogTableFactor makeFromValues(Variable[] vars, double[] vals) {
        double[] logVals = new double[vals.length];
        for (int i = 0; i < vals.length; ++i) {
            logVals[i] = Math.log(vals[i]);
        }
        // Count the logs, consistent with setValues and makeFromMatrix.
        Flops.log(vals.length);
        return LogTableFactor.makeFromLogValues(vars, logVals);
    }

    /**
     * Builds a factor from values that are already natural logarithms.
     *
     * @param vars variables of the factor
     * @param vals one log value per joint assignment
     * @return a new factor whose table is initialised from {@code vals}
     */
    public static LogTableFactor makeFromLogValues(Variable[] vars, double[] vals) {
        return new LogTableFactor(vars, vals);
    }

    /** Sets every stored log value to 0, i.e. every factor value to 1. */
    void setAsIdentity() {
        this.setAll(0.0);
    }

    /** Returns an independent copy of this factor (see the copy constructor). */
    public Factor duplicate() {
        return new LogTableFactor(this);
    }

    /** Creates an empty {@code LogTableFactor} over {@code vars}, used as a marginalization target. */
    protected AbstractTableFactor createBlankSubset(Variable[] vars) {
        return new LogTableFactor(vars);
    }

    /**
     * Normalizes in place so that the factor values sum to 1, by subtracting the
     * log of their sum from every stored entry.
     *
     * <p>If the log-sum is below -500 a warning is printed to standard error,
     * but normalization still proceeds. NOTE(review): if every entry is
     * {@code -Infinity} and the log-sum is {@code -Infinity}, the subtraction
     * {@code -Inf - -Inf} produces NaN entries.
     *
     * @return this factor
     */
    public Factor normalize() {
        double sum = this.logspaceOneNorm();
        if (sum < -500.0) {
            System.err.println("Attempt to normalize all-0 factor " + this.dumpToString());
        }
        int numLocs = this.probs.numLocations();
        for (int i = 0; i < numLocs; ++i) {
            double val = this.probs.valueAtLocation(i);
            this.probs.setValueAtLocation(i, val - sum);
        }
        return this;
    }

    /**
     * Returns the log of the sum of all factor values.
     *
     * <p>Stored entries are folded together with {@link Maths#sumLogProb},
     * assumed to compute {@code log(exp(a) + exp(b))}. The fold starts from
     * {@code -Infinity}, the log of zero.
     */
    private double logspaceOneNorm() {
        double sum = Double.NEGATIVE_INFINITY;
        int numLocs = this.probs.numLocations();
        for (int i = 0; i < numLocs; ++i) {
            sum = Maths.sumLogProb(sum, this.probs.valueAtLocation(i));
        }
        Flops.sumLogProb(numLocs);
        return sum;
    }

    /** Returns the sum of all factor values (not logs). */
    public double sum() {
        Flops.exp();
        return Math.exp(this.logspaceOneNorm());
    }

    /** Returns the natural log of the sum of all factor values. */
    public double logsum() {
        return this.logspaceOneNorm();
    }

    /**
     * Multiplies this factor by {@code ptl} in place. In log space this means
     * adding {@code ptl}'s log value for the projected assignment to each entry.
     */
    protected void multiplyByInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int singleLoc = 0; singleLoc < numLocs; ++singleLoc) {
            int smallIdx = projection[singleLoc];
            double prev = this.probs.valueAtLocation(singleLoc);
            double newVal = ptl.logValue(smallIdx);
            this.probs.setValueAtLocation(singleLoc, prev + newVal);
        }
        Flops.increment(numLocs);
    }

    /**
     * Divides this factor by {@code ptl} in place, by subtracting log values.
     *
     * <p>By convention, whenever the divisor's log value is infinite the result
     * entry is set to {@code -Infinity} (value 0). For a divisor of 0 (log
     * {@code -Infinity}), the plain subtraction would otherwise give
     * {@code +Infinity} or NaN.
     */
    protected void divideByInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int singleLoc = 0; singleLoc < numLocs; ++singleLoc) {
            int smallIdx = projection[singleLoc];
            double prev = this.probs.valueAtLocation(singleLoc);
            double newVal = ptl.logValue(smallIdx);
            double quotient = Double.isInfinite(newVal)
                    ? Double.NEGATIVE_INFINITY
                    : prev - newVal;
            this.probs.setValueAtLocation(singleLoc, quotient);
        }
        Flops.increment(numLocs);
    }

    /**
     * Adds {@code ptl} to this factor in place. Each entry is combined with the
     * projected entry of {@code ptl} using log-space addition.
     */
    protected void plusEqualsInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int singleLoc = 0; singleLoc < numLocs; ++singleLoc) {
            int smallIdx = projection[singleLoc];
            double prev = this.probs.valueAtLocation(singleLoc);
            double newVal = ptl.logValue(smallIdx);
            this.probs.setValueAtLocation(singleLoc, Maths.sumLogProb(prev, newVal));
        }
        Flops.sumLogProb(numLocs);
    }

    /**
     * Returns the (non-log) value of the factor at {@code assn}.
     * A factor with no variables always has value 1.
     */
    public double value(Assignment assn) {
        if (this.getNumVars() == 0) {
            return 1.0;
        }
        // Only count an exp when one is actually performed.
        Flops.exp();
        return Math.exp(this.rawValue(assn));
    }

    /** Returns the (non-log) value at the iterator's current assignment. */
    public double value(AssignmentIterator it) {
        Flops.exp();
        return Math.exp(this.rawValue(it.indexOfCurrentAssn()));
    }

    /** Returns the (non-log) value at single index {@code idx}. */
    public double value(int idx) {
        Flops.exp();
        return Math.exp(this.rawValue(idx));
    }

    /** Returns the stored log value at the iterator's current assignment. */
    public double logValue(AssignmentIterator it) {
        return this.rawValue(it.indexOfCurrentAssn());
    }

    /** Returns the stored log value at single index {@code idx}. */
    public double logValue(int idx) {
        return this.rawValue(idx);
    }

    /** Returns the stored log value at {@code assn}. */
    public double logValue(Assignment assn) {
        return this.rawValue(assn);
    }

    /**
     * Sums out the variables not present in {@code result}, writing into
     * {@code result}.
     *
     * <p>{@code result} is first reset to {@code -Infinity} (log of zero)
     * everywhere. Each entry of this table is then log-added into the
     * {@code result} entry it projects onto. NOTE(review): {@code smallIdx} is
     * read with {@code singleValue} (an index) but written with
     * {@code setValueAtLocation} (a location). That is only consistent if
     * {@code result}'s matrix is dense, presumably true for the blank subsets
     * created by {@link #createBlankSubset}.
     *
     * @return {@code result}
     */
    protected Factor marginalizeInternal(AbstractTableFactor result) {
        result.setAll(Double.NEGATIVE_INFINITY);
        int[] projection = this.largeIdxToSmall(result);
        int numLocs = this.probs.numLocations();
        for (int largeLoc = 0; largeLoc < numLocs; ++largeLoc) {
            int smallIdx = projection[largeLoc];
            double oldValue = this.probs.valueAtLocation(largeLoc);
            double currentValue = result.probs.singleValue(smallIdx);
            result.probs.setValueAtLocation(smallIdx, Maths.sumLogProb(oldValue, currentValue));
        }
        Flops.sumLogProb(numLocs);
        return result;
    }

    /**
     * Looks up the stored log value for {@code assn}. Each of this factor's
     * variables is mapped to its assigned outcome to form a multi-index.
     */
    protected double rawValue(Assignment assn) {
        int numVars = this.getNumVars();
        int[] indices = new int[numVars];
        for (int i = 0; i < numVars; ++i) {
            Variable var = this.getVariable(i);
            indices[i] = assn.get(var);
        }
        return this.rawValue(indices);
    }

    /** Converts a per-variable multi-index to a single index and looks it up. */
    private double rawValue(int[] indices) {
        int singleIdx = this.probs.singleIndex(indices);
        return this.rawValue(singleIdx);
    }

    /**
     * Returns the stored log value at single index {@code singleIdx}, or
     * {@code -Infinity} (log of zero) when the matrix has no location for that
     * index (negative location, e.g. an absent sparse entry).
     */
    protected double rawValue(int singleIdx) {
        int loc = this.probs.location(singleIdx);
        if (loc < 0) {
            return Double.NEGATIVE_INFINITY;
        }
        return this.probs.valueAtLocation(loc);
    }

    /** Raises every factor value to {@code power}, i.e. scales every log value by {@code power}. */
    public void exponentiate(double power) {
        Flops.increment(this.probs.numLocations());
        this.probs.timesEquals(power);
    }

    /** Stores {@code logValue} (already a log) as the entry for {@code assn}. */
    public void setLogValue(Assignment assn, double logValue) {
        this.setRawValue(assn, logValue);
    }

    /** Stores {@code logValue} (already a log) at the iterator's current assignment. */
    public void setLogValue(AssignmentIterator assnIt, double logValue) {
        this.setRawValue(assnIt, logValue);
    }

    /** Stores {@code Math.log(value)} at the iterator's current assignment. */
    public void setValue(AssignmentIterator assnIt, double value) {
        Flops.log();
        this.setRawValue(assnIt, Math.log(value));
    }

    /** Stores each {@code vals[i]} (already a log) as entry {@code i} of the table. */
    public void setLogValues(double[] vals) {
        for (int i = 0; i < vals.length; ++i) {
            this.setRawValue(i, vals[i]);
        }
    }

    /** Stores {@code Math.log(vals[i])} as entry {@code i} of the table. */
    public void setValues(double[] vals) {
        Flops.log(vals.length);
        for (int i = 0; i < vals.length; ++i) {
            this.setRawValue(i, Math.log(vals[i]));
        }
    }

    /** Multiplies every factor value by {@code v} (adds {@code log v} to every entry). */
    public void timesEquals(double v) {
        Flops.log();
        this.timesEqualsLog(Math.log(v));
    }

    /**
     * Adds {@code logV} to every stored entry, i.e. multiplies every value by
     * {@code exp(logV)}. The update is done in place, so no temporary matrix
     * the size of the table is allocated.
     */
    private void timesEqualsLog(double logV) {
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            this.probs.setValueAtLocation(loc, this.probs.valueAtLocation(loc) + logV);
        }
        Flops.increment(numLocs);
    }

    /**
     * Adds the ordinary value {@code v} to entry {@code loc}, using log-space
     * addition.
     *
     * <p>NOTE(review): {@code loc} is read through {@code logValue(int)}, which
     * treats its argument as a single index (via {@code probs.location}). It
     * is written through {@code setRawValue(int, double)}, whose index/location
     * semantics are defined in {@link AbstractTableFactor}. Confirm these
     * agree for sparse tables.
     */
    protected void plusEqualsAtLocation(int loc, double v) {
        Flops.log();
        Flops.sumLogProb(1);
        double oldVal = this.logValue(loc);
        this.setRawValue(loc, Maths.sumLogProb(oldVal, Math.log(v)));
    }

    /** Single-variable convenience for {@link #makeFromValues(Variable[], double[])}. */
    public static LogTableFactor makeFromValues(Variable var, double[] vals2) {
        return LogTableFactor.makeFromValues(new Variable[]{var}, vals2);
    }

    /**
     * Builds a factor from a sparse matrix of ordinary values. The input is
     * cloned and each stored entry is replaced by its natural log; the input
     * matrix is not modified.
     */
    public static LogTableFactor makeFromMatrix(Variable[] vars, SparseMatrixn values) {
        SparseMatrixn logValues = (SparseMatrixn)values.cloneMatrix();
        int numLocs = logValues.numLocations();
        for (int i = 0; i < numLocs; ++i) {
            logValues.setValueAtLocation(i, Math.log(logValues.valueAtLocation(i)));
        }
        Flops.log(numLocs);
        return new LogTableFactor(vars, (Matrix)logValues);
    }

    /** Builds a factor from a matrix that already holds log values; the matrix is cloned. */
    public static LogTableFactor makeFromLogMatrix(Variable[] vars, Matrix values) {
        Matrix logValues = (Matrix)values.cloneMatrix();
        return new LogTableFactor(vars, logValues);
    }

    /** Single-variable convenience for {@link #makeFromLogValues(Variable[], double[])}. */
    public static LogTableFactor makeFromLogValues(Variable v, double[] vals) {
        return LogTableFactor.makeFromLogValues(new Variable[]{v}, vals);
    }

    /** Returns a new matrix holding the exponentiated (ordinary) values of this factor. */
    public Matrix getValueMatrix() {
        Matrix values = (Matrix)this.probs.cloneMatrix();
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            values.setValueAtLocation(loc, Math.exp(values.valueAtLocation(loc)));
        }
        Flops.exp(numLocs);
        return values;
    }

    /**
     * Returns this factor's internal log-value matrix itself, not a copy.
     * Mutating the returned matrix mutates this factor.
     */
    public Matrix getLogValueMatrix() {
        return this.probs;
    }

    /** Returns the exponentiated value stored at matrix location {@code idx}. */
    public double valueAtLocation(int idx) {
        Flops.exp();
        return Math.exp(this.probs.valueAtLocation(idx));
    }

    /**
     * Slices the factor down to the single unobserved variable {@code var}.
     * For each outcome of {@code var}, the log value is read at {@code observed}
     * extended with that outcome. {@code observed} itself is not modified.
     */
    protected Factor slice_onevar(Variable var, Assignment observed) {
        Assignment assn = (Assignment)observed.duplicate();
        int numOutcomes = var.getNumOutcomes();
        double[] vals = new double[numOutcomes];
        for (int i = 0; i < numOutcomes; ++i) {
            assn.setValue(var, i);
            vals[i] = this.logValue(assn);
        }
        return LogTableFactor.makeFromLogValues(var, vals);
    }

    /**
     * Slices the factor down to the two unobserved variables {@code v1} and
     * {@code v2}. The result table is filled in {@code Matrixn.singleIndex}
     * order over sizes {@code {|v1|, |v2|}}. {@code observed} is not modified.
     */
    protected Factor slice_twovar(Variable v1, Variable v2, Assignment observed) {
        Assignment assn = (Assignment)observed.duplicate();
        int n1 = v1.getNumOutcomes();
        int n2 = v2.getNumOutcomes();
        int[] szs = new int[]{n1, n2};
        double[] vals = new double[n1 * n2];
        for (int i = 0; i < n1; ++i) {
            assn.setValue(v1, i);
            for (int j = 0; j < n2; ++j) {
                assn.setValue(v2, j);
                int idx = Matrixn.singleIndex(szs, new int[]{i, j});
                vals[idx] = this.logValue(assn);
            }
        }
        return LogTableFactor.makeFromLogValues(new Variable[]{v1, v2}, vals);
    }

    /**
     * General slice: keeps the variables of {@code vars} that are not observed.
     * For each of their joint assignments, the log value is read at the union
     * with {@code observed}.
     */
    protected Factor slice_general(Variable[] vars, Assignment observed) {
        HashVarSet toKeep = new HashVarSet(vars);
        toKeep.removeAll(observed.varSet());
        double[] vals = new double[toKeep.weight()];
        AssignmentIterator it = toKeep.assignmentIterator();
        while (it.hasNext()) {
            Assignment union = Assignment.union(observed, it.assignment());
            vals[it.indexOfCurrentAssn()] = this.logValue(union);
            it.advance();
        }
        return LogTableFactor.makeFromLogValues(toKeep.toVariableArray(), vals);
    }

    /**
     * Multiplies a collection of factors into a new {@code LogTableFactor} over
     * the union of their variables.
     *
     * @param phis collection whose elements must all be {@link Factor}s
     *             (raw type kept for API compatibility)
     * @return the product factor; a factor over no variables when {@code phis} is empty
     * @throws ClassCastException if an element is not a {@link Factor}
     */
    public static LogTableFactor multiplyAll(Collection phis) {
        HashVarSet vs = new HashVarSet();
        // phis is a raw Collection, so its elements are Objects and must be cast.
        for (Object o : phis) {
            vs.addAll(((Factor)o).varSet());
        }
        LogTableFactor newCPF = new LogTableFactor(vs);
        for (Object o : phis) {
            newCPF.multiplyBy((Factor)o);
        }
        return newCPF;
    }

    /**
     * Rescales the factor in place by subtracting, from every stored entry,
     * the entry at the location returned by {@code argmax()}. Presumably this
     * makes the largest value exactly 1 and guards against under/overflow.
     *
     * @return this factor
     */
    public AbstractTableFactor recenter() {
        int loc = this.argmax();
        double lval = this.probs.valueAtLocation(loc);
        this.timesEqualsLog(-lval);
        return this;
    }
}

