// Code viewer for World: Character recognition by J...
/* jorge blanco
   student number 19214700
   
   using sigmoid or tanh in the hidden and output layer with momentum and decay 
   
   noticed that with relu/lrelu (expected to do better than sigmoid/tanh) the score gets stuck at a low %
   
   the network uses 2 hidden layers (I tried 100 and everything was super slow) with the same parameters. I tried max pooling, but the results were not as expected (surprisingly); in fact after some runs
   the % seemed to go down.
   
   i added the ability to change the activation function for the hidden and output layer
   
   I added momentum to help get past local minima, and decay to help avoid values blowing up to infinity with
   some activation functions, as I was looking for a stable score after 'n' runs.
   
   After each run i print the result
   
   I have not added extra sample data by rotating the input images 10% to the left or right.
   
   tanh (hidden + output)
    testrun: 1 no: 8445 correct: 7552 score: 89.43% decay: 0.01 momentum: 0.95
    testrun: 2 no: 8445 correct: 7853 score: 92.99% decay: 0.01 momentum: 0.95
    testrun: 3 no: 8445 correct: 7936 score: 93.97% decay: 0.01 momentum: 0.95
    testrun: 4 no: 8445 correct: 7982 score: 94.52% decay: 0.01 momentum: 0.95
    testrun: 5 no: 8445 correct: 8002 score: 94.75% decay: 0.01 momentum: 0.95
    testrun: 6 no: 8445 correct: 8015 score: 94.91% decay: 0.01 momentum: 0.95
    testrun: 7 no: 8445 correct: 8029 score: 95.07% decay: 0.01 momentum: 0.95
    testrun: 8 no: 8445 correct: 8040 score: 95.20% decay: 0.01 momentum: 0.95
    testrun: 9 no: 8445 correct: 8046 score: 95.28% decay: 0.01 momentum: 0.95
    
    sigmoid (hidden + output)
    testrun: 1 no: 7815 correct: 7084 score: 90.65% decay: 0.01 momentum: 0.95
    testrun: 2 no: 7815 correct: 7316 score: 93.61% decay: 0.01 momentum: 0.95
    testrun: 3 no: 7815 correct: 7402 score: 94.72% decay: 0.01 momentum: 0.95
    testrun: 4 no: 7815 correct: 7447 score: 95.29% decay: 0.01 momentum: 0.95
    testrun: 5 no: 7815 correct: 7473 score: 95.62% decay: 0.01 momentum: 0.95
    
    sig (hidden) + relu (output)
    testrun: 1 no: 5145 correct: 3623 score: 70.42% decay: 0.01 momentum: 0.95
    testrun: 2 no: 5145 correct: 3739 score: 72.67% decay: 0.01 momentum: 0.95
    testrun: 3 no: 5145 correct: 3793 score: 73.72% decay: 0.01 momentum: 0.95
    testrun: 4 no: 5145 correct: 3811 score: 74.07% decay: 0.01 momentum: 0.95
    testrun: 5 no: 5145 correct: 3836 score: 74.56% decay: 0.01 momentum: 0.95
    testrun: 6 no: 5145 correct: 3845 score: 74.73% decay: 0.01 momentum: 0.95
    testrun: 7 no: 5145 correct: 3846 score: 74.75% decay: 0.01 momentum: 0.95
    
    relu + tanh
    testrun: 1 no: 1475 correct: 125 score: 8.47% decay: 0.01 momentum: 0.95
    testrun: 2 no: 1475 correct: 125 score: 8.47% decay: 0.01 momentum: 0.95
   
*/

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications 


// --- defined by MNIST - do not change these ---------------------------------------

const PIXELS        = 28;                       // images in data set are tiny: 28 x 28
const PIXELSSQUARED = PIXELS * PIXELS;          // pixels per image

// number of training and test exemplars in the data set:
const NOTRAIN = 60000;
const NOTEST  = 10000;



//--- can modify all these --------------------------------------------------

// no of nodes in network
const noinput  = PIXELSSQUARED;     // one input per pixel
const nohidden = 64;
const nooutput = 10;                // one output per digit

/* my code */
const learningrate = 0.01;   // default 0.1

// Starting values of the tunables shown in the run header.  Each MUST equal the option
// marked selected='selected' in its <select>: a select's onchange only fires when the user
// picks a different option, so a mismatch means the UI shows one value while training uses
// another.  decay used to start at 0.1 although the "Set decay" select shows 0.01 (the value
// used for every run recorded at the top of this file).
let decay = 0.01;
let momentum = 0.9;
let result = [];             // result log lines, rendered by printStats()
let hiddenLayersNum = 2;     // hidden-layer count (its select is currently disabled)
let maxPool = false;         // insert 2x2 max pooling between hidden layers

// should we train every timestep or not
let do_training = true;

// how many to train and test per timestep
const TRAINPERSTEP = 30;
const TESTPERSTEP  = 5;

// multiply it by this to magnify for display
const ZOOMFACTOR    = 7;
const ZOOMPIXELS    = ZOOMFACTOR * PIXELS;

// 3 rows of
// large image + 50 gap + small image
// 50 gap between rows

const canvaswidth = ( PIXELS + ZOOMPIXELS ) + 50;
const canvasheight = ( ZOOMPIXELS * 3 ) + 100;


const DOODLE_THICK = 18;    // thickness of doodle lines
const DOODLE_BLUR = 3;      // blur factor applied to doodles


let mnist;
// all data is loaded into this
// mnist.train_images
// mnist.train_labels
// mnist.test_images
// mnist.test_labels


let nn;

let trainrun = 1;
let train_index = 0;

let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

// images in LHS:
let doodle;
let demo;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?


// save inputs to global vars (window properties) to inspect
// type these names in console
var train_inputs;
var test_inputs;
var demo_inputs;
var doodle_inputs;


// Matrix.randomize() is changed to point to this. Must be defined by user of Matrix. 

/**
 * Initial value for one network weight or bias; Matrix.randomize() delegates here.
 * @returns {number} a sample from AB.randomFloatAtoB(-0.5, 0.5)
 *   (the Coding Train default range was -1 to 1)
 */
function randomWeight() {
  const LOWER = -0.5;
  const UPPER = 0.5;
  return AB.randomFloatAtoB(LOWER, UPPER);
}

// Enlarge the run header so the whole control panel fits on screen (CSS trick)
$("#runheaderbox").css({ "max-height": "95vh" });


//--- start of AB.msgs structure: ---------------------------------------------------------
// A series of AB.msg calls puts content into numbered slots of the run header.
// Slots 1, 3, 5 and 7 get fixed headers here; 2, 4, 6, 8 and 9 are filled at run time.
var thehtml;

// slot 1: doodle header
thehtml = [
  "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> ",
  " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ",
].join("");
AB.msg(thehtml, 1);

