// Cloned by Marko on 1 Dec 2019 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications

// --- defined by MNIST - do not change these ---------------------------------------

const PIXELS = 28;                          // images in data set are tiny
const PIXELSSQUARED = PIXELS * PIXELS;      // length of one flattened image
const IMAGE_CHANNELS = 1;                   // greyscale
const NUM_OUTPUT_CLASSES = 10;              // digits 0-9

// number of training and test exemplars in the data set:
const NOTRAIN = 60000;
const NOTEST = 10000;

//--- can modify all these --------------------------------------------------

// no of nodes in network
// NOTE(review): nothing else in this file reads these four; they look like
// leftovers from the toy-network version.
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;
const learningrate = 0.1;                   // default 0.1

// should we train every timestep or not
let do_training = false;

// how many to train and test per timestep
const TRAINPERSTEP = 30;
const TESTPERSTEP = 5;

// multiply it by this to magnify for display
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;

// 3 rows of
// large image + 50 gap + small image
// 50 gap between rows
const canvaswidth = (PIXELS + ZOOMPIXELS) + 50;
const canvasheight = (ZOOMPIXELS * 3) + 100;

const DOODLE_THICK = 10;                    // thickness of doodle lines
const DOODLE_BLUR = 2;                      // blur factor applied to doodles

const BATCH_SIZE = 512;                     // mini-batch size for model.fit
// all data is loaded into this:
//   mnist.train_images, mnist.train_labels, mnist.test_images, mnist.test_labels
let mnist;

// NOTE(review): declared here but never assigned in this file.
let nn;

// progress counters for the per-step training / testing loops
let trainrun = 1;
let train_index = 0;

let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

// images in LHS:
let doodle, demo;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?

// save inputs to global var to inspect
// type these names in console
// (kept as `var` so they also appear as window properties)
var train_inputs, test_inputs, demo_inputs, doodle_inputs;

/**
 * Random initial weight in [-0.5, 0.5].
 * Matrix.randomize() is changed to point to this. Must be defined by user of Matrix.
 * (Coding Train default is -1 to 1.)
 * @returns {number}
 */
function randomWeight() {
  return AB.randomFloatAtoB(-0.5, 0.5);
}

// CSS trick
// make run header bigger
$("#runheaderbox").css({ "max-height": "95vh" });

//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Slots: 1 doodle header, 2 doodle guess, 3 training header, 4 training data,
//        5 testing header, 6 testing data, 7 demo header, 8 demo ID, 9 demo guess.
let thehtml;

// 1 Doodle header
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> " +
  " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg(thehtml, 1);

// 2 Doodle variable data (guess)

// 3 Training header
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br> " +
  " <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
AB.msg(thehtml, 3);

// 4 variable training data

// 5 Testing header
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// 6 variable testing data

// 7 Demo header
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br>" +
  " The network is <i>not</i> trained on any of these images. <br> " +
  " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);

// 8 Demo variable data (random demo ID)
// 9 Demo variable data (changing guess)

// opening tag used to highlight a result; callers close it with "</span>"
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";

//--- end of AB.msgs structure: ---------------------------------------------------------
let model; // TensorFlow.js model, built by getModel() once the data has loaded

/**
 * p5 setup: create the canvas and doodle buffer, then load TensorFlow.js and
 * the MNIST data, preprocess the data in place, build the model and train it.
 * The loading screen is removed once training settles (success or failure).
 */
function setup() {
  createCanvas(canvaswidth, canvasheight);

  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS); // doodle on larger canvas
  doodle.pixelDensity(1);

  // JS load other JS
  // maybe have a loading screen while loading the JS and the data set
  AB.loadingScreen();

  $.getScript("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js", function () {
    $.getScript("/uploads/codingtrain/mnist.js", function () {
      loadMNIST(function (data) {
        mnist = data;

        // Converted in place: images become processed 0..1 input arrays and
        // labels become one-hot arrays. Everything below relies on this.
        imagesToInputs(mnist.train_images);
        imagesToInputs(mnist.test_images);
        mnist.train_labels = oneHotLabels(mnist.train_labels);
        mnist.test_labels = oneHotLabels(mnist.test_labels);
        console.log("All data loaded into mnist object:");

        model = getModel();
        console.log("Model built");

        train()
          .then(() => console.log("Model trained"))
          // without this a failed fit is an unhandled rejection and the
          // loading screen never goes away
          .catch((err) => console.error("Training failed:", err))
          .then(() => AB.removeLoading());
      });
    });
  });
}

