// Code viewer for World: Voice Doodle Recognizer

// Cloned by Sagar Ramachandra Murthy on 5 Dec 2021 from World "Doodle Recognizer" by Sagar Ramachandra Murthy 
// Please leave this clone trail here.

//Data Reference:
// https://quickdraw.withgoogle.com/data

// Number of bytes per sample: each doodle is a 28x28 bitmap stored as one
// byte per pixel (the Predict handler resizes the canvas to 28x28 and reads
// `len` values).
const len = 784;
// Samples read from each category's .bin file; prepareData() puts the first
// 80% into `.training` and the rest into `.testing`.
const totalData = 1000;

// Class indices: the network output neuron for each category, and the
// position set to 1 in the one-hot target built by trainEpoch().
const CAT = 0;
const RAINBOW = 1;
const TRAIN = 2;
const APPLE = 3;
const BREAD = 4;
const DONUT = 5;
const CACTUS = 6;
const GUITAR = 7;
const TSHIRT = 8;
const TV = 9;

// Raw downloads, assigned in preload() from loadBytes(); once loaded each
// object carries a `.bytes` Uint8Array with the concatenated samples.
let catsData;
let trainsData;
let rainbowsData;
let applesData;
let breadsData;
let donutsData;
let cactussData;
let guitarsData;
let tshirtsData;
let tvsData;

// Per-category containers; prepareData() fills `.training` and `.testing`
// with labelled views into the corresponding raw buffer.
let cats = {};
let trains = {};
let rainbows = {};
let apples = {};
let breads = {};
let donuts = {};
let cactuss = {};
let guitars = {};
let tshirts = {};
let tvs = {};

// The classifier (a NeuralNetwork), created in setup().
let nn;

//Voice
/*! p5.speech.js v0.0.1 2015-06-12 */
/* updated v0.0.2 2017-10-17 */
/**
 * @module p5.speech
 * @submodule p5.speech
 * @for p5.speech
 * @main
 */
/**
 *  p5.speech
 *  R. Luke DuBois (dubois@nyu.edu)
 *  ABILITY Lab / Brooklyn Experimental Media Center
 *  New York University
 *  The MIT License (MIT).
 *  
 *  https://github.com/IDMNYU/p5.js-speech
 *
 *  Web Speech API: https://dvcs.w3.org/hg/speech-api/raw-file/tip/speechapi.html
 *  Web Speech Recognition API: https://dvcs.w3.org/hg/speech-api/raw-file/tip/speechapi.html
 */