// slot 2: doodle guess (written by guessDoodle)

// slot 3: training header and controls
thehtml = [
  "<hr> <h1> 2. Training </h1> ",
  "<span>('restart' to be used for the activation function)</span> <br>",
  " <button onclick='do_training = false;' class='normbutton' >Stop</button>",
  " <button onclick='do_training = true;' class='normbutton' >Continue</button>",
  " <button onclick='clearResultsPrint()' class='normbutton'>Clear Result Log</button> <br> ",
  " <button style='background:red' onclick='reload()' class='normbutton' >Restart</button> <br> ",
  "<span id='actorlist'><input type='checkbox' value='true' onclick='useMaxPool(this);'>Use max POOL 2X2<br> </span>",
  "Num Hidden Layer (disabled): <span id='lNum'><select disabled onchange='setNumLayer(this)'><option value='2'>2</option><option value='10'>10</option><option value='50'>50</option><option value='100'>100</option></select></span> <br>",
  "Hidden layer: <span id='hLayer'><select onchange='setHLayer(this)'><option value='sigmoid'>sigmoid</option><option value='relu'>relu</option><option value='lrelu'>lrelu</option><option value='tanh'>tanh</option></select></span> <br>",
  "Output layer: <span id='oLayer'><select onchange='setOLayer(this)'><option value='sigmoid'>sigmoid</option><option value='relu'>relu</option><option value='lrelu'>lrelu</option><option value='tanh'>tanh</option></select></span> <br>",
  "Set decay: <span id='actorlist'><select onchange='setDecay(this)'><option value='0.0'>0.0</option><option value='0.001'>0.001</option><option value='0.0001'>0.0001</option><option selected='selected' value='0.01'>0.01</option></select></span><br>",
  "Set momentum: <span id='actorlist'><select onchange='setSpeed(this)'><option selected='selected' value='0.9'>0.9</option><option value='0.95'>0.95</option><option value='0.7'>0.7</option><option value='0.5'>0.5</option><option value='0.2'>0.2</option><option value='0.0'>0.0</option></select></span><br>",
].join("");
AB.msg(thehtml, 3);

// slot 4: training progress (written by trainit)

// slot 5: testing header
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// slot 6: testing progress (written by testit)

// slot 7: demo header
thehtml = [
  "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br>",
  " The network is <i>not</i> trained on any of these images. <br> ",
  " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ",
].join("");
AB.msg(thehtml, 7);

// slots 8 and 9: demo image id and its guess (written by makeDemo / guessDemo)

// div that printStats() fills with the per-run result log
AB.newDiv("statsDiv");
$("#statsDiv").css({
  "padding-left": "50vw",
  "color": "black",
});

// opening tag used to highlight guesses and scores; users close it with </span>
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";

/**
 * onchange handler for the "Set momentum" select.
 * @param {{value: string}} option - the select element; its value is a numeric string
 */
function setSpeed(option) {
    const chosen = Number(option.value);
    momentum = chosen;
}

/**
 * onchange handler for the "Set decay" select.
 * @param {{value: string}} option - the select element; its value is a numeric string
 */
function setDecay(option) {
    const chosen = Number(option.value);
    decay = chosen;
}

/**
 * onchange handler for the (currently disabled) hidden-layer-count select.
 * @param {{value: string}} option - the select element; its value is a numeric string
 */
function setNumLayer(option) {
    const count = Number(option.value);
    hiddenLayersNum = count;
}

/**
 * onclick handler for the "Use max POOL 2X2" checkbox.
 *
 * Reads the checkbox state instead of blindly toggling, so `maxPool` cannot drift out of
 * sync with what the UI shows (e.g. when the browser restores a checked box on reload).
 * Falls back to the old toggle when called without a checkbox-like argument.
 *
 * @param {{checked: boolean}} [option] - the checkbox element (`this` in the onclick)
 */
function useMaxPool(option) {
    if (option && typeof option.checked === 'boolean') {
        maxPool = option.checked;
    } else {
        maxPool = !maxPool;
    }
}

    
//--- end of AB.msgs structure: ---------------------------------------------------------

/**
 * p5 setup: create the canvas and the doodle buffer, build the network, and start loading
 * the MNIST data (draw() idles until it arrives).
 */
function setup()
{
  createCanvas(canvaswidth, canvasheight);

  // doodles are drawn on a larger off-screen buffer (pixel density 1: no high-DPI scaling)
  const buffer = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
  buffer.pixelDensity(1);
  doodle = buffer;

  AB.loadingScreen();
  console.log("All JS loaded");

  const network = new NeuralNetwork(noinput, nohidden, nooutput);
  network.setLearningRate(learningrate);
  nn = network;

  loadData();
}

/**
 * "Restart" button handler: log the current score, then train again from scratch with a
 * freshly built network.  The run header notes that Restart is needed for changes to the
 * activation-function selects to take effect.
 *
 * Counters go back to the values they have at page load (trainrun = 1, testrun = 1, all
 * indexes 0).  Previously trainrun/testrun were reset to 0, so the first run after a restart
 * was logged as run 0, and lRate() (which divides by 1 + decay * trainrun) gave a different
 * rate than the first run after page load.  test_index was not reset either, so the first
 * test run after a restart covered only part of the test set.
 */
function reload() {
  // paused while the network is rebuilt (stays paused if construction throws)
  do_training = false;

  nn = new NeuralNetwork(noinput, nohidden, nooutput);
  nn.setLearningRate(learningrate);

  // log the score of the run being abandoned
  printStats();

  trainrun = 1;
  train_index = 0;

  testrun = 1;
  test_index = 0;
  total_tests = 0;
  total_correct = 0;

  do_training = true;
}

/**
 * Append one line summarising the current test run to the result log, then re-render the
 * whole log into #statsDiv.
 *
 * Called at the end of every training epoch (trainit) and on Restart (reload).  When no
 * test has been scored yet, the score is shown as "n/a" instead of "NaN%".
 */
function printStats() {
  const score = total_tests > 0
    ? `${((total_correct / total_tests) * 100).toFixed(2)}%`
    : "n/a";

  const msg = `testrun: ${testrun} no: ${total_tests} correct: ${total_correct}  score: ${score} decay: ${decay} momentum: ${momentum}`;
  result.push(msg);

  $("#statsDiv").html(result.map((line) => `${line}<br>`).join(""));
}

// "Clear Result Log" handler: forget the stored log lines and blank #statsDiv.
function clearResultsPrint() {
    result = [];
    const statsDiv = $("#statsDiv");
    statsDiv.html('');
}

// load data set from local file (on this server)
// Load the MNIST data set from this server.  draw() does nothing until the callback
// has stored the data in the global `mnist`.
function loadData()
{
  const onLoaded = (data) => {
    mnist = data;
    console.log("All data loaded into mnist object:");
    console.log(mnist);
    AB.removeLoading();     // if no loading screen exists, this does nothing
  };
  loadMNIST(onLoaded);
}


