Code viewer for World:
CCNN
// Cloned by AC on 10 Dec 2020 from World "Character recognition neural network (clone by AC)" by AC
// Please leave this clone trail here.
// Cloned by AC on 2 Dec 2020 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.
// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications
// --- defined by MNIST - do not change these ---------------------------------------
const PIXELS = 28; // images in data set are tiny: PIXELS x PIXELS, stored row-major
const PIXELSSQUARED = PIXELS * PIXELS; // length of one flattened image
// number of training and test exemplars in the data set:
const NOTRAIN = 60000;
const NOTEST = 10000;
//--- can modify all these --------------------------------------------------
// no of nodes in network (one input per pixel, one output per digit 0-9)
const noinput = PIXELSSQUARED;
const nohidden = 64; //64 initially
const nooutput = 10;
const learningrate = 0.1; // default 0.1
// should we train every timestep or not (flipped by the Start/Stop button)
let do_training = false;
// how many to train and test per timestep (i.e. per draw() call)
const TRAINPERSTEP = 30;
const TESTPERSTEP = 5;
// multiply it by this to magnify for display
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
// 3 rows of
// large image + 50 gap + small image
// 50 gap between rows
const canvaswidth = ( PIXELS + ZOOMPIXELS ) + 50;
const canvasheight = ( ZOOMPIXELS * 3 ) + 100; // 3 large images + two 50px gaps
const DOODLE_THICK = 12; // thickness of doodle lines
const DOODLE_BLUR = 3; // blur factor applied to doodles
let mnist;
// all data is loaded into this (by loadData):
// mnist.train_images
// mnist.train_labels
// mnist.test_images
// mnist.test_labels
let nn; // the MLP, rebuilt by setupNetwork()
let trainrun = 1; // current pass over the training set (starts at 1)
let train_index = 0; // next training exemplar
let testrun = 1; // current pass over the test set
let test_index = 0; // next test exemplar
let total_tests = 0; // exemplars tested in the current testrun
let total_correct = 0; // ...of which the network got right
// images in LHS:
let doodle, demo; // doodle: createGraphics drawing buffer; demo: raw test-image data
let doodle_exists = false;
let demo_exists = false;
let mousedrag = false; // are we in the middle of a mouse drag drawing?
// save inputs to global var to inspect
// type these names in console
// (aug_doodle_inputs is the centred doodle, prepared in draw() when a stroke ends)
var train_inputs, test_inputs, demo_inputs, doodle_inputs, aug_doodle_inputs;
// ==================================================================================================
// === AC: additional global variable starts here ===========================================================
// ==================================================================================================
//data augmentation
// Order in which digits are introduced when "Order" is on. trainit() consumes this
// array with shift(), so once used up it stays empty (Refresh NN does not restore it).
const digitOrder = [1, 0,6,4,9,8,5,7,2,3];
var currentDigits = []; // digits introduced so far; only these are trained on / tested
const AUGMENTPERBATCH = 500; // training images augmented per augmentData() call
let last_augmented_index = -1; // end index returned by the most recent augmentData() call
const roundDownLevels = 5; // NOTE(review): not referenced in the visible source; augment() uses a fixed factor of 4
var bDigitOrdering = false;
// neural network count
let nnNO = 0;
var train_log =[]; // statistics of replaced networks, pushed by setupNetwork()
//CNN
var CNNTrained =false; // set by finishedTraining() once ml5 training completes
let cnn_doodle_result; // HTML of the CNN's latest doodle guess (set in classifyFinished)
//canvas
let canvas;
// Boolean for Modification to doodles
var bBlur = false; // blur the doodle when a stroke ends
var bCenter = true; // classify a centred copy of the doodle
// Boolean for Data Augmentation
var bAugment = true; // augment training images while training
var bAugmented = false; // set once every training image has been augmented
// per-transform switches, checked inside augment()
var bIShift=false;
var bICrop=true;
var bIRotate=false;
var bIRoundDown = false;
let bDoodleUpdated = false; // a stroke just ended; guessDoodle() classifies the centred copy once
var bTesting = true; // run hidden tests while training
// Matrix.randomize() is changed to point to this. Must be defined by user of Matrix.
function randomWeight()
{
  // Initial weights are random floats between -bound and +bound.
  // (The Coding Train default range is -1 to 1; a narrower range is used here.)
  const bound = 0.5;
  return AB.randomFloatAtoB(-bound, bound);
}
// CSS trick
// make run header bigger
$("#runheaderbox").css ( { "max-width": "50vw", "max-height": "95vh" } );
//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// thehtml[k] holds the HTML last written to AB.msg slot k.
var thehtml=[];
// Button captions by name (an array used as a string-keyed map). The toggle*Action
// functions update these captions and rebuild the headers that display them.
let strActionText = [];
strActionText.blur = bBlur?"Blur:ON":"Blur:OFF";
strActionText.center = bCenter?"Center:ON":"Center:OFF"
// 1 Doodle header
thehtml[1] = "<hr> <h1> 1. Doodle " +
"<button onclick='toggleBlurAction();' class='normbutton' >" + strActionText.blur + " </button>" +
"<button onclick='toggleCenterAction();' class='normbutton' >" + strActionText.center + " </button>" +
"</h1>" +
" Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg ( thehtml[1], 1 );
// 2 Doodle variable data (guess)
// (slot 2 is written later by guessDoodle(); the captions below belong to the Training header)
strActionText.training = do_training?"Stop":"Start";
strActionText.ordering = bDigitOrdering?"Order:ON":"Order:OFF";
strActionText.augmenting = bAugment?"Augment:ON":"Augment:OFF";
strActionText.shift = bIShift?"Shift:ON":"Shift:OFF";
strActionText.rotate = bIRotate?"Rotate:ON":"Rotate:OFF";
strActionText.rounddown = bIRoundDown?"RoundDown:ON":"RoundDown:OFF";
strActionText.crop = bICrop?"Crop:ON":"Crop:OFF";
// Per-augmentation buttons; the empty string hides them while augmentation is off.
let strAugment = bAugment?"<button onclick='toggleShiftAction();' class='normbutton' >" + strActionText.shift + " </button>" +
"<button onclick='toggleRotateAction();' class='normbutton' >" + strActionText.rotate + " </button>" +
"<button onclick='toggleRoundDownAction();' class='normbutton' >" + strActionText.rounddown + " </button>" +
"<button onclick='toggleCropAction();' class='normbutton' >" + strActionText.crop + " </button>":""
// 3 Training header
// NOTE(review): the toggle*Action functions rebuild this header with an extra trailing
// "<br>" that this initial version lacks, so the two copies differ slightly.
thehtml[3] = "<hr> <h1> 2. Training "+
"<button onclick='toggleTrainingAction();' class='normbutton' >" + strActionText.training + " </button>" +
"<button onclick='setupNetwork();' class='normbutton' >Refresh NN</button>"+ " </h1> " +
"Data Preprocessing: " +
"<button onclick='toggleOrderAction();' class='normbutton' >" + strActionText.ordering + " </button>" +
"<button onclick='toggleAugmentAction();' class='normbutton' >" + strActionText.augmenting + " </button>" +
"<br> " + strAugment;
AB.msg ( thehtml[3], 3 );
// 4 variable training data (written by trainit)
// 5 Testing header
thehtml[5] = "<h3> Hidden tests </h3> " ;
AB.msg ( thehtml[5], 5 );
// 6 variable testing data (written by testit, cleared by resetTestData)
// 7 Demo header
thehtml[7] = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br>" +
" The network is <i>not</i> trained on any of these images. <br> " +
" <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg ( thehtml[7], 7 );
// 8 Demo variable data (random demo ID) - written by makeDemo
// 9 Demo variable data (changing guess) - written by guessDemo
// bold dark-green span used to highlight guesses and scores
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> " ;
//--- end of AB.msgs structure: ---------------------------------------------------------
// AC: image Classifier
let digitClassifier; // ml5.neuralNetwork, created in setup() once ml5 has loaded
let featureExtractor; // ml5 MobileNet feature extractor, also created in setup()
// p5 setup: create the canvas and the doodle buffer, then load the helper scripts,
// build the network and load the MNIST data.
function setup()
{
canvas = createCanvas ( canvaswidth, canvasheight );
doodle = createGraphics ( ZOOMPIXELS, ZOOMPIXELS ); // doodle on larger canvas
doodle.pixelDensity(1); // one device pixel per logical pixel, independent of the display
// JS load other JS
// maybe have a loading screen while loading the JS and the data set
AB.loadingScreen();
// matrix.js -> nn.js -> mnist.js are loaded strictly in sequence (each callback starts
// the next); only then is the network built and the MNIST data requested.
$.getScript ( "/uploads/codingtrain/matrix.js", function()
{
$.getScript ( "/uploads/ac2021/nn.js", function()
{
$.getScript ( "/uploads/codingtrain/mnist.js", function()
{
console.log ("All JS loaded");
setupNetwork();
//pnn = new NeuralNetwork(nn);
//nnn = new NeuralNetwork(nn);
loadData();
});
});
});
// ml5 is loaded independently of, and concurrently with, the chain above.
// NOTE(review): "ml5@latest" is unpinned, so the ml5 API used here may change.
$.getScript ( "https://unpkg.com/ml5@latest/dist/ml5.min.js", function(){
// inputs are PIXELS x PIXELS x 4 (RGBA) arrays, the format getImage(img, true) returns
let options = { inputs: [PIXELS, PIXELS, 4],
task: "imageClassification",
debug:true
};
digitClassifier = ml5.neuralNetwork(options);
featureExtractor = ml5.featureExtractor("MobileNet", { numLabels: 10 }, modelLoaded);
});
}
// ==================================================================================================
// === AC: additional coding starts here ===========================================================
// ==================================================================================================
/**
 * Create a fresh MLP and reset the training/testing counters ("Refresh NN" button).
 *
 * Before resetting, the statistics of the network being replaced are appended to
 * train_log as ["train", nnNO, trainrun, train_index, total_tests, total_correct, accuracy].
 * This must happen BEFORE the counters are reset; otherwise every entry would only
 * record freshly-zeroed values and a 0/0 accuracy (NaN).
 * nnNO is 0 on the very first call (from setup), when there is nothing to log yet.
 */
