/* Code viewer for World: Character Recognition Netw... */
/* Image geometry, defined by MNIST: each digit is a PIXELS x PIXELS greyscale image. */
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;   // length of one flattened image / input vector

/* The number of training and test exemplars in the data set.
   trainit() and testit() wrap their indices back to 0 when they reach these counts. */
const NOTRAIN = 60000;
const NOTEST  = 10000;

/* The number of nodes in each layer of the network:
   one input per pixel, and one output per digit class (targets have 10 entries, indexed by label). */
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;

/* The rate at which the network should learn (passed to nn.setLearningRate() in setup()). */
const learningrate = 0.1;

/* Whether or not draw() should train (and test) every timestep.
   The "Stop training" button in the page header sets this to false. */
let do_training = true;

/* How many exemplars to train on, and test, per timestep (i.e. per draw() call). */
const TRAINPERSTEP = 30;
const TESTPERSTEP = 30;

/* The zoom factor to use when displaying each digit. */
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;  // side of a magnified digit, and of the doodle drawing area

/* Canvas layout. Each row shows a magnified image, a 50 px gap, then the original-size image. */
const canvaswidth = (PIXELS + ZOOMPIXELS) + 50;
/* Three rows of ZOOMPIXELS each (doodle, training, demo), separated by two 50 px gaps. */
const canvasheight = (ZOOMPIXELS * 3) + 100;

/* Doodle brush: stroke weight while drawing, and the BLUR filter amount applied when a stroke ends. */
const DOODLE_THICK = 18;
const DOODLE_BLUR = 3;

/* MNIST data set, stored by loadData(). draw() does nothing until it is defined. */
let mnist;
/* The neural network, created in setup() once the helper scripts have loaded. */
let nn;

/* Training progress: current pass over the training set (starting at 1) and next exemplar index. */
let trainrun = 1;
let train_index = 0;

/* Testing progress: current test pass, next exemplar index, and the running score for this pass. */
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

/* doodle: the off-screen p5 graphics the user draws on (created in setup()).
   demo: the test image chosen by makeDemo(). */
let doodle, demo;
/* Whether draw() should render and classify the doodle / the demo image. */
let doodle_exists = false;
let demo_exists = false;

/* True while a doodle stroke is in progress; used to blur the doodle once when the mouse is released. */
let mousedrag = false;

/* The most recent input vectors fed to the network by trainit(), testit(), guessDemo() and guessDoodle().
   Nothing in this file reads them; presumably they are kept for inspection from the console (TODO confirm). */
var train_inputs, test_inputs, demo_inputs, doodle_inputs;

/* Return a random initial weight, uniformly drawn by AB between -0.5 and 0.5. */
function randomWeight() {
    const lowest = -0.5;
    const highest = 0.5;
    return AB.randomFloatAtoB(lowest, highest);
}

/* Set the size of the header. */
AB.headerCSS ( { "max-height": "95vh" } );

/* Output various messages, using AB's built-in library. */
var thehtml;

/* The header for the doodle. */
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br>" +
          "Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br>";
AB.msg(thehtml, 1);

/* The guess being made about the doodle. */
  
/* The header for the training section. */
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>" +
          "<button onclick='do_training = false;' class='normbutton' >Stop training</button> <br>";
AB.msg(thehtml, 3);
     
/* Further information about the training data. */
  
/* The header for the testing section. */
thehtml = "<h3> Hidden tests </h3>";
AB.msg(thehtml, 5);
           
/* Further information about the testing data. */ 
  
/* The header for the demo. */
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br>" +
          "The network is <i>not</i> trained on any of these images. <br>" +
          "<button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br>";
AB.msg(thehtml, 7);
   
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'>";

/* p5 setup: create the canvas and the doodle surface, show AB's loading screen, then
   load the helper scripts strictly one after another before building the network
   and starting the MNIST download via loadData(). */
function setup() {
    createCanvas(canvaswidth, canvasheight);

    doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
    doodle.pixelDensity(1);

    AB.loadingScreen();

    // Order matters: each script is only requested once the previous one has loaded.
    const helperScripts = [
        "/uploads/codingtrain/matrix.js",
        "/uploads/codingtrain/nn.js",
        "/uploads/codingtrain/mnist.js",
    ];

    const loadFrom = (position) => {
        if (position < helperScripts.length) {
            $.getScript(helperScripts[position], () => loadFrom(position + 1));
            return;
        }
        // Every helper is available: NeuralNetwork comes from the scripts above.
        nn = new NeuralNetwork(noinput, nohidden, nooutput);
        nn.setLearningRate(learningrate);
        loadData();
    };

    loadFrom(0);
}

/* Start loading MNIST via loadMNIST(); when it arrives, store it in the global
   `mnist` (which lets draw() start working) and dismiss AB's loading screen. */
