/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math4.legacy.ml.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.ml.clustering.CentroidCluster;
import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
import org.apache.commons.math4.legacy.ml.clustering.DoublePoint;
import org.apache.commons.math4.legacy.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.stat.descriptive.moment.VectorialMean;
import org.apache.commons.rng.UniformRandomProvider;

/**
 * K-means clusterer that prunes distance computations with the triangle
 * inequality (Elkan's acceleration scheme).
 *
 * <p>For every point the algorithm keeps an upper bound {@code u} on the
 * distance to its assigned center and, for every (point, center) pair, a
 * lower bound {@code l}. Together with half the pairwise distances between
 * centers these bounds allow most point-to-center distance evaluations to be
 * skipped.</p>
 *
 * <p>Note: this class forwards the empty-cluster strategy to the superclass
 * constructor but does not consult it itself. A cluster that receives no
 * points during an iteration simply keeps its previous center.</p>
 *
 * @param <T> type of the points to cluster
 */
public class ElkanKMeansPlusPlusClusterer<T extends Clusterable>
extends KMeansPlusPlusClusterer<T> {
    /**
     * Creates a clusterer; the argument is forwarded to the superclass.
     *
     * @param k number of clusters to split the data into
     */
    public ElkanKMeansPlusPlusClusterer(int k) {
        super(k);
    }

    /**
     * Creates a clusterer; all arguments are forwarded to the superclass.
     *
     * @param k number of clusters to split the data into
     * @param maxIterations upper limit on the number of refinement iterations
     * @param measure distance measure used for all distance computations
     * @param random source of randomness used for seeding the centers
     */
    public ElkanKMeansPlusPlusClusterer(int k, int maxIterations, DistanceMeasure measure, UniformRandomProvider random) {
        super(k, maxIterations, measure, random);
    }

    /**
     * Creates a clusterer; all arguments are forwarded to the superclass.
     *
     * @param k number of clusters to split the data into
     * @param maxIterations upper limit on the number of refinement iterations
     * @param measure distance measure used for all distance computations
     * @param random source of randomness used for seeding the centers
     * @param emptyStrategy strategy handed to the superclass (not consulted by
     *        the {@link #cluster(Collection)} implementation of this class)
     */
    public ElkanKMeansPlusPlusClusterer(int k, int maxIterations, DistanceMeasure measure, UniformRandomProvider random, KMeansPlusPlusClusterer.EmptyClusterStrategy emptyStrategy) {
        super(k, maxIterations, measure, random, emptyStrategy);
    }

    /**
     * Partitions the given points into {@code k} clusters.
     *
     * @param points points to cluster
     * @return one {@link CentroidCluster} per center, in center order; a
     *         cluster may be empty if no point was assigned to its center
     * @throws NumberIsTooSmallException if there are fewer points than clusters
     */
    @Override
    public List<CentroidCluster<T>> cluster(Collection<T> points) {
        final int k = this.getNumberOfClusters();
        if (points.size() < k) {
            throw new NumberIsTooSmallException(points.size(), k, false);
        }
        final List<T> pointsList = new ArrayList<T>(points);
        final int n = pointsList.size();
        final int dim = pointsList.get(0).getPoint().length;

        // s[c]: half the distance from center c to its closest other center
        // (refreshed at the start of every iteration).
        final double[] s = new double[k];
        // dcc[a][b]: half the distance between centers a and b.
        final double[][] dcc = new double[k][k];
        // u[x]: upper bound on the distance from point x to its assigned center.
        final double[] u = new double[n];
        Arrays.fill(u, Double.MAX_VALUE);
        // l[x][c]: lower bound on the distance from point x to center c.
        final double[][] l = new double[n][k];

        final double[][] centers = this.seed(pointsList);
        final int[] partitions = this.partitionPoints(pointsList, centers, u, l);

        final double[] deltas = new double[k];
        final VectorialMean[] means = new VectorialMean[k];
        final int max = this.getMaxIterations();
        for (int it = 0; it < max; ++it) {
            int changes = 0;
            this.updateIntraCentersDistances(centers, dcc, s);

            for (int xi = 0; xi < n; ++xi) {
                // The assigned center is closer than half the distance to any
                // other center, so no other center can be nearer.
                if (u[xi] <= s[partitions[xi]]) {
                    continue;
                }
                // True until the upper bound of this point has been tightened
                // to the exact distance during the current iteration.
                boolean r = true;
                for (int c = 0; c < k; ++c) {
                    if (ElkanKMeansPlusPlusClusterer.isSkipNext(partitions, u, l, dcc, xi, c)) {
                        continue;
                    }
                    final double[] x = pointsList.get(xi).getPoint();
                    if (r) {
                        u[xi] = this.distance(x, centers[partitions[xi]]);
                        l[xi][partitions[xi]] = u[xi];
                        r = false;
                    }
                    // Re-test with the tightened upper bound before paying for
                    // the distance to the candidate center.
                    if (u[xi] > l[xi][c] || u[xi] > dcc[partitions[xi]][c]) {
                        l[xi][c] = this.distance(x, centers[c]);
                        if (l[xi][c] < u[xi]) {
                            partitions[xi] = c;
                            u[xi] = l[xi][c];
                            ++changes;
                        }
                    }
                }
            }

            // The first iteration always runs to completion so that the
            // centers are recomputed at least once.
            if (changes == 0 && it != 0) {
                break;
            }

            // Recompute every center as the mean of its assigned points.
            Arrays.fill(means, null);
            for (int i = 0; i < n; ++i) {
                final int assigned = partitions[i];
                if (means[assigned] == null) {
                    means[assigned] = new VectorialMean(dim);
                }
                means[assigned].increment(pointsList.get(i).getPoint());
            }
            for (int i = 0; i < k; ++i) {
                if (means[i] == null) {
                    // No point is assigned to this center: leave it where it
                    // is. A zero displacement keeps all bounds valid.
                    deltas[i] = 0.0;
                    continue;
                }
                final double[] mean = means[i].getResult();
                deltas[i] = this.distance(centers[i], mean);
                centers[i] = mean;
            }

            this.updateBounds(partitions, u, l, deltas);
        }
        return this.buildResults(pointsList, partitions, centers);
    }

    /**
     * Chooses the initial centers among the data points. The first center is
     * drawn uniformly; each following one is drawn with probability
     * proportional to the squared distance to the nearest center chosen so far.
     *
     * @param points data points
     * @return the {@code k} initial centers
     */
    private double[][] seed(List<T> points) {
        final int k = this.getNumberOfClusters();
        final UniformRandomProvider random = this.getRandomGenerator();
        final double[][] result = new double[k][];
        final int n = points.size();
        final int pointIndex = random.nextInt(n);
        // Squared distance from each point to its nearest chosen center.
        final double[] minDistances = new double[n];
        int idx = 0;
        result[idx] = points.get(pointIndex).getPoint();
        double sumSqDist = 0.0;
        for (int i = 0; i < n; ++i) {
            final double d = this.distance(result[idx], points.get(i).getPoint());
            minDistances[i] = d * d;
            sumSqDist += minDistances[i];
        }
        while (++idx < k) {
            final double p = sumSqDist * random.nextDouble();
            final int chosen = ElkanKMeansPlusPlusClusterer.selectWeighted(minDistances, p);
            result[idx] = points.get(chosen).getPoint();
            for (int i = 0; i < n; ++i) {
                final double d = this.distance(result[idx], points.get(i).getPoint());
                sumSqDist -= minDistances[i];
                minDistances[i] = Math.min(minDistances[i], d * d);
                sumSqDist += minDistances[i];
            }
        }
        return result;
    }

    /**
     * Returns the first index whose cumulative weight reaches {@code target}.
     *
     * <p>Only entries with a strictly positive weight are eligible. If the
     * cumulative sum never reaches {@code target} (rounding in the caller's
     * running total), the last entry with a positive weight is returned. If no
     * weight is positive (every point coincides with an already chosen
     * center), index 0 is returned.</p>
     *
     * @param weights non-negative weights
     * @param target threshold on the cumulative weight
     * @return index of the selected entry, always within the array bounds
     */
    private static int selectWeighted(double[] weights, double target) {
        double cdf = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < weights.length; ++i) {
            if (weights[i] > 0.0) {
                cdf += weights[i];
                lastPositive = i;
                if (cdf >= target) {
                    return i;
                }
            }
        }
        return lastPositive >= 0 ? lastPositive : 0;
    }

    /**
     * Assigns every point to its nearest center and initialises the bounds.
     *
     * @param pointsList data points
     * @param centers current centers
     * @param u upper bounds; on return holds the distance of each point to its
     *        assigned center (expected to be pre-filled with a large value)
     * @param l lower bounds; on return holds the exact distance of each point
     *        to each center
     * @return index of the assigned center for each point
     */
    private int[] partitionPoints(List<T> pointsList, double[][] centers, double[] u, double[][] l) {
        final int k = this.getNumberOfClusters();
        final int n = pointsList.size();
        final int[] assignments = new int[n];
        Arrays.fill(assignments, -1);
        for (int i = 0; i < n; ++i) {
            final double[] x = pointsList.get(i).getPoint();
            for (int j = 0; j < k; ++j) {
                l[i][j] = this.distance(x, centers[j]);
                // Strict comparison: on ties the lowest center index wins.
                if (u[i] > l[i][j]) {
                    u[i] = l[i][j];
                    assignments[i] = j;
                }
            }
        }
        return assignments;
    }

    /**
     * Recomputes half the pairwise distances between centers and, for each
     * center, half the distance to its closest other center.
     *
     * @param centers current centers
     * @param dcc output: {@code dcc[a][b]} is half the distance between
     *        centers {@code a} and {@code b}; the diagonal is not written
     * @param s output: {@code s[c]} is the minimum of {@code dcc[c][*]} over
     *        the other centers, or {@code Double.MAX_VALUE} when there is no
     *        other center
     */
    private void updateIntraCentersDistances(double[][] centers, double[][] dcc, double[] s) {
        final int k = this.getNumberOfClusters();
        // Start from scratch: values left over from a previous iteration would
        // otherwise keep s smaller than necessary once centers move apart.
        Arrays.fill(s, Double.MAX_VALUE);
        for (int i = 0; i < k; ++i) {
            // The matrix is symmetric: compute the upper triangle and mirror it.
            for (int j = i + 1; j < k; ++j) {
                final double half = 0.5 * this.distance(centers[i], centers[j]);
                dcc[i][j] = half;
                dcc[j][i] = half;
                if (half < s[i]) {
                    s[i] = half;
                }
                if (half < s[j]) {
                    s[j] = half;
                }
            }
        }
    }

    /**
     * Tells whether center {@code c} can be ruled out for point {@code xi}
     * without computing a distance.
     *
     * @param partitions current assignment of points to centers
     * @param u upper bounds on the distance to the assigned center
     * @param l lower bounds on the distance to every center
     * @param dcc half distances between centers
     * @param xi index of the point
     * @param c index of the candidate center
     * @return {@code true} if {@code c} is the assigned center or one of the
     *         bounds shows it cannot be closer than the assigned center
     */
    private static boolean isSkipNext(int[] partitions, double[] u, double[][] l, double[][] dcc, int xi, int c) {
        return c == partitions[xi] || u[xi] <= l[xi][c] || u[xi] <= dcc[partitions[xi]][c];
    }

    /**
     * Wraps the final centers and assignments into cluster objects.
     *
     * @param pointsList data points
     * @param partitions index of the assigned center for each point
     * @param centers final centers
     * @return one cluster per center, in center order
     */
    private List<CentroidCluster<T>> buildResults(List<T> pointsList, int[] partitions, double[][] centers) {
        final int k = this.getNumberOfClusters();
        final List<CentroidCluster<T>> result = new ArrayList<CentroidCluster<T>>(k);
        for (int i = 0; i < k; ++i) {
            result.add(new CentroidCluster<T>(new DoublePoint(centers[i])));
        }
        for (int i = 0; i < pointsList.size(); ++i) {
            result.get(partitions[i]).addPoint(pointsList.get(i));
        }
        return result;
    }

    /**
     * Adjusts the bounds after the centers have moved: upper bounds grow by
     * the displacement of the assigned center, lower bounds shrink by the
     * displacement of the respective center (never below zero).
     *
     * @param partitions index of the assigned center for each point
     * @param u upper bounds, updated in place
     * @param l lower bounds, updated in place
     * @param deltas distance each center moved in the last update
     */
    private void updateBounds(int[] partitions, double[] u, double[][] l, double[] deltas) {
        final int k = this.getNumberOfClusters();
        for (int i = 0; i < partitions.length; ++i) {
            u[i] += deltas[partitions[i]];
            for (int j = 0; j < k; ++j) {
                l[i][j] = Math.max(0.0, l[i][j] - deltas[j]);
            }
        }
    }

    /**
     * Computes the distance between two vectors with the configured measure.
     *
     * @param a first vector
     * @param b second vector
     * @return distance between {@code a} and {@code b}
     */
    private double distance(double[] a, double[] b) {
        return this.getDistanceMeasure().compute(a, b);
    }
}

