Code viewer for World: Resit-Practical 2: Khizer'...


// Cloned by Khizer Ahmed on 25 Jul 2022 from World "Character recognition neural network" by "Coding Train" project 
// Please leave this clone trail here.
 

// Resit Practical 2: Doodle Recognition 

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications 
// --- defined by MNIST - do not change these ---------------------------------------

// Side of the square random crop that is actually fed to the CNN (see toModelFormat / randomCrop).
// The hand-built networks below declare a 24 x 24 input layer; presumably the saved JSON network
// expects the same size (TODO confirm against the JSON file).
var CROP_PIXELS = 24;
var PIXELS = 28;                    // images in data set are tiny: 28 x 28
var PIXELSSQUARED = PIXELS * PIXELS;


// number of training and test exemplars in the data set:
var NOTRAIN = 60000;
var NOTEST = 10000;

//--- can modify all these --------------------------------------------------

// Sizes for a plain fully-connected network.
// NOTE(review): no code in this file reads noinput / nohidden / nooutput / learningrate /
// BATCH_SIZE. The CNN's shape and hyper-parameters come from the create*Network functions or
// from the loaded JSON.
var noinput = PIXELSSQUARED;
var nohidden = 64;
var nooutput = 10;

var learningrate = 0.1;         // default 0.1

// Should draw() train and test every frame? The "Stop training" button in the run header sets
// this to false through an inline onclick, so it must stay a global var.
var do_training = true;
var BATCH_SIZE = 50;

// Which network setup() builds:
//   0 = createShallowNetwork(ACTIVATION_RELU)
//   1 = createShallowNetwork(ACTIVATION_TANH)
//   2 = createDefaultNetwork()
//   3 = loadNetworkFromJSON(pretrained cnn_mnist_10_20_98accuracy.json)
var theNN = 3;

// draw() calls trainit() TRAINPERSTEP times per frame. Only one of those calls trains, on a batch
// of TRAINPERSTEP consecutive images. draw() also runs TESTPERSTEP single-image tests per frame.
var TRAINPERSTEP = 60;
var TESTPERSTEP = 5;

// Magnification of the large on-canvas images. ZOOMPIXELS is derived from ZOOMFACTOR so that
// changing the factor actually resizes the display.
var ZOOMFACTOR = 7;
var ZOOMPIXELS = ZOOMFACTOR * PIXELS;

// Canvas layout: 3 rows, each holding
//   large image + 50 gap + small image,
// with 50 px gaps between rows.

var canvaswidth = PIXELS + ZOOMPIXELS + 50;
var canvasheight = 3 * ZOOMPIXELS + 100;


var DOODLE_THICK = 18;    // thickness of doodle lines
var DOODLE_BLUR = 0;      // blur factor applied to doodles when a stroke ends


// All data is loaded into this by loadData(); it stays undefined until then, and draw() does
// nothing while it is undefined. Fields used here:
//   mnist.train_images, mnist.train_labels, mnist.test_images, mnist.test_labels
var mnist = void 0;

// The neural network object (a WebCNN instance once setup() has run).
var cnn;

// Training cursor: current pass over the training set, and the next exemplar index.
var trainrun = 1;
var train_index = 0;

// Testing cursor and running score for the current pass over the test set.
var testrun = 1;
var test_index = 0;
var total_tests = 0;
var total_correct = 0;

// doodle: createGraphics buffer the user draws on. demo: raw pixel array of the current demo image.
let doodle;
let demo;
var doodle_exists = false;
var demo_exists = false;


var mousedrag = false;     // are we in the middle of a mouse drag drawing?

// Latest normalised inputs, saved to globals so they can be inspected from the console
// (type these names in the console).
var train_inputs, test_inputs, demo_inputs, doodle_inputs;

var thehtml;

// Weight initialiser: Matrix.randomize() is redirected to call this, so any user of Matrix must
// define it. Returns a uniform random value between -0.5 and 0.5 from the Ancient Brain helper.
// NOTE(review): no code in this file calls it directly.
function randomWeight()
{
    const lowerBound = -0.5;
    const upperBound = 0.5;
    return AB.randomFloatAtoB(lowerBound, upperBound);
}

// make run header bigger
AB.headerCSS ( { "max-height": "95vh" } );