(function (root, factory) {
  // UMD-style wrapper: register as an AMD module, load p5 through CommonJS,
  // or fall back to the `p5` property of `root` (the `this` passed below).
  if (typeof define === 'function' && define.amd)
    define('p5.speech', ['p5'], function (p5) { (factory(p5));});
  else if (typeof exports === 'object')
    factory(require('../p5'));
  else
    factory(root['p5']);
}(this, function (p5) {
// =============================================================================
//                         p5.Speech
// =============================================================================


  /**
   * Base class for a Speech Synthesizer. Wraps window.speechSynthesis and a
   * single reusable SpeechSynthesisUtterance.
   *
   * @class p5.Speech
   * @constructor
   * @param {string|number} [_dv] initial voice, by name or by index; it is
   *   applied with setVoice() once the voice list has loaded.
   * @param {function} [_callback] stored as onLoad and fired once the voice
   *   list has loaded.
   */
  p5.Speech = function(_dv, _callback) {

    //
    // speech synthesizers consist of a single synthesis engine
    // per window instance, and a variable number of 'utterance'
    // objects, which can be cached and re-used for, e.g.
    // auditory UI.
    //
    // this implementation assumes a monolithic (one synth, 
    // one phrase at a time) system.
    //

    // grab the window's speech synthesizer (this will load voices):
    this.synth = window.speechSynthesis;

    // make an utterance to use with this synthesizer:
    this.utterance = new SpeechSynthesisUtterance();

    this.isLoaded = 0; // do we have voices yet? (set to 1 in onvoiceschanged)

    // do we queue new utterances upon firing speak() 
    // or interrupt what's speaking (cancel() first):
    this.interrupt = false; 

    // callback properties to be filled in within the p5 sketch
    // if the author needs custom callbacks:
    this.onLoad; // fires when voices are loaded and synth is ready
    this.onStart; // fires when an utterance begins...
    this.onPause; // ...is paused...
    this.onResume; // ...resumes...
    this.onEnd; // ...and ends.

    this.voices = []; // array of available voices (dependent on browser/OS)

    // first parameter of constructor is an initial voice selector
    this.initvoice;
    if(_dv !== undefined) this.initvoice=_dv;
    if(_callback !== undefined) this.onLoad =_callback;

    var that = this; // keep a reference to this instance for the handler closures below

    // onvoiceschanged() fires automatically when the synthesizer
    // is configured and has its voices loaded.  you don't need
    // to wait for this if you're okay with the default voice.
    // 
    // we use this function to load the voice array and bind our
    // custom callback functions.
    //
    // NOTE(review): this assignment replaces any existing onvoiceschanged
    // handler, and the utterance callbacks below are bound only inside it, so
    // they are never bound if the event does not fire after construction.
    window.speechSynthesis.onvoiceschanged = function() {
      if(that.isLoaded==0) { // run only once
        that.voices = window.speechSynthesis.getVoices();
        that.isLoaded = 1; // we're ready
        console.log("p5.Speech: voices loaded!");

        if(that.initvoice!=undefined) {
          that.setVoice(that.initvoice); // set a custom initial voice
          console.log("p5.Speech: initial voice: " + that.initvoice);
        }

        // fire custom onLoad() callback, if it exists:
        if(that.onLoad!=undefined) that.onLoad();

        //
        // bind other custom callbacks (each forwards the browser event):
        //

        that.utterance.onstart = function(e) {
          //console.log("STARTED");
          if(that.onStart!=undefined) that.onStart(e);     
        };
        that.utterance.onpause = function(e) {
          //console.log("PAUSED");
          if(that.onPause!=undefined) that.onPause(e);
        };
        that.utterance.onresume = function(e) {
          //console.log("RESUMED");
          if(that.onResume!=undefined) that.onResume(e);
        };
        that.utterance.onend = function(e) {
          //console.log("ENDED");
          if(that.onEnd!=undefined) that.onEnd(e); 
        };
      }
    };

  };     // end p5.Speech constructor


  // listVoices() - dump voice names to javascript console
  // (or report that the voice list has not loaded yet):
  p5.Speech.prototype.listVoices = function() {
    if(this.isLoaded)
    {
      for(var i = 0;i<this.voices.length;i++)
      {
        console.log(this.voices[i].name);
      }
    }
    else
    {
    	console.log("p5.Speech: voices not loaded yet!")
    }
  };

  // setVoice() - assign voice to speech synthesizer, by name
  // (using voices found in the voices[] array), or by index.
  // an unknown name leaves utterance.voice undefined; an index is clamped
  // to [0, voices.length-1]; any other type is ignored.
  p5.Speech.prototype.setVoice = function(_v) {
    // type check so you can set by label or by index:
    if(typeof(_v)=='string') this.utterance.voice = this.voices.filter(function(v) { return v.name == _v; })[0];
    else if(typeof(_v)=='number') this.utterance.voice = this.voices[Math.min(Math.max(_v,0),this.voices.length-1)];
  };

  // volume of voice, clamped to the API range 0.0-1.0.
  p5.Speech.prototype.setVolume = function(_v) {
    this.utterance.volume = Math.min(Math.max(_v, 0.0), 1.0);
  };

  // rate of voice.  not all voices support this feature.
  // API range 0.1-2.0; out-of-range values are clamped to it here.
  p5.Speech.prototype.setRate = function(_v) {
    this.utterance.rate = Math.min(Math.max(_v, 0.1), 2.0);
  };

  // pitch of voice.  not all voices support this feature.
  // API range >0.0-2.0; values are clamped to 0.01-2.0 here.
  p5.Speech.prototype.setPitch = function(_v) {
    this.utterance.pitch = Math.min(Math.max(_v, 0.01), 2.0);
  };

  // sets the language of the voice (stored on the shared utterance).
  p5.Speech.prototype.setLang = function(_lang) {
    this.utterance.lang = _lang;
}

  // speak a phrase through the current synthesizer. when `interrupt` is
  // true, anything already queued is cancelled first.
  p5.Speech.prototype.speak = function(_phrase) {
    if(this.interrupt) this.synth.cancel();
    this.utterance.text = _phrase;

    this.synth.speak(this.utterance);
  };

  // forwards to speechSynthesis.pause().
  // NOTE: marked "not working" by the original author.
  p5.Speech.prototype.pause = function() {
    this.synth.pause();
  };

  // forwards to speechSynthesis.resume().
  // NOTE: marked "not working" by the original author.
  p5.Speech.prototype.resume = function() {
    this.synth.resume();
  };

  // stop current utterance (implemented with cancel(), which also clears
  // anything queued):
  p5.Speech.prototype.stop = function() {
    // the commented-out stop() call below did not work (per original author):
    //this.synth.stop();
    this.synth.cancel();
  };

  // kill synthesizer completely, clearing any queued utterances:
  p5.Speech.prototype.cancel = function() {
    this.synth.cancel(); // KILL SYNTH
  };

  // Setting callbacks with functions instead of assigning the on* properties
  p5.Speech.prototype.started = function(_cb) {
   this.onStart = _cb;
  }

  p5.Speech.prototype.ended = function(_cb) {
    this.onEnd = _cb;
  }

  p5.Speech.prototype.paused = function(_cb) {
    this.onPause = _cb;
  }

  p5.Speech.prototype.resumed = function(_cb) {
    this.onResume = _cb;
  }

// =============================================================================
//                         p5.SpeechRec
// =============================================================================


  /**
   * Base class for a Speech Recognizer (wraps webkitSpeechRecognition).
   *
   * @class p5.SpeechRec
   * @constructor
   * @param {string} [_lang] BCP-47 language tag assigned to rec.lang.
   * @param {function} [_callback] stored as onResult; called with no
   *   arguments after each recognition result.
   */
  p5.SpeechRec = function(_lang, _callback) {

    //
    // speech recognition consists of a recognizer object per 
    // window instance that returns a JSON object containing
    // recognition.  this JSON object grows when the recognizer
    // is in 'continuous' mode, with new recognized phrases
    // appended into an internal array.
    //
    // this implementation returns the full JSON, but also a set
    // of simple, query-ready properties containing the most
    // recently recognized speech.
    //

    // make a recognizer object. when webkitSpeechRecognition is missing,
    // a plain object is used instead: the handlers below are attached but
    // never fire, and start() does nothing.
    if('webkitSpeechRecognition' in window) {
      this.rec = new webkitSpeechRecognition();
    }
    else {
      this.rec = new Object();
      console.log("p5.SpeechRec: webkitSpeechRecognition not supported in this browser.");
    }

    // first parameter is language model (if omitted, rec.lang is left
    // unset; the original author notes this means U.S. English).
    // no list of valid models in API, but it must use BCP-47.
    // here's some hints:
    // http://stackoverflow.com/questions/14257598/what-are-language-codes-for-voice-recognition-languages-in-chromes-implementati
    if(_lang !== undefined) this.rec.lang=_lang;

    // callback properties to be filled in within the p5 sketch
    // if the author needs custom callbacks:
    this.onResult; // fires when something has been recognized
    this.onStart; // fires when the recognition system is started...
    this.onError; // ...has a problem (e.g. the mic is shut off)...
    this.onEnd; // ...and ends (in non-continuous mode).
    if(_callback !== undefined) this.onResult=_callback;

    // recognizer properties (copied onto this.rec by start()):

    // continuous mode means the object keeps recognizing speech,
    // appending new tokens to the internal JSON.
    this.continuous = false; 
    // interimResults means the object will report (i.e. fire its
    // onresult() callback) more frequently, rather than at pauses
    // in microphone input.  this gets you quicker, but less accurate,
    // results.
    this.interimResults = false;

    // result data:

    // resultJSON:
    // this is a full JSON returned by onresult().  it consists of a 
    // SpeechRecognitionEvent object, which contains a (wait for it)
    // SpeechRecognitionResultList.  this is an array.  in continuous
    // mode, it will be appended to, not cleared.  each element is a 
    // SpeechRecognition result, which contains a (groan)
    // SpeechRecognitionAlternative, containing a 'transcript' property.
    // the 'transcript' is the recognized phrase.  have fun.
    this.resultJSON; 
    // resultValue:
    // validation flag which indicates whether the recognizer succeeded.  
    // this is *not* a metric of speech clarity, but rather whether the
    // speech recognition system successfully connected to and received
    // a response from the server.  you can construct an if() around this
    // if you're feeling worried.
    this.resultValue; 
    // resultString:
    // the 'transcript' (whitespace-trimmed) of the most recently recognized
    // speech as a simple string.  this will be blown out and replaced at
    // every firing of the onresult() callback.
    this.resultString; 
    // resultConfidence:
    // the 'confidence' (0-1) of the most recently recognized speech, e.g.
    // that it reflects what was actually spoken.  you can use this to filter
    // out potentially bogus recognition tokens.
    this.resultConfidence; 

    var that = this; // keep a reference to this instance for the handler closures below

    // onresult() fires automatically when the recognition engine
    // detects speech, or times out trying.
    // 
    // it fills up a JSON array internal to the webkitSpeechRecognition
    // object.  we reference it over in our struct here, and also copy 
    // out the most recently detected phrase and confidence value
    // (first alternative of the last result), then call onResult().
    this.rec.onresult = function(e) { 
      that.resultJSON = e; // full JSON of callback event
      that.resultValue = e.returnValue; // was successful?
      // store latest result in top-level object struct
      that.resultString = e.results[e.results.length-1][0].transcript.trim();
      that.resultConfidence = e.results[e.results.length-1][0].confidence;
      if(that.onResult!=undefined) that.onResult();
    };

    // fires when the recognition system starts (i.e. when you 'allow'
    // the mic to be used in the browser).
    this.rec.onstart = function(e) {
      if(that.onStart!=undefined) that.onStart(e);
    };
    // fires on a client-side error (server-side errors are expressed 
    // by the resultValue in the JSON coming back as 'false').
    this.rec.onerror = function(e) {
      if(that.onError!=undefined) that.onError(e);
    };
    // fires when the recognition finishes, in non-continuous mode.
    this.rec.onend = function() {
      if(that.onEnd!=undefined) that.onEnd();
    };

  }; // end p5.SpeechRec constructor

  // start the speech recognition engine.  this will prompt a 
  // security dialog in the browser asking for permission to 
  // use the microphone.  this permission will persist throughout
  // this one 'start' cycle.  if you need to recognize speech more
  // than once, use continuous mode rather than firing start() 
  // multiple times in a single script.
  // optional arguments override this.continuous / this.interimResults
  // before they are copied onto the recognizer; does nothing when
  // webkitSpeechRecognition is unavailable.
  p5.SpeechRec.prototype.start = function(_continuous, _interim) {
    if('webkitSpeechRecognition' in window) {
      if(_continuous !== undefined) this.continuous = _continuous;
      if(_interim !== undefined) this.interimResults = _interim;
      this.rec.continuous = this.continuous;
      this.rec.interimResults = this.interimResults;
      this.rec.start();
    }
  };

}));

