package dali.learning;

import dali.prefs.PeerData;
import java.io.Serializable;
import java.util.Random;

/* loaded from: input_file:dali/learning/Perceptron.class */
/**
 * A fully connected feed-forward neural network (multi-layer perceptron) trained by
 * backpropagation.
 *
 * <p>{@code nodeCounts[0]} is the number of inputs and the last entry is the number of outputs;
 * entries in between are hidden-layer sizes. Every layer's output vector has one extra leading
 * slot (index 0) that is held at {@code -1.0f} and acts as the bias input to the next layer.
 * Weights for layer {@code l} (l &gt;= 1) are stored row-major: one row of
 * {@code 1 + nodeCounts[l - 1]} weights per node, with the bias weight first.
 *
 * <p>Hidden layers always use the hyperbolic-tangent activation of {@link #sigmoidFn}; the output
 * layer uses the same activation for {@link #CP_SIGMOID_OUTPUTS} or the identity for
 * {@link #CP_LINEAR_OUTPUTS}.
 *
 * <p>Instances hold mutable per-call state (layer outputs, backprop buffers) and perform no
 * synchronization.
 */
public class Perceptron extends NeuralNetwork implements Serializable {
    /** {@code configParams} value: the output layer uses the tanh activation of {@link #sigmoidFn}. */
    public static final int CP_SIGMOID_OUTPUTS = 0;
    /** {@code configParams} value: the output layer is linear (identity activation). */
    public static final int CP_LINEAR_OUTPUTS = 1;
    /** Training mode: the weight deltas are applied immediately after each training case. */
    public static final int CASEBYCASE_MODE = 0;
    /** Training mode: weight updates are accumulated until {@link #endBatch()} is called. */
    public static final int BATCH_MODE = 1;

    /** Output-activation selector; one of the {@code CP_*} constants. */
    protected int configParams = CP_SIGMOID_OUTPUTS;
    /** Number of nodes per layer, input layer first (bias slots not counted). */
    protected int[] nodeCounts;
    /** Per-layer weight rows; index 0 is left null because the input layer has no incoming weights. */
    protected float[][] layerWeights;
    /** Per-layer outputs; slot 0 of each row is the constant -1 bias input. */
    protected float[][] layerOutputs;
    /** Factor applied to every gradient term when computing weight updates. */
    protected float learningRate = 0.01f;
    /** Momentum coefficient handed to the backprop buffers after each training case. */
    protected float momentum = 0.0f;
    /** Lazily created training buffers; null until a backprop method first needs them. */
    protected BackpropData backpropData = null;

    /**
     * Creates a network with tanh outputs and weights drawn from an unseeded {@link Random}.
     *
     * @param nodeCounts number of nodes per layer, input layer first
     */
    public Perceptron(int[] nodeCounts) {
        this(nodeCounts, null, CP_SIGMOID_OUTPUTS);
    }

    /**
     * Creates a network with tanh outputs and random initial weights.
     *
     * @param nodeCounts number of nodes per layer, input layer first
     * @param random source of initial weights; a fresh {@link Random} is used when null
     */
    public Perceptron(int[] nodeCounts, Random random) {
        this(nodeCounts, random, CP_SIGMOID_OUTPUTS);
    }

    /**
     * Creates a network with random initial weights in [-0.5, 0.5).
     *
     * @param nodeCounts number of nodes per layer, input layer first (copied)
     * @param random source of initial weights; a fresh {@link Random} is used when null
     * @param configParams {@link #CP_SIGMOID_OUTPUTS} or {@link #CP_LINEAR_OUTPUTS}
     */
    public Perceptron(int[] nodeCounts, Random random, int configParams) {
        this.configParams = configParams;
        initNetDataStructures(nodeCounts);
        initRandomWeights(random);
    }

    /**
     * Creates a tanh-output network whose topology and weights are copied from a snapshot.
     *
     * @param perceptronWeights snapshot supplying node counts and weights
     */
    public Perceptron(PerceptronWeights perceptronWeights) {
        this(perceptronWeights, CP_SIGMOID_OUTPUTS);
    }