//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Odd-numbered slots below are static headers written once here. The other slots are filled
// in later by the running sketch:
//   2 <- guessDoodle(), 4 <- trainit(), 6 <- testit(), 8 <- makeDemo(), 9 <- guessDemo().
// The buttons call wipeDoodle(), makeDemo() and set do_training through inline onclick
// handlers, so those names must remain globals.



  var thehtml;

  // 1 Doodle header 
  thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> " +
        " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
   AB.msg ( thehtml, 1 );

  // 2 Doodle variable data (guess) - written by guessDoodle()
  
  // 3 Training header (the button stops training by clearing the global do_training)
  thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>  " +
        " <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
  AB.msg ( thehtml, 3 );
     
  // 4 variable training data - written by trainit()
  
  // 5 Testing header
  thehtml = "<h3> Hidden tests </h3> " ;
  AB.msg ( thehtml, 5 );
           
  // 6 variable testing data - written by testit()
  
  // 7 Demo header 
  thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br>" +
        " The network is <i>not</i> trained on any of these images. <br> " +
        " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
   AB.msg ( thehtml, 7 );
   
  // 8 Demo variable data (random demo ID) - written by makeDemo()
  // 9 Demo variable data (changing guess) - written by guessDemo()
  
// Opening tag used to highlight guesses and scores; callers close it with "</span>".
var greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> "  ;

//--- end of AB.msgs structure: ---------------------------------------------------------



// p5 entry point. Creates the canvas and the doodle buffer, then chains the downloads of the
// helper scripts and the saved network. Once everything is in, it builds the CNN selected by
// theNN and starts loading the MNIST data.
// NOTE(review): no failure callbacks are attached. If any download fails, loadData() never runs
// and the loading screen is never removed.
function setup()
{
    createCanvas(canvaswidth, canvasheight);

    // Off-screen buffer the user doodles on; density 1 keeps its pixel array at ZOOMPIXELS^2.
    doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
    doodle.pixelDensity(1);

    // Keep a loading screen up while the scripts and the data set arrive.
    AB.loadingScreen();

    $.getScript("/uploads/codingtrain/mnist.js", function ()
    {
        $.getScript("/uploads/khizer2599/mathutils.js", function ()
        {
            $.getScript("/uploads/khizer2599/webcnn.js", function ()
            {
                $.getJSON("/uploads/khizer2599/cnn_mnist_10_20_98accuracy.json", function (savedNetwork)
                {
                    console.log("All JS loaded");

                    switch (theNN)
                    {
                        case 0:
                            cnn = createShallowNetwork(ACTIVATION_RELU);
                            break;
                        case 1:
                            cnn = createShallowNetwork(ACTIVATION_TANH);
                            break;
                        case 2:
                            cnn = createDefaultNetwork();
                            break;
                        case 3:
                            cnn = loadNetworkFromJSON(savedNetwork);
                            break;
                        default:
                            console.log("Unknown NN type: " + theNN);
                    }

                    loadData();
                });
            });
        });
    });
}

// Build a WebCNN from a saved network description.
// networkJSON.layers is an array of layer descriptors passed straight to newLayer(). Conv and
// fully-connected descriptors may carry saved `weights` and `biases`, which are copied in
// before initialize() is called. Optional momentum / lambda / learningRate override the defaults.
// Each layer descriptor is logged to the console as it is created.
function loadNetworkFromJSON(networkJSON)
{
    var net = new WebCNN();

    if (networkJSON.momentum !== undefined)
    {
        net.setMomentum(networkJSON.momentum);
    }
    if (networkJSON.lambda !== undefined)
    {
        net.setLambda(networkJSON.lambda);
    }
    if (networkJSON.learningRate !== undefined)
    {
        net.setLearningRate(networkJSON.learningRate);
    }

    // Pass 1: create every layer, in order.
    for (let idx = 0; idx < networkJSON.layers.length; idx++)
    {
        const spec = networkJSON.layers[idx];
        console.log(spec);
        net.newLayer(spec);
    }

    // Pass 2: restore saved parameters. Only conv and fully-connected layers have them, and
    // only when both weights and biases are present.
    for (let idx = 0; idx < networkJSON.layers.length; idx++)
    {
        const spec = networkJSON.layers[idx];
        const hasParameters = spec.type === LAYER_TYPE_CONV || spec.type === LAYER_TYPE_FULLY_CONNECTED;
        if (hasParameters && spec.weights !== undefined && spec.biases !== undefined)
        {
            net.layers[idx].setWeightsAndBiases(spec.weights, spec.biases);
        }
    }

    net.initialize();
    return net;
}

