// Code viewer for World: Practical-2: Doodle Recogn... (page title, truncated in the original)

// Cloned by Vyoma Patel on 18 Nov 2021 from World "Character recognition neural network" by "Coding Train" project 
// Please leave this clone trail here.
 

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications 


// --- defined by the data set - do not change these -------------------------------
// NOTE(review): the counts below (124800 train / 20800 test), the 26-letter label
// table and the "Emnist" console message suggest this is the EMNIST "letters" split
// rather than the original MNIST digits — confirm against the loader script.

var PIXELS = 28; // images in the data set are tiny: PIXELS x PIXELS
var PIXELSSQUARED = PIXELS * PIXELS; // pixels per image = number of network inputs

// number of training and test exemplars in the data set:
var NOTRAIN = 124800;
var NOTEST = 20800;


//--- can modify all these --------------------------------------------------

// no of nodes in a fully-connected network.
// NOTE(review): noinput / nohidden / nooutput / learningrate are not read by any of the
// visible code — the convnet built in setup() defines its own layer sizes (ending in a
// 27-way softmax), so changing these values has no visible effect.
var noinput = PIXELSSQUARED;
var nohidden = 64;
var nooutput = 10;

var learningrate = 0.1;   // default 0.1

// should we train every draw() frame or not (toggled by the Pause/Resume training buttons)
let do_training = true;

// how many exemplars to train and test per draw() frame
const TRAINPERSTEP = 20; // was 30
const TESTPERSTEP  = 5;

// Display layout. The canvas holds 3 rows; each row shows an image magnified
// (left), a gap, then the same image at its original size (right). The lower
// rows use a 50px horizontal gap, and rows are separated by 50px vertical gaps.
const ZOOMFACTOR = 7; // magnification applied to the large (left) images
const ZOOMPIXELS = PIXELS * ZOOMFACTOR; // side length of a magnified image

const canvaswidth = ZOOMPIXELS + 50 + PIXELS; // big image + gap + small image
const canvasheight = 3 * ZOOMPIXELS + 2 * 50; // 3 rows + 2 gaps between them

// Doodle pen settings.
const DOODLE_THICK = 18; // stroke weight of doodle lines
const DOODLE_BLUR = 3; // blur applied to the doodle when the mouse is released


let mnist;
// all data is loaded into this
// mnist.train_images
// mnist.train_labels
// mnist.test_images
// mnist.test_labels


let nn; // NOTE(review): not read by the visible code — possibly left over from the Matrix-based original

// Letter displayed for each class label. Labels are used 1-based elsewhere
// (alphabets[label - 1]), so label 1 -> "A" ... label 26 -> "Z".
// Declared explicitly: the bare assignment used before created an implicit global
// (a ReferenceError in strict mode).
const alphabets = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"];
let trainrun = 1;
let train_index = 0;

let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

// images in LHS:
let doodle;
let demo;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?


// save inputs to global var to inspect
// type these names in console
var train_inputs, test_inputs, demo_inputs, doodle_inputs;
var thehtml;

// Weight initialiser for the Matrix library (matrix.js, loaded in setup()).
// Per the original template, Matrix.randomize() is redirected to call this, so it
// must be defined by whoever uses Matrix. The value comes from
// AB.randomFloatAtoB with bounds -0.5 and 0.5.
function randomWeight() {
  const low = -0.5;
  const high = 0.5;
  return AB.randomFloatAtoB(low, high);
}
// make run header bigger
AB.headerCSS({
  "max-height": "95vh",
});

//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Slots used in this file: 1 doodle header, 2 doodle guesses, 3 training header,
// 4 training progress, 5 hidden-tests header, 6 test score, 7 demo header,
// 8 demo label, 9 demo guess.
// (Previously all of this was one long comma-operator expression; split into
// plain statements with the same order of calls.)

// 1 Doodle header
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br>  Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear and Re-draw</button> <br> ";
AB.msg(thehtml, 1);

// 2 Training header
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>   <button onclick='do_training = false;' class='normbutton' >Pause training</button>  <button onclick='do_training = true;' class='normbutton' >Resume training</button> <br> ";
AB.msg(thehtml, 3);

// Hidden tests header (the running score goes in slot 6)
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// 3 Demo header
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br> The network is <i>not</i> trained on any of these images. <br>  <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);

// Opening tag for large, bold, dark-blue text; users close it with "</span>".
const bluespan = "<span style='font-weight:bold; font-size:x-large; color:darkblue'> ";
// The convnet and its trainer are created once convnet.js has loaded (see setup()).
// Declared here so they are not implicit globals.
let mycnnModel;
let mycnnTrain;