/*
todo:
* fix callbacks (pause, resume) in synthesizer.
* support speech grammar models for scoped auditory UI.
* support markdown, boundaries, etc for better synthesis tracking.
* support utterance parser for long phrases.
*/

// EOF

// Voice control state.
// Xpos/Ypos: offset of the voice-driven pen from the canvas' upper-left third.
// Xv/Yv: per-frame velocity applied by draw() while Stop is false.
let Xpos = 0;
let Xv = 0;
let Ypos = 0;
let Yv = 0;
let Stop = false;
// Recognizer shared by setup(), onStart() and showResult().
var SpeechRec = new p5.SpeechRec();
// Last transcript acted upon; showResult() ignores exact repeats of it.
let pword = "";
// Keep listening and report partial results so commands react quickly.
Object.assign(SpeechRec, { continuous: true, interimResults: true });

/**
 * p5 preload hook: requests the QuickDraw bitmap file for every category.
 * Each request's result object is stored in the matching *Data global, in the
 * same order as the category constants are listed in the file's globals.
 */
function preload() {
  const fromUploads = (name) => loadBytes(`/uploads/sagarr1/${name}`);
  [
    catsData,
    trainsData,
    rainbowsData,
    applesData,
    breadsData,
    donutsData,
    cactussData,
    guitarsData,
    tshirtsData,
    tvsData,
  ] = [
    'cats.bin',
    'trains.bin',
    'rainbows.bin',
    'apple.bin',
    'bread.bin',
    'donut.bin',
    'cactus.bin',
    'guitar.bin',
    't-shirt.bin',
    'television.bin',
  ].map((file) => fromUploads(file));
}