    /**
     * Creates a network whose topology and weights are copied from a snapshot.
     *
     * @param perceptronWeights snapshot supplying node counts and weights
     * @param configParams {@link #CP_SIGMOID_OUTPUTS} or {@link #CP_LINEAR_OUTPUTS}
     */
    public Perceptron(PerceptronWeights perceptronWeights, int configParams) {
        this.configParams = configParams;
        initNetDataStructures(perceptronWeights.getNodeCounts());
        setAllWeights(perceptronWeights);
    }

    /**
     * Allocates the weight and output arrays for the given topology.
     *
     * @param counts number of nodes per layer, input layer first (defensively copied)
     */
    protected void initNetDataStructures(int[] counts) {
        this.nodeCounts = (int[]) counts.clone();
        int layers = this.nodeCounts.length;
        // Jagged arrays: the second dimension is sized per layer below.
        this.layerWeights = new float[layers][];
        for (int layer = 1; layer < layers; layer++) {
            this.layerWeights[layer] = new float[(1 + this.nodeCounts[layer - 1]) * this.nodeCounts[layer]];
        }
        this.layerOutputs = new float[layers][];
        for (int layer = 0; layer < layers; layer++) {
            // +1 for the bias slot at index 0.
            this.layerOutputs[layer] = new float[1 + this.nodeCounts[layer]];
        }
    }

    /**
     * Copies the given weights into this network. {@code weights[0]} is ignored.
     *
     * @param weights per-layer weight rows laid out like {@link #layerWeights}
     */
    public void initWeights(float[][] weights) {
        copyWeights(weights);
    }

    /**
     * Re-initializes all weights uniformly in [-0.5, 0.5).
     *
     * @param random weight source; a fresh {@link Random} is used when null
     */
    public void initRandomWeights(Random random) {
        initRandomWeights(random, 0.5f);
    }

    /**
     * Re-initializes all weights uniformly in [-range, range) for a positive {@code range}.
     *
     * @param random weight source; a fresh {@link Random} is used when null
     * @param range half-width of the sampling interval
     */
    public void initRandomWeights(Random random, float range) {
        Random rng = (random == null) ? new Random() : random;
        for (int layer = 1; layer < this.layerWeights.length; layer++) {
            float[] weights = this.layerWeights[layer];
            for (int w = 0; w < weights.length; w++) {
                weights[w] = 2.0f * range * (rng.nextFloat() - 0.5f);
            }
        }
    }

    @Override // dali.learning.NeuralNetwork
    public int getInputCount() {
        return this.nodeCounts[0];
    }

    /**
     * Returns the current value of input {@code i} (0-based, bias slot excluded).
     */
    public float getInputValue(int i) {
        return this.layerOutputs[0][1 + i];
    }

    /**
     * Sets input {@code i} (0-based) and immediately recomputes every layer's outputs.
     */
    @Override // dali.learning.NeuralNetwork
    public void setInputValue(int i, float f) {
        this.layerOutputs[0][1 + i] = f;
        calculateOutputs();
    }

    /**
     * Sets the first {@code getInputCount()} inputs from {@code inputs} and recomputes the outputs.
     */
    @Override // dali.learning.NeuralNetwork
    public void setAllInputValues(float[] inputs) {
        System.arraycopy(inputs, 0, this.layerOutputs[0], 1, this.nodeCounts[0]);
        calculateOutputs();
    }

    @Override // dali.learning.NeuralNetwork
    public int getOutputCount() {
        return this.nodeCounts[this.nodeCounts.length - 1];
    }

    @Override // dali.learning.NeuralNetwork
    public float getOutputValue(int i) {
        return this.layerOutputs[this.layerOutputs.length - 1][1 + i];
    }

    /**
     * Returns a fresh copy of the output-layer values (bias slot excluded).
     */
    @Override // dali.learning.NeuralNetwork
    public float[] getAllOutputValues() {
        float[] outputs = new float[getOutputCount()];
        System.arraycopy(this.layerOutputs[this.layerOutputs.length - 1], 1, outputs, 0, outputs.length);
        return outputs;
    }