// Build the CNN and its SGD trainer. Must only be called after convnet.js is loaded.
// Architecture follows the ConvNetJS MNIST demo:
// https://cs.stanford.edu/people/karpathy/convnetjs/demo/mnist.html
function buildConvNet() {
  const layerDefs = [
    // One channel per pixel; getInputs() has already scaled pixel values to 0..1.
    { type: "input", out_sx: PIXELS, out_sy: PIXELS, out_depth: 1 },
    // Convolutional layers learn local features; ReLU activation.
    { type: "conv", sx: 5, filters: 8, stride: 1, pad: 2, activation: "relu" },
    // Pooling keeps only the strongest (most salient) response in each window.
    { type: "pool", sx: 2, stride: 2 },
    { type: "conv", sx: 5, filters: 16, stride: 1, pad: 2, activation: "relu" },
    { type: "pool", sx: 3, stride: 3 },
    // 27 classes. NOTE(review): labels are displayed via alphabets[label - 1], so they
    // appear to run 1..26 with class 0 unused — confirm against the data loader.
    { type: "softmax", num_classes: 27 },
  ];

  mycnnModel = new convnetjs.Net();
  mycnnModel.makeLayers(layerDefs);
  mycnnTrain = new convnetjs.SGDTrainer(mycnnModel, {
    method: "adadelta",
    momentum: 0.9,
    batch_size: 10,
    l2_decay: 0.001,
  });
}

// p5 entry point: create the canvas and the doodle buffer, then load the helper
// scripts, build the network and load the data set.
function setup() {
  createCanvas(canvaswidth, canvasheight);

  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
  doodle.pixelDensity(1);
  wipeDoodle();

  // Show a loading screen while the JS files and the data set load.
  // loadData() removes it once the data is ready.
  AB.loadingScreen();

  // Load the helper scripts one after another, then build the network and load the data.
  $.getScript("/uploads/codingtrain/matrix.js", function () {
    $.getScript("/uploads/vyoma136/convnet.js", function () {
      $.getScript("/uploads/vyoma136/Vyoma_MNIST.js", function () {
        console.log("All JS Files loaded");
        buildConvNet();
        loadData();
      });
    });
  });
}
// Load the data set from a file on this server (via loadMNIST), store it in the
// global `mnist`, transpose the first NOTRAIN training images and the first NOTEST
// test images in place, then remove the loading screen.
// NOTE(review): the transpose presumably corrects the transposed pixel layout of the
// EMNIST files — confirm against the loader.
function loadData() {
  loadMNIST(function (data) {
    mnist = data;

    const transposeFirst = (images, count) => {
      for (let n = 0; n < count; n++) {
        rotateImage(images[n]);
      }
    };
    transposeFirst(mnist.train_images, NOTRAIN);
    transposeFirst(mnist.test_images, NOTEST);

    console.log("All data loaded into Emnist object.");
    console.log(mnist);
    AB.removeLoading(); // if no loading screen exists, this does nothing
  });
}
// Make a p5 image (PIXELS x PIXELS) from a raw array of pixel values.
// Each value is written to the R, G and B channels (so it displays as grey),
// and alpha is set fully opaque.
function getImage(id) {
  const img = createImage(PIXELS, PIXELS);
  img.loadPixels();
  for (let p = 0; p < PIXELSSQUARED; p++) {
    const grey = id[p];
    const base = p * 4;
    for (let channel = 0; channel < 3; channel++) {
      img.pixels[base + channel] = grey;
    }
    img.pixels[base + 3] = 255;
  }
  img.updatePixels();
  return img;
}
// Convert a raw image array into the normalised network input: a plain array of
// the first PIXELSSQUARED values, each divided by 255 (0..255 -> 0..1).
function getInputs(id) {
  return Array.from({ length: PIXELSSQUARED }, (_, k) => id[k] / 255);
}
// Transpose, in place, a PIXELS x PIXELS image stored row-major in a flat array:
// the pixel at (row r, column c) is swapped with the pixel at (row c, column r).
function rotateImage(ctx) {
  for (let row = 0; row < PIXELS; row++) {
    for (let col = row; col < PIXELS; col++) {
      const here = row * PIXELS + col;
      const mirror = col * PIXELS + row;
      [ctx[here], ctx[mirror]] = [ctx[mirror], ctx[here]];
    }
  }
}
// Train the network on one exemplar (the one at global train_index).
// If showImage is truthy, the exemplar is also drawn in the middle row
// (magnified on the left, original size on the right). Afterwards train_index
// advances; when it reaches NOTRAIN it wraps to 0 and trainrun is incremented.
function trainit(showImage) {
  const pixels = mnist.train_images[train_index];
  const label = mnist.train_labels[train_index];

  if (showImage) {
    const img = getImage(pixels);
    const rowY = ZOOMPIXELS + 50;
    image(img, 0, rowY, ZOOMPIXELS, ZOOMPIXELS); // magnified
    image(img, ZOOMPIXELS + 50, rowY, PIXELS, PIXELS); // original
  }

  train_inputs = getInputs(pixels); // kept in a global for console inspection
  mycnnTrain.train(getmycnnInputs(train_inputs), label);

  thehtml = ` trainrun: ${trainrun}<br> no: ${train_index}`;
  AB.msg(thehtml, 4);

  train_index += 1;
  if (train_index === NOTRAIN) {
    train_index = 0;
    console.log(`finished trainrun: ${trainrun}`);
    trainrun += 1;
  }
}
// Wrap a flat array of PIXELSSQUARED normalised inputs in a ConvNetJS volume of
// size PIXELS x PIXELS x 1 (zero-initialised) so it can be fed to the network.
// Uses the shared PIXELS constant instead of a hard-coded 28, so it always matches
// the data-set image size.
function getmycnnInputs(obj) {
  const vol = new convnetjs.Vol(PIXELS, PIXELS, 1, 0.0);
  for (let i = 0; i < PIXELSSQUARED; i++) {
    vol.w[i] = obj[i];
  }
  return vol;
}
// Test the network on one exemplar (the one at global test_index) and update the
// running score shown in msg slot 6. When test_index reaches NOTEST, the score is
// logged, testrun is incremented and the counters reset.
// Test images are deliberately NOT drawn: the middle row is labelled as showing the
// training image, and tests are listed as "hidden". Drawing them here used to
// overwrite the training image on every frame.
function testit() {
  const pixels = mnist.test_images[test_index];
  const label = mnist.test_labels[test_index];

  const inputs = getInputs(pixels);
  test_inputs = inputs; // kept in a global for console inspection
  const guess = findMax(mycnnModel.forward(getmycnnInputs(inputs)).w);

  total_tests++;
  // Loose equality kept: the label's type depends on the external loader.
  if (guess == label) {
    total_correct++;
  }

  const percent = (total_correct / total_tests) * 100;
  thehtml = " testrun: " + testrun + "<br> no: " + total_tests + " <br>  correct: " + total_correct + "<br>  score: " + bluespan + percent.toFixed(2) + "</span>";
  AB.msg(thehtml, 6);

  test_index++;
  if (test_index === NOTEST) {
    console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
    testrun++;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
  }
}