/**
 * Normalise one brightness value and keep only significant pixels.
 * The value is scaled to 0..1, pushed down by p / (2 - p), and zeroed if the
 * result is below 0.4.
 * @param {number} pixel - brightness, expected in 0..255
 * @returns {number} 0, or a value in [0.4, 1] for inputs in 0..255
 */
function processPixel(pixel) {
  const p = pixel / 255;
  const curved = p / (2 - p);
  return curved < 0.4 ? 0 : curved;
}

/**
 * If the image is not centred, shift it so the bounding box of its non-zero
 * pixels sits in the middle of the PIXELS x PIXELS grid.
 * @param {number[]} image - row-major array of PIXELSSQUARED values
 * @returns {number[]} a new shifted image (the input is not modified);
 *   an all-zero image comes back all zero
 */
function centerImage(image) {
  let leftmost = PIXELS;
  let rightmost = 0;
  let highest = PIXELS;
  let lowest = 0;

  for (let row = 0; row < PIXELS; row++) {
    for (let col = 0; col < PIXELS; col++) {
      if (image[row * PIXELS + col] > 0) {
        highest = Math.min(highest, row);
        lowest = Math.max(lowest, row);
        leftmost = Math.min(leftmost, col);
        rightmost = Math.max(rightmost, col);
      }
    }
  }

  // for an empty image the bounds sum to PIXELS, giving a shift of 0
  const shiftX = Math.floor((PIXELS - rightmost - leftmost) / 2);
  const shiftY = Math.floor((PIXELS - highest - lowest) / 2);

  const newImage = Array(PIXELSSQUARED).fill(0);
  for (let row = 0; row < PIXELS; row++) {
    for (let col = 0; col < PIXELS; col++) {
      const value = image[row * PIXELS + col];
      if (value > 0) {
        newImage[(row + shiftY) * PIXELS + col + shiftX] = value;
      }
    }
  }
  return newImage;
}

// --- TensorFlow methods ------------------------------------------------------

/**
 * Build and compile the CNN: two [conv 5x5 relu -> 2x2 max-pool] stages,
 * flatten, then a softmax dense layer with NUM_OUTPUT_CLASSES units.
 * @returns {tf.Sequential} compiled with Adam and categorical cross-entropy
 */