// Build the default CNN:
//   24x24x1 input -> conv(10 @ 5x5) -> maxpool 2x2 -> conv(20 @ 5x5) -> maxpool 2x2 -> softmax(10).
// The network is stored in the global `cnn` and also returned. Training hyper-parameters are set
// after initialize(), matching the original call order.
function createDefaultNetwork()
{
    cnn = new WebCNN();

    cnn.newLayer({ name: "image", type: LAYER_TYPE_INPUT_IMAGE, width: 24, height: 24, depth: 1 });
    cnn.newLayer({ name: "conv1", type: LAYER_TYPE_CONV, units: 10, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false });
    cnn.newLayer({ name: "pool1", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    cnn.newLayer({ name: "conv2", type: LAYER_TYPE_CONV, units: 20, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false });
    cnn.newLayer({ name: "pool2", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    cnn.newLayer({ name: "out", type: LAYER_TYPE_FULLY_CONNECTED, units: 10, activation: ACTIVATION_SOFTMAX });
    cnn.initialize();

    cnn.setLearningRate(0.01);
    cnn.setMomentum(0.9);
    cnn.setLambda(0);

    return cnn;
}



// Build a CNN whose convolution layers use the activation `act`. setup() passes
// ACTIVATION_RELU for theNN = 0 and ACTIVATION_TANH for theNN = 1.
// Previously `act` was ignored, so both choices produced the same network.
// NOTE(review): this assumes WebCNN conv layers honour an `activation` key, as the
// fully-connected layer does. Confirm against webcnn.js. Despite its name, the topology is the
// same as createDefaultNetwork's.
// The network is stored in the global `cnn` and also returned.
function createShallowNetwork(act)
{
    cnn = new WebCNN();

    cnn.newLayer({ name: "image", type: LAYER_TYPE_INPUT_IMAGE, width: 24, height: 24, depth: 1 });
    cnn.newLayer({ name: "conv1", type: LAYER_TYPE_CONV, units: 10, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false, activation: act });
    cnn.newLayer({ name: "pool1", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    cnn.newLayer({ name: "conv2", type: LAYER_TYPE_CONV, units: 20, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false, activation: act });
    cnn.newLayer({ name: "pool2", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    cnn.newLayer({ name: "out", type: LAYER_TYPE_FULLY_CONNECTED, units: 10, activation: ACTIVATION_SOFTMAX });
    cnn.initialize();

    cnn.setLearningRate(0.01);
    cnn.setMomentum(0.9);
    cnn.setLambda(0.0);

    return cnn;
}

// Ask mnist.js for the data set. When it arrives, publish it in the global `mnist` (which lets
// draw() start working) and drop the loading screen.
function loadData()
{
    loadMNIST(function onMnistLoaded(dataset)
    {
        mnist = dataset;

        console.log("All data loaded into mnist object:");
        console.log(mnist);

        // If no loading screen exists, this does nothing.
        AB.removeLoading();
    });
}

// Move a doodle to the middle of a width x width grid.
//
// img   : RGBA pixel array of a width x width image (e.g. p5 `pixels`). Only the red channel
//         (every 4th value) is read.
// width : side length of the image in pixels.
// Returns a flat, row-major array of width*width brightness values. The bounding box of the
// fully-white (255) pixels is shifted so that its midpoint sits at the grid centre; everything
// outside that box becomes 0. If there is no 255 pixel at all, an all-zero array is returned.
//
// Previously the copy loop was written `for (let i; ...)` with a shadowing inner `let i;`, so it
// never executed. Every doodle was therefore turned into a blank image.
function centerImage(img, width)
{
    // Red channel -> 2-D grid[row][col].
    const grid = [];
    for (let row = 0; row < width; row++)
    {
        const line = [];
        for (let col = 0; col < width; col++)
        {
            line.push(img[4 * (row * width + col)]);
        }
        grid.push(line);
    }

    // Bounding box of the fully-white pixels.
    let top = -1;
    let bottom = -1;
    let left = -1;
    let right = -1;
    for (let row = 0; row < width; row++)
    {
        const first = grid[row].indexOf(255);
        if (first < 0)
        {
            continue;
        }
        const last = grid[row].lastIndexOf(255);
        if (top < 0)
        {
            top = row;
        }
        bottom = row;
        if (left < 0 || first < left)
        {
            left = first;
        }
        if (last > right)
        {
            right = last;
        }
    }

    const result = new Array(width * width).fill(0);
    if (top < 0)
    {
        return result;      // nothing drawn: blank image
    }

    // Shift that moves the box midpoint to the grid centre. It always keeps the box inside the
    // grid, because 0 <= top <= bottom < width and 0 <= left <= right < width.
    const rowShift = Math.floor((width - bottom - top) / 2);
    const colShift = Math.floor((width - right - left) / 2);

    for (let row = top; row <= bottom; row++)
    {
        for (let col = left; col <= right; col++)
        {
            result[(row + rowShift) * width + (col + colShift)] = grid[row][col];
        }
    }
    return result;
}


// Make a p5 image object from a raw greyscale data array.
// Builds a PIXELS x PIXELS image and writes the first PIXELSSQUARED values of img as grey levels
// (R = G = B = value, alpha 255), in row-major order. Any extra arguments are ignored.
function getImage ( img )
{
    const picture = createImage(PIXELS, PIXELS);
    picture.loadPixels();

    const rgba = picture.pixels;
    for (let p = 0; p < PIXELSSQUARED; p++)
    {
        const grey = img[p];
        const base = 4 * p;
        for (let channel = 0; channel < 3; channel++)
        {
            rgba[base + channel] = grey;
        }
        rgba[base + 3] = 255;
    }

    picture.updatePixels();
    return picture;
}

// Random size x size crop of a PIXELS x PIXELS row-major image, used as light data augmentation.
// Each crop offset is drawn uniformly from 0 .. PIXELS - size inclusive.
// Previously the maximum offset could never be drawn. That biased crops toward the top-left, with
// a mean offset of 1.5 instead of crop()'s centred default of 2.
// The row offset is drawn first, then the column offset. Returns the array produced by crop().
function randomCrop(context, duration)
{
    const slack = Math.max(0, PIXELS - duration);       // largest valid offset
    const rowOffset = Math.floor(Math.random() * (slack + 1));
    const colOffset = Math.floor(Math.random() * (slack + 1));
    return crop(context, duration, rowOffset, colOffset);
}


// Cut a t x t window out of a PIXELS-wide, row-major array `a`.
// The window's top-left corner is at (rowStart, colStart), which defaults to (2, 2). That is the
// centred position for a 24 px crop of a 28 px image. Returns the window values, row by row.
// Cropping is used to augment the data and improve classification accuracy.
function crop(a, t, rowStart = 2, colStart = 2)
{
    const window = [];
    for (let r = rowStart; r < rowStart + t; r++)
    {
        const rowBase = r * PIXELS;
        for (let c = colStart; c < colStart + t; c++)
        {
            window.push(a[rowBase + c]);
        }
    }
    return window;
}


// Convert the first PIXELSSQUARED brightness values (0..255) of img into a normalised 0..1 array.
// The result always has PIXELSSQUARED entries.
function getInputs ( img )
{
    return Array.from({ length: PIXELSSQUARED }, (_, idx) => img[idx] / 255);
}


// Convert a raw 28x28 data array into the { width, height, data } image record the CNN consumes.
// A random b x b crop is taken and rendered through getImage(), and its RGBA pixels become `data`.
// getImage() ignores its second argument and always builds a PIXELS x PIXELS image. The b*b crop
// values therefore occupy its first b*b pixels, in row-major order.
// NOTE(review): this assumes WebCNN reads `data` as RGBA indexed with the given width. Confirm
// in webcnn.js.
function toModelFormat(a, b)
{
    const cropped = randomCrop(a, b);
    const rgba = getImage(cropped, b).pixels;
    return { width: b, height: b, data: rgba };
}



// Train the network, advancing the global cursor "train_index" by one per call.
//
// draw() calls this TRAINPERSTEP times per frame. A call made when train_index is a multiple of
// TRAINPERSTEP trains one batch of up to TRAINPERSTEP consecutive exemplars starting there,
// clipped at the end of the training set. It also draws the batch's first image and updates
// run-header slot 4. Other calls only advance the cursor and return the index they skipped.
// `show` is accepted for compatibility; the batch image is always drawn.
//
// The cursor wraps to 0 at NOTRAIN on both paths. Previously only the training path checked for
// the wrap, so after the first full pass the cursor ran past the end of the data.
function trainit (show)
{
    if (train_index % TRAINPERSTEP !== 0)
    {
        const skippedIndex = train_index;
        advanceTrainIndex();
        return skippedIndex;
    }

    const img = mnist.train_images[train_index];

    // Never read beyond the end of the training set, even if NOTRAIN is not a multiple of TRAINPERSTEP.
    const batchSize = Math.min(TRAINPERSTEP, NOTRAIN - train_index);
    const inputs = [];
    const labels = [];
    for (let k = 0; k < batchSize; k++)
    {
        inputs.push(toModelFormat(mnist.train_images[train_index + k], CROP_PIXELS));
        labels.push(mnist.train_labels[train_index + k]);
    }

    // Visual: first image of the batch, magnified (left) and at original size (right).
    const theimage = getImage(img);
    image(theimage, 0,              ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);
    image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS,     PIXELS);

    // Normalised copy kept only so it can be inspected from the console.
    train_inputs = getInputs(img);

    cnn.trainCNNClassifier(inputs, labels);

    thehtml = " trainrun: " + trainrun + "<br> no: " + train_index;
    AB.msg(thehtml, 4);

    advanceTrainIndex();

    // Step the cursor by one. At the end of the training set, wrap to 0 and start a new trainrun.
    function advanceTrainIndex()
    {
        train_index++;
        if (train_index >= NOTRAIN)
        {
            train_index = 0;
            console.log( "finished trainrun: " + trainrun );
            trainrun++;
        }
    }
}

// Test the network on the single exemplar at the global cursor "test_index".
// It updates the running score in run-header slot 6. After NOTEST exemplars it logs the final
// score, starts a new testrun and resets the counters.
function testit()
{
    const sample = mnist.test_images[test_index];
    const label = mnist.test_labels[test_index];

    // Normalised copy kept only so it can be inspected from the console.
    test_inputs = getInputs(sample);

    // Index of the top output node is the network's guess.
    const output = cnn.classifyImages([toModelFormat(sample, CROP_PIXELS)]);
    const prediction = findMax(output);

    total_tests++;
    if (prediction == label)
    {
        total_correct++;
    }

    const score = total_correct / total_tests * 100;
    thehtml = " testrun: " + testrun + "<br> no: " + total_tests + " <br>  correct: " + total_correct + "<br>  score: " + greenspan + score.toFixed(2) + "</span>";
    AB.msg(thehtml, 6);

    test_index++;
    if (test_index == NOTEST)
    {
        console.log("finished testrun: " + testrun + " score: " + score.toFixed(2));
        testrun++;
        test_index = 0;
        total_tests = 0;
        total_correct = 0;
    }
}





//--- find no.1 and no.2 output nodes ---------------------------------------

// Return [best, second] class indices from the first classifier output in `formItems`.
// Class k's score is read with formItems[0].getValue(0, 0, k), for k = 0..9. Ties go to the lower
// index, and the two indices differ. Whenever a new leader is found, the previous leader becomes
// the runner-up. Previously it was simply dropped, so e.g. scores [0.3, 0.2, 0.5, ...] reported
// class 1 instead of class 0 as the second guess.
function find12(formItems)
{
    const NUM_CLASSES = 10;
    const output = formItems[0];

    let bestIndex = 0;
    let bestScore = -Infinity;
    let secondIndex = 0;
    let secondScore = -Infinity;

    for (let k = 0; k < NUM_CLASSES; k++)
    {
        const score = output.getValue(0, 0, k);
        if (score > bestScore)
        {
            // demote the previous leader to runner-up
            secondIndex = bestIndex;
            secondScore = bestScore;
            bestIndex = k;
            bestScore = score;
        }
        else if (score > secondScore)
        {
            secondIndex = k;
            secondScore = score;
        }
    }
    return [bestIndex, secondIndex];
}



// Fast "top-1 only" companion to find12(): return the index of the highest score among the first
// classifier output's 10 classes. Scores are read with getValue(0, 0, k).
// The search starts from a best score of 0, so if no score is positive the result is 0. Ties go
// to the lower index.
function findMax(node)
{
    const output = node[0];
    let bestIndex = 0;
    let bestScore = 0;
    for (let k = 0; k < 10; k++)
    {
        const score = output.getValue(0, 0, k);
        if (score > bestScore)
        {
            bestScore = score;
            bestIndex = k;
        }
    }
    return bestIndex;
}


// p5 per-frame loop. It does nothing until the data set has loaded. After that it:
//   1. redraws the demo and doodle rows, with their guesses;
//   2. while the mouse is held, extends the doodle stroke (top-left area only);
//   3. on the frame the mouse is released after a drag, blurs the doodle;
//   4. otherwise, if training is enabled, runs one training batch and TESTPERSTEP tests.
function draw()
{
    if (mnist === undefined)
    {
        return;
    }

    background("black");

    if (demo_exists)
    {
        drawDemo();
        guessDemo();
    }

    if (doodle_exists)
    {
        drawDoodle();
        guessDoodle();
    }

    if (mouseIsPressed)
    {
        // Draw only when both the current and the previous mouse positions are in the doodle area.
        const limit = ZOOMPIXELS + 20;
        const insideDoodle = mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit;
        if (insideDoodle)
        {
            mousedrag = true;
            doodle_exists = true;
            doodle.stroke("white");
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
        return;
    }

    if (mousedrag)
    {
        // The stroke has just ended: soften it once.
        mousedrag = false;
        doodle.filter(BLUR, DOODLE_BLUR);
        return;
    }

    if (do_training)
    {
        for (let step = 0; step < TRAINPERSTEP; step++)
        {
            trainit(step === 0);
        }
        for (let step = 0; step < TESTPERSTEP; step++)
        {
            testit();
        }
    }
}


//--- demo -------------------------------------------------------------
// Demo some test image and predict it. The image is taken from the test set, so the network has
// not used it in training.

// Pick a random test image and make it the current demo (global `demo`). Its index and true label
// are written to run-header slot 8. Type "demo" in the console to see the raw data.
function makeDemo()
{
    demo_exists = true;

    const pick = AB.randomIntAtoB(0, NOTEST - 1);
    demo = mnist.test_images[pick];
    const label = mnist.test_labels[pick];

    thehtml = "Test image no: " + pick + "<br>" + "Classification: " + label + "<br>";
    AB.msg(thehtml, 8);
}



// Draw the current demo image in the bottom row: magnified on the left, original size on the right.
function drawDemo()
{
    const picture = getImage(demo);
    const rowTop = canvasheight - ZOOMPIXELS;
    image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
    image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);
}


// Classify the current demo image (a raw data array taken from the test set) and show the top
// guess in run-header slot 9.
function guessDemo()
{
    // Normalised copy kept only so it can be inspected from the console.
    demo_inputs = getInputs(demo);

    const output = cnn.classifyImages([toModelFormat(demo, CROP_PIXELS)]);
    const guess = findMax(output);

    thehtml = " We classify it as: " + greenspan + guess + "</span>";
    AB.msg(thehtml, 9);
}


//--- doodle -------------------------------------------------------------

// Draw the doodle in the top row, magnified on the left and shrunk to MNIST size on the right.
// `doodle` is a createGraphics buffer, so get() is used to snapshot it as an image.
function drawDoodle()
{
    const snapshot = doodle.get();
    const smallLeft = ZOOMPIXELS + 50;
    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
    image(snapshot, smallLeft, 0, PIXELS, PIXELS);
}


// Classify the current doodle and show the top two guesses in run-header slot 2.
// The doodle (a createGraphics buffer) is snapshotted and shrunk to PIXELS x PIXELS. Its strokes
// are then centred with centerImage() before being fed to the network.
// The normalised, centred inputs are stored in `doodle_inputs`, so the console hook
// showInputs(doodle_inputs) works. Previously that global was never assigned.
function guessDoodle()
{
    const dst = doodle.get();
    dst.resize(PIXELS, PIXELS);
    dst.loadPixels();

    // Centre the strokes so the result does not depend on where in the box they were drawn.
    const centered = centerImage(dst.pixels, PIXELS);

    // Normalised copy kept only for console inspection.
    doodle_inputs = getInputs(centered);

    // Feed forward and pick the no.1 and no.2 classes.
    const t = find12(cnn.classifyImages([toModelFormat(centered, CROP_PIXELS)]));
    thehtml = " We classify it as: " + greenspan + t[0] + "</span> <br> No.2 guess is: " + greenspan + t[1] + "</span>";

    AB.msg(thehtml, 2);
}


// Handler for the "Clear doodle" button: stop drawing and classifying the doodle, then paint its
// buffer black.
function wipeDoodle()
{
    doodle_exists = false;      // stop showing/classifying first
    doodle.background("black"); // then clear the buffer
}


// --- debugging --------------------------------------------------
// In the console:
//   showInputs(demo_inputs);
//   showInputs(doodle_inputs);

// Print an input vector to the console as a grid PIXELS values wide.
// Each value is shown with 2 decimals and preceded by a space. A newline starts every row,
// including the first.
function showInputs(groups)
{
    const cells = Array.from(groups, function (value, k)
    {
        const rowBreak = (k % PIXELS === 0) ? "\n" : "";
        return rowBreak + " " + value.toFixed(2);
    });
    console.log(cells.join(""));
}