/**
 * Build a PIXELS x PIXELS p5 image from a raw array of grey levels (one per pixel).
 * @param {ArrayLike<number>} img - PIXELSSQUARED grey values, row by row
 * @returns {p5.Image} opaque greyscale image
 */
function getImage(img)
{
    const theimage = createImage(PIXELS, PIXELS);   // blank image, filled in below

    theimage.loadPixels();
    for (let p = 0; p < PIXELSSQUARED; p++) {
        const grey = img[p];
        const base = p * 4;                          // RGBA: 4 slots per pixel
        for (let channel = 0; channel < 3; channel++) {
            theimage.pixels[base + channel] = grey;  // same value in R, G and B
        }
        theimage.pixels[base + 3] = 255;             // fully opaque
    }
    theimage.updatePixels();

    return theimage;
}

/**
 * Convert a raw image array into network inputs.
 * @param {ArrayLike<number>} img - PIXELSSQUARED grey levels in 0..255
 * @returns {number[]} the same pixels scaled to 0..1
 */
function getInputs(img)
{
    return Array.from({ length: PIXELSSQUARED }, (_, i) => img[i] / 255);
}

/**
 * Train the network on one exemplar, the one at global `train_index`, then advance.
 * At the end of the training set it wraps to 0, logs the score via printStats() and
 * counts the next training run.
 * @param {boolean} show - also draw the exemplar in the middle row of the canvas
 */
function trainit(show)
{
  const img   = mnist.train_images[train_index];
  const label = mnist.train_labels[train_index];

  if (show) {
    const theimage = getImage(img);
    image(theimage, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);          // magnified
    image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);    // original
  }

  const inputs = getInputs(img);

  // one-hot target: 1 at the label's position, 0 everywhere else
  const targets = new Array(10).fill(0);
  targets[label] = 1;

  train_inputs = inputs;        // can inspect in console
  nn.train(inputs, targets, maxPool);

  thehtml = " trainrun: " + trainrun + " / " + train_index;
  AB.msg(thehtml, 4);

  train_index++;
  if (train_index !== NOTRAIN) return;

  // finished a full pass over the training set
  train_index = 0;
  console.log("finished trainrun: " + trainrun);
  printStats();
  trainrun++;
}


/**
 * Score the network on one test exemplar, the one at global `test_index`, then advance.
 * The running score is shown in header slot 6.  After the last test exemplar the counters
 * are reset and the next test run starts.
 */
function testit()
{
  const img   = mnist.test_images[test_index];
  const label = mnist.test_labels[test_index];

  const inputs = getInputs(img);
  test_inputs = inputs;        // can inspect in console

  const guess = findMax(nn.predict(inputs));   // index of the strongest output

  total_tests++;
  // loose equality kept as in the original; the label's type comes from loadMNIST
  if (guess == label) total_correct++;

  const percent = (total_correct / total_tests) * 100;

  thehtml = " testrun: " + testrun + " / " + total_tests + " correct: " + total_correct + " = " + greenspan + percent.toFixed(2) + "%</span>";
  AB.msg(thehtml, 6);

  test_index++;
  if (test_index !== NOTEST) return;

  // finished a full pass over the test set: start the next test run from zero
  console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
  testrun++;
  test_index = 0;
  total_tests = 0;
  total_correct = 0;
}




//--- find no.1 (and maybe no.2) output nodes ---------------------------------------
// find12 below accepts any real values (including negative tanh / leaky-relu outputs).


/**
 * Indexes of the largest and second-largest values in an array (top two guesses).
 *
 * When a new maximum is found, the previous maximum is demoted to no.2.  The previous
 * version discarded it, so for an ascending array such as [0.1, 0.5, 0.9] it reported
 * index 0 instead of 1 as the runner-up.  The search also starts at -Infinity rather than
 * 0, so negative outputs are not ignored.  Ties keep the earlier index first; NaN entries
 * are ignored.
 *
 * @param {number[]} a - output node values
 * @returns {number[]} [indexOfLargest, indexOfSecondLargest]; each defaults to 0 when the
 *   array is too short to provide it
 */
function find12 (a)
{
  let no1 = 0;
  let no2 = 0;
  let no1value = -Infinity;
  let no2value = -Infinity;

  for (let i = 0; i < a.length; i++)
  {
    if (a[i] > no1value)
    {
      // the old winner becomes the runner-up
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = a[i];
    }
    else if (a[i] > no2value)
    {
      no2 = i;
      no2value = a[i];
    }
  }

  return [ no1, no2 ];
}


// just get the maximum - separate function for speed - done many times
// find our guess - the max of the output nodes array
// this gets the index
/**
 * Index of the largest value in an array; the first one wins on ties.
 *
 * The search starts at -Infinity, so all-negative arrays work.  With a tanh or leaky-relu
 * output layer, every output can be negative.  The previous version started at 0 and then
 * always answered 0 in that case.  For non-negative arrays the result is unchanged.
 * NaN entries are ignored; an empty array gives 0.
 *
 * @param {number[]} a - values to search
 * @returns {number} index of the maximum
 */
function findMax (a)
{
  let no1 = 0;
  let no1value = -Infinity;

  for (let i = 0; i < a.length; i++)
  {
    if (a[i] > no1value)
    {
      no1 = i;
      no1value = a[i];
    }
  }

  return no1;
}


// --- the draw function -------------------------------------------------------------
// every step:
 
/**
 * p5 per-frame callback.
 * - While training is on: train TRAINPERSTEP exemplars, then score TESTPERSTEP test exemplars.
 * - Redraws and re-guesses the demo and doodle images, so the guesses improve as the
 *   network trains.
 * - Turns mouse drags in the top-left corner into doodle strokes.
 */