function getModel() {
  const net = tf.sequential();

  net.add(tf.layers.conv2d({
    inputShape: [PIXELS, PIXELS, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  net.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

  net.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  net.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

  net.add(tf.layers.flatten());
  net.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax',
  }));

  net.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return net;
}

/**
 * Train the global `model` on the whole training set, validating on the test set.
 * The data tensors are freed once fitting settles, whether it succeeded or not.
 * @returns {Promise} resolves with whatever model.fit resolves with
 */
function train() {
  const td = trainData(mnist);
  return modelFit(td).then(
    (history) => {
      tf.dispose(td);
      return history;
    },
    (err) => {
      tf.dispose(td);
      throw err;
    }
  );
}

/**
 * One-hot encode digit labels.
 * @param {ArrayLike<number>} labels - digits in 0..NUM_OUTPUT_CLASSES-1
 * @returns {number[][]} one array of NUM_OUTPUT_CLASSES zeros per label,
 *   with a 1 at the label's index
 */
function oneHotLabels(labels) {
  return Array.from(labels, (label) => {
    const row = Array(NUM_OUTPUT_CLASSES).fill(0);
    row[label] = 1;
    return row;
  });
}

/**
 * Replace every raw image in `images` (in place) with its processed input array.
 * @param {Array} images - array of raw 0..255 images; mutated
 */
function imagesToInputs(images) {
  for (let i = 0; i < images.length; i++) {
    images[i] = getInputs(images[i]);
  }
}

/**
 * Build the training/validation tensors from the preprocessed mnist object.
 * Images are reshaped to [n, PIXELS, PIXELS, IMAGE_CHANNELS]; tf.tidy frees
 * the intermediate 2-D tensors and keeps only the returned ones.
 * @returns {{trainImages, trainLabels, testImages, testLabels}} caller must dispose
 */
function trainData(mnist) {
  return tf.tidy(() => {
    const trainImages = tf.tensor2d(mnist.train_images, [mnist.train_images.length, PIXELSSQUARED]);
    const trainLabels = tf.tensor2d(mnist.train_labels, [mnist.train_labels.length, NUM_OUTPUT_CLASSES]);
    const testImages = tf.tensor2d(mnist.test_images, [mnist.test_images.length, PIXELSSQUARED]);
    const testLabels = tf.tensor2d(mnist.test_labels, [mnist.test_labels.length, NUM_OUTPUT_CLASSES]);
    return {
      trainImages: trainImages.reshape([mnist.train_images.length, PIXELS, PIXELS, IMAGE_CHANNELS]),
      trainLabels: trainLabels,
      testImages: testImages.reshape([mnist.test_images.length, PIXELS, PIXELS, IMAGE_CHANNELS]),
      testLabels: testLabels,
    };
  });
}

/** Fit the global `model` for 10 shuffled epochs with BATCH_SIZE batches. */
function modelFit(trainData) {
  return model.fit(trainData.trainImages, trainData.trainLabels, {
    batchSize: BATCH_SIZE,
    validationData: [trainData.testImages, trainData.testLabels],
    epochs: 10,
    shuffle: true,
  });
}

/**
 * Predict one processed image. Called every frame, so every tensor created
 * here is released: intermediates by tf.tidy, the output after reading it.
 * @param {number[]} img - PIXELSSQUARED processed inputs
 * @returns {Promise<ArrayLike<number>>} the per-class output values
 */
function digitPrediction(img) {
  const output = tf.tidy(() =>
    model.predict(tf.tensor2d(img, [1, PIXELSSQUARED]).reshape([1, PIXELS, PIXELS, IMAGE_CHANNELS]))
  );
  return output.data().then(
    (values) => {
      output.dispose();
      return values;
    },
    (err) => {
      output.dispose();
      throw err;
    }
  );
}

/** Indexes of the top two predictions; see find12. */
function findPredictionLabels(arr) {
  return find12(arr);
}

// --- end TensorFlow ------------------------------------------------------------

/**
 * Make a P5 image object from a data array.
 * @param {number[]} img - PIXELSSQUARED brightness values
 * @param {number} [scale=1] - multiplier applied to each value; pass 255 for
 *   processed 0..1 input arrays
 * @returns {p5.Image} greyscale, fully opaque
 */
function getImage(img, scale = 1) {
  const theimage = createImage(PIXELS, PIXELS); // make blank image, then populate it
  theimage.loadPixels();
  for (let i = 0; i < PIXELSSQUARED; i++) {
    const bright = img[i] * scale;
    const index = i * 4;
    theimage.pixels[index + 0] = bright;
    theimage.pixels[index + 1] = bright;
    theimage.pixels[index + 2] = bright;
    theimage.pixels[index + 3] = 255;
  }
  theimage.updatePixels();
  return theimage;
}

/**
 * Convert a raw 0..255 image array into a normalised, centred input array.
 * @param {number[]} img
 * @returns {number[]} PIXELSSQUARED values in 0..1
 */
function getInputs(img) {
  const inputs = [];
  for (let i = 0; i < PIXELSSQUARED; i++) {
    inputs[i] = processPixel(img[i]); // normalise to 0 to 1
  }
  return centerImage(inputs);
}

/**
 * Per-step "training" of one exemplar at train_index, optionally shown in the
 * middle row. The TF model itself is trained in bulk by train(); the legacy
 * per-exemplar network `nn` is never assigned in this file, so it is only used
 * if something else has assigned it.
 * @param {boolean} show - draw the exemplar
 */
function trainit(show) {
  // images/labels were converted in setup: processed 0..1 inputs, one-hot labels
  const inputs = mnist.train_images[train_index];
  const targets = mnist.train_labels[train_index];

  if (show) {
    const theimage = getImage(inputs, 255);
    image(theimage, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);          // magnified
    image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);     // original
  }

  train_inputs = inputs; // can inspect in console
  if (nn) nn.train(inputs, targets);

  thehtml = " trainrun: " + trainrun + "<br> no: " + train_index;
  AB.msg(thehtml, 4);

  train_index++;
  if (train_index === NOTRAIN) {
    train_index = 0;
    console.log("finished trainrun: " + trainrun);
    trainrun++;
  }
}