function setupNetwork()
{
  if (nnNO > 0)
  {
    // accuracy is reported as 0 when no tests have run yet (avoids 0/0 = NaN)
    const accuracy = total_tests > 0 ? total_correct / total_tests : 0;
    train_log.push(["train", nnNO, trainrun, train_index, total_tests, total_correct, accuracy]);
  }
  nn = new NeuralNetwork( noinput, nohidden, nooutput );
  nn.setLearningRate ( learningrate );
  trainrun = 1;
  train_index = 0;
  resetTestData();
  nnNO++;
}
/**
 * Restart the hidden-test pass: back to the first test exemplar, with empty
 * counters, and clear the testing panel (AB.msg slot 6).
 * (A stray bare `console.log` expression statement, which did nothing, was removed.)
 */
function resetTestData()
{
  testrun = 1;
  test_index = 0;
  total_tests = 0;
  total_correct = 0;
  AB.msg ( "", 6 );
}
/**
 * Return a copy of a flattened PIXELS x PIXELS image with the ink moved towards
 * the centre, so that opposite empty margins become (nearly) equal.
 *
 * @param {number[]} inputs row-major image, values > 0 count as ink
 * @returns {number[]} the shifted image (see shiftInputs)
 */
function centerInputs(inputs)
{
  const [left, right, top, bottom] = checkMargins(inputs);
  // Moving by half the margin imbalance balances the margins. A positive hShift moves
  // the content left and a positive vShift moves it up (see shiftInputs/performShift).
  // Math.trunc rounds odd imbalances toward zero in BOTH directions; Math.floor
  // shifted right/down moves one pixel further than the matching left/up moves.
  const hShift = Math.trunc((left - right) / 2);
  const vShift = Math.trunc((top - bottom) / 2);
  return shiftInputs(inputs, hShift, vShift);
}
function shiftInputs(inputs, r=0, c=0) //can be positive or negative
{
  // r = horizontal shift (positive moves content left), c = vertical shift
  // (positive moves content up). performShift handles one sign per pass, so
  // mixed-sign shifts are done as a horizontal pass followed by a vertical pass.
  if (r == 0 && c == 0)
    return inputs;

  const allForward = (r >= 0) && (c >= 0);
  const allBackward = (r <= 0) && (c <= 0);

  if (allForward)
  {
    console.log(`both forward ${r},${c}`);
    return performShift(inputs, r, c);
  }
  if (allBackward)
  {
    console.log(` both backward ${r},${c}`);
    return performShift(inputs, r, c, true);
  }
  if (r > 0)
  {
    console.log(`forward -> backward ${r},${c}`);
    const horizontal = performShift(inputs, r, 0);
    return performShift(horizontal, 0, c, true);
  }
  console.log(`backward -> forward ${r},${c}`);
  const horizontal = performShift(inputs, r, 0, true);
  return performShift(horizontal, 0, c);
}
/**
 * Shift a row-major, PIXELS-wide image: output pixel (row, col) takes input pixel
 * (row + c, col + r); anything pulled from outside the image becomes 0.
 * A positive r moves the content left and a positive c moves it up; negative values
 * move it right/down.
 *
 * Working in (row, col) coordinates stops horizontal shifts from wrapping pixels of
 * one row onto the opposite edge of the neighbouring row. The flat 1-D index
 * arithmetic used previously did wrap, and the backward path could also read past
 * the end of the array.
 *
 * @param {number[]} inputs flattened image (not modified)
 * @param {number} r horizontal shift
 * @param {number} c vertical shift
 * @param {boolean} bBackward kept for compatibility with existing callers; the
 *   2-D bounds check handles both directions, so it does not affect the result
 * @returns {number[]} a new shifted array of the same length
 */
