package org.jskat.ai.nn.util;

import java.io.BufferedReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.StringTokenizer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:org/jskat/ai/nn/util/NeuralNetwork.class */
/**
 * Fully connected feed-forward neural network consisting of one input layer, any number of
 * hidden layers and one output layer.
 *
 * <p>New weights are drawn uniformly from {@code [-0.05, 0.05)}. The network is trained one
 * sample at a time with {@link #adjustWeights(double[], double[])} and queried with
 * {@link #getPredictedOutcome(double[])}. {@link #toString()} produces the text format that
 * {@link #saveNetwork(String)} writes and {@link #loadNetwork(String)} reads back.
 *
 * <p>The methods that read or replace the layers ({@code adjustWeights},
 * {@code getPredictedOutcome}, {@code resetNetwork}, {@code loadNetwork}) are synchronized on
 * this instance.
 */
public class NeuralNetwork {

    private static final Log log = LogFactory.getLog(NeuralNetwork.class);

    /** Source of the small random initial weights. */
    Random rand = new Random();
    /** Number of training steps performed so far; incremented by {@link #propagateBackward()}. */
    private long iterations = 0;
    /** Learning rate handed to the output layer and to every hidden neuron. */
    private double learningRate = 0.3d;
    /** Layers in order: input, hidden..., output. */
    private List<Layer> layers = new ArrayList<>();
    /** Layout of the network; non-null once construction has finished. */
    private NetworkTopology topo;

    /**
     * Creates a network with the default layout: 24 input neurons, one hidden layer with 10
     * neurons and a single output neuron.
     */
    public NeuralNetwork() {
        initializeVariables();
    }

    /**
     * Creates a network with the given layout.
     *
     * @param networkTopology layout of the network; {@code null} selects the default layout
     */
    public NeuralNetwork(NetworkTopology networkTopology) {
        this.topo = networkTopology;
        initializeVariables();
    }

    /** Resets the iteration counter and builds fully connected, randomly weighted layers. */
    private void initializeVariables() {
        iterations = 0L;
        if (topo == null) {
            // Keep an explicit default topology instead of null so that toString()/saveNetwork()
            // emit a topology section that loadNetwork() can parse back. Argument order matches
            // getTopology(): input count, output count, hidden layer count, hidden neuron counts.
            topo = new NetworkTopology(24, 1, 1, new int[] {10});
        }
        layers = buildLayers(topo);
        connectAllLayers();
    }

    /** Creates the (still unconnected) input, hidden and output layers described by a topology. */
    private static List<Layer> buildLayers(NetworkTopology topology) {
        List<Layer> result = new ArrayList<>();
        result.add(new InputLayer(topology.getInputNeuronCount()));
        for (int i = 0; i < topology.getHiddenLayerCount(); i++) {
            result.add(new HiddenLayer(topology.getHiddenNeuronCount(i)));
        }
        result.add(new OutputLayer(topology.getOutputNeuronCount()));
        return result;
    }

    /** Connects every layer to its successor with small random weights. */
    private void connectAllLayers() {
        for (int i = 0; i < layers.size() - 1; i++) {
            connectLayers(layers.get(i), layers.get(i + 1));
        }
    }

    /** Fully connects {@code from} to {@code to} using small random weights. */
    private void connectLayers(Layer from, Layer to) {
        int weightCount = from.getNeurons().size() * to.getNeurons().size();
        List<Double> weights = new ArrayList<>(weightCount);
        for (int i = 0; i < weightCount; i++) {
            weights.add(getLittleRandomWeightValue());
        }
        connectLayers(from, to, weights);
    }

    /** Returns a random value uniformly distributed in {@code [-0.05, 0.05)}. */
    private double getLittleRandomWeightValue() {
        return (rand.nextDouble() - 0.5d) * 0.1d;
    }

    /**
     * Creates a weight from every neuron of {@code from} to every neuron of {@code to}, taking
     * the values from {@code weights} in order (outer loop over {@code from}, inner over
     * {@code to}). Surplus values are ignored.
     *
     * @throws IllegalArgumentException if {@code weights} holds fewer values than connections
     */
    private void connectLayers(Layer from, Layer to, List<Double> weights) {
        int required = from.getNeurons().size() * to.getNeurons().size();
        if (weights.size() < required) {
            throw new IllegalArgumentException(
                    "Expected " + required + " weights but got " + weights.size());
        }
        int i = 0;
        for (Neuron source : from.getNeurons()) {
            for (Neuron target : to.getNeurons()) {
                Weight weight = new Weight(source, target, weights.get(i).doubleValue());
                source.addOutgoingWeight(weight);
                target.addIncomingWeight(weight);
                i++;
            }
        }
    }

