package edu.tau.compbio.expression.algo;

import edu.tau.compbio.ds.MatrixData;
import edu.tau.compbio.math.VecCalc;
import edu.tau.compbio.med.graph.GraphEvent;
import edu.tau.compbio.stat.NormalDistribution;
import edu.tau.compbio.stat.StatUtils;
import java.util.Arrays;
import java.util.Collection;

/* loaded from: input_file:edu/tau/compbio/expression/algo/SimpleCorrelationAnalysis.class */
/**
 * Pairwise similarity / correlation measures between expression profiles, selected by a
 * {@link CorrelationType}.
 *
 * <p>Some measures need matrix-wide preprocessing: the average profile, per-probe ranks,
 * discretized bins with per-probe entropies, per-probe medians, or per-condition
 * means/standard deviations. The constructor performs the preprocessing required by the chosen
 * type, exactly once. Changing settings such as the bin count afterwards does not redo it.
 *
 * <p>Not thread-safe. Mutual-information scoring reuses the shared {@code _crossCounts} buffer,
 * and {@link #computeSpearman} writes {@code _pval}.
 */
public class SimpleCorrelationAnalysis extends CorrelationAnalysis {
    /** The similarity measure computed by this instance. */
    protected CorrelationType _corrType = CorrelationType.DOT_PRODUCT;
    /**
     * Number of discretization bins for MUTUAL_INFORMATION.
     *
     * <p>Binning happens in the constructor. Changing this value afterwards requires calling
     * {@link #preprocessBins()} again; otherwise {@code _crossCounts} has the wrong size.
     */
    protected int _binCount = 5;
    /** Reusable joint-count buffer of size {@code _binCount * _binCount} (MUTUAL_INFORMATION). */
    protected int[] _crossCounts = null;
    /** Per-probe medians, computed for MEDIAN_NORMALIZED_EUCLIDEAN_DISTANCE. */
    protected float[] _medians = null;
    /** Per-condition means over all probes (MASS_DISTANCE_NORMAL). */
    protected float[] _means = null;
    /** Per-condition (population) standard deviations over all probes (MASS_DISTANCE_NORMAL). */
    protected float[] _stds = null;
    /** Per-condition normal distributions built from {@code _means} / {@code _stds}. */
    protected NormalDistribution[] _normalDists = null;
    /** Weight of the reference-profile term in the focused Pearson measures. */
    protected float _focusedScaling = 1.5f;

    /**
     * Creates an analysis of the given type and runs the preprocessing that type needs.
     *
     * @param matrixData the expression matrix. May be {@code null} when the instance is only used
     *     to score explicitly supplied vectors (see {@link #computeSimilarityIgnoringNaNs}). In
     *     that case, preprocessing that needs the matrix is skipped.
     * @param correlationType the similarity measure to compute
     */
    public SimpleCorrelationAnalysis(MatrixData matrixData, CorrelationType correlationType) {
        super(matrixData);
        this._corrType = correlationType;
        if (this._data != null
                && (correlationType == CorrelationType.PEARSON_CORR_SUBS_BACKGROUND
                        || correlationType == CorrelationType.PARTIAL_CORRELATION)) {
            preprocessAverageProfile();
            // The average profile doubles as the default reference for partial correlation.
            this._partialRefProf = this._averageProf;
            return;
        }
        if (correlationType == CorrelationType.SPEARMAN_CORR) {
            // computeRanks() already tolerates a null matrix.
            preprocessRanks();
            return;
        }
        if (this._data == null) {
            // The remaining preprocessing steps dereference the matrix directly.
            return;
        }
        if (correlationType == CorrelationType.MUTUAL_INFORMATION) {
            preprocessBins();
        } else if (correlationType == CorrelationType.MEDIAN_NORMALIZED_EUCLIDEAN_DISTANCE) {
            preprocessMedians();
        } else if (correlationType == CorrelationType.MASS_DISTANCE_NORMAL) {
            preprocessMeansStds();
        }
    }

    /** Sets the reference profile used by the partial and focused correlation measures. */
    public void setPartialReferenceProfile(float[] fArr) {
        this._partialRefProf = fArr;
    }

