package edu.tau.compbio.interaction.parameters;

import edu.tau.compbio.ds.SimilarityMatrix;
import edu.tau.compbio.util.ProgressManager;

/* loaded from: input_file:edu/tau/compbio/interaction/parameters/MultipleRegParamsEM.class */
/**
 * Iterative (EM-style) estimation of three similarity distributions — "mates",
 * "foes" and "non-mates" — together with the two probabilities {@code _matesP}
 * and {@code _foesP}, over all index pairs of the input whose similarity is not NaN.
 *
 * <p>The mates / non-mates distributions, the similarity matrix, the index array,
 * the per-element weights ({@code _regPs}) and the update flags are inherited from
 * {@code AbstractParamsEM}; this class adds the foes component.
 */
public class MultipleRegParamsEM extends AbstractParamsEM {
    /** Lower clamp applied to {@code _matesP} and {@code _foesP} once iteration has finished. */
    private static final double MIN_BORDER = 0.01d;
    /** Upper clamp applied to the initial {@code _matesP} for small inputs. */
    private static final double MAX_BORDER = 0.2d;
    /** Iteration stops once the likelihood gain of a step drops below this value. */
    private static final double CONVERGENCE_DELTA = 0.5d;
    /** Inputs with fewer valid pairs than this get the extra small-sample handling in {@link #run}. */
    private static final int SMALL_SAMPLE_LIMIT = 5000;
    /** Small inputs with at most this many valid pairs are rejected. */
    private static final int MIN_SAMPLE_SIZE = 10;

    /** Distribution of the "foes" component; set by {@link #setInitial} or by the self-initialization. */
    protected edu.tau.compbio.stat.NormalDistribution _foesDist = null;
    /** Probability parameter of the "foes" component. */
    protected double _foesP;
    /**
     * When true, the non-mates mean/std are only estimated on the first iteration.
     * Nothing in this class ever sets it to true.
     */
    private boolean _fixedNon = false;

    /**
     * Creates the estimator; all arguments are forwarded to the superclass constructor.
     *
     * @param similarityMatrix similarity matrix handed to the superclass
     * @param dArr             array handed to the superclass
     */
    public MultipleRegParamsEM(SimilarityMatrix<String> similarityMatrix, double[] dArr) {
        super(similarityMatrix, dArr);
    }

    /**
     * Creates the estimator; all arguments are forwarded to the superclass constructor.
     *
     * @param similarityMatrix similarity matrix handed to the superclass
     * @param dArr             array handed to the superclass
     * @param i                integer argument handed to the superclass
     */
    public MultipleRegParamsEM(SimilarityMatrix<String> similarityMatrix, double[] dArr, int i) {
        super(similarityMatrix, dArr, i);
    }

    /**
     * @return the current foes distribution, or {@code null} if none has been set or estimated yet
     */
    public edu.tau.compbio.stat.NormalDistribution getFoesDistribution() {
        return this._foesDist;
    }

    /**
     * @return the current foes probability parameter
     */
    public double getFoesP() {
        return this._foesP;
    }

    /**
     * Supplies an explicit starting point and disables the automatic start-point
     * search, so that {@link #run} begins iterating from these values.
     *
     * @param normalDistribution  initial mates distribution
     * @param normalDistribution2 initial foes distribution
     * @param normalDistribution3 initial non-mates distribution
     * @param d                   initial mates probability
     * @param d2                  initial foes probability
     */
    public void setInitial(edu.tau.compbio.stat.NormalDistribution normalDistribution, edu.tau.compbio.stat.NormalDistribution normalDistribution2, edu.tau.compbio.stat.NormalDistribution normalDistribution3, double d, double d2) {
        this._matesDist = normalDistribution;
        // Fixed: the foes distribution used to be assigned the mates argument,
        // leaving the second parameter unused.
        this._foesDist = normalDistribution2;
        this._nonDist = normalDistribution3;
        this._matesP = d;
        this._foesP = d2;
        this._selfInit = false;
    }