function performShift(inputs, r, c, bBackward = false)
{
  return Array.from({ length: inputs.length }, (_, i) =>
  {
    const srcCol = (i % PIXELS) + r;
    const src = (Math.floor(i / PIXELS) + c) * PIXELS + srcCol;
    const inside = srcCol >= 0 && srcCol < PIXELS && src >= 0 && src < inputs.length;
    return inside ? inputs[src] : 0;
  });
}
function modelLoaded()
{
  // Ready callback handed to ml5.featureExtractor in setup(); it only reports readiness.
  const message = "Model Loaded";
  console.log(message);
}
let classifier;
/**
 * Build a classifier on top of the MobileNet feature extractor and queue its
 * training images. The training step itself is currently disabled.
 */
async function train_features()
{
  classifier = featureExtractor.classification();
  await prepareTrainingData();
  // Disabled training step:
  // console.log('training model')
  // await classifier.train((lossValue) => console.log("Loss is", lossValue));
}
/**
 * Debug helper: sum the empty margins [left, right, top, bottom] (see checkMargins)
 * over all training images, or only over those labelled n % 10 when n is given.
 *
 * The result is [leftSum, rightSum, topSum, bottomSum, count]. It is logged and now
 * also returned. The count used to be written to index 5, leaving a hole at index 4.
 *
 * @param {number|null} n digit to filter on (taken modulo 10), or null for all images
 * @returns {number[]} the four margin sums followed by the number of images counted
 */
function checkAllMargins(n = null)
{
  const totals = [0, 0, 0, 0, 0];
  for (let i = 0; i < mnist.train_images.length; i++)
  {
    if ((n != null) && (mnist.train_labels[i] != (n % 10)))
      continue;
    const m = checkMargins(getInputs(mnist.train_images[i]));
    for (let k = 0; k < 4; k++)
      totals[k] += m[k];
    totals[4]++;
  }
  console.log(totals);
  return totals;
}
let img_base;
/**
 * Intended to queue MNIST images into the feature-extractor classifier.
 * The addImage call is disabled, so this currently resolves to an empty list.
 *
 * Disabled experiment (was run for the first two training images):
 *   promises.push(classifier.addImage(createImg(getImage(mnist.train_images[i]).canvas.toDataURL()), mnist.train_labels[i].toString()));
 *
 * @returns {Promise<Array>} resolves once all queued additions finish
 */
function prepareTrainingData()
{
  const pending = [];
  return Promise.all(pending);
}
/**
 * Feed the first `count` MNIST training images (as RGBA pixel arrays) into the
 * ml5 classifier, normalise the data and start training.
 *
 * ml5 trains asynchronously: train() returns immediately and finishedTraining()
 * sets CNNTrained when training is done. The final log therefore reports that
 * training STARTED; it previously claimed training had finished.
 *
 * @param {number} count how many images to use (default 10000); clamped to the
 *   number of images actually loaded, so it can never index past the data
 */