    /**
     * Runs a forward pass from the current input values through every layer.
     * When backprop buffers exist, each node's weighted input sum is also recorded there.
     */
    public void calculateOutputs() {
        this.layerOutputs[0][0] = -1.0f;
        for (int layer = 1; layer < this.nodeCounts.length; layer++) {
            float[] prevOutputs = this.layerOutputs[layer - 1];
            float[] weights = this.layerWeights[layer];
            int prevCount = this.nodeCounts[layer - 1];
            // Weight rows are contiguous, so a single running index walks them in order.
            int w = 0;
            this.layerOutputs[layer][0] = -1.0f;
            for (int node = 1; node <= this.nodeCounts[layer]; node++) {
                float sum = 0.0f;
                for (int k = 0; k <= prevCount; k++) {
                    sum += prevOutputs[k] * weights[w++];
                }
                if (this.backpropData != null) {
                    this.backpropData.layerInputs[layer][node] = sum;
                }
                this.layerOutputs[layer][node] = activationFn(layer, sum);
            }
        }
    }

    public float getLearningRate() {
        return this.learningRate;
    }

    public void setLearningRate(float f) {
        this.learningRate = f;
    }

    public float getMomentum() {
        return this.momentum;
    }

    public void setMomentum(float f) {
        this.momentum = f;
    }

    /**
     * Performs one backpropagation step for a single (input, target) case.
     *
     * @param inputs network inputs
     * @param targets desired outputs
     * @param mode {@link #BATCH_MODE} to defer applying the updates until {@link #endBatch()};
     *     any other value applies them immediately
     * @return the sum of squared output errors for this case, measured before the update
     */
    public float trainNetwork(float[] inputs, float[] targets, int mode) {
        checkBackpropData();
        setAllInputValues(inputs);
        this.backpropData.incrementCaseCount();
        float squaredError = calculateOutputNodeDeltas(targets);
        propagateAndUpdate(mode);
        return squaredError;
    }

    /**
     * Performs one backpropagation step using caller-supplied output error terms against the
     * network's current forward pass.
     *
     * @param outputDeltas error term per output node
     * @param mode {@link #BATCH_MODE} to defer applying the updates until {@link #endBatch()};
     *     any other value applies them immediately
     * @return the sum of squares of {@code outputDeltas}
     */
    public float trainNetwork(float[] outputDeltas, int mode) {
        boolean freshBuffers = this.backpropData == null;
        checkBackpropData();
        if (freshBuffers) {
            // The last forward pass could not record layerInputs (no buffers existed yet),
            // so redo it now that they do.
            calculateOutputs();
        }
        this.backpropData.incrementCaseCount();
        float squaredDeltas = setOutputNodeDeltas(outputDeltas);
        propagateAndUpdate(mode);
        return squaredDeltas;
    }

    /** Backpropagates output deltas, accumulates weight updates and applies them unless batching. */
    private void propagateAndUpdate(int mode) {
        calculateHiddenNodeDeltas();
        calculateLayerWeightUpdates();
        this.backpropData.addMomentumToDeltas(this.momentum);
        if (mode != BATCH_MODE) {
            this.backpropData.applyLayerWeightDeltas(this.layerWeights);
        }
    }

    /** Creates the backprop buffers on first use. */
    protected void checkBackpropData() {
        if (this.backpropData == null) {
            this.backpropData = new BackpropData(this.nodeCounts);
        }
    }

    /**
     * Computes the error gradient with respect to each input for the given (input, target) case,
     * without modifying any weights.
     */
    public float[] getInpDeltasFromOutputs(float[] inputs, float[] targets) {
        checkBackpropData();
        setAllInputValues(inputs);
        calculateOutputNodeDeltas(targets);
        calculateHiddenNodeDeltas();
        return calculateInputNodeDeltas();
    }

    /**
     * Computes the error term for each input given caller-supplied output error terms,
     * without modifying any weights.
     */
    public float[] getInpDeltasFromOutDeltas(float[] inputs, float[] outputDeltas) {
        checkBackpropData();
        setAllInputValues(inputs);
        setOutputNodeDeltas(outputDeltas);
        calculateHiddenNodeDeltas();
        return calculateInputNodeDeltas();
    }