    /**
     * Grid search for a starting point.
     *
     * <p>First measures which fraction of the valid (non-NaN) pair similarities lies
     * above the matrix average, at or above average + std, and at or above
     * average - std. Then scans mean offsets (0.5 to 2.5, step 0.1) and component
     * probabilities (0.02 to 0.25, step 0.01) and keeps the candidate whose
     * {@link #mixtureUpperTail} values at those three thresholds are closest, in
     * summed absolute difference, to the measured fractions.
     *
     * @param progressManager receives the progress title
     * @throws IllegalStateException if there is no valid pair, or no grid point was accepted
     */
    private void init(ProgressManager progressManager) {
        progressManager.setTitle("Seeking start point...");
        final double avg = this._simMat.getAverageSimilarity();
        final double simStd = this._simMat.getSimilarityStd();

        int aboveAvg = 0;
        int aboveAvgPlusStd = 0;
        int aboveAvgMinusStd = 0;
        int total = 0;
        for (int a = 0; a < this._inpSize; a++) {
            for (int b = a + 1; b < this._inpSize; b++) {
                float similarity = this._simMat.getSimilarity(this._inds[a], this._inds[b]);
                if (Float.isNaN(similarity)) {
                    continue;
                }
                if (similarity > avg) {
                    aboveAvg++;
                }
                if (similarity >= avg + simStd) {
                    aboveAvgPlusStd++;
                }
                if (similarity >= avg - simStd) {
                    aboveAvgMinusStd++;
                }
                total++;
            }
        }
        if (total == 0) {
            throw new IllegalStateException("No non-NaN similarity values to seek a start point from");
        }
        // Fixed: these were int/int divisions, which truncated every fraction to 0 or 1.
        final double fracAboveAvg = (double) aboveAvg / total;
        final double fracAbovePlus = (double) aboveAvgPlusStd / total;
        final double fracAboveMinus = (double) aboveAvgMinusStd / total;

        double bestPMates = 0.0d;
        double bestPFoes = 0.0d;
        double bestMatesMean = 0.0d;
        double bestFoesMean = 0.0d;
        double bestSigma = 0.0d;
        double bestError = Double.MAX_VALUE;

        // NOTE(review): sigma is carried over from one grid point to the next (it feeds
        // the means below before being re-derived), and a NaN from a negative radicand
        // would propagate to every remaining grid point — confirm this is intended.
        double sigma = 1.0d;
        for (double matesOffset = 0.5d; matesOffset <= 2.5d; matesOffset += 0.1d) {
            for (double foesOffset = 0.5d; foesOffset <= 2.5d; foesOffset += 0.1d) {
                for (double pMates = 0.02d; pMates <= 0.25d; pMates += 0.01d) {
                    for (double pFoes = 0.02d; pFoes <= 0.25d; pFoes += 0.01d) {
                        double pNon = (1.0d - pFoes) - pMates;
                        double nonMean = (avg - ((pMates * matesOffset) * sigma)) + (pFoes * foesOffset * sigma);
                        double matesMean = nonMean + (matesOffset * sigma);
                        double foesMean = nonMean - (foesOffset * sigma);
                        sigma = Math.sqrt(((simStd * simStd) - (((2.0d * pMates) * pFoes) * (((nonMean * foesMean) + (nonMean * matesMean)) - 1.0d))) / ((1.0d + (((pMates * (1.0d - pMates)) * matesOffset) * matesOffset)) + (((pFoes * (1.0d - pFoes)) * foesOffset) * foesOffset)));
                        // Candidates whose common std falls below half the sample std are skipped.
                        if (sigma >= simStd / 2.0d) {
                            edu.tau.compbio.stat.NormalDistribution mates = new edu.tau.compbio.stat.NormalDistribution(matesMean, sigma);
                            edu.tau.compbio.stat.NormalDistribution foes = new edu.tau.compbio.stat.NormalDistribution(foesMean, sigma);
                            edu.tau.compbio.stat.NormalDistribution non = new edu.tau.compbio.stat.NormalDistribution(nonMean, sigma);
                            double error = Math.abs(mixtureUpperTail(pMates, mates, pFoes, foes, pNon, non, avg) - fracAboveAvg)
                                    + Math.abs(mixtureUpperTail(pMates, mates, pFoes, foes, pNon, non, avg + simStd) - fracAbovePlus)
                                    + Math.abs(mixtureUpperTail(pMates, mates, pFoes, foes, pNon, non, avg - simStd) - fracAboveMinus);
                            if (error < bestError) {
                                bestError = error;
                                bestPMates = pMates;
                                bestPFoes = pFoes;
                                bestMatesMean = matesMean;
                                bestFoesMean = foesMean;
                                bestSigma = sigma;
                            }
                        }
                    }
                }
            }
        }
        if (bestError == Double.MAX_VALUE) {
            // Previously this fell through with zeroed parameters and failed later in run().
            throw new IllegalStateException("Failed to find a start point for the EM");
        }
        this._matesDist = new edu.tau.compbio.stat.NormalDistribution(bestMatesMean, bestSigma);
        this._foesDist = new edu.tau.compbio.stat.NormalDistribution(bestFoesMean, bestSigma);
        // The non-mates start distribution is centered at 0, not at the candidate's non-mates mean.
        this._nonDist = new edu.tau.compbio.stat.NormalDistribution(0.0d, bestSigma);
        this._matesP = bestPMates;
        this._foesP = bestPFoes;
    }