    /** Hands the input values to the input layer. */
    private void setInputParameter(double[] dArr) {
        ((InputLayer) layers.get(0)).setInputParameter(dArr);
    }

    /** Recomputes the activation of every neuron, layer by layer, after the input layer. */
    private void propagateForward() {
        for (int i = 1; i < layers.size(); i++) {
            for (Neuron neuron : layers.get(i).getNeurons()) {
                neuron.calcActivationValue();
            }
        }
    }

    /** Hands the desired output values and the learning rate to the output layer. */
    private void setOutputParameter(double[] dArr) {
        ((OutputLayer) layers.get(layers.size() - 1)).setOuputParameter(dArr, learningRate);
    }

    /**
     * Returns the average difference reported by the output layer.
     *
     * @return the output layer's average difference
     */
    public double getAvgDiff() {
        return ((OutputLayer) layers.get(layers.size() - 1)).getAvgDiff();
    }

    /**
     * Lets every hidden neuron adjust its weights, from the last hidden layer back to the first,
     * then counts the training step.
     */
    private void propagateBackward() {
        for (int layerIndex = layers.size() - 2; layerIndex > 0; layerIndex--) {
            for (Neuron neuron : layers.get(layerIndex).getNeurons()) {
                neuron.adjustWeights(learningRate);
            }
        }
        iterations++;
    }

    /**
     * Performs one training step for a single sample.
     *
     * @param dArr input values
     * @param dArr2 desired output values
     * @return the output layer's average difference after the step
     */
    public synchronized double adjustWeights(double[] dArr, double[] dArr2) {
        setInputParameter(dArr);
        propagateForward();
        setOutputParameter(dArr2);
        propagateBackward();
        return getAvgDiff();
    }