function trainCNN(count = 10000)
{
  CNNTrained = false;
  const total = Math.min(count, mnist.train_images.length);
  for (let i = 0; i < total; i++)
  {
    digitClassifier.addData(getImage(mnist.train_images[i], true), { label: mnist.train_labels[i].toString() });
  }
  console.log("CNN-finished adding data " + total);
  digitClassifier.normalizeData();
  console.log("CNN-finished normalizing data");
  digitClassifier.train({ epochs: 50, batchSize: 32 }, finishedTraining);
  console.log("CNN-started training");
}
function finishedTraining()
{
  // Training-complete callback given to digitClassifier.train() in trainCNN();
  // from now on guessDoodle() also asks the CNN for a guess.
  const trained = true;
  CNNTrained = trained;
}
/**
 * Augment the next batch of training images in place, starting at train_index.
 *
 * A batch is at most AUGMENTPERBATCH images and never runs past NOTRAIN. The previous
 * inverted ternary augmented every remaining image in a single frame, and near the
 * end it indexed past the data set.
 *
 * @returns {number} index one past the last augmented image (exclusive end)
 */
function augmentData()
{
  const batchEnds = Math.min(train_index + AUGMENTPERBATCH, NOTRAIN);
  for (let i = train_index; i < batchEnds; i++)
  {
    mnist.train_images[i] = augment(mnist.train_images[i])[0];
  }
  if (batchEnds >= NOTRAIN)
  {
    bAugmented = true; // whole training set done; draw() stops calling us
  }
  return batchEnds;
}
/**
 * Count the completely empty columns/rows on each side of the ink in a row-major
 * PIXELS x PIXELS image. A pixel counts as ink when its value is > 0.
 *
 * The margins come from the ink's bounding box. The previous column scan had two
 * problems: its right margin was one column short, and it treated the first blank
 * column after the ink started (which may be an interior gap) as the right margin.
 *
 * @param {number[]} inputs flattened image
 * @returns {number[]} [leftEmpty, rightEmpty, topEmpty, bottomEmpty];
 *   [0, 0, 0, 0] when the image has no ink at all
 */
function checkMargins(inputs)
{
  let minCol = PIXELS;
  let maxCol = -1;
  let minRow = PIXELS;
  let maxRow = -1;
  for (let r = 0; r < PIXELS; r++)
  {
    for (let c = 0; c < PIXELS; c++)
    {
      if (inputs[r * PIXELS + c] > 0)
      {
        minCol = Math.min(minCol, c);
        maxCol = Math.max(maxCol, c);
        minRow = Math.min(minRow, r);
        maxRow = Math.max(maxRow, r);
      }
    }
  }
  if (maxCol < 0)
    return [0, 0, 0, 0]; // blank image: nothing to balance
  return [minCol, PIXELS - 1 - maxCol, minRow, PIXELS - 1 - maxRow];
}
/**
 * Gather index for rotating a row-major PIXELS x PIXELS single-channel image by
 * 90 degrees. Position x of the rotated image takes the source pixel at the
 * returned index; augment() uses it as rotated[x] = src[getRotatedPixelIndex(x)].
 *
 * Worked 3x3 example (PIXELS = 3), destination -> source:
 *   clockwise:         0->6 1->3 2->0 3->7 4->4 5->1 6->8 7->5 8->2
 *   counter-clockwise: 0->2 1->5 2->8 3->1 4->4 5->7 6->0 7->3 8->6
 *
 * The counter-clockwise branch used to multiply by 4, a leftover from RGBA pixel
 * indexing. On these single-channel arrays that pointed past the end of the image.
 *
 * @param {number} x destination index in [0, PIXELS*PIXELS)
 * @param {boolean} bClockwise rotate clockwise (default) or counter-clockwise
 * @returns {number} source index in the same single-channel array
 */