    /**
     * Weighted sum of {@code 1 - approxPhi(x)} over the three candidate components
     * (presumably the upper-tail mass of the mixture at {@code x} — confirm against
     * {@code NormalDistribution.approxPhi}).
     */
    private static double mixtureUpperTail(double pMates, edu.tau.compbio.stat.NormalDistribution mates, double pFoes, edu.tau.compbio.stat.NormalDistribution foes, double pNon, edu.tau.compbio.stat.NormalDistribution non, double x) {
        return ((pMates * (1.0d - mates.approxPhi(x))) + (pFoes * (1.0d - foes.approxPhi(x)))) + (pNon * (1.0d - non.approxPhi(x)));
    }

    /**
     * Fills {@code dArr} with one value per valid pair, in pair-iteration order:
     * {@code d * regP(first) * regP(second) * statDistribution.calcF(similarity)}.
     *
     * @param dArr             output vector, one slot per valid pair
     * @param d                component probability
     * @param statDistribution component distribution
     * @throws IllegalStateException if fewer valid pairs were found than {@code dArr} has slots
     */
    protected void calcPFMultMatesFoes(double[] dArr, double d, edu.tau.compbio.stat.StatDistribution statDistribution) {
        int next = 0;
        for (int a = 0; a < this._inpSize; a++) {
            for (int b = a + 1; b < this._inpSize; b++) {
                float similarity = this._simMat.getSimilarity(this._inds[a], this._inds[b]);
                if (Float.isNaN(similarity)) {
                    continue;
                }
                dArr[next] = d * this._regPs[this._inds[a]] * this._regPs[this._inds[b]] * statDistribution.calcF(similarity);
                next++;
            }
        }
        if (next != dArr.length) {
            throw new IllegalStateException("The vector size is bogus!");
        }
    }

    /**
     * Fills {@code dArr} with one value per valid pair, in pair-iteration order:
     * {@code (1 - d * regP(first) * regP(second)) * statDistribution.calcF(similarity)}.
     *
     * @param dArr             output vector, one slot per valid pair
     * @param d                combined probability of the other components
     * @param statDistribution component distribution
     * @throws IllegalStateException if fewer valid pairs were found than {@code dArr} has slots
     */
    protected void calcPFMultNon(double[] dArr, double d, edu.tau.compbio.stat.StatDistribution statDistribution) {
        int next = 0;
        for (int a = 0; a < this._inpSize; a++) {
            for (int b = a + 1; b < this._inpSize; b++) {
                float similarity = this._simMat.getSimilarity(this._inds[a], this._inds[b]);
                if (Float.isNaN(similarity)) {
                    continue;
                }
                dArr[next] = (1.0d - ((d * this._regPs[this._inds[a]]) * this._regPs[this._inds[b]])) * statDistribution.calcF(similarity);
                next++;
            }
        }
        if (next != dArr.length) {
            throw new IllegalStateException("The vector size is bogus!");
        }
    }

