package edu.tau.compbio.expression.deconvolute;

import edu.tau.compbio.expression.ds.ExtendedDataMatrix;
import edu.tau.compbio.math.VecCalc;
import edu.tau.compbio.util.CollectionUtil;
import edu.tau.compbio.util.OutputUtilities;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/tau/compbio/expression/deconvolute/SimpleSADeconvolutor.class */
/**
 * Estimates, for every mixture sample (condition of the combined matrix), non-negative fractions
 * of each basis profile (condition of the basis matrix) by randomized local search.
 *
 * <p>The objective is the residual sum of squares over the probes shared by the id collection,
 * the basis matrix and the combined matrix:
 * {@code sum_p sum_s (observed[p][s] - sum_c fracs[s][c] * basis[p][c])^2}.
 *
 * <p>Each sample's fractions start uniform ({@code 1/#basisConditions}). Every move transfers a
 * fixed step between two entries of one sample's fraction row, so the row sum is preserved (up
 * to float rounding) and entries never go negative. Despite the "SA" in the class name, only
 * strictly improving moves are accepted; there is no temperature schedule.
 *
 * <p>Not thread-safe: all state lives in mutable instance fields.
 */
public class SimpleSADeconvolutor implements ExpressionDecovolutor {
    /** Basis (pure profile) matrix; its conditions are the components being mixed. */
    protected ExtendedDataMatrix _basisMat;
    /** Mixture matrix whose conditions are the samples to deconvolute. */
    protected ExtendedDataMatrix _combineMat;
    /** Per-sample correlation between observed and predicted values before optimization. */
    protected float[] _initialCorrelations;
    /** Per-sample correlation between observed and predicted values after optimization. */
    protected float[] _finalCorrelations;
    /** Probe ids restricting which probes take part in the fit. */
    protected Collection<String> _ids;
    /** Fractions indexed {@code [sample][basisCondition]}. */
    protected float[][] _fracs = null;
    /** Predicted values indexed {@code [overlapProbe][sample]}. */
    protected double[][] _preds = null;
    /** Residuals (observed - predicted) indexed {@code [overlapProbe][sample]}. */
    protected double[][] _diffs = null;
    /** Per-probe sum of squared residuals over all samples. */
    protected double[] _probeDiffs = null;
    /** Current objective value (total residual sum of squares). */
    protected double _score = Double.NaN;
    /** Not read anywhere in this class; kept for subclasses / compatibility. */
    protected double _epsilon = 1.0d;
    /** Number of search iterations run by {@link #deconvolute}. */
    protected int _iters = 20000;
    /** Amount of fraction moved per step; the float literal 0.01f widened (~0.0099999998). */
    protected double _step = 0.01f;
    protected Random _random = new Random();
    /** Row index in the basis matrix of each overlapping probe. */
    protected int[] _basisInds = null;
    /** Row index in the combined matrix of each overlapping probe (same order as _basisInds). */
    protected int[] _combinedInds = null;
    /** Current iteration counter, exposed for diagnostics (e.g. {@link #checkScore()}). */
    int iter = 0;

    /** Sets how many search iterations {@link #deconvolute} performs. */
    public void setMaxIters(int i) {
        this._iters = i;
    }

    /**
     * Sets the amount of fraction moved between two components in a single search step.
     *
     * @param step a positive, finite step size
     * @throws IllegalArgumentException if {@code step} is not positive and finite
     */
    public void setStep(double step) {
        if (!(step > 0.0d) || Double.isInfinite(step)) {
            throw new IllegalArgumentException("step must be positive and finite: " + step);
        }
        this._step = step;
    }

    /**
     * Runs the search and returns the estimated fractions.
     *
     * @param collection probe ids allowed to take part in the fit
     * @param extendedDataMatrix basis matrix (one condition per component)
     * @param extendedDataMatrix2 mixture matrix (one condition per sample)
     * @return fractions indexed {@code [sample][basisCondition]}; this is the live internal array
     */
    @Override
    public float[][] deconvolute(Collection<String> collection, ExtendedDataMatrix extendedDataMatrix, ExtendedDataMatrix extendedDataMatrix2) {
        this._ids = collection;
        this._basisMat = extendedDataMatrix;
        this._combineMat = extendedDataMatrix2;
        initInds();
        int sampleCount = extendedDataMatrix2.sizeConditions();
        int probeCount = this._combinedInds.length;
        this._fracs = new float[sampleCount][extendedDataMatrix.sizeConditions()];
        initFracs();
        this._probeDiffs = new double[probeCount];
        this._diffs = new double[probeCount][sampleCount];
        this._preds = new double[probeCount][sampleCount];
        // Fills _preds, _diffs and _probeDiffs as well, so no separate prediction pass is needed.
        this._score = computeScoreAndPredictions();
        this._initialCorrelations = computeCorrelations();
        System.out.println("Initial score: " + this._score);

        int improvements = 0;
        for (this.iter = 0; this.iter < this._iters; this.iter++) {
            if (this.iter % 1000 == 0) {
                System.out.println(this.iter);
            }
            if (iterate()) {
                improvements++;
            }
        }
        System.out.println("Optimization finished after " + this.iter + " iterations with " + improvements + " improvements");

        // Recompute from scratch: the incrementally maintained score accumulates rounding error.
        this._score = computeScoreAndPredictions();
        System.out.println("Final score: " + this._score);
        this._finalCorrelations = computeCorrelations();

        int better = 0;
        int worse = 0;
        for (int s = 0; s < this._finalCorrelations.length; s++) {
            // Ties (and NaN comparisons) are counted as "worse", as before.
            if (this._finalCorrelations[s] > this._initialCorrelations[s]) {
                better++;
            } else {
                worse++;
            }
        }
        System.out.println("Initial correlations:");
        System.out.println(OutputUtilities.buildString(this._initialCorrelations));
        System.out.println("Final correlations:");
        System.out.println(OutputUtilities.buildString(this._finalCorrelations));
        System.out.println("Better: " + better + "; Worse: " + worse);
        return this._fracs;
    }