function draw()
{
  // nothing to do until libraries and data have loaded
  if (typeof mnist === 'undefined') return;

  background('black');

  if (do_training) {
    // only the first exemplar of each step is drawn - it still flashes by
    for (let step = 0; step < TRAINPERSTEP; step++) {
      trainit(step === 0);
    }
    for (let step = 0; step < TESTPERSTEP; step++) {
      testit();
    }
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }
  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  // --- doodle drawing ---
  // (restriction) assumes the doodle starts at 0,0
  if (!mouseIsPressed) {
    // mouse released: if a stroke was in progress, blur it once
    if (mousedrag) {
      mousedrag = false;
      doodle.filter(BLUR, DOODLE_BLUR);
    }
    return;
  }

  // mouseIsPressed is also true when clicking buttons, hence the corner check
  const limit = ZOOMPIXELS + 20;     // can draw up to this many pixels into the corner
  const inCorner = mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit;
  if (inCorner) {
    mousedrag = true;       // start (or continue) a mouse drag
    doodle_exists = true;
    doodle.stroke('white');
    doodle.strokeWeight(DOODLE_THICK);
    doodle.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}




//--- demo -------------------------------------------------------------
// demo some test image and predict it
// get it from test set so have not used it in training


/**
 * "Demo test image" handler: pick a random test image as the demo and show its id and true
 * label in header slot 8.  It comes from the test set, so the network was never trained on it.
 */
function makeDemo()
{
    const pick = AB.randomIntAtoB(0, NOTEST - 1);
    demo_exists = true;
    demo = mnist.test_images[pick];
    const label = mnist.test_labels[pick];

    thehtml = `Test image no: ${pick}<br>Classification: ${label}<br>`;
    AB.msg(thehtml, 8);
}


// Draw the demo image in the bottom row: magnified on the left, original size on the right.
function drawDemo()
{
    const theimage = getImage(demo);
    const rowTop = canvasheight - ZOOMPIXELS;

    image(theimage, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);           // magnified
    image(theimage, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);     // original
}


// Classify the current demo image and show the guess in header slot 9.
function guessDemo()
{
  const inputs = getInputs(demo);
  demo_inputs = inputs;  // can inspect in console

  const guess = findMax(nn.predict(inputs));   // index of the strongest output

  thehtml = " We classify it as: " + greenspan + guess + "</span>";
  AB.msg(thehtml, 9);
}




//--- doodle -------------------------------------------------------------

// Draw the doodle in the top row: as drawn on the left, shrunk to network size on the right.
function drawDoodle()
{
    // doodle is a createGraphics buffer, not a createImage, so take an image snapshot first
    const snapshot = doodle.get();
    const shrunkX = ZOOMPIXELS + 50;

    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);      // original
    image(snapshot, shrunkX, 0, PIXELS, PIXELS);        // shrunk
}
      
      
// Classify the doodle and show the top two guesses in header slot 2.
function guessDoodle()
{
  // doodle is a createGraphics buffer, not a createImage; get() returns an image copy
  const img = doodle.get();
  img.resize(PIXELS, PIXELS);
  img.loadPixels();

  // the red channel of each RGBA pixel, scaled to 0..1 (strokes are drawn in white)
  const inputs = [];
  for (let p = 0; p < PIXELSSQUARED; p++) {
    inputs.push(img.pixels[p * 4] / 255);
  }
  doodle_inputs = inputs;     // can inspect in console

  const [best, second] = find12(nn.predict(inputs));

  thehtml = " We classify it as: " + greenspan + best + "</span> <br>" + " No.2 guess is: " + greenspan + second + "</span>";
  AB.msg(thehtml, 2);
}


// "Clear doodle" handler: stop showing/guessing the doodle and paint its buffer black.
function wipeDoodle()
{
    doodle_exists = false;          // draw() stops drawing and guessing it
    doodle.background('black');     // paint over every stroke
}


// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);


function showInputs(inputs)
// print the inputs to the console as a grid, one row of PIXELS values per line
{
    const cells = Array.from(inputs, (value, i) => (i % PIXELS === 0 ? "\n" : "") + " " + value.toFixed(2));
    console.log(cells.join(""));
}

// ------------------------------   MATRIX CLASS ------- 

/**
 * Minimal dense matrix of numbers, stored as `data`: an array of `rows` row arrays,
 * each `cols` long.
 *
 * Instance methods map / add / multiply / randomize / initValue modify the matrix in
 * place and return it, so calls can be chained.  Static methods return a new matrix.
 * On a shape mismatch an operation logs a message to the console and returns undefined
 * instead of throwing.
 */
class Matrix {
  // rows x cols matrix of zeros
  constructor(rows, cols) {
    this.rows = rows;
    this.cols = cols;
    this.data = new Array(rows).fill(null).map(() => new Array(cols).fill(0));
  }

  // independent copy with the same values
  copy() {
    return new Matrix(this.rows, this.cols).map((_, r, c) => this.data[r][c]);
  }

  // arr.length x 1 column holding the array's values
  static fromArray(arr) {
    const column = new Matrix(arr.length, 1);
    return column.map((_, r) => arr[r]);
  }

  // element-wise a - b as a new matrix
  static subtract(a, b) {
    const sameShape = a.rows === b.rows && a.cols === b.cols;
    if (!sameShape) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return new Matrix(a.rows, a.cols).map((_, r, c) => a.data[r][c] - b.data[r][c]);
  }

  // all values, row by row
  toArray() {
    const flat = [];
    for (let r = 0; r < this.rows; r++) {
      for (let c = 0; c < this.cols; c++) {
        flat.push(this.data[r][c]);
      }
    }
    return flat;
  }

  // fill with randomWeight() samples (randomWeight is supplied by the user of Matrix)
  randomize() {
    return this.map(() => randomWeight());
  }

  // fill every cell with `value`
  initValue(value) {
    return this.map(() => value);
  }

  // element-wise add of a same-shaped matrix, or add a scalar to every cell
  add(n) {
    if (!(n instanceof Matrix)) {
      return this.map(e => e + n);
    }
    if (this.rows !== n.rows || this.cols !== n.cols) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return this.map((e, r, c) => e + n.data[r][c]);
  }

  // new cols x rows matrix
  static transpose(matrix) {
    const flipped = new Matrix(matrix.cols, matrix.rows);
    return flipped.map((_, r, c) => matrix.data[c][r]);
  }

  // matrix product a * b (a.cols must equal b.rows)
  static multiply(a, b) {
    if (a.cols !== b.rows) {
      console.log('Columns of A must match rows of B.');
      return;
    }
    const product = new Matrix(a.rows, b.cols);
    return product.map((_, r, c) => {
      let dot = 0;
      for (let k = 0; k < a.cols; k++) {
        dot += a.data[r][k] * b.data[k][c];
      }
      return dot;
    });
  }

  // Hadamard (element-wise) product with a same-shaped matrix, or scale by a number
  multiply(n) {
    if (!(n instanceof Matrix)) {
      return this.map(e => e * n);
    }
    if (this.rows !== n.rows || this.cols !== n.cols) {
      console.log('Columns and Rows of A must match Columns and Rows of B.');
      return;
    }
    return this.map((e, r, c) => e * n.data[r][c]);
  }

  // replace every cell with func(value, row, col), in place
  map(func) {
    for (let r = 0; r < this.rows; r++) {
      for (let c = 0; c < this.cols; c++) {
        this.data[r][c] = func(this.data[r][c], r, c);
      }
    }
    return this;
  }

  // new matrix whose cells are func(matrix value, row, col)
  static map(matrix, func) {
    const mapped = new Matrix(matrix.rows, matrix.cols);
    return mapped.map((_, r, c) => func(matrix.data[r][c], r, c));
  }