//--- find no.1 and no.2 output nodes ---------------------------------------
// Return [index of the largest value, index of the second-largest value].
// Works for any real values (the old version assumed values >= 0 and returned a
// duplicate index for all-zero input). Ties keep the earlier index as no.1.
// With fewer than 2 elements, the missing index(es) are reported as 0.
// For non-negative inputs with a positive maximum, results match the old behaviour.
function find12(values) {
  let first = 0;
  let second = 0;
  let firstVal = -Infinity;
  let secondVal = -Infinity;
  for (let i = 0; i < values.length; i++) {
    const v = values[i];
    if (v > firstVal) {
      // previous best drops to second place
      second = first;
      secondVal = firstVal;
      first = i;
      firstVal = v;
    } else if (v > secondVal) {
      second = i;
      secondVal = v;
    }
  }
  return [first, second];
}
// Return the index of the largest value in `array` (the first index wins on ties).
// Works for any real values (previously values <= 0 were never selected);
// an empty array yields 0.
function findMax(array) {
  let maxI = 0;
  let maxValue = -Infinity;
  for (let i = 0; i < array.length; i++) {
    if (array[i] > maxValue) {
      maxI = i;
      maxValue = array[i];
    }
  }
  return maxI;
}
// p5 draw loop. Once the data set is loaded it:
//   - redraws the background and the doodle area outline,
//   - runs TRAINPERSTEP training steps and TESTPERSTEP test steps (unless paused),
//   - refreshes the demo and doodle panels and their guesses,
//   - lets the user doodle with the mouse in the top-left area.
function draw() {
  if (mnist === undefined) {
    return; // data set still loading
  }

  background("black");

  // outline and label of the doodle area
  strokeWeight(1);
  stroke("green");
  rect(0, 0, ZOOMPIXELS, ZOOMPIXELS);
  textSize(10);
  textAlign(CENTER);
  text("DOODLE HERE", ZOOMPIXELS / 2, ZOOMPIXELS / 2.2);

  if (do_training) {
    for (let i = 0; i < TRAINPERSTEP; i++) {
      trainit(i === 0); // only draw the first training image of the batch
    }
    for (let i = 0; i < TESTPERSTEP; i++) {
      testit();
    }
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }
  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  if (mouseIsPressed) {
    // Draw only while both the current and previous mouse positions are within the
    // doodle area (plus a 20px margin; the doodle buffer clips anything outside it).
    const limit = ZOOMPIXELS + 20;
    if (mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit) {
      mousedrag = true;
      doodle_exists = true;
      doodle.stroke("red");
      // Set the join on the doodle buffer (previously applied to the main canvas by mistake).
      doodle.strokeJoin(ROUND);
      doodle.strokeWeight(DOODLE_THICK);
      doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
  } else if (mousedrag) {
    // Stroke finished: blur the doodle once (presumably so it resembles the soft
    // edges of the data-set images).
    mousedrag = false;
    doodle.filter(BLUR, DOODLE_BLUR);
  }
}
// Pick a random test image as the demo (drawn in the bottom row) and show its true
// label in msg slot 8. Labels are 1-based, hence alphabets[label - 1].
// Does nothing if the button is pressed before the data set has finished loading
// (previously this threw a TypeError on the undefined `mnist`).
function makeDemo() {
  if (mnist === undefined) {
    return;
  }
  const i = AB.randomIntAtoB(0, NOTEST - 1);
  demo = mnist.test_images[i];
  demo_exists = true;
  const label = mnist.test_labels[i];
  thehtml = "Test image no: " + i + "<br>Classification: " + alphabets[label - 1] + "<br>";
  AB.msg(thehtml, 8);
}
// Draw the current demo image in the bottom row of the canvas:
// magnified on the left, original size on the right.
function drawDemo() {
  const img = getImage(demo);
  const rowTop = canvasheight - ZOOMPIXELS;
  const thumbX = ZOOMPIXELS + 50;
  image(img, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
  image(img, thumbX, rowTop, PIXELS, PIXELS);
}
// Run the network on the current demo image and show its top guess in msg slot 9.
// The output index is a 1-based label, hence alphabets[guess - 1].
function guessDemo() {
  demo_inputs = getInputs(demo); // kept in a global for console inspection
  const scores = mycnnModel.forward(getmycnnInputs(demo_inputs)).w;
  const guess = findMax(scores);
  thehtml = " We classify it as: " + bluespan + alphabets[guess - 1] + "</span>";
  AB.msg(thehtml, 9);
}
// Draw the doodle buffer in the top row: full size at the top-left, plus a
// PIXELS x PIXELS thumbnail 20px to the right of it.
function drawDoodle() {
  const snapshot = doodle.get();
  const thumbX = ZOOMPIXELS + 20;
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
  image(snapshot, thumbX, 0, PIXELS, PIXELS);
}
// Classify the current doodle and show the top two guesses in msg slot 2.
// A copy of the doodle buffer is shrunk to PIXELS x PIXELS. The red channel of each
// pixel (the doodle is drawn in red on black), scaled to 0..1, becomes the network input.
// Output indices are 1-based labels, hence alphabets[index - 1].
function guessDoodle() {
  const small = doodle.get();
  small.resize(PIXELS, PIXELS);
  small.loadPixels();

  // kept in a global for console inspection
  doodle_inputs = Array.from({ length: PIXELSSQUARED }, (_, p) => small.pixels[p * 4] / 255);

  const [best, runnerUp] = find12(mycnnModel.forward(getmycnnInputs(doodle_inputs)).w);
  thehtml = " Our 1st Guess is: " + bluespan + alphabets[best - 1] + "</span> <br> Our 2nd Guess is: " + bluespan + alphabets[runnerUp - 1] + "</span>";
  AB.msg(thehtml, 2);
}
// Forget the current doodle: mark that there is nothing to classify,
// then paint the doodle buffer solid black.
function wipeDoodle() {
  const clearColour = "black";
  doodle_exists = false;
  doodle.background(clearColour);
}

// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);
// Log an input array to the console as a grid: a newline starts each row of
// PIXELS values, and each value is printed with 2 decimals after a space.
function showInputs(groups) {
  const parts = [];
  for (let idx = 0; idx < groups.length; idx++) {
    if (idx % PIXELS === 0) {
      parts.push("\n");
    }
    parts.push(" " + groups[idx].toFixed(2));
  }
  console.log(parts.join(""));
}