    /** Sets the weight of the reference-profile term in the focused measures. */
    public void setFocusedScaling(float f) {
        this._focusedScaling = f;
    }

    /** @return the similarity measure computed by this instance */
    public CorrelationType getType() {
        return this._corrType;
    }

    /** @return the number of discretization bins used for mutual information */
    public int getBinCount() {
        return this._binCount;
    }

    /**
     * Sets the number of discretization bins. This does not re-bin already preprocessed data;
     * call {@link #preprocessBins()} afterwards if that is required.
     */
    public void setBinCount(int i) {
        this._binCount = i;
    }

    /**
     * Similarity with caller-supplied extra arguments for the DOT_PRODUCT / PEARSON_CORRELATION
     * fast path. All other types delegate to {@link #getSimilarity(float[], float[])}.
     *
     * <p>NOTE(review): the two extra arguments are passed straight through to
     * {@code VecCalc.calcCorrelationCoefficient}. Presumably they are precomputed per-vector
     * statistics; confirm against VecCalc. Also, {@code _transform} is NOT applied on that fast
     * path, unlike the delegated path. Confirm whether that is intended.
     */
    public float getSimilarity(float[] x, float[] y, float f, float f2) {
        if (this._corrType == CorrelationType.DOT_PRODUCT
                || this._corrType == CorrelationType.PEARSON_CORRELATION) {
            return VecCalc.calcCorrelationCoefficient(x, y, f, f2);
        }
        return getSimilarity(x, y);
    }