    /**
     * Computes, per entry, the share of {@code dArr2} in the sum of the three
     * input vectors and stores it in {@code dArr}.
     *
     * @param dArr  output: {@code dArr2[i] / (dArr2[i] + dArr3[i] + dArr4[i])}
     * @param dArr2 vector of the component of interest
     * @param dArr3 vector of a second component
     * @param dArr4 vector of a third component
     * @return the sum of all values written to {@code dArr}
     */
    protected double calcGi(double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
        double total = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] = dArr2[k] / ((dArr2[k] + dArr3[k]) + dArr4[k]);
            total += dArr[k];
        }
        return total;
    }

    /**
     * Computes a single standard deviation shared by three components: the square
     * root of the weighted squared deviations of every valid pair similarity from
     * each component mean, divided by the summed weights.
     *
     * @param dArr  per-pair weights of the first component
     * @param dArr2 per-pair weights of the second component
     * @param dArr3 per-pair weights of the third component
     * @param d     weight sum of the first component
     * @param d2    weight sum of the second component
     * @param d3    weight sum of the third component
     * @param d4    mean of the first component
     * @param d5    mean of the second component
     * @param d6    mean of the third component
     * @return the combined standard deviation
     * @throws IllegalStateException if the number of valid pairs differs from {@code dArr.length}
     */
    protected double calcCombinedStd(double[] dArr, double[] dArr2, double[] dArr3, double d, double d2, double d3, double d4, double d5, double d6) {
        double weightedSquares = 0.0d;
        int next = 0;
        for (int a = 0; a < this._inpSize; a++) {
            for (int b = a + 1; b < this._inpSize; b++) {
                float similarity = this._simMat.getSimilarity(this._inds[a], this._inds[b]);
                if (Float.isNaN(similarity)) {
                    continue;
                }
                weightedSquares = weightedSquares + (dArr[next] * Math.pow(d4 - similarity, 2.0d)) + (dArr2[next] * Math.pow(d5 - similarity, 2.0d)) + (dArr3[next] * Math.pow(d6 - similarity, 2.0d));
                next++;
            }
        }
        if (next != dArr.length) {
            throw new IllegalStateException("The vector size is bogus!");
        }
        return Math.sqrt(weightedSquares / ((d + d2) + d3));
    }

    /**
     * @param dArr  per-pair values of the first component
     * @param dArr2 per-pair values of the second component
     * @param dArr3 per-pair values of the third component
     * @return the sum over all entries of {@code log(dArr[i] + dArr2[i] + dArr3[i])}
     */
    protected double calcLikelihood(double[] dArr, double[] dArr2, double[] dArr3) {
        double logSum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            logSum += Math.log(dArr[k] + dArr2[k] + dArr3[k]);
        }
        return logSum;
    }

    /**
     * Sums {@code regP(first) * regP(second)} over every valid (non-NaN) pair.
     * Used as the denominator when re-estimating the component probabilities.
     */
    private double sumValidPairRegWeights() {
        double sum = 0.0d;
        for (int a = 0; a < this._inpSize; a++) {
            for (int b = a + 1; b < this._inpSize; b++) {
                if (!Float.isNaN(this._simMat.getSimilarity(this._inds[a], this._inds[b]))) {
                    sum += this._regPs[this._inds[a]] * this._regPs[this._inds[b]];
                }
            }
        }
        return sum;
    }

    /**
     * Runs the estimation: optionally seeks a start point, then repeatedly
     * re-estimates means, standard deviations and probabilities (subject to the
     * inherited update flags) until one step improves the likelihood by less than
     * {@value #CONVERGENCE_DELTA}. The final likelihood is stored in {@code _likelihood}.
     *
     * @param progressManager receives progress titles
     * @throws IllegalStateException if there are too few values, the initial mates
     *         probability is outside (0, 1], a responsibility sum is zero, or the
     *         likelihood becomes NaN or decreases
     */
    public void run(ProgressManager progressManager) {
        double[] pfMates = new double[this._nonNaNCount];
        double[] pfFoes = new double[this._nonNaNCount];
        double[] pfNon = new double[this._nonNaNCount];
        if (this._selfInit) {
            init(progressManager);
        }
        System.out.println("Initial mates:" + this._matesDist);
        System.out.println("Initial foes:" + this._foesDist);
        System.out.println("Initial non-mates:" + this._nonDist);
        System.out.println("Initial p-mates:" + this._matesP);
        System.out.println("Initial p-foes:" + this._foesP);
        progressManager.setTitle("Starts iterating...");

        if (this._nonNaNCount < SMALL_SAMPLE_LIMIT) {
            if (this._nonNaNCount <= MIN_SAMPLE_SIZE) {
                throw new IllegalStateException("CEM::Run - not enough values");
            }
            // For small inputs the starting mates probability is capped.
            // (A discarded likelihood evaluation that used to precede this was removed.)
            if (this._matesP > MAX_BORDER) {
                this._matesP = MAX_BORDER;
            }
        }
        if (this._matesP <= 0.0d || this._matesP > 1.0d) {
            throw new IllegalStateException("CEM::Run - illegal value of initial mates prob");
        }

        calcPFMultMatesFoes(pfMates, this._matesP, this._matesDist);
        calcPFMultMatesFoes(pfFoes, this._foesP, this._foesDist);
        calcPFMultNon(pfNon, this._matesP + this._foesP, this._nonDist);

        double[] gMates = new double[this._nonNaNCount];
        double[] gFoes = new double[this._nonNaNCount];
        double[] gNon = new double[this._nonNaNCount];

        double likelihood = calcLikelihood(pfMates, pfFoes, pfNon);
        double prevLikelihood;
        int iteration = 0;

        // NOTE(review): when _updateExp is false these means stay NaN and are still
        // handed to the std calculators — confirm that configuration is supported.
        double matesExp = Double.NaN;
        double matesStd = Double.NaN;
        double foesExp = Double.NaN;
        double foesStd = Double.NaN;
        double nonExp = Double.NaN;
        double nonStd = Double.NaN;

        do {
            prevLikelihood = likelihood;

            // Per-pair shares of each component and their sums.
            double sumGMates = calcGi(gMates, pfMates, pfFoes, pfNon);
            if (sumGMates == 0.0d) {
                throw new IllegalStateException("sumGMates == 0");
            }
            double sumGFoes = calcGi(gFoes, pfFoes, pfMates, pfNon);
            if (sumGFoes == 0.0d) {
                throw new IllegalStateException("sumGFoes == 0");
            }
            double sumGNon = calcGi(gNon, pfNon, pfMates, pfFoes);
            if (sumGNon == 0.0d) {
                throw new IllegalStateException("sumGNon == 0");
            }

            boolean updateNon = iteration == 0 || !this._fixedNon;
            if (this._updateExp) {
                matesExp = calcExp(gMates, sumGMates);
                foesExp = calcExp(gFoes, sumGFoes);
                if (updateNon) {
                    nonExp = calcExp(gNon, sumGNon);
                }
            }
            if (this._updateStd) {
                if (this._sameStd) {
                    matesStd = calcCombinedStd(gMates, gFoes, gNon, sumGMates, sumGFoes, sumGNon, matesExp, foesExp, nonExp);
                    foesStd = matesStd;
                    if (updateNon) {
                        nonStd = matesStd;
                    }
                } else {
                    matesStd = calcStd(gMates, sumGMates, matesExp);
                    foesStd = calcStd(gFoes, sumGFoes, foesExp);
                    if (updateNon) {
                        nonStd = calcStd(gNon, sumGNon, nonExp);
                    }
                }
            }

            // Push the new parameters into the three distributions (mates, foes, non-mates).
            if (this._updateExp && this._updateStd) {
                this._matesDist.init(matesExp, matesStd);
                this._foesDist.init(foesExp, foesStd);
                this._nonDist.init(nonExp, nonStd);
            } else if (this._updateExp) {
                this._matesDist.initExp(matesExp);
                this._foesDist.initExp(foesExp);
                this._nonDist.initExp(nonExp);
            } else if (this._updateStd) {
                this._matesDist.initStd(matesStd);
                this._foesDist.initStd(foesStd);
                this._nonDist.initStd(nonStd);
            }

            if (this._updateP) {
                // Both probabilities share the same denominator; it used to be computed twice.
                double pairWeightSum = sumValidPairRegWeights();
                this._matesP = sumGMates / pairWeightSum;
                this._foesP = sumGFoes / pairWeightSum;
            }

            calcPFMultMatesFoes(pfMates, this._matesP, this._matesDist);
            calcPFMultMatesFoes(pfFoes, this._foesP, this._foesDist);
            calcPFMultNon(pfNon, this._matesP + this._foesP, this._nonDist);
            likelihood = calcLikelihood(pfMates, pfFoes, pfNon);
            if (Double.isNaN(likelihood)) {
                throw new IllegalStateException("The likelihood went NaN");
            }
            if (likelihood < prevLikelihood) {
                System.out.flush();
                System.err.flush();
                System.err.println("Prev likelihood:" + prevLikelihood);
                System.err.println("New likelihood:" + likelihood);
                throw new IllegalStateException("The likelihood went down during EM.");
            }

            iteration++;
            progressManager.setTitle("Iteraction " + iteration);
            if (iteration % this._plotIterCount == 0) {
                System.out.println(iteration + ", likelihood=" + likelihood);
                System.out.println("mates:" + this._matesDist);
                System.out.println("foes:" + this._foesDist);
                System.out.println("non-mates:" + this._nonDist);
                System.out.println("p-mates:" + this._matesP);
                System.out.println("p-foes:" + this._foesP);
            }
            // Written as a negated '<' so that a NaN difference keeps iterating, as before.
        } while (!(likelihood - prevLikelihood < CONVERGENCE_DELTA));

        if (this._matesP < MIN_BORDER) {
            this._matesP = MIN_BORDER;
        }
        if (this._foesP < MIN_BORDER) {
            this._foesP = MIN_BORDER;
        }
        this._likelihood = likelihood;
        System.out.println("EM finished after " + iteration + " iterations");
        System.out.println("Mates Distribution:" + this._matesDist);
        System.out.println("Foes Distribution:" + this._foesDist);
        System.out.println("Non-Mates Distribution:" + this._nonDist);
        System.out.println("Mates P:" + this._matesP);
        System.out.println("Foes P:" + this._foesP);
    }
}