/**
 * Test the model on the exemplar at test_index. Prediction is asynchronous,
 * so the score is updated (and a run of NOTEST completed) when it resolves.
 */
function testit() {
  const index = test_index;
  test_index++;
  if (test_index === NOTEST) test_index = 0;

  const inputs = mnist.test_images[index];          // already processed
  const label = findMax(mnist.test_labels[index]);  // one-hot -> digit
  test_inputs = inputs; // can inspect in console

  digitPrediction(inputs).then((prediction) => {
    total_tests++;
    if (findMax(prediction) === label) total_correct++;

    const percent = (total_correct / total_tests) * 100;
    thehtml = " testrun: " + testrun + "<br> no: " + total_tests + " <br> " +
      " correct: " + total_correct + "<br>" +
      " score: " + greenspan + percent.toFixed(2) + "</span>";
    AB.msg(thehtml, 6);

    if (total_tests === NOTEST) {
      console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
      testrun++;
      total_tests = 0;
      total_correct = 0;
    }
  });
}

//--- find no.1 (and maybe no.2) output nodes ---------------------------------------

/**
 * Indexes of the largest and second-largest values.
 * (restriction) assumes values are >= 0: values <= 0 never win, so an
 * all-zero array yields [0, 0]. Ties keep the earlier index.
 * @param {ArrayLike<number>} a
 * @returns {number[]} [index of no.1, index of no.2]
 */
function find12(a) {
  let no1 = 0;
  let no2 = 0;
  let no1value = 0;
  let no2value = 0;
  for (let i = 0; i < a.length; i++) {
    if (a[i] > no1value) {
      // the previous leader becomes the runner-up
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = a[i];
    } else if (a[i] > no2value) {
      no2 = i;
      no2value = a[i];
    }
  }
  return [no1, no2];
}

/**
 * Index of the maximum value (our guess from the output nodes). Kept separate
 * from find12 for speed - done many times. Same >= 0 restriction as find12.
 * @param {ArrayLike<number>} a
 * @returns {number}
 */
function findMax(a) {
  let no1 = 0;
  let no1value = 0;
  for (let i = 0; i < a.length; i++) {
    if (a[i] > no1value) {
      no1 = i;
      no1value = a[i];
    }
  }
  return no1;
}

// --- the draw function -------------------------------------------------------------