/**
 * p5 setup hook. It:
 * - builds the control panel and sizes the canvas,
 * - starts voice recognition,
 * - splits every category's bitmaps into training/testing sets,
 * - creates the network,
 * - wires the Train / Test / Predict / Clear buttons.
 */
function setup() {
  createCanvas(500, 500);
  AB.msg(`<div> <button id="train">Model Train</button>
	<button id="test">Model Test</button>
	<button id="guess">Predict</button>
	<button id="clear">Clear Doodle</button> </div>
	<br> <div> Put your doodle to the left </div>
	<br> <div> Draw: Apple, Bread, Cactus, Cat, Donut, Television, Train, T-Shirt, Guitar, Rainbow </div>
	
	<br><div id = "epoch"></div>
	<div id = "percent"></div>
	<br><div id = "output"></div>`);
  background(255);

  // Voice: switch to a display-sized canvas and start listening for commands.
  createCanvas(displayWidth, displayHeight);
  strokeWeight(4);
  background(255);
  SpeechRec.start();
  SpeechRec.onResult = showResult;
  SpeechRec.onStart = onStart;

  // [container, raw bytes, class index] for each category. The same order is
  // used to prepare the data and to build the combined train/test sets.
  const categories = [
    [cats, catsData, CAT],
    [rainbows, rainbowsData, RAINBOW],
    [trains, trainsData, TRAIN],
    [apples, applesData, APPLE],
    [breads, breadsData, BREAD],
    [donuts, donutsData, DONUT],
    [cactuss, cactussData, CACTUS],
    [guitars, guitarsData, GUITAR],
    [tshirts, tshirtsData, TSHIRT],
    [tvs, tvsData, TV],
  ];
  for (const [group, bytes, label] of categories) {
    prepareData(group, bytes, label);
  }

  nn = new NeuralNetwork(784, 100, 10);

  // All categories' samples in one list each (trainEpoch shuffles in place).
  const training = categories.flatMap(([group]) => group.training);
  const testing = categories.flatMap(([group]) => group.testing);

  let epochCounter = 0;
  $('#train').click(() => {
    trainEpoch(training);
    epochCounter += 1;
    console.log("Epoch: " + epochCounter);
    $('#epoch').text("Train Epoch: " + epochCounter);
  });

  $('#test').click(() => {
    const shown = nf(testAll(testing), 2, 2) + "%";
    console.log("Percent: " + shown);
    $('#percent').text("Test Accuracy: " + shown);
  });

  // Console name and on-screen name for each class index.
  const classNames = {
    [CAT]: ["cat", "Cat"],
    [RAINBOW]: ["rainbow", "Rainbow"],
    [TRAIN]: ["train", "Train"],
    [APPLE]: ["apple", "Apple"],
    [BREAD]: ["bread", "Bread"],
    [DONUT]: ["donut", "Donut"],
    [CACTUS]: ["cactus", "Cactus"],
    [GUITAR]: ["guitar", "Guitar"],
    [TSHIRT]: ["tshirt", "TShirt"],
    [TV]: ["television", "Television"],
  };

  $('#guess').click(() => {
    // Shrink the canvas to 28x28 and invert the red channel of each RGBA
    // pixel so dark ink becomes 1 and white paper becomes 0.
    const img = get();
    img.resize(28, 28);
    img.loadPixels();
    const inputs = [];
    for (let i = 0; i < len; i++) {
      inputs[i] = (255 - img.pixels[i * 4]) / 255.0;
    }

    const guess = nn.predict(inputs);
    const classification = guess.indexOf(max(guess));
    const names = classNames[classification];
    if (names !== undefined) {
      const [logName, displayName] = names;
      console.log(logName);
      $('#output').text('Output: ' + displayName);
    }
  });

  $('#clear').click(() => {
    background(255);
  });
}