  print() {
    console.table(this.data);
    return this;
  }

  serialize() {
    return JSON.stringify(this);
  }

  // accepts the JSON string produced by serialize() or an already-parsed object
  static deserialize(data) {
    const parsed = typeof data == 'string' ? JSON.parse(data) : data;
    const matrix = new Matrix(parsed.rows, parsed.cols);
    matrix.data = parsed.data;
    return matrix;
  }
}

if (typeof module !== 'undefined') {
  module.exports = Matrix;
}

// ------------------------------  Other techniques for learning ------- 

/**
 * Pairs an activation function with its derivative.
 *
 * dfunc takes the activation's OUTPUT y = func(x), not its input x; the sigmoid and tanh
 * derivatives below are written that way.  Layer.calculateGradient maps its
 * gradient_function over the layer's already-activated values.
 */
class ActivationFunction {
  constructor(func, dfunc) {
    this.func = func;     // y = f(x)
    this.dfunc = dfunc;   // f'(x), expressed in terms of y
  }
}

let sigmoid = new ActivationFunction(
  x => 1 / (1 + Math.exp(-x)), // activation
  y => y * (1 - y)             // derivative, in terms of y = sigmoid(x)
);

// my code
// ReLU: f(x) = max(0, x).  The derivative is 1 where the unit is active (y > 0), else 0.
// The previous derivative returned y itself, which scales each gradient by the activation
// instead of passing it through; that may explain the stalled relu runs recorded above.
let relu = new ActivationFunction(
  x => ((x < 0) ? 0 : x),
  y => ((y > 0) ? 1 : 0)
);

// my code
// Leaky ReLU: f(x) = x for x >= 0, 0.01 * x otherwise.  y < 0 exactly when x < 0, so the
// derivative is 0.01 for negative outputs and 1 otherwise (previously 0.01 * y and y).
let lRelu = new ActivationFunction(
  x => ((x < 0) ? (0.01 * x) : x),
  y => ((y < 0) ? 0.01 : 1)
);

let tanh = new ActivationFunction(
  x => Math.tanh(x),
  y => 1 - (y * y)             // derivative, in terms of y = tanh(x)
);

// Learning rate after decay:  rate / (1 + decay * trainrun).
// trainrun only changes once per pass over the training set, so the rate decays per run.
// Decaying more often gave worse results, as did an exponential schedule.
// NOTE(review): that schedule was written `rate * 0.96 ^ (train_index / 1000)`; in JavaScript
// `^` is bitwise XOR, not a power (`**`), which may be why it did so badly - worth a retry.
let lRate = function (rate) {
    if (decay === 0.0) {
       return rate;
    }
    return rate * (1 / (1 + decay * trainrun));
};

// my code: activation functions currently selected in the run header
// (both selects list sigmoid first, which is the default here)
let hiddenActivation = sigmoid;
let outputActivation = sigmoid;

/**
 * onchange handler for the "Hidden layer" select; unknown values leave the choice unchanged.
 * @param {{value: string}} option - the select element
 */
function setHLayer(option) {
    switch (option.value) {
        case 'sigmoid': hiddenActivation = sigmoid; break;
        case 'relu':    hiddenActivation = relu;    break;
        case 'lrelu':   hiddenActivation = lRelu;   break;
        case 'tanh':    hiddenActivation = tanh;    break;
        default:        break;
    }
}

/**
 * onchange handler for the "Output layer" select; unknown values leave the choice unchanged.
 * @param {{value: string}} option - the select element
 */
function setOLayer(option) {
    switch (option.value) {
        case 'sigmoid': outputActivation = sigmoid; break;
        case 'relu':    outputActivation = relu;    break;
        case 'lrelu':   outputActivation = lRelu;   break;
        case 'tanh':    outputActivation = tanh;    break;
        default:        break;
    }
}

// setup the max layer -- my code 
/**
 * 2x2 max-pooling step placed between two hidden layers (my code).
 *
 * inValue is a flat list of the previous layer's node values whose length should be a
 * perfect square (e.g. 64 -> an 8 x 8 grid).  LayerGroup.calculate passes Matrix.data of
 * an n x 1 column, i.e. single-element arrays [v].  Those wrappers are carried through
 * unchanged and compared by their numeric value.
 * NOTE(review): LayerGroup.calculate assigns a `weight_in` property on this object,
 * while the constructor initialises `weights_in` - confirm which one is meant.
 */
class MaxPoolLayer {
    constructor(inValue, outValue) {
        this.inValue = inValue;
        this.weights_in = null;

        this.outValue = outValue;
    }

    // Pool inValue and store the result as a column Matrix in outValue.
    calculateOutValue () {
        const pool = this.maxPool2by2(this.inValue, Math.sqrt(this.inValue.length));
        this.outValue = Matrix.fromArray(pool);
    }

    /**
     * Non-overlapping 2x2 max pooling (stride 2) over a square grid stored row by row.
     *
     * Output cell (r, c) is the largest of grid cells [2r][2c], [2r][2c+1], [2r+1][2c],
     * [2r+1][2c+1].  The first of equal values wins.  Results are flattened row by row, so
     * an n x n grid gives (n/2)^2 values; an odd trailing row/column is ignored.
     *
     * The grid used to be walked in 4 x 4 tiles, so only side lengths divisible by 4
     * worked.  Any other even side (e.g. 6 x 6 from 36 hidden nodes) read past the grid
     * and threw.  For sides divisible by 4 the output is unchanged.
     *
     * @param {Array} inputs - flat grid values (numbers, or single-element arrays [v])
     * @param {number} targetSizeArray - side length n of the square grid
     * @returns {Array} pooled values, in the same form as the inputs
     */
    maxPool2by2(inputs, targetSizeArray)
    {
        const grid = this.splitInput(inputs, targetSizeArray);
        const half = Math.floor(targetSizeArray / 2);
        const pooled = [];

        for (let r = 0; r < half; r++) {
            for (let c = 0; c < half; c++) {
                pooled.push(this.getMaxValuesSquare(grid, 2 * r, 2 * c, 2));
            }
        }
        return pooled;
    }

    /**
     * Largest of the 2x2 square whose top-left cell is grid[row][col].
     * Values are compared numerically (Number([v]) === v).  Comparing the [v] wrappers
     * directly compared them as strings, e.g. "1e-7" > "0.3".
     * The 4th parameter is unused and kept for compatibility.
     * @returns the winning element itself (wrapper included)
     */
    getMaxValuesSquare(inputs, row, col, numRowsScan) {
        const square = [
            inputs[row][col],
            inputs[row][col + 1],
            inputs[row + 1][col],
            inputs[row + 1][col + 1],
        ];
        return square[findMax(square.map(Number))];
    }