    /**
     * Applies the weight updates accumulated in {@link #BATCH_MODE}. Does nothing when no
     * training call has happened yet.
     */
    public void endBatch() {
        if (this.backpropData == null) {
            return;
        }
        this.backpropData.applyLayerWeightDeltas(this.layerWeights);
    }

    /**
     * Stores {@code activation'(net) * outputDeltas[i]} as the delta of each output node.
     *
     * @return the sum of squares of {@code outputDeltas}
     */
    protected float setOutputNodeDeltas(float[] outputDeltas) {
        int last = this.nodeCounts.length - 1;
        float sumSquares = 0.0f;
        for (int i = 0; i < getOutputCount(); i++) {
            this.backpropData.backpropDeltas[last][i] =
                    dActivationFn(last, this.backpropData.layerInputs[last][i + 1], getOutputValue(i)) * outputDeltas[i];
            sumSquares += outputDeltas[i] * outputDeltas[i];
        }
        return sumSquares;
    }

    /**
     * Stores {@code activation'(net) * 2 * (target - output)} as the delta of each output node.
     *
     * @return the sum of squared output errors
     */
    protected float calculateOutputNodeDeltas(float[] targets) {
        int last = this.nodeCounts.length - 1;
        float sumSquares = 0.0f;
        for (int node = 1; node <= this.nodeCounts[last]; node++) {
            float netInput = this.backpropData.layerInputs[last][node];
            float output = this.layerOutputs[last][node];
            float error = targets[node - 1] - output;
            this.backpropData.backpropDeltas[last][node - 1] = dActivationFn(last, netInput, output) * 2.0f * error;
            sumSquares += error * error;
        }
        return sumSquares;
    }

    /** Propagates node deltas backwards from the output layer through every hidden layer. */
    protected void calculateHiddenNodeDeltas() {
        for (int layer = this.nodeCounts.length - 2; layer >= 1; layer--) {
            float[] nextDeltas = this.backpropData.backpropDeltas[layer + 1];
            for (int node = 0; node < this.nodeCounts[layer]; node++) {
                float downstream = 0.0f;
                for (int next = 0; next < this.nodeCounts[layer + 1]; next++) {
                    // Column node + 1 skips the bias weight of each downstream row.
                    downstream += nextDeltas[next] * getLayerWeight(layer + 1, next, node + 1);
                }
                this.backpropData.backpropDeltas[layer][node] =
                        dActivationFn(layer, this.backpropData.layerInputs[layer][node + 1], this.layerOutputs[layer][node + 1])
                                * downstream;
            }
        }
    }

    /**
     * Returns, for each input, the weighted sum of first-layer deltas flowing back to it.
     */
    protected float[] calculateInputNodeDeltas() {
        float[] inputDeltas = new float[this.nodeCounts[0]];
        float[] firstDeltas = this.backpropData.backpropDeltas[1];
        for (int input = 0; input < this.nodeCounts[0]; input++) {
            float sum = 0.0f;
            for (int node = 0; node < this.nodeCounts[1]; node++) {
                sum += firstDeltas[node] * getLayerWeight(1, node, input + 1);
            }
            inputDeltas[input] = sum;
        }
        return inputDeltas;
    }

    /** Adds {@code learningRate * delta * upstreamOutput} to every pending weight delta. */
    protected void calculateLayerWeightUpdates() {
        for (int layer = 1; layer < this.nodeCounts.length; layer++) {
            float[] pending = this.backpropData.layerWeightDeltas[layer];
            int w = 0;
            for (int node = 1; node <= this.nodeCounts[layer]; node++) {
                for (int k = 0; k <= this.nodeCounts[layer - 1]; k++) {
                    pending[w++] += this.learningRate * this.backpropData.backpropDeltas[layer][node - 1] * this.layerOutputs[layer - 1][k];
                }
            }
        }
    }

    /**
     * Returns the weight from slot {@code source} of layer {@code layer - 1} (0 = bias) into node
     * {@code node} (0-based) of layer {@code layer}.
     */
    protected float getLayerWeight(int layer, int node, int source) {
        try {
            return this.layerWeights[layer][(node * (1 + this.nodeCounts[layer - 1])) + source];
        } catch (RuntimeException e) {
            // Diagnostic aid: report the offending indices, then propagate the original failure.
            System.out.println(layer + "\t" + node + "\t" + source);
            throw e;
        }
    }