/**
 * p5 draw hook, run every frame. It:
 * - plots the voice-controlled pen at its offset from (width/3, height/3),
 * - advances that offset by (Xv, Yv) unless Stop is set,
 * - draws a freehand segment while the mouse button is held.
 */
function draw() {
  strokeWeight(8);
  stroke(0);

  // Voice pen: plot at the current offset, then advance it.
  const penX = width / 3 + Xpos;
  const penY = height / 3 + Ypos;
  point(penX, penY);
  if (!Stop) {
    Xpos += Xv;
    Ypos += Yv;
  }

  // Mouse pen.
  if (mouseIsPressed) line(pmouseX, pmouseY, mouseX, mouseY);
}

//Voice
/**
 * Recognizer start callback: records the recognizer's current transcript as
 * the last phrase handled, so showResult() will not act on it again.
 */
function onStart() {
  ({ resultString: pword } = SpeechRec);
}

/**
 * Recognizer result callback: turns a spoken command into pen movement.
 *
 * Commands are matched case-insensitively:
 *   "stop"    – halt and zero the velocity,
 *   "start"   – resume updating the position,
 *   "clear"   – wipe the canvas,
 *   "restart" – reset position and velocity and wipe the canvas,
 *   "up" / "down" / "left" / "right" – set a unit velocity in that direction.
 *
 * Results with confidence below 0.009, or identical to the previous
 * transcript, are ignored.
 *
 * @returns {number|undefined} 0 when the result was ignored, otherwise
 *   undefined (p5.SpeechRec discards the value).
 */
function showResult() {
  const heard = SpeechRec.resultString;
  if (SpeechRec.resultConfidence < 0.009 || pword === heard) {
    return 0;
  }
  pword = heard;

  // The recognizer may capitalise the transcript ("Stop", "Up", ...).
  // Previously only "Stop" had an alias, so other capitalised commands were
  // silently dropped; normalise the case before matching instead.
  switch (String(heard).toLowerCase()) {
    case "stop":
      Stop = true;
      Xv = 0;
      Yv = 0;
      break;
    case "start":
      Stop = false;
      break;
    case "clear":
      background(255);
      break;
    case "restart":
      Xv = 0;
      Xpos = 0;
      Yv = 0;
      Ypos = 0;
      Stop = false;
      background(255);
      break;
    case "up":
      Yv = -1;
      Xv = 0;
      break;
    case "down":
      Yv = 1;
      Xv = 0;
      break;
    case "left":
      Yv = 0;
      Xv = -1;
      break;
    case "right":
      Yv = 0;
      Xv = 1;
      break;
  }
}
// Voice

p5.prototype.registerPreloadMethod('loadBytes');

/**
 * Loads a binary file as bytes. It is registered above as a p5 preload
 * method, so calls made from preload() hold off setup() until
 * _decrementPreload() runs.
 *
 * On success the returned object receives a `bytes` Uint8Array, `callback`
 * (if any) is called with that object, and the preload counter is
 * decremented. On failure (HTTP status outside 2xx, network error, or empty
 * response) the problem is reported through `errorCallback`, or
 * console.error when none is given. The preload counter is NOT decremented,
 * so the sketch stays in preload rather than running setup() without data.
 *
 * @param {string} file - URL to fetch.
 * @param {function(Object)} [callback] - called with the result object on success.
 * @param {function(Error)} [errorCallback] - called with an Error on failure.
 * @returns {Object} Result object; `bytes` is added once loading succeeds.
 */