function getRotatedPixelIndex( x, bClockwise = true)
{
  const col = x % PIXELS;
  const row = Math.floor(x / PIXELS);
  if (bClockwise)
    return (PIXELS - 1 - col) * PIXELS + row;
  return col * PIXELS + (PIXELS - 1 - row);
}
// load data set from local file (on this server)
function loadData()
{
  // loadMNIST hands the parsed data set to this callback; store it globally,
  // log it, and dismiss the loading screen.
  const onLoaded = (data) =>
  {
    mnist = data;
    console.log ("All data loaded into mnist object:");
    console.log(mnist);
    AB.removeLoading(); // if no loading screen exists, this does nothing
  };
  loadMNIST(onLoaded);
}
function getImage ( img, bGetPixelArray = false ) // make a P5 image object from a raw data array
{
  // Each grey value becomes an opaque RGBA pixel (R = G = B = grey, A = 255).
  const picture = createImage(PIXELS, PIXELS);
  picture.loadPixels();
  for (let p = 0; p < PIXELSSQUARED; p++)
  {
    const grey = img[p];
    const offset = p * 4;
    for (let ch = 0; ch < 3; ch++)
      picture.pixels[offset + ch] = grey;
    picture.pixels[offset + 3] = 255;
  }
  picture.updatePixels();
  // Either the raw RGBA array (used as CNN input) or the p5 image itself (for display).
  return bGetPixelArray ? picture.pixels : picture;
}
function getInputs ( img ) // convert img array into normalised input array
{
  // Scale the first PIXELSSQUARED brightness values from 0..255 down to 0..1.
  return Array.from({ length: PIXELSSQUARED }, (_, idx) => img[idx] / 255);
}
// Train the network on one exemplar (mnist.train_images[train_index]), update the
// training panel (AB.msg slot 4) and advance train_index. At the end of the
// training set, train_index wraps to 0 and trainrun is incremented.
function trainit (show) // train the network with a single exemplar, from global var "train_index", show visual on or off
{
let img = mnist.train_images[train_index];
let label = mnist.train_labels[train_index];
// Digit ordering: at every multiple of 3000 (including 0), introduce the next digit
// from digitOrder, until digitOrder is used up.
if(bDigitOrdering && ((train_index % 3000) == 0 && digitOrder.length > 0))
currentDigits.push(digitOrder.shift())
// optional - show visual of the image (second canvas row: magnified + original)
if (show)
{
var theimage = getImage ( img ); // get image from data array
noFill();
stroke(255,255,255);
rect(0, ZOOMPIXELS+50-1, ZOOMPIXELS+1, ZOOMPIXELS+1);
image ( theimage, 1, ZOOMPIXELS+50, ZOOMPIXELS, ZOOMPIXELS ); // magnified
image ( theimage, ZOOMPIXELS+50, ZOOMPIXELS+50, PIXELS, PIXELS ); // original
}
// set up the inputs
let inputs = getInputs ( img ); // get inputs from data array
// set up the outputs
let targets = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
targets[label] = 1; // one-hot: change one output location to 1, the rest stay at 0
// console.log(train_index);
// console.log(inputs);
// console.log(targets);
// inputs = augment(inputs);
train_inputs = inputs; // can inspect in console
// With ordering on, exemplars of digits not yet introduced are skipped (but
// train_index still advances past them below).
if (!bDigitOrdering || (currentDigits.includes(label)))
nn.train ( inputs, targets );
// Panel text; with ordering on it also lists the digits introduced so far.
if (bDigitOrdering)
{
thehtml[4] = "Current Training numbers :" + currentDigits.toString() + " <br>" +
" trainrun: " + trainrun + "<br> no: " + train_index ;
}
else
thehtml[4] = " trainrun: " + trainrun + "<br> no: " + train_index ;
AB.msg ( thehtml[4], 4 );
train_index++;
if ( train_index == NOTRAIN )
{
train_index = 0;
console.log( "finished trainrun: " + trainrun );
trainrun++;
}
}
/**
 * Randomly augment one row-major PIXELS x PIXELS image.
 *
 * If no transform is requested explicitly, each transform is picked independently
 * with probability 0.4. A picked transform only runs when its UI switch (bIShift,
 * bIRoundDown, bIRotate, bICrop) is also on. The caller's array is never modified.
 *
 * Fixes relative to the earlier version:
 *  - shift: the horizontal amount is bounded by the left margin and the vertical one
 *    by the top margin (the two axes used to be swapped);
 *  - round-down: operates on the already-shifted image instead of discarding it;
 *  - crop: indexes row * PIXELS + col, and no longer logs on every call.
 *
 * @param {number[]} inputs flattened image (left untouched)
 * @param {boolean} bShift move the digit up/left into its empty margins
 * @param {boolean} bRoundDown quantise values to multiples of 0.25
 * @param {boolean} bRotate rotate 90 degrees clockwise
 * @param {boolean} bCrop blank one random quadrant
 * @param {number[]|null} shiftMargins optional precomputed checkMargins(inputs) result
 * @returns {Array} [augmentedImage, bShift, bRoundDown, bRotate, bCrop] (requested flags)
 */
function augment(inputs, bShift = false, bRoundDown = false, bRotate = false, bCrop = false, shiftMargins = null){
  if (!(bShift || bRoundDown || bRotate || bCrop))
  {
    bShift = Math.floor(Math.random() * 10) <= 3;
    bRoundDown = Math.floor(Math.random() * 10) <= 3;
    bRotate = Math.floor(Math.random() * 10) <= 3;
    bCrop = Math.floor(Math.random() * 10) <= 3;
  }
  let temp = Array.from(inputs); // work on a copy; crop/round-down write in place
  if (bShift && bIShift)
  {
    const margins = (shiftMargins != null) ? shiftMargins : checkMargins(inputs);
    // Stay inside the empty left/top margins so no ink is pushed off the image.
    const nColumns = Math.floor(Math.random() * margins[0]);
    const nRows = Math.floor(Math.random() * margins[2]);
    temp = performShift(temp, nColumns, nRows); // (image, horizontal, vertical)
  }
  if (bRoundDown && bIRoundDown)
  {
    // NOTE(review): assumes values normalised to 0..1; on raw 0..255 integers (as
    // passed by augmentData) this leaves the values unchanged.
    for (let i = 0; i < PIXELSSQUARED; i++)
    {
      temp[i] = Math.floor(temp[i] * 4) / 4;
    }
  }
  if (bRotate && bIRotate)
  {
    // 90 degrees clockwise: gather each output pixel from its source index
    const rotated = [];
    for (let j = 0; j < PIXELSSQUARED; j++)
    {
      rotated[j] = temp[getRotatedPixelIndex(j)];
    }
    temp = rotated;
  }
  if (bCrop && bICrop)
  {
    // blank one randomly chosen quadrant (half the rows x half the columns)
    const nRowStart = Math.floor(Math.random() * 2) * PIXELS / 2;
    const nColStart = Math.floor(Math.random() * 2) * PIXELS / 2;
    for (let r = nRowStart; r < nRowStart + (PIXELS / 2); r++)
    {
      for (let c = nColStart; c < nColStart + (PIXELS / 2); c++)
        temp[r * PIXELS + c] = 0;
    }
  }
  return [temp, bShift, bRoundDown, bRotate, bCrop];
}
/**
 * Test the network on one exemplar (mnist.test_images[test_index]), update the
 * testing panel (AB.msg slot 6) and advance test_index. At the end of the test
 * set, the score is logged and a new testrun starts.
 *
 * With digit ordering on, exemplars whose label has not been introduced yet are
 * skipped. test_index must still advance for them; previously the early return
 * skipped the increment, so testing stalled forever on the first excluded exemplar.
 */