function loadData() {
    const onLoaded = (data) => {
        mnist = data;
        AB.removeLoading();
    };
    loadMNIST(onLoaded);
}

/* Construct a P5 image object for a given image.
   img: flat array of PIXELSSQUARED brightness values (assumed 0-255, matching getInputs()).
   Returns an opaque greyscale PIXELS x PIXELS image. */
function getImage(img) {
    const picture = createImage(PIXELS, PIXELS);
    picture.loadPixels();

    // The pixel buffer is RGBA: four consecutive entries per pixel.
    for (let p = 0, offset = 0; p < PIXELSSQUARED; p++, offset += 4) {
        const grey = img[p];
        picture.pixels.fill(grey, offset, offset + 3);  // R, G and B all get the brightness
        picture.pixels[offset + 3] = 255;               // fully opaque
    }

    picture.updatePixels();
    return picture;
}

/* Convert an image into a normalised set of inputs: each of the first PIXELSSQUARED
   brightness values is divided by 255. Returns a new array; img is left untouched. */
function getInputs(img) {
    return Array.from({ length: PIXELSSQUARED }, (_, p) => img[p] / 255);
}

/* Train our network with a single exemplar (the one at train_index).
   show: when truthy, draw the exemplar in the middle row, magnified (left) and original size (right).
   Writes training progress to message slot 4, then advances train_index, wrapping to 0
   and incrementing trainrun after NOTRAIN exemplars. */
function trainit(show) {
    const img = mnist.train_images[train_index];
    const label = mnist.train_labels[train_index];

    if (show) {
        const theimage = getImage(img);
        image(theimage, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);
        image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);
    }

    const inputs = getInputs(img);
    // One target per output node, sized from the network rather than a hard-coded 10.
    const targets = Array.from({ length: nooutput }, () => 0);

    /* Randomly set some of the outputs to a confidence level between 0 and 0.5.
       The value must not be passed through Math.round(): that maps anything below 0.5
       back to 0 and silently disables this noise. */
    for (let i = 0; i < 3; i++) {
        targets[Math.floor(Math.random() * targets.length)] = AB.randomFloatAtoB(0.0, 0.5);
    }

    // The true label always gets full confidence, overriding any noise that landed on it.
    targets[label] = 1;

    train_inputs = inputs;
    nn.train(inputs, targets);

    thehtml = "trainrun: " + trainrun + "<br> no: " + train_index;
    AB.msg(thehtml, 4);

    train_index++;
    if (train_index === NOTRAIN) {
        train_index = 0;
        trainrun++;
    }
}

/* Evaluate the network on the test exemplar at test_index and report the running
   score for the current pass in message slot 6. After NOTEST exemplars the pass
   counter advances and the score counters start again from zero. */
function testit() {
    const sample = mnist.test_images[test_index];
    const expected = mnist.test_labels[test_index];

    test_inputs = getInputs(sample);
    const guess = findMax(nn.predict(test_inputs));

    total_tests += 1;
    // Loose equality kept on purpose: the label's type comes from the MNIST loader.
    if (guess == expected) {
        total_correct += 1;
    }

    const percent = (total_correct / total_tests) * 100;
    thehtml = `testrun: ${testrun}<br> no: ${total_tests}<br>` +
              `correct: ${total_correct}<br>` +
              `score: ${greenspan}${percent.toFixed(2)}</span>`;
    AB.msg(thehtml, 6);

    test_index += 1;
    if (test_index !== NOTEST) return;

    // End of a full pass over the test set.
    testrun += 1;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
}

/* Find the no. 1 and no. 2 nodes in a given array.
   Returns [indexOfLargest, indexOfSecondLargest]; on ties the earlier index ranks higher.
   Works for any real values, including zero or negative outputs. With fewer than two
   elements, the missing positions default to index 0. */
function find12(a) {
    let no1 = 0;
    let no2 = 0;
    // Start below every real value so outputs <= 0 are still ranked.
    let no1value = -Infinity;
    let no2value = -Infinity;

    for (let i = 0; i < a.length; i++) {
        if (a[i] > no1value) {
            // New leader: the previous leader drops to second place.
            no2 = no1;
            no2value = no1value;
            no1 = i;
            no1value = a[i];
        } else if (a[i] > no2value) {
            no2 = i;
            no2value = a[i];
        }
    }

    return [no1, no2];
}

/* Find the maximum of our outputs.
   Returns the index of the largest value in `a` (the earliest one on ties), or 0 for an
   empty array. Works for any real values, including zero or negative outputs. */
function findMax(a) {
    let no1 = 0;
    // Start below every real value so an all-non-positive array still yields its true maximum.
    let no1value = -Infinity;

    for (let i = 0; i < a.length; i++) {
        if (a[i] > no1value) {
            no1 = i;
            no1value = a[i];
        }
    }

    return no1;
}