    /**
     * Computes the configured similarity between two arbitrary profiles, without relying on
     * per-probe preprocessing.
     *
     * @return the (optionally transformed) similarity, or NaN for types that need per-probe
     *     preprocessing (e.g. MUTUAL_INFORMATION, MASS_DISTANCE_NORMAL)
     */
    public float getSimilarity(float[] x, float[] y) {
        float similarity;
        switch (this._corrType) {
            case DOT_PRODUCT:
            case PEARSON_CORRELATION:
                similarity = VecCalc.calcCorrelationCoefficient(x, y);
                break;
            case SPEARMAN_CORR:
                similarity = computeSpearman(x, y);
                break;
            case EUCLIDEAN_DISTANCE:
                similarity = VecCalc.distance(x, y);
                break;
            case PEARSON_CORR_SUBS_BACKGROUND: {
                float background = clippedBackground(
                        VecCalc.calcCorrelationCoefficient(x, this._averageProf),
                        VecCalc.calcCorrelationCoefficient(y, this._averageProf));
                similarity = VecCalc.calcCorrelationCoefficient(x, y) - background;
                break;
            }
            case PARTIAL_CORRELATION:
                similarity = computePartial(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_PARTIAL:
                similarity = computeFocusedPartial(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_STANDARD:
                similarity = computeFocusedStandard(x, y);
                break;
            default:
                similarity = Float.NaN;
                break;
        }
        return transformIfSet(similarity);
    }

    /**
     * Blends the partial correlation with the smaller of the two profiles' correlations to the
     * reference profile, weighted by {@code _focusedScaling}.
     *
     * @throws IllegalStateException if the result is NaN
     */
    public float computeFocusedPartial(float[] x, float[] y) {
        return focusedBlend(computePartial(x, y), x, y);
    }

    /**
     * Blends the plain Pearson correlation with the smaller of the two profiles' correlations to
     * the reference profile, weighted by {@code _focusedScaling}.
     *
     * @throws IllegalStateException if the result is NaN
     */
    public float computeFocusedStandard(float[] x, float[] y) {
        return focusedBlend(VecCalc.calcCorrelationCoefficient(x, y), x, y);
    }

    /**
     * Returns {@code (base + s * min(r(x,ref), r(y,ref))) / (1 + s)}, where
     * {@code s = _focusedScaling}. Throws IllegalStateException if the result is NaN.
     */
    private float focusedBlend(float base, float[] x, float[] y) {
        float refTerm = Math.min(
                VecCalc.calcCorrelationCoefficient(x, this._partialRefProf),
                VecCalc.calcCorrelationCoefficient(y, this._partialRefProf));
        float blended = (base + (this._focusedScaling * refTerm)) / (1.0f + this._focusedScaling);
        if (Float.isNaN(blended)) {
            throw new IllegalStateException("NaN focused correlation");
        }
        return blended;
    }

    /**
     * First-order partial correlation of {@code x} and {@code y}, controlling for the reference
     * profile: {@code (r_xy - r_xz r_yz) / sqrt((1 - r_xz^2)(1 - r_yz^2))}.
     */
    public float computePartial(float[] x, float[] y) {
        float rxz = VecCalc.calcCorrelationCoefficient(x, this._partialRefProf);
        float rxzSquared = rxz * rxz;
        float ryz = VecCalc.calcCorrelationCoefficient(y, this._partialRefProf);
        float numerator = VecCalc.calcCorrelationCoefficient(x, y) - (rxz * ryz);
        return (float) (numerator / Math.sqrt((1.0f - rxzSquared) * (1.0f - (ryz * ryz))));
    }

    /**
     * Scores matrix probes {@code i} and {@code j}, whose data rows are {@code x} and {@code y}.
     * Uses per-probe preprocessing (ranks, bins/entropies) where the type requires it.
     *
     * @return the (optionally transformed) score, or NaN for unsupported types
     */
    @Override // edu.tau.compbio.expression.algo.CorrelationAnalysis
    public float computeCoef(float[] x, float[] y, int i, int j) {
        float coef;
        switch (this._corrType) {
            case DOT_PRODUCT:
                coef = computePearsonFast(x, y);
                break;
            case PEARSON_CORRELATION:
                coef = VecCalc.calcCorrelationCoefficient(x, y);
                break;
            case SPEARMAN_CORR:
                coef = computeSpearmanFromRanks(this._ranks[i], this._ranks[j]);
                break;
            case PEARSON_CORR_SUBS_BACKGROUND:
                coef = computePearsonFastSubsBackground(x, y);
                break;
            case MUTUAL_INFORMATION:
                // I(X;Y) = H(X) + H(Y) - H(X,Y), on the discretized profiles.
                coef = (float) ((this._entropies[i] + this._entropies[j])
                        - computeJointEntropy(this._bins[i], this._bins[j]));
                break;
            case PARTIAL_CORRELATION:
                coef = computePartial(x, y);
                break;
            case EUCLIDEAN_DISTANCE:
                coef = VecCalc.distance(x, y);
                break;
            case MASS_DISTANCE_NORMAL:
                // Use the raw value: going through getCorrelation() would apply _transform twice.
                coef = computeMassDistanceNormal(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_PARTIAL:
                coef = computeFocusedPartial(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_STANDARD:
                coef = computeFocusedStandard(x, y);
                break;
            default:
                coef = Float.NaN;
                break;
        }
        return transformIfSet(coef);
    }

    /**
     * Shannon entropy (bits) of the joint distribution of two equally long bin-index vectors.
     * Reuses {@code _crossCounts}, so it requires {@link #preprocessBins()} to have run and is
     * not thread-safe.
     */
    protected double computeJointEntropy(int[] binsX, int[] binsY) {
        Arrays.fill(this._crossCounts, 0);
        for (int k = 0; k < binsX.length; k++) {
            this._crossCounts[(binsX[k] * this._binCount) + binsY[k]]++;
        }
        return computeEntropy(this._crossCounts, binsX.length, 0.0d);
    }

    /**
     * Average similarity between every member of {@code collection} and every member of
     * {@code collection2}. For DOT_PRODUCT, this uses the identity
     * {@code sum_{a,b} <a,b> = <sum a, sum b>}. The result is the average Pearson correlation
     * only if the profiles are standardized (assumed here; confirm against the data source).
     */
    @Override // edu.tau.compbio.algorithm.AbstractSimilarityAnalysis, edu.tau.compbio.algorithm.SimilarityAnalysis
    public float getAverageSimilarity(Collection<String> collection, Collection<String> collection2) {
        if (this._corrType != CorrelationType.DOT_PRODUCT) {
            return super.getAverageSimilarity(collection, collection2);
        }
        int conditions = this._data.sizeConditions();
        float[] sum1 = new float[conditions];
        int count1 = this._mda.getProfileSum(collection, sum1);
        // Fill the second sum BEFORE taking the dot product.
        float[] sum2 = new float[conditions];
        int count2 = this._mda.getProfileSum(collection2, sum2);
        return (calcDotProductFast(sum1, sum2) / conditions) / (count1 * count2);
    }

    /**
     * Computes the average pairwise similarity within {@code collection}. The result is stored in
     * {@code _avgSim} and the set size in {@code _lastCount}.
     *
     * <p>For DOT_PRODUCT, the n self-pairs are removed from {@code <sum, sum>/m} by subtracting n.
     * This assumes each standardized profile has self-similarity 1. The remainder is averaged
     * over the {@code n*(n-1)} ordered off-diagonal pairs.
     *
     * <p>NOTE: the method name is misspelled, so it does NOT override
     * {@code computeSimilarities}. Callers must use this exact name to get the DOT_PRODUCT fast
     * path. The name is kept for compatibility with existing callers.
     */
    public void computeSimilairities(Collection<String> collection) {
        if (this._corrType != CorrelationType.DOT_PRODUCT) {
            super.computeSimilarities(collection);
            return;
        }
        float[] sum = new float[this._data.sizeConditions()];
        this._lastCount = this._mda.getProfileSum(collection, sum);
        this._avgSim = ((calcDotProductFast(sum, sum) / this._data.sizeConditions()) - this._lastCount)
                / (this._lastCount * (this._lastCount - 1));
    }

    /**
     * Pearson correlation over positions where neither value is {@code -Infinity}.
     *
     * @param mask output buffer, same length as the profiles. On return, {@code mask[k]} is true
     *     iff position k was used.
     */
    protected float computePearson(float[] x, float[] y, boolean[] mask) {
        int used = 0;
        for (int k = 0; k < mask.length; k++) {
            mask[k] = x[k] != Float.NEGATIVE_INFINITY && y[k] != Float.NEGATIVE_INFINITY;
            if (mask[k]) {
                used++;
            }
        }
        float meanX = mean(x, mask);
        float stdX = calcStd(x, mask);
        float meanY = mean(y, mask);
        float stdY = calcStd(y, mask);
        // Already standardized: the correlation reduces to a scaled dot product.
        if (meanX == 0.0f && meanY == 0.0f && stdX == 1.0f && stdY == 1.0f) {
            return calcDotProduct(x, y, mask) / used;
        }
        float scale = stdX * stdY * used;
        if (scale == 0.0f) {
            scale = 1.0E-8f; // avoid division by zero for constant profiles
        }
        float sum = 0.0f;
        for (int k = 0; k < x.length; k++) {
            if (mask[k]) {
                sum += ((x[k] - meanX) * (y[k] - meanY)) / scale;
            }
        }
        return sum;
    }

    /**
     * Pearson correlation for profiles assumed to be already standardized (mean 0, population
     * std 1), computed as {@code <x,y>/len}.
     *
     * @throws IllegalStateException if the result lies outside [-1.1, 1.1], which indicates the
     *     standardization assumption does not hold
     */
    protected float computePearsonFast(float[] x, float[] y) {
        float pearson = calcDotProductFast(x, y) / x.length;
        if (pearson > 1.1f || pearson < -1.1f) {
            throw new IllegalStateException("Pearson correlation in an illegal range :" + pearson);
        }
        return pearson;
    }

    /** Dot product restricted to positions where {@code mask} is true. */
    public float calcDotProduct(float[] x, float[] y, boolean[] mask) {
        float sum = 0.0f;
        for (int k = 0; k < x.length; k++) {
            if (mask[k]) {
                sum += x[k] * y[k];
            }
        }
        return sum;
    }

    /** Dot product skipping positions where either value is NaN. */
    public static float calcDotProductFast(float[] x, float[] y) {
        float sum = 0.0f;
        for (int k = 0; k < x.length; k++) {
            if (!Float.isNaN(x[k]) && !Float.isNaN(y[k])) {
                sum += x[k] * y[k];
            }
        }
        return sum;
    }

    /** Mean over positions where {@code mask} is true (NaN if none). */
    public float mean(float[] values, boolean[] mask) {
        float sum = 0.0f;
        int used = 0;
        for (int k = 0; k < values.length; k++) {
            if (mask[k]) {
                sum += values[k];
                used++;
            }
        }
        return sum / used;
    }

    /** Mean of the non-NaN values (NaN if none). */
    public float mean(float[] values) {
        float sum = 0.0f;
        int used = 0;
        for (float v : values) {
            if (!Float.isNaN(v)) {
                sum += v;
                used++;
            }
        }
        return sum / used;
    }

    /** Population standard deviation of the non-NaN values. */
    public float calcStd(float[] values) {
        return (float) Math.sqrt(calcVariance(values));
    }

    /** Population variance of the non-NaN values (NaN if none). */
    public float calcVariance(float[] values) {
        float mean = mean(values);
        float sumSq = 0.0f;
        int used = 0;
        for (float v : values) {
            if (!Float.isNaN(v)) {
                float dev = v - mean;
                sumSq += dev * dev;
                used++;
            }
        }
        return sumSq / used;
    }

    /** Population standard deviation over positions where {@code mask} is true. */
    public float calcStd(float[] values, boolean[] mask) {
        return (float) Math.sqrt(calcVariance(values, mask));
    }

    /**
     * Population variance over positions where {@code mask} is true (NaN if none). The divisor
     * is the number of masked-in entries, consistent with {@link #mean(float[], boolean[])}.
     */
    public float calcVariance(float[] values, boolean[] mask) {
        float mean = mean(values, mask);
        float sumSq = 0.0f;
        int used = 0;
        for (int k = 0; k < values.length; k++) {
            if (mask[k]) {
                float dev = values[k] - mean;
                sumSq += dev * dev;
                used++;
            }
        }
        return sumSq / used;
    }

    /**
     * Spearman correlation of two raw profiles. The associated p-value is stored in
     * {@code _pval} (see {@link #getPValue()}).
     */
    public float computeSpearman(float[] x, float[] y) {
        double[] result = StatUtils.calculateSpearman(VecCalc.calcRanks(x), VecCalc.calcRanks(y));
        this._pval = (float) result[1];
        return (float) result[0];
    }

    /** Spearman correlation from precomputed rank vectors. The p-value is not recorded. */
    public float computeSpearmanFromRanks(float[] ranksX, float[] ranksY) {
        return (float) StatUtils.calculateSpearman(ranksX, ranksY)[0];
    }

    /** @return the p-value recorded by the most recent {@link #computeSpearman} call */
    public float getPValue() {
        return this._pval;
    }

    /** Discretizes every probe's profile into {@code _binCount} bins. */
    public int[][] computeBins() {
        System.out.println("Computing bins...");
        int[][] binData = new MatrixDataBinner(this._data).binData(this._binCount);
        System.out.println("Bin computation finished");
        return binData;
    }

    /**
     * Computes the median of each probe's row. For an even number of conditions, this is the
     * mean of the two middle values; with zero conditions, it is NaN. NaN entries sort last
     * ({@link Arrays#sort(float[])}), so they bias the median upward.
     */
    public float[] computeMedians() {
        int conditions = this._data.sizeConditions();
        float[] medians = new float[this._data.sizeProbes()];
        float[] sorted = new float[conditions];
        int upperMid = conditions / 2;
        for (int row = 0; row < medians.length; row++) {
            float[] dataRow = this._data.getDataRow(row);
            System.arraycopy(dataRow, 0, sorted, 0, dataRow.length);
            Arrays.sort(sorted);
            if (conditions == 0) {
                medians[row] = Float.NaN;
            } else if (conditions % 2 == 1) {
                medians[row] = sorted[upperMid];
            } else {
                medians[row] = (sorted[upperMid - 1] + sorted[upperMid]) / 2.0f;
            }
        }
        return medians;
    }

    /** Rank-transforms each probe's row; returns {@code null} when there is no matrix. */
    public float[][] computeRanks() {
        if (this._data == null) {
            return null;
        }
        // Rows are replaced by calcRanks' output, so only the outer array is allocated here.
        float[][] ranks = new float[this._data.sizeProbes()][];
        for (int row = 0; row < ranks.length; row++) {
            ranks[row] = VecCalc.calcRanks(this._data.getDataRow(row));
        }
        return ranks;
    }

    /** Caches per-probe ranks in {@code _ranks}. */
    public void preprocessRanks() {
        this._ranks = computeRanks();
    }

    /** Caches per-probe medians in {@code _medians}. */
    public void preprocessMedians() {
        this._medians = computeMedians();
    }

    /**
     * Computes per-condition means and population standard deviations over all probes, and a
     * normal distribution for each condition.
     *
     * <p>Accumulation and the {@code E[x^2] - mean^2} step use double precision. The variance is
     * clamped at 0 so that rounding cannot produce a NaN standard deviation.
     *
     * @throws IllegalStateException if the matrix contains NaN
     */
    public void preprocessMeansStds() {
        int conditions = this._data.sizeConditions();
        int probes = this._data.sizeProbes();
        this._means = new float[conditions];
        this._stds = new float[conditions];
        this._normalDists = new NormalDistribution[conditions];
        double[] sums = new double[conditions];
        double[] sumSquares = new double[conditions];
        for (int row = 0; row < probes; row++) {
            float[] dataRow = this._data.getDataRow(row);
            for (int c = 0; c < conditions; c++) {
                if (Float.isNaN(dataRow[c])) {
                    throw new IllegalStateException("NaN value encountered");
                }
                double v = dataRow[c];
                sums[c] += v;
                sumSquares[c] += v * v;
            }
        }
        for (int c = 0; c < conditions; c++) {
            double mean = sums[c] / probes;
            double variance = Math.max(0.0d, (sumSquares[c] / probes) - (mean * mean));
            this._means[c] = (float) mean;
            this._stds[c] = (float) Math.sqrt(variance);
            this._normalDists[c] = new NormalDistribution(this._means[c], this._stds[c]);
        }
    }

    /**
     * Bins every probe and caches its marginal entropy. Also (re)allocates the joint-count buffer
     * for the current {@code _binCount}.
     */
    public void preprocessBins() {
        this._bins = computeBins();
        this._entropies = new double[this._data.sizeProbes()];
        int[] counts = new int[this._binCount];
        for (int row = 0; row < this._bins.length; row++) {
            Arrays.fill(counts, 0);
            int[] rowBins = this._bins[row];
            for (int bin : rowBins) {
                counts[bin]++;
            }
            this._entropies[row] = computeEntropy(counts, rowBins.length, 0.0d);
        }
        this._crossCounts = new int[this._binCount * this._binCount];
    }

    /**
     * Shannon entropy in bits of a histogram, with an optional additive pseudocount.
     *
     * @param counts per-bin counts
     * @param total number of observations (normally the sum of {@code counts})
     * @param pseudocount value added to every bin before normalizing
     * @return {@code -sum p log2 p}. Empty bins contribute 0 (the convention 0·log 0 = 0)
     *     instead of NaN.
     */
    public static double computeEntropy(int[] counts, int total, double pseudocount) {
        double denominator = total + (pseudocount * counts.length);
        double log2 = Math.log(2.0d);
        double sum = 0.0d;
        for (int count : counts) {
            double p = (count + pseudocount) / denominator;
            if (p == 0.0d) {
                continue; // 0 * log(0) would be NaN; its limit is 0
            }
            sum += p * (Math.log(p) / log2);
        }
        return -sum;
    }

    /**
     * Correlation between two arbitrary profiles. Unlike {@link #computeCoef}, it does not use
     * per-probe ranks or bins; Spearman is computed from the raw values.
     *
     * @return the (optionally transformed) value, or NaN for unsupported types
     */
    @Override // edu.tau.compbio.expression.algo.CorrelationAnalysis
    public float getCorrelation(float[] x, float[] y) {
        float corr;
        switch (this._corrType) {
            case DOT_PRODUCT:
                corr = computePearsonFast(x, y);
                break;
            case PEARSON_CORRELATION:
                corr = VecCalc.calcCorrelationCoefficient(x, y);
                break;
            case SPEARMAN_CORR:
                corr = computeSpearman(x, y);
                break;
            case PEARSON_CORR_SUBS_BACKGROUND:
                corr = computePearsonFastSubsBackground(x, y);
                break;
            case PARTIAL_CORRELATION:
                corr = computePartial(x, y);
                break;
            case EUCLIDEAN_DISTANCE:
                corr = VecCalc.distance(x, y);
                break;
            case MASS_DISTANCE_NORMAL:
                corr = computeMassDistanceNormal(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_PARTIAL:
                corr = computeFocusedPartial(x, y);
                break;
            case FOCUSED_PEARSON_CORRELATION_STANDARD:
                corr = computeFocusedStandard(x, y);
                break;
            default:
                corr = Float.NaN;
                break;
        }
        return transformIfSet(corr);
    }

    /**
     * Untransformed MASS_DISTANCE_NORMAL score. It is
     * {@code -sum_c log(Phi_c(max) - Phi_c(min))}, where Phi_c is the fitted per-condition
     * normal CDF (as approximated by {@code approxPhi}). Requires {@link #preprocessMeansStds()}.
     */
    private float computeMassDistanceNormal(float[] x, float[] y) {
        double logMass = 0.0d;
        for (int c = 0; c < this._data.sizeConditions(); c++) {
            float a = x[c];
            float b = y[c];
            float hi = a > b ? a : b;
            float lo = a > b ? b : a;
            logMass += Math.log(this._normalDists[c].approxPhi(hi) - this._normalDists[c].approxPhi(lo));
        }
        return (float) (-logMass);
    }

    /**
     * PEARSON_CORR_SUBS_BACKGROUND using {@link #computePearsonFast}. The result is the pair's
     * correlation minus the (clipped) mean correlation of each profile with the average profile.
     */
    private float computePearsonFastSubsBackground(float[] x, float[] y) {
        float background = clippedBackground(
                computePearsonFast(x, this._averageProf),
                computePearsonFast(y, this._averageProf));
        return computePearsonFast(x, y) - background;
    }

    /** Mean of two background correlations, clipped below at 0. */
    private static float clippedBackground(float corrX, float corrY) {
        float background = (corrX + corrY) / 2.0f;
        if (background < 0.0f) {
            background = 0.0f;
        }
        return background;
    }

    /** Applies {@code _transform} when one is set and the value is not NaN. */
    private float transformIfSet(float value) {
        if (this._transform != null && !Float.isNaN(value)) {
            return this._transform.transformValue(value);
        }
        return value;
    }

    /** Computes the matrix's average profile and standardizes it in place. */
    protected void preprocessAverageProfile() {
        this._averageProf = this._mda.getAverageProfile();
        MatrixDataNormalizer.standardize(this._averageProf);
    }

    @Override
    public String toString() {
        return this._corrType.toString();
    }

    /**
     * Computes the similarity of the given type over the positions where both inputs are
     * non-NaN.
     *
     * @return a two-element array: {@code [similarity, number of positions dropped]}
     */
    public static float[] computeSimilarityIgnoringNaNs(float[] fArr, float[] fArr2, CorrelationType correlationType) {
        int kept = 0;
        for (int k = 0; k < fArr.length; k++) {
            if (!Float.isNaN(fArr[k]) && !Float.isNaN(fArr2[k])) {
                kept++;
            }
        }
        float[] xs = new float[kept];
        float[] ys = new float[kept];
        int next = 0;
        for (int k = 0; k < fArr.length; k++) {
            if (!Float.isNaN(fArr[k]) && !Float.isNaN(fArr2[k])) {
                xs[next] = fArr[k];
                ys[next] = fArr2[k];
                next++;
            }
        }
        float similarity = new SimpleCorrelationAnalysis(null, correlationType).getSimilarity(xs, ys);
        return new float[] {similarity, fArr.length - kept};
    }
}