function testit() // test the network with a single exemplar, from global var "test_index"
{
  const img = mnist.test_images[test_index];
  const label = mnist.test_labels[test_index];
  const bEligible = !bDigitOrdering || currentDigits.includes(label);
  if (bEligible)
  {
    const inputs = getInputs ( img );
    test_inputs = inputs; // can inspect in console
    const prediction = nn.predict(inputs); // array of outputs
    const guess = findMax(prediction); // the top output
    total_tests++;
    if (guess == label) total_correct++;
    const percent = (total_correct / total_tests) * 100;
    thehtml[6] = " testrun: " + testrun + "<br> no: " + total_tests + " <br> " +
      " correct: " + total_correct + "<br>" +
      " score: " + greenspan + percent.toFixed(2) + "</span>";
    AB.msg ( thehtml[6], 6 );
  }
  test_index++;
  if ( test_index == NOTEST )
  {
    // 0 instead of 0/0 = NaN when every exemplar of this run was skipped
    const percent = total_tests > 0 ? (total_correct / total_tests) * 100 : 0;
    console.log( "finished testrun: " + testrun + " score: " + percent.toFixed(2) );
    testrun++;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
  }
}
//--- find no.1 (and maybe no.2) output nodes ---------------------------------------
// (restriction) assumes array values start at 0 (which is true for output nodes)
/**
 * Indexes of the largest and second-largest positive values in `a`.
 * When a new maximum is found, the previous leader is demoted to no.2. Before this
 * fix it was dropped, so e.g. [0.1, 0.5, 0.9] gave no.2 = 0 instead of 1.
 * Entries that are not > 0 are never chosen; index 0 is the fallback for both.
 *
 * @param {number[]} a output-node activations
 * @returns {number[]} [indexOfNo1, indexOfNo2]
 */
function find12 (a) // return array showing indexes of no.1 and no.2 values in array
{
  let no1 = 0;
  let no2 = 0;
  let no1value = 0;
  let no2value = 0;
  for (let i = 0; i < a.length; i++)
  {
    if (a[i] > no1value)
    {
      // the previous leader becomes the runner-up
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = a[i];
    }
    else if (a[i] > no2value)
    {
      no2 = i;
      no2value = a[i];
    }
  }
  return [ no1, no2 ];
}
// just get the maximum - separate function for speed - done many times
// find our guess - the max of the output nodes array
function findMax (a)
{
  // Index of the first strictly-largest positive entry; 0 if no entry exceeds 0.
  let bestIndex = 0;
  let bestValue = 0;
  for (let k = 0; k < a.length; k++)
  {
    const v = a[k];
    if (!(v > bestValue)) continue; // also skips NaN, exactly like a plain ">" test
    bestIndex = k;
    bestValue = v;
  }
  return bestIndex;
}
// --- the draw function -------------------------------------------------------------
// every step:
// Each call: augment the next training batch if due, train/test a few exemplars,
// redraw the demo and doodle panels with fresh guesses, and handle doodle drawing.
function draw()
{
// check if libraries and data loaded yet:
if ( typeof mnist == 'undefined' ) return;
// Augment the next batch once training has moved past the last augmented batch.
// NOTE(review): augmentData() returns an exclusive end index, and training advances
// TRAINPERSTEP exemplars per call, so by the time this fires train_index may already
// be past last_augmented_index. The images in between look like they are trained on
// but never augmented - confirm whether that is intended.
if (do_training && bAugment && train_index > last_augmented_index && !bAugmented ) {
console.log( "augment");
last_augmented_index = augmentData();
}
// how can we get white doodle on black background on yellow canvas?
// background('#ffffcc'); doodle.background('black');
background ('black');
// AC: Added a white square bracket.
noFill();
stroke(255,255,255);
rect(0, 0, ZOOMPIXELS+1, ZOOMPIXELS+1);
if ( do_training )
{
// do some training per step
for (let i = 0; i < TRAINPERSTEP; i++)
{
if (i == 0) trainit(true); // show only one per step - still flashes by
else trainit(false);
}
// do some testing per step
// Tests only run once train_index passes 10000. Because trainit() wraps train_index
// back to 0 after each trainrun, testing pauses for the first 10000 exemplars of every run.
if(train_index > 10000 && bTesting)
for (let i = 0; i < TESTPERSTEP; i++)
testit();
}
// keep drawing demo and doodle images
// and keep guessing - we will update our guess as time goes on
if ( demo_exists )
{
drawDemo();
guessDemo();
}
if ( doodle_exists )
{
drawDoodle();
guessDoodle();
}
// detect doodle drawing
// (restriction) the following assumes doodle starts at 0,0
if ( mouseIsPressed ) // gets called when we click buttons, as well as if in doodle corner
{
// console.log ( mouseX + " " + mouseY + " " + pmouseX + " " + pmouseY );
var MAX = ZOOMPIXELS + 20; // can draw up to this pixels in corner
if ( (mouseX < MAX) && (mouseY < MAX) && (pmouseX < MAX) && (pmouseY < MAX) )
{
mousedrag = true; // start a mouse drag
doodle_exists = true;
doodle.stroke('white');
doodle.strokeWeight( DOODLE_THICK );
doodle.line(mouseX, mouseY, pmouseX, pmouseY);
}
}
else
{
// are we exiting a drawing
// When a stroke ends: optionally blur it, and with Center on, prepare a centred
// PIXELS x PIXELS copy (aug_doodle_inputs) that guessDoodle() classifies once.
if ( mousedrag )
{
mousedrag = false;
bDoodleUpdated = true;
// console.log ("Exiting draw. Now blurring.");
if(bBlur)
doodle.filter (BLUR, DOODLE_BLUR); // just blur once
if (bCenter)
{
// prepare an augment version of input.
let img = doodle.get();
img.resize ( PIXELS, PIXELS );
img.loadPixels();
// set up inputs (red channel of each RGBA pixel, scaled to 0..1)
let inputs = [];
for (let i = 0; i < PIXELSSQUARED ; i++)
{
inputs[i] = img.pixels[i * 4] / 255;
}
aug_doodle_inputs = centerInputs(inputs);
//console.log(find12(nn.predict(aug_doodle_inputs)));
//console.log(checkMargins(inputs));
}
}
}
/*if(keyIsPressed & CNNTrained)
{
let d = createGraphics(PIXELS, PIXELS);
d.copy(canvas, 0, 0, ZOOMPIXELS, ZOOMPIXELS, 0, 0, PIXELS, PIXELS)
console.log("w, h" + d.width + "," + d.height)
image ( d, ZOOMPIXELS+50, PIXELS, PIXELS, PIXELS ); // shrunk
digitClassifier.classify({image:d}, classifyFinished);
}*/
}
//--- demo -------------------------------------------------------------
// demo some test image and predict it
// get it from test set so have not used it in training
function makeDemo()
{
  demo_exists = true;
  const pick = AB.randomIntAtoB ( 0, NOTEST - 1 );
  demo = mnist.test_images[pick];
  const trueLabel = mnist.test_labels[pick];
  // show which test image was picked and its true class (AB.msg slot 8)
  thehtml[8] = `Test image no: ${pick}<br>Classification: ${trueLabel}<br>`;
  AB.msg ( thehtml[8], 8 );
  // type "demo" in console to see raw data
}
function drawDemo()
{
  // Bottom canvas row: framed magnified copy on the left, 1:1 copy beside it.
  const pic = getImage(demo);
  const rowTop = canvasheight - ZOOMPIXELS;
  noFill();
  stroke(255, 255, 255);
  rect(0, rowTop - 1, ZOOMPIXELS + 1, ZOOMPIXELS + 1);
  image(pic, 1, rowTop, ZOOMPIXELS, ZOOMPIXELS); // magnified
  image(pic, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS); // original
}
function guessDemo()
{
  // Classify the demo test image and show the network's top guess (AB.msg slot 9).
  demo_inputs = getInputs ( demo ); // can inspect in console
  const best = findMax(nn.predict(demo_inputs));
  thehtml[9] = ` We classify it as: ${greenspan}${best}</span>`;
  AB.msg ( thehtml[9], 9 );
}
//--- doodle -------------------------------------------------------------
function drawDoodle()
{
  // doodle is a createGraphics buffer (not a createImage), so snapshot it first,
  // then show it full size in the corner and shrunk to network size beside it.
  const snapshot = doodle.get();
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS); // original
  image(snapshot, ZOOMPIXELS + 50, 0, PIXELS, PIXELS); // shrunk
}
/**
 * Classify the current doodle with the MLP and, once trained, with the ml5 CNN.
 *
 * There are two paths:
 *  - While drawing, or when centring is off or no centred copy exists yet: the doodle
 *    is downsampled and classified on every call.
 *  - Otherwise: the centred copy prepared when the stroke ended is classified once
 *    per new stroke (bDoodleUpdated).
 *
 * Fixes:
 *  - Until the CNN has produced a result, cnn_doodle_result is undefined. It is no
 *    longer appended as the literal text "undefined".
 *  - The per-call offscreen buffer made for the CNN is released once ml5 has
 *    classified it. Previously a new p5.Graphics was created every frame and
 *    never freed.
 */