/* p5 per-frame callback.
   1. If training is on: trains on TRAINPERSTEP exemplars (drawing only the first), then tests TESTPERSTEP.
   2. Redraws and re-classifies the demo image and the doodle, when they exist.
   3. While the mouse is pressed inside the doodle area, draws a white stroke on the doodle.
      When the press ends, blurs the doodle once. */
function draw() {
    // Nothing to do until loadData() has stored the MNIST data.
    if (typeof mnist == 'undefined') return;

    background('black');

    if (do_training) {
        for (let i = 0; i < TRAINPERSTEP; i++) {
            trainit(i === 0);
        }

        for (let i = 0; i < TESTPERSTEP; i++) {
            testit();
        }
    }

    if (demo_exists) {
        drawDemo();
        guessDemo();
    }

    if (doodle_exists) {
        drawDoodle();
        guessDoodle();
    }

    if (mouseIsPressed) {
        /* Both ends of the stroke must lie inside the doodle area.
           mouseX/mouseY are canvas-relative and go negative when the pointer is above or left
           of the canvas (e.g. on the header buttons). Without the lower bound, such a click
           would mark an empty doodle as existing and start classifying it. */
        const MAX = ZOOMPIXELS + 20;
        const inArea = (v) => v >= 0 && v < MAX;
        if (inArea(mouseX) && inArea(mouseY) && inArea(pmouseX) && inArea(pmouseY)) {
            mousedrag = true;
            doodle_exists = true;
            doodle.stroke('white');
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    } else if (mousedrag) {
        // Stroke just finished: apply the blur filter once.
        mousedrag = false;
        doodle.filter(BLUR, DOODLE_BLUR);
    }
}

/* Pick a random test image for the demo row and show its number and label in message slot 8.
   Sets demo_exists so that draw() starts drawing and classifying it. */
function makeDemo() {
    demo_exists = true;
    const pick = AB.randomIntAtoB(0, NOTEST - 1);

    demo = mnist.test_images[pick];
    const classification = mnist.test_labels[pick];

    thehtml = `Test image no: ${pick}<br>Classification: ${classification}<br>`;
    AB.msg(thehtml, 8);
}

/* Draw the demo image in the bottom row: magnified on the left, original size to its right. */
function drawDemo() {
    const picture = getImage(demo);
    const rowTop = canvasheight - ZOOMPIXELS;
    const smallLeft = ZOOMPIXELS + 50;

    image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
    image(picture, smallLeft, rowTop, PIXELS, PIXELS);
}

/* Classify the demo image and show the network's top guess in message slot 9. */
function guessDemo() {
    demo_inputs = getInputs(demo);
    const guess = findMax(nn.predict(demo_inputs));

    thehtml = `We classify it as: ${greenspan}${guess}</span>`;
    AB.msg(thehtml, 9);
}

/* Draw the doodle in the top row: full size on the left, shrunk to PIXELS x PIXELS on the right. */
function drawDoodle() {
    const snapshot = doodle.get();
    const smallLeft = ZOOMPIXELS + 50;

    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
    image(snapshot, smallLeft, 0, PIXELS, PIXELS);
}
      
/* Classify the doodle. Takes a snapshot, shrinks it to PIXELS x PIXELS, and feeds the
   normalised red channel of each pixel to the network. Shows the top two guesses in message slot 2. */
function guessDoodle() {
    const snapshot = doodle.get();
    snapshot.resize(PIXELS, PIXELS);
    snapshot.loadPixels();

    // The pixel buffer is RGBA; the first entry of every group of four (red) is used as brightness.
    doodle_inputs = Array.from({ length: PIXELSSQUARED }, (_, p) => snapshot.pixels[p * 4] / 255);

    const [best, runnerUp] = find12(nn.predict(doodle_inputs));

    thehtml = `We classify it as: ${greenspan}${best}</span> <br>` +
              `No.2 guess is: ${greenspan}${runnerUp}</span>`;
    AB.msg(thehtml, 2);
}

/* "Clear doodle" button handler. Stops draw() from rendering and classifying the doodle,
   then paints the doodle surface solid black. */
function wipeDoodle() {
    const blank = 'black';
    doodle_exists = false;
    doodle.background(blank);
}

/* Debug helper: format an input vector as a grid PIXELS values wide, each with 2 decimals.
   A newline starts every row, including the first. The grid is printed to the console and
   also returned. Previously the string was built and then discarded, so the function did nothing. */
function showInputs(inputs) {
    let str = "";
    for (let i = 0; i < inputs.length; i++) {
        if (i % PIXELS === 0) str += "\n";
        str += " " + inputs[i].toFixed(2);
    }
    console.log(str);
    return str;
}