p5.prototype.loadBytes = function (file, callback, errorCallback) {
  const self = this;
  const data = {};
  const request = new XMLHttpRequest();

  const fail = (reason) => {
    const err = new Error(`loadBytes: could not load "${file}" (${reason})`);
    if (typeof errorCallback === 'function') {
      errorCallback(err);
    } else {
      console.error(err);
    }
  };

  request.open('GET', file, true);
  request.responseType = 'arraybuffer';
  request.onload = function () {
    // onload also fires for HTTP errors such as 404, whose error-page body
    // must not be parsed as sample bytes. Some browsers report status 0 for
    // local file:// loads, so that is accepted alongside 2xx.
    const ok = request.status === 0 || (request.status >= 200 && request.status < 300);
    if (!ok) {
      fail(`HTTP ${request.status}`);
      return;
    }
    const arrayBuffer = request.response;
    if (!arrayBuffer) {
      fail('empty response');
      return;
    }
    data.bytes = new Uint8Array(arrayBuffer);
    if (callback) {
      callback(data);
    }
    self._decrementPreload();
  };
  request.onerror = function () {
    fail('network error');
  };
  request.send(null);
  return data;
};

/**
 * Runs one training epoch over `training`.
 *
 * The list is first shuffled in place (p5 `shuffle` with its modify flag set).
 * Then the network is trained on every sample once:
 * - input: the sample's values divided by 255,
 * - target: a 10-element one-hot vector with a 1 at the sample's `label`.
 *
 * @param {Array} training - samples, each an array-like of byte values that
 *   also carries a numeric `label` property.
 */
function trainEpoch(training) {
  shuffle(training, true);
  for (const sample of training) {
    const inputs = Array.from(sample, (value) => value / 255);
    const targets = new Array(10).fill(0);
    targets[sample.label] = 1;
    nn.train(inputs, targets);
  }
}

/**
 * Measures classification accuracy of `nn` on labelled samples.
 *
 * Each sample's values are divided by 255 and passed to `nn.predict`. The
 * index of the largest output is compared with the sample's `label`.
 * (Per-sample debug logging was removed: it printed three console lines for
 * every sample, which flooded the console and slowed the Test button.)
 *
 * @param {Array} testing - samples, each an array-like of byte values that
 *   also carries a numeric `label` property.
 * @returns {number} Percentage (0-100) classified correctly; 0 for an empty
 *   set (previously NaN).
 */
function testAll(testing) {
  if (testing.length === 0) {
    return 0;
  }

  let correct = 0;
  for (const sample of testing) {
    const inputs = Array.from(sample, (value) => value / 255);
    const guess = nn.predict(inputs);
    const classification = guess.indexOf(max(guess));
    if (classification === sample.label) {
      correct++;
    }
  }
  return (100 * correct) / testing.length;
}

/**
 * Splits a category's raw buffer into labelled samples of `len` bytes.
 *
 * The first floor(0.8 * totalData) samples go to `category.training` and the
 * remaining ones to `category.testing`. Each sample is a subarray view into
 * `data.bytes` (no copy) with a `label` property attached.
 *
 * @param {Object} category - container; its training/testing arrays are replaced.
 * @param {{bytes: Uint8Array}} data - result object produced by loadBytes().
 * @param {number} label - class index assigned to every sample.
 */
function prepareData(category, data, label) {
  const split = floor(0.8 * totalData);
  category.training = [];
  category.testing = [];
  for (let n = 0; n < totalData; n++) {
    const start = n * len;
    const sample = data.bytes.subarray(start, start + len);
    sample.label = label;
    if (n < split) {
      category.training.push(sample);
    } else {
      category.testing.push(sample);
    }
  }
}

// Other techniques for learning

/**
 * An activation function paired with its derivative.
 *
 * `dfunc` takes the activation's OUTPUT y = func(x), not its input:
 * NeuralNetwork.train applies it to matrices that have already been passed
 * through `func`.
 */
class ActivationFunction {
  constructor(func, dfunc) {
    Object.assign(this, { func, dfunc });
  }
}

// Logistic sigmoid: y = 1 / (1 + e^-x), dy/dx = y * (1 - y).
let sigmoid = new ActivationFunction(
  (x) => 1 / (1 + Math.exp(-x)),
  (y) => y * (1 - y),
);

// Hyperbolic tangent: y = tanh(x), dy/dx = 1 - y^2.
let tanh = new ActivationFunction(
  (x) => Math.tanh(x),
  (y) => 1 - y * y,
);


/**
 * Fully connected network with a single hidden layer, trained by stochastic
 * gradient descent (one sample per `train` call).
 *
 * Weights and biases are `Matrix` objects filled by `Matrix.randomize`. The
 * same activation function is used for the hidden and output layers.
 */