    /** Returns a snapshot of this network's topology and weights. */
    public PerceptronWeights getAllWeights() {
        return new PerceptronWeights(this.nodeCounts, this.layerWeights);
    }

    /** Copies every weight from the given snapshot into this network. */
    public void setAllWeights(PerceptronWeights perceptronWeights) {
        copyWeights(perceptronWeights.getLayerWeights());
    }

    /** Copies layers 1..n of {@code source} into {@link #layerWeights}, row lengths taken from this network. */
    private void copyWeights(float[][] source) {
        for (int layer = 1; layer < this.layerWeights.length; layer++) {
            System.arraycopy(source[layer], 0, this.layerWeights[layer], 0, this.layerWeights[layer].length);
        }
    }

    /**
     * Dumps the weights (one row per node), the current layer outputs and, if present, the
     * backprop buffers.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        // NOTE(review): this leading constant looks like a decompiler substitution for an
        // equal-valued literal (presumably ""); kept so the output is unchanged — confirm.
        sb.append(PeerData.DEFAULT_SOCKS_PROXY_HOST).append("BEGIN_PERCEPTRON_DATA\n");
        sb.append("Network Weights:\n");
        for (int layer = 1; layer < this.layerWeights.length; layer++) {
            sb.append("  Layer ").append(layer).append("\n");
            int rowWidth = this.nodeCounts[layer - 1] + 1;
            float[] weights = this.layerWeights[layer];
            for (int w = 0; w < weights.length; w++) {
                sb.append("\t").append(weights[w]);
                if (w % rowWidth == rowWidth - 1) {
                    sb.append("\n");
                }
            }
            sb.append("\n");
        }
        sb.append("Network Outputs:\n");
        for (int layer = 0; layer < this.layerOutputs.length; layer++) {
            sb.append("  Layer ").append(layer).append("\n");
            for (float value : this.layerOutputs[layer]) {
                sb.append("\t").append(value);
            }
            sb.append("\n");
        }
        sb.append("\n").append("END_PERCEPTRON_DATA\n");
        if (this.backpropData != null) {
            sb.append(this.backpropData.toString());
        }
        return sb.toString();
    }

    /**
     * Activation for a node in {@code layer}: tanh for hidden layers; for the output layer, tanh or
     * identity depending on {@link #configParams}.
     */
    protected float activationFn(int layer, float netInput) {
        if (layer != this.nodeCounts.length - 1) {
            return sigmoidFn(netInput);
        }
        switch (this.configParams) {
            case CP_SIGMOID_OUTPUTS:
                return sigmoidFn(netInput);
            case CP_LINEAR_OUTPUTS:
                return netInput;
            default:
                // NOTE(review): an unrecognized configParams yields 0 for every output node.
                return 0.0f;
        }
    }

    /**
     * Derivative of {@link #activationFn} for a node in {@code layer}, given its net input and
     * its output value.
     */
    protected float dActivationFn(int layer, float netInput, float output) {
        if (layer != this.nodeCounts.length - 1) {
            return dSigmoidFn(netInput, output);
        }
        switch (this.configParams) {
            case CP_SIGMOID_OUTPUTS:
                return dSigmoidFn(netInput, output);
            case CP_LINEAR_OUTPUTS:
                return 1.0f;
            default:
                return 0.0f;
        }
    }

    /**
     * Despite the name, computes the hyperbolic tangent {@code (e^x - e^-x) / (e^x + e^-x)},
     * range (-1, 1); saturates to exactly +/-1 when {@code |x| > 20}.
     */
    protected float sigmoidFn(float x) {
        if (Math.abs(x) > 20.0f) {
            return x > 0.0f ? 1.0f : -1.0f;
        }
        double ePos = Math.exp(x);
        double eNeg = Math.exp(-x);
        return (float) ((ePos - eNeg) / (ePos + eNeg));
    }

    /** Derivative of tanh expressed via its output: {@code 1 - y^2}; {@code netInput} is unused. */
    protected float dSigmoidFn(float netInput, float output) {
        return 1.0f - (output * output);
    }
}