function guessDoodle()
{
  // Write the no.1 / no.2 guesses (plus any CNN result) to AB.msg slot 2.
  const showGuess = (prediction) =>
  {
    const b = find12(prediction); // get no.1 and no.2 guesses
    const cnnText = (cnn_doodle_result === undefined) ? "" : cnn_doodle_result;
    thehtml[2] = " We classify it as: " + greenspan + b[0] + "</span> <br>" +
      " No.2 guess is: " + greenspan + b[1] + "</span> <br>" + cnnText;
    AB.msg ( thehtml[2], 2 );
  };

  if (mousedrag || (aug_doodle_inputs == undefined) || !bCenter)
  {
    // doodle is createGraphics not createImage
    const img = doodle.get();
    img.resize ( PIXELS, PIXELS );
    img.loadPixels();
    // strokes are drawn white, so the red channel of each RGBA pixel carries the brightness
    const inputs = [];
    for (let i = 0; i < PIXELSSQUARED ; i++)
    {
      inputs[i] = img.pixels[i * 4] / 255;
    }
    doodle_inputs = inputs; // can inspect in console
    showGuess(nn.predict(inputs));
  }
  else if (bDoodleUpdated)
  {
    bDoodleUpdated = false;
    const prediction = nn.predict(aug_doodle_inputs); // array of outputs
    console.log("centered");
    showGuess(prediction);
  }

  if (CNNTrained)
  {
    const d = createGraphics(PIXELS, PIXELS);
    d.copy(canvas, 0, 0, ZOOMPIXELS, ZOOMPIXELS, 0, 0, PIXELS, PIXELS);
    image ( d, ZOOMPIXELS+50, PIXELS, PIXELS, PIXELS ); // shrunk
    // classify() is asynchronous, so free the buffer only in its callback
    digitClassifier.classify({image:d}, (error, result) =>
    {
      d.remove();
      classifyFinished(error, result);
    });
  }
}
function wipeDoodle()
{
  // "Clear doodle" button: stop drawing/guessing the doodle panel and paint the
  // drawing buffer black.
  doodle_exists = false;
  const blank = 'black';
  doodle.background(blank);
}
/**
 * ml5 classify() callback: store the CNN's top label and confidence as HTML in
 * cnn_doodle_result, which guessDoodle() appends to the doodle panel.
 *
 * When classification fails (or returns nothing), the error is logged and the
 * previous result is kept. Previously the code went on to read result[0] and threw.
 *
 * @param {*} error ml5 error, falsy on success
 * @param {Array<{label: string, confidence: number}>} result ranked predictions
 */