class NeuralNetwork {
  /**
   * If the first argument is a NeuralNetwork, the constructor clones it,
   * including its learning rate and activation function.
   * USAGE: cloned_nn = new NeuralNetwork(to_clone_nn);
   *
   * @param {number|NeuralNetwork} in_nodes - input size, or a network to clone.
   * @param {number} [hid_nodes] - hidden-layer size (ignored when cloning).
   * @param {number} [out_nodes] - output size (ignored when cloning).
   */
  constructor(in_nodes, hid_nodes, out_nodes) {
    if (in_nodes instanceof NeuralNetwork) {
      const a = in_nodes;
      this.input_nodes = a.input_nodes;
      this.hidden_nodes = a.hidden_nodes;
      this.output_nodes = a.output_nodes;

      this.weights_ih = a.weights_ih.copy();
      this.weights_ho = a.weights_ho.copy();

      this.bias_h = a.bias_h.copy();
      this.bias_o = a.bias_o.copy();

      // Keep the source's hyper-parameters; they used to be reset to the
      // defaults, so copy() silently changed how the clone trains.
      this.setLearningRate(a.learning_rate);
      this.setActivationFunction(a.activation_function);
    } else {
      this.input_nodes = in_nodes;
      this.hidden_nodes = hid_nodes;
      this.output_nodes = out_nodes;

      this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
      this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
      this.weights_ih.randomize();
      this.weights_ho.randomize();

      this.bias_h = new Matrix(this.hidden_nodes, 1);
      this.bias_o = new Matrix(this.output_nodes, 1);
      this.bias_h.randomize();
      this.bias_o.randomize();

      this.setLearningRate();
      this.setActivationFunction();
    }
  }

  /**
   * Forward pass.
   * @param {number[]} input_array - one value per input node.
   * @returns {number[]} one activation per output node.
   */
  predict(input_array) {
    // Hidden layer: f(W_ih · x + b_h)
    const inputs = Matrix.fromArray(input_array);
    const hidden = Matrix.multiply(this.weights_ih, inputs);
    hidden.add(this.bias_h);
    hidden.map(this.activation_function.func);

    // Output layer: f(W_ho · h + b_o)
    const output = Matrix.multiply(this.weights_ho, hidden);
    output.add(this.bias_o);
    output.map(this.activation_function.func);

    return output.toArray();
  }

  /** @param {number} [learning_rate=0.1] step size used by train(). */
  setLearningRate(learning_rate = 0.1) {
    this.learning_rate = learning_rate;
  }

  /** @param {{func: Function, dfunc: Function}} [func=sigmoid] activation pair. */
  setActivationFunction(func = sigmoid) {
    this.activation_function = func;
  }

  /**
   * One gradient-descent step on a single sample; updates weights and biases
   * in place.
   * @param {number[]} input_array - one value per input node.
   * @param {number[]} target_array - desired output, one value per output node.
   */
  train(input_array, target_array) {
    // Forward pass (as in predict, keeping the intermediate activations).
    const inputs = Matrix.fromArray(input_array);
    const hidden = Matrix.multiply(this.weights_ih, inputs);
    hidden.add(this.bias_h);
    hidden.map(this.activation_function.func);

    const outputs = Matrix.multiply(this.weights_ho, hidden);
    outputs.add(this.bias_o);
    outputs.map(this.activation_function.func);

    // ERROR = TARGETS - OUTPUTS
    const targets = Matrix.fromArray(target_array);
    const output_errors = Matrix.subtract(targets, outputs);

    // Propagate the error back to the hidden layer BEFORE weights_ho is
    // adjusted: the hidden errors must come from the weights that produced
    // this output. (They used to be computed from the already-updated
    // weights.)
    const who_t = Matrix.transpose(this.weights_ho);
    const hidden_errors = Matrix.multiply(who_t, output_errors);

    // Output gradient = f'(outputs) ⊙ errors · learning rate
    const gradients = Matrix.map(outputs, this.activation_function.dfunc);
    gradients.multiply(output_errors);
    gradients.multiply(this.learning_rate);

    // hidden -> output deltas; the bias delta is just the gradient.
    const hidden_T = Matrix.transpose(hidden);
    const weight_ho_deltas = Matrix.multiply(gradients, hidden_T);
    this.weights_ho.add(weight_ho_deltas);
    this.bias_o.add(gradients);

    // Hidden gradient = f'(hidden) ⊙ hidden errors · learning rate
    const hidden_gradient = Matrix.map(hidden, this.activation_function.dfunc);
    hidden_gradient.multiply(hidden_errors);
    hidden_gradient.multiply(this.learning_rate);

    // input -> hidden deltas; the bias delta is just the gradient.
    const inputs_T = Matrix.transpose(inputs);
    const weight_ih_deltas = Matrix.multiply(hidden_gradient, inputs_T);
    this.weights_ih.add(weight_ih_deltas);
    this.bias_h.add(hidden_gradient);
  }

  /** @returns {string} JSON of this network (functions are not serialised). */
  serialize() {
    return JSON.stringify(this);
  }

  /**
   * Rebuilds a network from serialize() output (string or parsed object).
   * The activation function cannot be serialised, so the default is used.
   */
  static deserialize(data) {
    if (typeof data == 'string') {
      data = JSON.parse(data);
    }
    const restored = new NeuralNetwork(data.input_nodes, data.hidden_nodes, data.output_nodes);
    restored.weights_ih = Matrix.deserialize(data.weights_ih);
    restored.weights_ho = Matrix.deserialize(data.weights_ho);
    restored.bias_h = Matrix.deserialize(data.bias_h);
    restored.bias_o = Matrix.deserialize(data.bias_o);
    restored.learning_rate = data.learning_rate;
    return restored;
  }