    /**
     * Break the flat list into rows of `length` items, e.g. 784 values -> 28 rows of 28.
     * A length below 1 is treated as 1, which used to loop forever.
     */
    splitInput(inputs, length) {
        // copy into a real array first: the input may be array-like rather than an Array
        const clone = [];
        inputs.forEach((item, ind) => {
            clone[ind] = item;
        });

        const size = Math.max(1, Math.trunc(length) || 0);
        const results = [];
        while (clone.length > 0) {
            results.push(clone.splice(0, size));
        }
        return results;
    }
}

// encapsulate 1 layer - my code 
/**
 * One fully connected layer:  values = activation(weights_in * inputs + bias).
 *
 * Weights and biases are trained with momentum.  lastDelta / lastGradient hold the
 * previous weight / bias update (the momentum "velocity").
 */
class Layer {

    /**
     * @param {number} numNeuronsIn - size of the input column
     * @param {number} numNeuronsLayer - number of nodes in this layer
     * @param {Function} activation_function - applied element-wise to the weighted sums
     * @param {Function} gradient_function - activation derivative, applied element-wise to
     *   this layer's activated values in calculateGradient
     */
    constructor(numNeuronsIn, numNeuronsLayer, activation_function, gradient_function) {
        this.initialise(numNeuronsIn, numNeuronsLayer);

        this.activation_function = activation_function;
        this.gradient_function = gradient_function;

        this.lastInputValue = null;   // inputs of the most recent calculate()
        this.values = null;           // activated outputs of the most recent calculate()
        this.errorValue = null;       // targets - values, set by error()

        this.lastDelta = null;        // previous weight update (momentum velocity)
        this.lastGradient = null;     // previous bias update (momentum velocity)
        this.layerBefore = null;      // set by LayerGroup.calculate()
    }

    /**
     * (Re)create random weights and biases, only when the requested size differs from the
     * current one.  Momentum history has the old shape, so it is discarded too.  Previously
     * it was kept, and adding it to the resized weights failed the shape check.
     */
    initialise (numNeuronsIn, numNeuronsLayer) {
        if (this.numNeuronsIn === numNeuronsIn && this.numNeuronsLayer === numNeuronsLayer) {
            return;
        }

        // dimensions
        this.numNeuronsIn = numNeuronsIn;
        this.numNeuronsLayer = numNeuronsLayer;

        if (this.numNeuronsIn > 0 && this.numNeuronsLayer > 0) {
            // one row of incoming weights per node
            this.weights_in = new Matrix(this.numNeuronsLayer, this.numNeuronsIn);
            this.weights_in.randomize();
        }

        if (this.numNeuronsLayer > 0) {
            // one bias per node
            this.bias = new Matrix(this.numNeuronsLayer, 1);
            this.bias.randomize();
        }

        this.lastDelta = null;
        this.lastGradient = null;
    }

    // Forward pass: values = activation(weights_in * inputs + bias); inputs is a column Matrix.
    calculate(inputs) {
        this.lastInputValue = inputs;

        this.values = Matrix.multiply(this.weights_in, inputs);
        this.values.add(this.bias);
        this.normaliseValues();  // apply the activation function
    }

    // Output error against the expected values: targets - values.
    error (targets) {
        this.errorValue = Matrix.subtract(targets, this.values);
    }

    // Propagate an error back through `weight`: transpose(weight) * errorValue.
    errorLayerBefore(weight, errorValue) {
        const who_t = Matrix.transpose(weight);

        return Matrix.multiply(who_t, errorValue);
    }

    // Bias update for this layer: activation'(values) (element-wise) * error * learningRate.
    calculateGradient (learningRate, error) {
        const gradients = Matrix.map(this.values, this.gradient_function);
        gradients.multiply(error);
        gradients.multiply(learningRate);

        return gradients;
    }

    // Apply the activation function to the raw weighted sums.
    normaliseValues () {
        this.values = Matrix.map(this.values, this.activation_function);
    }

    // Weight update: gradients (m x 1) times the transposed n x 1 input column -> m x n.
    deltaWeight(gradients, values) {
        const inputsT = Matrix.transpose(values);

        return Matrix.multiply(gradients, inputsT);
    }

    /**
     * Apply this step's updates with classical momentum:
     *   velocity = delta + momentum * previousVelocity;   parameter += velocity
     *
     * Previously the stored "last delta" was only ever the FIRST update.  It was shrunk by
     * `momentum` on every call and never replaced, so the momentum term was a decaying echo
     * of step one rather than a running average of recent updates.  The bias history was
     * also the caller's `gradients` matrix itself, which was then scaled in place.
     *
     * @param {Matrix} gradients - m x 1 bias update (presumably from calculateGradient)
     * @param {Matrix} values - n x 1 inputs that this layer saw
     */
    adjustWeightBias(gradients, values) {
        const deltaW = this.deltaWeight(gradients, values);

        this.lastDelta = this.applyMomentum(this.weights_in, deltaW, this.lastDelta);
        this.lastGradient = this.applyMomentum(this.bias, gradients, this.lastGradient);
    }

    /**
     * Add `delta` plus momentum * `previous` to `target` in place, and return the velocity
     * that was added (to be passed back as `previous` next time).  The momentum term is
     * skipped on the first step, when `momentum` is 0, or when `previous` has another shape.
     */
    applyMomentum(target, delta, previous) {
        const velocity = delta.copy();   // never modify the caller's matrix
        const usable = previous !== null && momentum !== 0.0 &&
            previous.rows === velocity.rows && previous.cols === velocity.cols;
        if (usable) {
            velocity.add(previous.multiply(momentum));
        }
        target.add(velocity);
        return velocity;
    }
}

// encapsulate a group of Layers - my code 
class LayerGroup {
    /**
     * A stack of hidden layers that all share one activation function.
     *
     * @param {number} numLayers - total number of hidden layers, including the first one
     * @param {number} numNeuronsIn - size of the input fed to the first layer
     * @param {number} numNeuronsLayer - neurons in each hidden layer
     * @param {Function} activation_function - activation used by every layer
     * @param {Function} gradient_function - its derivative, used during back-propagation
     */
    constructor(numLayers, numNeuronsIn, numNeuronsLayer, activation_function, gradient_function) {
        this.firstLayer = new Layer(numNeuronsIn, numNeuronsLayer, activation_function, gradient_function);

        // every layer after the first takes the previous hidden layer's output as input;
        // calculate() may re-initialise a layer later if max pooling changes its input size
        this.layers = Array.from({ length: numLayers - 1 }, () =>
            new Layer(numNeuronsLayer, numNeuronsLayer, activation_function, gradient_function));
    }