function classifyFinished(error, result)
{
  if (error)
  {
    console.log("classify error:" + error);
    return;
  }
  if (!result || result.length === 0)
    return;
  console.log("cnn predicted: " + result[0].label);
  cnn_doodle_result = "(CNN) We classify it as: " + greenspan + result[0].label + "</span> <br>" +
    "(CNN) Confidence is: " + greenspan + result[0].confidence + "</span> <br>";
}
// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);
function showInputs ( inputs )
// display inputs row by row, corresponding to square of pixels
{
  const rows = [];
  for (let start = 0; start < inputs.length; start += PIXELS)
  {
    let row = "\n"; // every row of pixels begins on a new line
    const end = Math.min(start + PIXELS, inputs.length);
    for (let k = start; k < end; k++)
      row += " " + inputs[k].toFixed(2);
    rows.push(row);
  }
  console.log(rows.join(""));
}
// Rebuild the "1. Doodle" header (AB.msg slot 1) so its buttons show the current state.
function renderDoodleHeader()
{
  thehtml[1] = "<hr> <h1> 1. Doodle " +
    "<button onclick='toggleBlurAction();' class='normbutton' >" + strActionText.blur + " </button>" +
    "<button onclick='toggleCenterAction();' class='normbutton' >" + strActionText.center + " </button>" +
    "</h1>" +
    " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
  AB.msg ( thehtml[1], 1 );
}
// "Blur" button: flip blurring of finished strokes and refresh the header.
function toggleBlurAction()
{
  bBlur = !bBlur;
  strActionText.blur = bBlur ? "Blur:ON" : "Blur:OFF";
  renderDoodleHeader();
}
// "Center" button: flip doodle centring and refresh the header.
function toggleCenterAction()
{
  bCenter = !bCenter;
  strActionText.center = bCenter ? "Center:ON" : "Center:OFF";
  renderDoodleHeader();
}
// HTML for the per-augmentation buttons; empty while augmentation is switched off.
function augmentButtonsHtml()
{
  if (!bAugment)
    return "";
  return "<button onclick='toggleShiftAction();' class='normbutton' >" + strActionText.shift + " </button>" +
    "<button onclick='toggleRotateAction();' class='normbutton' >" + strActionText.rotate + " </button>" +
    "<button onclick='toggleRoundDownAction();' class='normbutton' >" + strActionText.rounddown + " </button>" +
    "<button onclick='toggleCropAction();' class='normbutton' >" + strActionText.crop + " </button>";
}
// Rebuild the "2. Training" header (AB.msg slot 3) from the current captions and strAugment.
function renderTrainingHeader()
{
  thehtml[3] = "<hr> <h1> 2. Training " +
    "<button onclick='toggleTrainingAction();' class='normbutton' >" + strActionText.training + " </button>" +
    "<button onclick='setupNetwork();' class='normbutton' >Refresh NN</button>" + " </h1> " +
    "Data Preprocessing: " +
    "<button onclick='toggleOrderAction();' class='normbutton' >" + strActionText.ordering + " </button>" +
    "<button onclick='toggleAugmentAction();' class='normbutton' >" + strActionText.augmenting + " </button>" +
    "<br> " + strAugment + "<br>";
  AB.msg ( thehtml[3], 3 );
}
// Start/Stop button: flip per-frame training.
function toggleTrainingAction()
{
  do_training = !do_training;
  strActionText.training = do_training ? "Stop" : "Start";
  renderTrainingHeader();
}
// "Order" button: flip digit-by-digit training order.
function toggleOrderAction()
{
  bDigitOrdering = !bDigitOrdering;
  strActionText.ordering = bDigitOrdering ? "Order:ON" : "Order:OFF";
  renderTrainingHeader();
}
// "Augment" button: flip augmentation; shows or hides the per-transform buttons.
function toggleAugmentAction()
{
  bAugment = !bAugment;
  strActionText.augmenting = bAugment ? "Augment:ON" : "Augment:OFF";
  strAugment = augmentButtonsHtml();
  renderTrainingHeader();
}
// Per-transform switches consulted by augment().
function toggleShiftAction()
{
  bIShift = !bIShift;
  strActionText.shift = bIShift ? "Shift:ON" : "Shift:OFF";
  strAugment = augmentButtonsHtml();
  renderTrainingHeader();
}
function toggleRotateAction()
{
  bIRotate = !bIRotate;
  strActionText.rotate = bIRotate ? "Rotate:ON" : "Rotate:OFF";
  strAugment = augmentButtonsHtml();
  renderTrainingHeader();
}
function toggleRoundDownAction()
{
  bIRoundDown = !bIRoundDown;
  strActionText.rounddown = bIRoundDown ? "RoundDown:ON" : "RoundDown:OFF";
  strAugment = augmentButtonsHtml();
  renderTrainingHeader();
}
function toggleCropAction()
{
  bICrop = !bICrop;
  console.log("@toggle " + bICrop);
  strActionText.crop = bICrop ? "Crop:ON" : "Crop:OFF";
  strAugment = augmentButtonsHtml();
  renderTrainingHeader();
}