/** p5 per-frame loop: optional step training/testing, demo & doodle display and guesses, doodle input. */
function draw() {
  // check if libraries and data loaded yet:
  if (typeof mnist === 'undefined') return;

  // how can we get white doodle on black background on yellow canvas?
  //   background('#ffffcc');    doodle.background('black');
  background('black');

  if (do_training) {
    // do some training per step; show only one per step - still flashes by
    for (let i = 0; i < TRAINPERSTEP; i++) {
      trainit(i === 0);
    }
    // do some testing per step
    for (let i = 0; i < TESTPERSTEP; i++) {
      testit();
    }
  }

  // keep drawing demo and doodle images
  // and keep guessing - we will update our guess as time goes on
  if (demo_exists) {
    drawDemo();
    guessDemo();
  }
  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  // detect doodle drawing
  // (restriction) the following assumes doodle starts at 0,0
  if (mouseIsPressed) {
    // gets called when we click buttons, as well as if in doodle corner
    const MAX = ZOOMPIXELS + 20; // can draw up to this pixels in corner
    if (mouseX < MAX && mouseY < MAX && pmouseX < MAX && pmouseY < MAX) {
      mousedrag = true; // start a mouse drag
      doodle_exists = true;
      doodle.stroke('white');
      doodle.strokeWeight(DOODLE_THICK);
      doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
  } else if (mousedrag) {
    // exiting a drawing: blur the finished doodle just once
    mousedrag = false;
    doodle.filter(BLUR, DOODLE_BLUR);
  }
}

//--- demo -------------------------------------------------------------
// demo some test image and predict it
// get it from test set so have not used it in training

/** Pick a random test image as the demo and show its ID and true digit. */
function makeDemo() {
  demo_exists = true;
  const i = AB.randomIntAtoB(0, NOTEST - 1);
  demo = mnist.test_images[i]; // type "demo" in console to see raw data
  const label = findMax(mnist.test_labels[i]); // labels are one-hot after setup
  thehtml = "Test image no: " + i + "<br>" + "Classification: " + label + "<br>";
  AB.msg(thehtml, 8);
}

/** Draw the demo image (processed 0..1 values, so scaled back up) in the bottom row. */
function drawDemo() {
  const theimage = getImage(demo, 255);
  image(theimage, 0, canvasheight - ZOOMPIXELS, ZOOMPIXELS, ZOOMPIXELS);              // magnified
  image(theimage, ZOOMPIXELS + 50, canvasheight - ZOOMPIXELS, PIXELS, PIXELS);         // original
}

/** Predict the demo image and report the top guess. */
function guessDemo() {
  // demo is already a processed input array, so it goes to the model as-is
  demo_inputs = demo; // can inspect in console
  digitPrediction(demo).then((prediction) => {
    const guess = findMax(prediction); // the top output
    thehtml = " We classify it as: " + greenspan + guess + "</span>";
    AB.msg(thehtml, 9);
  });
}

//--- doodle -------------------------------------------------------------

/** Draw the doodle buffer (a createGraphics, not createImage) full size and shrunk. */
function drawDoodle() {
  const theimage = doodle.get();
  image(theimage, 0, 0, ZOOMPIXELS, ZOOMPIXELS);              // original
  image(theimage, ZOOMPIXELS + 50, 0, PIXELS, PIXELS);         // shrunk
}

/** Shrink the doodle to PIXELS x PIXELS, preprocess it, and report the top two guesses. */
function guessDoodle() {
  const img = doodle.get(); // doodle is createGraphics not createImage
  img.resize(PIXELS, PIXELS);
  img.loadPixels();

  // use the red channel of each RGBA pixel as its brightness
  const inputs = [];
  for (let i = 0; i < PIXELSSQUARED; i++) {
    inputs[i] = processPixel(img.pixels[i * 4]);
  }
  doodle_inputs = centerImage(inputs); // can inspect in console

  digitPrediction(doodle_inputs).then((prediction) => {
    const [first, second] = findPredictionLabels(prediction);
    thehtml = " We classify it as: " + greenspan + first + "</span> <br>" +
      " No.2 guess is: " + greenspan + second + "</span>";
    AB.msg(thehtml, 2);
  });
}

/** Clear the doodle. */
function wipeDoodle() {
  doodle_exists = false;
  doodle.background('black');
}

// --- debugging --------------------------------------------------
// in console
//   showInputs(demo_inputs);
//   showInputs(doodle_inputs);

/** Log inputs row by row, corresponding to the square of pixels, to 2 decimals. */
function showInputs(inputs) {
  let str = "";
  for (let i = 0; i < inputs.length; i++) {
    if (i % PIXELS === 0) str += "\n"; // new line for each row of pixels
    str += " " + inputs[i].toFixed(2);
  }
  console.log(str);
}