    /**
     * Debug check that the incrementally maintained score matches a full recomputation.
     * Note that the recomputation also overwrites _preds, _diffs and _probeDiffs.
     *
     * @throws IllegalStateException if the two differ by more than 0.01
     */
    protected void checkScore() {
        double recomputed = computeScoreAndPredictions();
        if (Math.abs(recomputed - this._score) > 0.01d) {
            throw new IllegalStateException("Iteration: " + this.iter + "; The score has gone bogus. Should be:" + recomputed + " and is: " + this._score);
        }
    }

    /**
     * Computes, for every mixture sample, the correlation across overlapping probes between the
     * observed values and the current {@link #_preds}.
     *
     * @return one correlation per mixture sample
     */
    public float[] computeCorrelations() {
        int sampleCount = this._combineMat.sizeConditions();
        int probeCount = this._combinedInds.length;
        // Fetch each observed row once rather than once per (sample, probe) pair.
        float[][] observedRows = new float[probeCount][];
        for (int p = 0; p < probeCount; p++) {
            observedRows[p] = this._combineMat.getDataRow(this._combinedInds[p]);
        }
        float[] correlations = new float[sampleCount];
        float[] observed = new float[probeCount];
        float[] predicted = new float[probeCount];
        for (int s = 0; s < sampleCount; s++) {
            for (int p = 0; p < probeCount; p++) {
                observed[p] = observedRows[p][s];
                predicted[p] = (float) this._preds[p][s];
            }
            correlations[s] = VecCalc.calcCorrelationCoefficient(observed, predicted);
        }
        return correlations;
    }

    /**
     * Computes predictions from the current fractions without touching instance state.
     *
     * @return predictions indexed {@code [overlapProbe][sample]}. The array has one row per
     *     probe of the combined matrix (kept for compatibility); only the first
     *     {@code _combinedInds.length} rows are filled and the rest stay zero.
     */
    public double[][] computePredictions() {
        double[][] predictions = new double[this._combineMat.sizeProbes()][this._combineMat.sizeConditions()];
        for (int p = 0; p < this._combinedInds.length; p++) {
            accumulatePrediction(p, predictions[p]);
        }
        return predictions;
    }

    /**
     * Resolves the probes present in the id collection and in both matrices, and records each
     * one's row index in the basis and combined matrices (parallel arrays, same order).
     */
    public void initInds() {
        Set<String> overlap = CollectionUtil.getOverlap(new Collection[]{this._ids, this._basisMat.getProbeIdsSet(), this._combineMat.getProbeIdsSet()});
        this._basisInds = new int[overlap.size()];
        this._combinedInds = new int[overlap.size()];
        int i = 0;
        for (String str : overlap) {
            this._basisInds[i] = this._basisMat.getProbeIndex(str);
            this._combinedInds[i] = this._combineMat.getProbeIndex(str);
            i++;
        }
    }