    /** Assigns a new small random value to every weight and resets the iteration counter. */
    public synchronized void resetNetwork() {
        // Every weight is the incoming weight of exactly one neuron, so this visits all of them.
        for (Layer layer : layers) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Weight weight : neuron.incomingWeights) {
                    weight.setWeightValue(getLittleRandomWeightValue());
                }
            }
        }
        iterations = 0L;
    }

    /**
     * Feeds the input through the network.
     *
     * @param dArr input values
     * @return the activation value of the first output neuron
     */
    public synchronized double getPredictedOutcome(double[] dArr) {
        setInputParameter(dArr);
        propagateForward();
        return getOutputValue(0);
    }

    /**
     * Returns the number of training steps performed since creation, reset or load.
     *
     * @return the iteration counter
     */
    public long getIterations() {
        return iterations;
    }

    /** Returns the activation value of output neuron {@code i}. */
    private double getOutputValue(int i) {
        return layers.get(layers.size() - 1).getNeurons().get(i).getActivationValue();
    }

    /**
     * Writes {@link #toString()} to the file {@code str}, replacing any existing content.
     *
     * @param str path of the file to write
     * @return {@code true} only if the text was written and the file closed without error
     */
    public boolean saveNetwork(String str) {
        String serialized;
        // Snapshot under the same lock as training so weights are not read mid-update; serialize
        // before opening the file so a failure here does not truncate an existing file.
        synchronized (this) {
            serialized = toString();
        }
        try (FileWriter fileWriter = new FileWriter(str, false)) {
            fileWriter.write(serialized);
            return true;
        } catch (IOException e) {
            log.error("Could not save neural network to " + str, e);
            return false;
        }
    }

    /**
     * Replaces this network with one read from a resource in the format produced by
     * {@link #toString()}. The resource is looked up with {@code getClass().getResourceAsStream}.
     *
     * <p>If reading the resource fails, the error is logged and the current network is kept.
     *
     * @param str name of the resource to read
     * @throws IllegalArgumentException if the resource does not exist or a weight line holds too
     *     few values for its layers
     */
    public synchronized void loadNetwork(String str) {
        InputStream inputStream = getClass().getResourceAsStream(str);
        if (inputStream == null) {
            throw new IllegalArgumentException("Neural network resource not found: " + str);
        }
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
            String line;
            while ((line = reader.readLine()) != null) {
                log.debug(line);
                lines.add(line);
            }
        } catch (IOException e) {
            // Do not parse a partially read file; keep the current network instead.
            log.error("Could not read neural network from " + str + "; keeping current network", e);
            return;
        }
        log.debug(lines.size() + " lines read...");

        // Parse into locals first so a malformed file cannot leave this network with iterations
        // and topology from the new file but layers from the old one.
        long loadedIterations = getIterations(lines);
        NetworkTopology loadedTopology = getTopology(lines);
        List<Layer> loadedLayers = getLayers(loadedTopology, lines);

        iterations = loadedIterations;
        topo = loadedTopology;
        layers = loadedLayers;
        log.debug(this);
    }

    /** Returns the iteration count, stored on the line after the "iterations" header. */
    private static long getIterations(List<String> lines) {
        return Long.parseLong(lines.get(1));
    }

    /**
     * Parses the topology section: "input" followed by the input neuron count, "hidden N"
     * followed by N hidden neuron counts, and "output" followed by the output neuron count.
     * Parsing stops at the "weights" header.
     */
    private static NetworkTopology getTopology(List<String> lines) {
        int inputNeurons = 0;
        int outputNeurons = 0;
        int hiddenLayers = 0;
        int[] hiddenNeurons = null;
        Iterator<String> it = lines.iterator();
        while (it.hasNext()) {
            String line = it.next();
            log.debug("Current line: " + line + " length: " + line.length());
            if (line.equals("iterations")) {
                // Skip the iteration count; getIterations() reads it.
                it.next();
            } else if (line.equals("input")) {
                log.debug("parsing input node count");
                inputNeurons = Integer.parseInt(it.next());
                log.debug(inputNeurons + " input nodes");
            } else if (line.startsWith("hidden")) {
                // startsWith instead of substring(0, 6): lines shorter than 6 chars must not throw.
                log.debug("parsing hidden layer count");
                StringTokenizer tokenizer = new StringTokenizer(line);
                tokenizer.nextToken();
                hiddenLayers = Integer.parseInt(tokenizer.nextToken());
                log.debug(hiddenLayers + " hidden layers");
                hiddenNeurons = new int[hiddenLayers];
                for (int i = 0; i < hiddenLayers; i++) {
                    hiddenNeurons[i] = Integer.parseInt(it.next());
                }
            } else if (line.equals("output")) {
                log.debug("parsing output node count");
                outputNeurons = Integer.parseInt(it.next());
                log.debug(outputNeurons + " output nodes");
            } else if (line.equals("weights")) {
                break;
            }
        }
        return new NetworkTopology(inputNeurons, outputNeurons, hiddenLayers, hiddenNeurons);
    }

    /**
     * Builds the layers for {@code networkTopology} and connects them using the weight lines
     * that follow the "input layer" and "hidden layer" headers.
     *
     * <p>The "input layer" weights connect layer 0 to layer 1. The k-th "hidden layer" section
     * connects hidden layer k to its successor. NOTE(review): this assumes the hidden layers
     * serialize their outgoing weights in layer order, as {@link #toString()} iterates them;
     * confirm against {@code Layer.toString}.
     */
    private List<Layer> getLayers(NetworkTopology networkTopology, List<String> lines) {
        List<Layer> result = buildLayers(networkTopology);
        // Index of the layer whose outgoing weights the next "hidden layer" section describes.
        int nextHiddenIndex = 1;
        Iterator<String> it = lines.iterator();
        while (it.hasNext()) {
            String line = it.next();
            log.debug("Current line: " + line + " length: " + line.length());
            if (line.equals("input layer")) {
                log.debug("parsing input layer weights");
                connectLayers(result.get(0), result.get(1), parseWeights(it.next()));
            } else if (line.equals("hidden layer")) {
                log.debug("parsing hidden layer weights");
                connectLayers(result.get(nextHiddenIndex), result.get(nextHiddenIndex + 1),
                        parseWeights(it.next()));
                nextHiddenIndex++;
            }
        }
        return result;
    }

    /** Parses one line of whitespace-separated weight values. */
    private static List<Double> parseWeights(String line) {
        List<Double> weights = new ArrayList<>();
        StringTokenizer tokenizer = new StringTokenizer(line);
        while (tokenizer.hasMoreTokens()) {
            weights.add(Double.valueOf(tokenizer.nextToken()));
        }
        return weights;
    }

    /**
     * Serializes the network as: "iterations", the iteration count, "topology", the topology's
     * text, "weights", then the text of every layer in order.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("iterations\n");
        builder.append(iterations);
        builder.append('\n');
        builder.append("topology\n");
        builder.append(topo);
        builder.append("weights\n");
        for (Layer layer : layers) {
            builder.append(layer);
        }
        return builder.toString();
    }
}