    /**
     * Forward pass through all hidden layers.
     *
     * @param {Matrix} inputs - input matrix for the first layer
     * @param {boolean} [maxPool] - when truthy, a MaxPoolLayer is placed in front of every
     *        even-indexed entry of `this.layers`, shrinking that layer's input
     * @returns {Matrix} output values of the last hidden layer
     */
    calculate(inputs, maxPool) {
        const firstLayer = this.firstLayer;
        firstLayer.calculate(inputs);

        let inputNextLayer = firstLayer.values;
        let previousLayer = firstLayer;

        for (let i = 0; i < this.layers.length; i++) {
            const layer = this.layers[i];

            // make the looping easy
            layer.layerBefore = previousLayer;

            // `pool` is deliberately a separate name so the `maxPool` flag is never overwritten
            if (maxPool && i % 2 === 0) {
                const pool = new MaxPoolLayer(inputNextLayer.data, null);
                pool.calculateOutValue();

                // all-ones "fake" weight (pooled size x original size), used in
                // adjustWeightBias() to expand the error back to the un-pooled size
                pool.weight_in = new Matrix(pool.outValue.rows, pool.inValue.length);
                pool.weight_in.initValue(1);

                // resize the layer to the pooled input
                layer.initialise(pool.outValue.rows, pool.outValue.rows);
                layer.maxPool = pool;

                inputNextLayer = pool.outValue;
            }

            layer.calculate(inputNextLayer);

            inputNextLayer = layer.values;
            previousLayer = layer;
        }

        return inputNextLayer;
    }

    /**
     * @returns {Layer} the deepest hidden layer: the last entry of `this.layers`,
     *          or `firstLayer` when the group has a single layer
     */
    lastLayer() {
        if (this.layers.length > 0) {
            return this.layers[this.layers.length - 1];
        }

        return this.firstLayer;
    }

    /**
     * Back-propagate `error` through the hidden layers (last to first), updating
     * each layer's weights and bias.
     *
     * NOTE(review): each layer's weights are updated before its error is propagated
     * backwards, so propagation uses the post-update weights (same as the output layer
     * in NeuralNetwork.train).
     *
     * @param {number} learningRate - learning rate for this step
     * @param {Matrix} error - error at the output of the last hidden layer
     */
    adjustWeightBias(learningRate, error) {
        let errorValue = error;

        for (let i = this.layers.length - 1; i >= 0; i--) {
            const hiddenLayer = this.layers[i];

            const gradients = hiddenLayer.calculateGradient(learningRate, errorValue);
            hiddenLayer.adjustWeightBias(gradients, hiddenLayer.lastInputValue);
            hiddenLayer.errorValue = errorValue;

            // error at this layer's input
            errorValue = hiddenLayer.errorLayerBefore(hiddenLayer.weights_in, errorValue);

            if (hiddenLayer.maxPool) {
                // this layer's input was the pooled output; expand the error back to the
                // size of the previous layer through the all-ones pool weight
                // (e.g. 16 -> 64)
                errorValue = hiddenLayer.errorLayerBefore(hiddenLayer.maxPool.weight_in, errorValue);
            }
        }

        const firstGradients = this.firstLayer.calculateGradient(learningRate, errorValue);
        this.firstLayer.adjustWeightBias(firstGradients, this.firstLayer.lastInputValue);
    }

    /**
     * 2x2 max pooling of a square image stored as a flat, row-major array.
     *
     * Works for any side length. An odd trailing row/column is dropped (floor mode),
     * so e.g. 28 -> 14, 14 -> 7, 8 -> 4.
     *
     * @param {ArrayLike<number>} inputs - flat pixels of a targetSizeArray x targetSizeArray image
     * @param {number} targetSizeArray - side length of the input image
     * @returns {number[]} flat, row-major pooled image
     */
    maxPool2by2(inputs, targetSizeArray) {
        const grid = this.splitInput(inputs, targetSizeArray);
        const resizeSize = Math.floor(targetSizeArray / 2);

        const resized = [];
        for (let row = 0; row < resizeSize; row++) {
            const pooledRow = [];
            for (let col = 0; col < resizeSize; col++) {
                pooledRow.push(this.getMaxValuesSquare(grid, 2 * row, 2 * col));
            }
            resized.push(pooledRow);
        }

        // flatten back to a single row-major array
        return [].concat(...resized);
    }

    /**
     * Maximum of the 2x2 square whose top-left cell is inputs[columnIndex][rowIndex].
     * Note that the first index selects the row array of the grid.
     *
     * @param {Array<Array<number>>} inputs - grid produced by splitInput()
     * @param {number} columnIndex - outer-array index of the square's first cell
     * @param {number} rowIndex - inner-array index of the square's first cell
     * @param {number} [numRowsScan] - unused; kept for backward compatibility
     * @returns {number} the largest of the four values, as picked by findMax()
     */
    getMaxValuesSquare(inputs, columnIndex, rowIndex, numRowsScan) {
        const square = [
            inputs[columnIndex][rowIndex],
            inputs[columnIndex][rowIndex + 1],
            inputs[columnIndex + 1][rowIndex],
            inputs[columnIndex + 1][rowIndex + 1],
        ];

        return square[findMax(square)];
    }

    /**
     * Split a flat array into consecutive rows of `length` items (e.g. 784 -> 28 x 28).
     *
     * @param {ArrayLike<number>} inputs - flat data; typed arrays are accepted
     * @param {number} length - items per row; must be positive
     * @returns {Array<Array<number>>} the rows (the last row may be shorter)
     * @throws {RangeError} if `length` is not positive (the loop would never end)
     */
    splitInput(inputs, length) {
        if (!(length > 0)) {
            throw new RangeError(`splitInput: row length must be positive, got ${length}`);
        }

        // copy into a plain array: the input may be a typed array, which has no splice()
        const clone = Array.from(inputs);
        const results = [];

        while (clone.length > 0) {
            results.push(clone.splice(0, length));
        }

        return results;
    }
}

class NeuralNetwork {
  /*
  * Fully connected network: a LayerGroup of hidden layers followed by one output layer.
  * USAGE: nn = new NeuralNetwork(inputCount, hiddenCount, outputCount);
  *
  * The number of hidden layers and the activation functions come from the globals
  * `hiddenLayersNum`, `hiddenActivation` and `outputActivation`.
  *
  * NOTE(review): passing another NeuralNetwork as the first argument does NOT clone it,
  * so copy() below does not currently produce a usable network.
  */
  constructor(in_nodes, hid_nodes, out_nodes) {

    // my changes encapsulate all in my layers
    this.input_nodes = in_nodes;
    this.hidden_nodes = hid_nodes;
    this.output_nodes = out_nodes;

    this.hiddenLayers = new LayerGroup(hiddenLayersNum, this.input_nodes, this.hidden_nodes, hiddenActivation.func, hiddenActivation.dfunc);

    // the output layer is created with input size -1; train() initialises it
    // from the size of the last hidden layer's output
    this.ouputLayer = new Layer(-1, this.output_nodes, outputActivation.func, outputActivation.dfunc);

    // pooling mode used by the last train() call; predict() replays it so layers
    // resized for pooled input also receive pooled input at inference time
    this.maxPool = false;

    // TODO: copy these as well
    this.setLearningRate();
    this.setActivationFunction();
  }