  // Adding function for neuro-evolution: independent deep copy.
  copy() {
    return new NeuralNetwork(this);
  }

  // Accept an arbitrary function for mutation; it is applied in place to
  // every weight and bias as func(value, row, col).
  mutate(func) {
    this.weights_ih.map(func);
    this.weights_ho.map(func);
    this.bias_h.map(func);
    this.bias_o.map(func);
  }
}

// let m = new Matrix(3,2);


/**
 * Minimal dense 2-D numeric matrix stored as an array of row arrays.
 *
 * Instance methods `add`, `multiply`, `map` and `randomize` modify the
 * matrix in place and return it for chaining. The static helpers build and
 * return a new matrix. Shape mismatches are reported with console.log and
 * the method then returns undefined.
 */
class Matrix {
  /**
   * @param {number} rows
   * @param {number} cols
   */
  constructor(rows, cols) {
    this.rows = rows;
    this.cols = cols;
    this.data = Array.from(Array(this.rows), () => Array(this.cols).fill(0));
  }

  /** @returns {Matrix} an independent copy of the rows x cols entries. */
  copy() {
    return new Matrix(this.rows, this.cols).map((_, r, c) => this.data[r][c]);
  }

  /** @returns {Matrix} an (arr.length x 1) column vector holding `arr`. */
  static fromArray(arr) {
    const column = new Matrix(arr.length, 1);
    return column.map((_, r) => arr[r]);
  }

  /** @returns {Matrix|undefined} a new matrix a - b, or undefined on shape mismatch. */
  static subtract(a, b) {
    const sameShape = a.rows === b.rows && a.cols === b.cols;
    if (!sameShape) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return Matrix.map(a, (value, r, c) => value - b.data[r][c]);
  }

  /** @returns {number[]} the entries in row-major order. */
  toArray() {
    const flat = [];
    for (let r = 0; r < this.rows; r++) {
      const row = this.data[r];
      for (let c = 0; c < this.cols; c++) {
        flat.push(row[c]);
      }
    }
    return flat;
  }

  /** Fills every entry with Math.random() * 2 - 1, i.e. a value in [-1, 1). */
  randomize() {
    return this.map(() => Math.random() * 2 - 1);
  }

  /** Element-wise (Matrix argument) or scalar addition, in place. */
  add(n) {
    if (!(n instanceof Matrix)) {
      return this.map((value) => value + n);
    }
    if (this.rows !== n.rows || this.cols !== n.cols) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return this.map((value, r, c) => value + n.data[r][c]);
  }

  /** @returns {Matrix} a new (cols x rows) transpose of `matrix`. */
  static transpose(matrix) {
    const flipped = new Matrix(matrix.cols, matrix.rows);
    return flipped.map((_, r, c) => matrix.data[c][r]);
  }

  /** @returns {Matrix|undefined} the matrix product a · b, or undefined if a.cols !== b.rows. */
  static multiply(a, b) {
    if (a.cols !== b.rows) {
      console.log('Columns of A must match rows of B.');
      return;
    }
    const product = new Matrix(a.rows, b.cols);
    return product.map((_, r, c) => {
      let dot = 0;
      for (let k = 0; k < a.cols; k++) {
        dot += a.data[r][k] * b.data[k][c];
      }
      return dot;
    });
  }

  /** Hadamard (Matrix argument) or scalar product, in place. */
  multiply(n) {
    if (!(n instanceof Matrix)) {
      return this.map((value) => value * n);
    }
    if (this.rows !== n.rows || this.cols !== n.cols) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return this.map((value, r, c) => value * n.data[r][c]);
  }

  /** Replaces every entry with func(value, row, col); returns this. */
  map(func) {
    for (let r = 0; r < this.rows; r++) {
      for (let c = 0; c < this.cols; c++) {
        this.data[r][c] = func(this.data[r][c], r, c);
      }
    }
    return this;
  }

  /** @returns {Matrix} a new matrix of func(value, row, col) over `matrix`. */
  static map(matrix, func) {
    const result = new Matrix(matrix.rows, matrix.cols);
    return result.map((_, r, c) => func(matrix.data[r][c], r, c));
  }

  /** Logs the entries as a console table; returns this. */
  print() {
    console.table(this.data);
    return this;
  }

  /** @returns {string} JSON with rows, cols and data. */
  serialize() {
    return JSON.stringify(this);
  }

  /**
   * Rebuilds a matrix from serialize() output (string or parsed object).
   * The returned matrix uses `data.data` directly rather than a copy.
   */
  static deserialize(data) {
    const parsed = typeof data == 'string' ? JSON.parse(data) : data;
    const matrix = new Matrix(parsed.rows, parsed.cols);
    matrix.data = parsed.data;
    return matrix;
  }
}

// CommonJS export when a `module` object exists (e.g. under Node).
if (typeof module !== 'undefined') {
  module.exports = Matrix;
}