    /**
     * Performs one search step. It picks a random sample and a random step direction, then two
     * distinct components for which the move keeps both fractions non-negative. The step is
     * moved between them only if the residual sum of squares strictly decreases.
     *
     * @return {@code true} if the move was accepted; {@code false} if it was rejected or no
     *     valid move exists for the drawn sample and direction
     */
    public boolean iterate() {
        int sampleCount = this._combineMat.sizeConditions();
        if (sampleCount == 0) {
            return false; // Random.nextInt(0) would throw; nothing to optimize.
        }
        int sample = this._random.nextInt(sampleCount);
        double delta = this._random.nextBoolean() ? this._step : -this._step;
        float[] fracRow = this._fracs[sample];
        // Without this guard the selection loop below never terminates when no valid pair exists
        // (e.g. fewer than two components, or no entry large enough to give up the step).
        if (!hasFeasibleMove(fracRow, delta)) {
            return false;
        }
        int from;
        int to;
        do {
            from = this._random.nextInt(fracRow.length);
            to = this._random.nextInt(fracRow.length);
        } while (!(to != from && fracRow[from] - delta >= 0.0d && fracRow[to] + delta >= 0.0d));

        float newFrom = (float) (fracRow[from] - delta);
        float newTo = (float) (fracRow[to] + delta);
        // Use the changes that survive rounding to float, so the incremental residuals match the
        // fractions actually stored and do not drift from a full recomputation.
        double removed = fracRow[from] - (double) newFrom;
        double added = (double) newTo - fracRow[to];

        int probeCount = this._combinedInds.length;
        double[] predShift = new double[probeCount];
        double[] sqChange = new double[probeCount];
        double scoreDelta = 0.0d;
        for (int p = 0; p < probeCount; p++) {
            float[] basisRow = this._basisMat.getDataRow(this._basisInds[p]);
            double shift = (added * basisRow[to]) - (removed * basisRow[from]);
            double oldDiff = this._diffs[p][sample];
            double newDiff = oldDiff - shift;
            predShift[p] = shift;
            sqChange[p] = (newDiff * newDiff) - (oldDiff * oldDiff);
            scoreDelta += sqChange[p];
        }
        if (scoreDelta >= 0.0d) {
            return false;
        }

        fracRow[to] = newTo;
        fracRow[from] = newFrom;
        // Keep predictions, residuals, per-probe errors and the total score mutually consistent.
        for (int p = 0; p < probeCount; p++) {
            this._diffs[p][sample] -= predShift[p];
            this._preds[p][sample] += predShift[p];
            this._probeDiffs[p] += sqChange[p];
        }
        this._score += scoreDelta;
        return true;
    }

    /**
     * Tells whether some pair of distinct indices {@code (from, to)} satisfies the move
     * preconditions {@code row[from] - delta >= 0} and {@code row[to] + delta >= 0}.
     */
    private static boolean hasFeasibleMove(float[] row, double delta) {
        int donors = 0;
        int receivers = 0;
        int lastDonor = -1;
        int lastReceiver = -1;
        for (int k = 0; k < row.length; k++) {
            if (row[k] - delta >= 0.0d) {
                donors++;
                lastDonor = k;
            }
            if (row[k] + delta >= 0.0d) {
                receivers++;
                lastReceiver = k;
            }
        }
        if (donors == 0 || receivers == 0) {
            return false;
        }
        // Infeasible only when the single donor and the single receiver are the same index.
        return donors > 1 || receivers > 1 || lastDonor != lastReceiver;
    }

    /** Initializes every sample's fractions to the uniform distribution over components. */
    protected void initFracs() {
        for (int i = 0; i < this._fracs.length; i++) {
            for (int i2 = 0; i2 < this._fracs[i].length; i2++) {
                this._fracs[i][i2] = 1.0f / this._fracs[i].length;
            }
        }
    }

    /** Returns the current objective value (total residual sum of squares). */
    public double getScore() {
        return this._score;
    }

    /**
     * Recomputes predictions, residuals and per-probe squared errors from the current fractions,
     * storing them in {@link #_preds}, {@link #_diffs} and {@link #_probeDiffs}.
     *
     * @return the total residual sum of squares
     */
    protected double computeScoreAndPredictions() {
        int sampleCount = this._combineMat.sizeConditions();
        double total = 0.0d;
        for (int p = 0; p < this._combinedInds.length; p++) {
            double[] predRow = this._preds[p];
            double[] diffRow = this._diffs[p];
            Arrays.fill(predRow, 0.0d);
            accumulatePrediction(p, predRow);
            float[] observed = this._combineMat.getDataRow(this._combinedInds[p]);
            double probeSse = 0.0d;
            for (int s = 0; s < sampleCount; s++) {
                diffRow[s] = observed[s] - predRow[s];
                probeSse += diffRow[s] * diffRow[s];
            }
            this._probeDiffs[p] = probeSse;
            total += probeSse;
        }
        return total;
    }

    /**
     * Adds {@code sum_c fracs[s][c] * basis[probe][c]} to {@code out[s]} for every mixture
     * sample {@code s}.
     */
    private void accumulatePrediction(int probe, double[] out) {
        float[] basisRow = this._basisMat.getDataRow(this._basisInds[probe]);
        int sampleCount = this._combineMat.sizeConditions();
        int componentCount = this._basisMat.sizeConditions();
        for (int c = 0; c < componentCount; c++) {
            for (int s = 0; s < sampleCount; s++) {
                out[s] += this._fracs[s][c] * basisRow[c];
            }
        }
    }

    public float[] getInitialCorrelations() {
        return this._initialCorrelations;
    }

    public float[] getFinalCorrelations() {
        return this._finalCorrelations;
    }
}