  /**
   * Forward pass.
   *
   * @param {number[]} input_array - network input
   * @param {boolean} [maxPool=this.maxPool] - pooling mode; defaults to the one used in training
   * @returns {number[]} output layer values
   */
  predict(input_array, maxPool = this.maxPool) {
    const inputs = Matrix.fromArray(input_array);

    // pass the input via the hidden layers, then through the output layer
    const hiddenOutput = this.hiddenLayers.calculate(inputs, maxPool);
    this.ouputLayer.calculate(hiddenOutput);

    return this.ouputLayer.values.toArray();
  }

  setLearningRate(learning_rate = 0.1) {
    this.learning_rate = learning_rate;
  }

  // NOTE(review): the layers use hiddenActivation/outputActivation, not this value
  setActivationFunction(func = sigmoid) {
    this.activation_function = func;
  }

  /**
   * One training step: forward pass, then back-propagation through the output
   * layer and every hidden layer.
   *
   * @param {number[]} input_array - network input
   * @param {number[]} target_array - expected output
   * @param {boolean} [maxPool] - insert max-pool layers between hidden layers
   */
  train(input_array, target_array, maxPool) {
    const inputs = Matrix.fromArray(input_array);
    const targets = Matrix.fromArray(target_array);

    // lRate is defined elsewhere; presumably applies the learning-rate decay
    const learning_rate = lRate(this.learning_rate);

    // remember the pooling mode so predict() feeds the layers the same shapes
    this.maxPool = maxPool;

    // pass the input through all the hidden layers
    const hiddenLayerPass = this.hiddenLayers.calculate(inputs, maxPool);

    // size the output layer from the hidden output, then compute its values and error
    this.ouputLayer.initialise(hiddenLayerPass.rows, this.ouputLayer.numNeuronsLayer);
    this.ouputLayer.calculate(hiddenLayerPass);
    this.ouputLayer.error(targets);

    // adjust the output layer
    const gradientsOutput = this.ouputLayer.calculateGradient(learning_rate, this.ouputLayer.errorValue);
    this.ouputLayer.adjustWeightBias(gradientsOutput, hiddenLayerPass);

    // error for the last hidden layer, then adjust all hidden layers
    const error = this.ouputLayer.errorLayerBefore(this.ouputLayer.weights_in, this.ouputLayer.errorValue);
    this.hiddenLayers.adjustWeightBias(learning_rate, error);
  }

  serialize() {
    return JSON.stringify(this);
  }

  // NOTE(review): serialize() writes hiddenLayers/ouputLayer, but this restores the
  // legacy weights_ih/weights_ho/bias_h/bias_o fields, which no longer exist on the
  // serialized object; it needs to rebuild the layers instead.
  static deserialize(data) {
    if (typeof data == 'string') {
      data = JSON.parse(data);
    }
    let nn = new NeuralNetwork(data.input_nodes, data.hidden_nodes, data.output_nodes);
    nn.weights_ih = Matrix.deserialize(data.weights_ih);
    nn.weights_ho = Matrix.deserialize(data.weights_ho);
    nn.bias_h = Matrix.deserialize(data.bias_h);
    nn.bias_o = Matrix.deserialize(data.bias_o);
    nn.learning_rate = data.learning_rate;
    return nn;
  }


  // Adding function for neuro-evolution
  // NOTE(review): the constructor does not support cloning, so this is not a real copy
  copy() {
    return new NeuralNetwork(this);
  }

  /**
   * Apply `func` (via Matrix#map) to every weight and bias matrix of every layer.
   *
   * @param {Function} func - mapping applied to each weight/bias entry
   */
  mutate(func) {
    const layers = [this.hiddenLayers.firstLayer, ...this.hiddenLayers.layers, this.ouputLayer];

    layers.forEach((layer) => {
      // guard in case a layer has not been given its matrices yet
      // (e.g. the output layer is built with input size -1 until train() runs)
      if (layer.weights_in) {
        layer.weights_in.map(func);
      }
      if (layer.bias) {
        layer.bias.map(func);
      }
    });
  }
}

// load the MNIST training and test data

/**
 * Fetch all four MNIST files in parallel, then hand them to `callback`.
 *
 * @param {Function} callback - receives an object with keys train_images,
 *        train_labels, test_images and test_labels
 * @returns {Promise<*>} settles with whatever `callback` returns; rejects if any file fails
 */
async function loadMNIST(callback) {
  const mnist = {};
  const files = {
    train_images: '/uploads/codingtrain/train-images-idx3-ubyte',
    train_labels: '/uploads/codingtrain/train-labels-idx1-ubyte',
    test_images:  '/uploads/codingtrain/t10k-images-idx3-ubyte',
    test_labels:  '/uploads/codingtrain/t10k-labels-idx1-ubyte',
  };

  const downloads = Object.entries(files).map(async ([key, path]) => {
    // each result is stored as soon as its own download finishes
    mnist[key] = await loadFile(path);
  });
  await Promise.all(downloads);

  return callback(mnist);
}

/**
 * Download and parse one MNIST IDX file.
 *
 * All header words are big-endian uint32. The magic number is 2049 for labels and
 * 2051 for images.
 * - Labels (magic 2049): 2 header words (magic, count), then one byte per label.
 * - Images (magic 2051): 4 header words (magic, count, rows, cols), then rows*cols
 *   bytes per image.
 *
 * @param {string} file - URL to fetch
 * @returns {Promise<Uint8Array|Uint8Array[]>} the label bytes, or one subarray per image
 * @throws {Error} on an HTTP error status or an unknown magic number
 */
async function loadFile(file) {
  const response = await fetch(file);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${file}: HTTP ${response.status}`);
  }
  const buffer = await response.arrayBuffer();

  // read only the header words this file type actually has, so short files are accepted
  const view = new DataView(buffer);
  const magic = view.getUint32(0, false);

  if (magic === 2049) {
    // labels: skip the 2-word header
    return new Uint8Array(buffer, 2 * 4);
  }

  if (magic === 2051) {
    const count = view.getUint32(4, false);
    const rows = view.getUint32(8, false);
    const cols = view.getUint32(12, false);
    const dataLength = rows * cols;

    // images: skip the 4-word header, then cut one view per image (no copying)
    const data = new Uint8Array(buffer, 4 * 4);
    const dataArr = [];
    for (let i = 0; i < count; i++) {
      dataArr.push(data.subarray(dataLength * i, dataLength * (i + 1)));
    }
    return dataArr;
  }

  throw new Error(`Unknown file type ${magic}`);
}