// Cloned by Andrew Merrigan on 17 Dec 2019 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.
// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications
// Name: Andrew Merrigan
// Student No: 19214610
// Based on:
// https://github.com/tensorflow/tfjs-vis/tree/master/tfjs-vis/demos/mnist
// https://storage.googleapis.com/tfjs-vis/mnist/dist/index.html
// https://www.tensorflow.org/js/guide
// --- Defined by MNIST - do not change these ---------------------------------------
const PIXELS = 28; // images in data set are tiny: PIXELS x PIXELS
const PIXELSSQUARED = PIXELS * PIXELS; // length of one flattened image
//--- Can modify all of these --------------------------------------------------
// Training hyper-parameters.
// NOTE(review): LEARNING_RATE and LEARNING_RATE_DECAY are not referenced in this file;
// presumably the model script loaded in setup() (tsModel.js) reads them - confirm.
const LEARNING_RATE = 0.26;
const LEARNING_RATE_DECAY = 0.004;
const EPOCHS = 2; // epochs passed to nn.fit() in startTraining()
const BATCH_SIZE = 300; // batch size passed to nn.fit() in startTraining()
// Model constants.
// NOTE(review): CLASSES, NUM_CLASSES and IMAGE_SIZE are not referenced in this file and
// duplicate each other / PIXELS; presumably read by tsModel.js or mnist.js - confirm.
const CLASSES = 10;
const NUM_CLASSES = 10;
// Number of training and test exemplars requested from the data set.
const TRAIN_DATA_SIZE = 60000;
const TEST_DATA_SIZE = 10000;
const IMAGE_SIZE = 28;
// When true, the training callbacks in startTraining() ask the model to stop;
// restartTraining() clears it and trains again.
// NOTE(review): nothing in this file sets it to true; presumably a run-header button does.
let stop_training = false;
// Magnification used for the large copy of each image.
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
// Canvas layout: 3 rows (doodle, training image, test/demo image).
// Each row is: large image + 50px gap + small image.
// The rows are separated by 50px gaps (2 gaps = 100px of height).
const canvaswidth = (PIXELS + ZOOMPIXELS) + 50;
const canvasheight = (ZOOMPIXELS * 3) + 100;
const DOODLE_THICK = 25; // stroke weight of doodle lines
const DOODLE_BLUR = 8; // amount passed to filter(BLUR, ...) after shrinking a doodle
// Data-set object returned by loadMNIST(); provides the train/test tensors.
let mnist;
// The tf.js model returned by getModel().
let nn;
// Images on the left-hand side: `doodle` is the drawing buffer (createGraphics),
// `demo` holds the raw pixel values of the current demo test image.
let doodle, demo;
let doodle_exists = false;
let demo_exists = false;
let mousedrag = false; // are we in the middle of a mouse drag drawing?
// Declared with var so they can be typed into the browser console for inspection.
// NOTE(review): train_inputs, test_inputs and demo_inputs are never assigned in this
// file; doodle_inputs is set by guessDoodle(), but that tensor is disposed right after use.
var train_inputs, test_inputs, demo_inputs, doodle_inputs;
// CSS trick: make the run header bigger.
$("#runheaderbox").css({ "max-height": "90vh", "max-width": "70vW" });
//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Message slots written in this file:
//   1 doodle header             2 doodle guesses (guessDoodle / wipeDoodle)
//   3 training header           5 training-image refresh (startTraining)
//   7 test/demo header          8 demo image number and label (makeDemo)
//   9 demo guesses (guessDemo)
// NOTE(review): the header strings below break across lines inside double quotes,
// which is a JavaScript syntax error; their HTML markup appears to have been
// stripped. Restore the original markup from the source project.
var thehtml; // scratch string used to build each message before calling AB.msg()
// 1 Doodle header
thehtml = "
1. Doodle
Top row: Doodle (left) and shrunk (right). " +
" Draw your doodle in top LHS. ";
AB.msg(thehtml, 1);
// 2 Doodle variable data (guess)
// 3 Training header
thehtml = "
2. Training
Middle row: Training image magnified (left) and original (right). " +
" " +
" " +
" ";
AB.msg(thehtml, 3);
// 7 Test / demo header
// (slot 4 is not written in this file; slot 5 is written by startTraining)
thehtml = "
Bottom row: Test image magnified (left) and original (right). " +
" The network is not trained on any of these images. " +
" ";
AB.msg(thehtml, 7);
// 8 Demo variable data (random demo ID)
// 9 Demo variable data (changing guess)
// Inserted before the top guess by guessDemo() and guessDoodle().
// NOTE(review): the name suggests it once held an opening coloured <span> tag that
// was stripped; as written it is a single space.
const greenspan = " ";
//--- end of AB.msgs structure: ---------------------------------------------------------
// Promise of the current training run (the value returned by startTraining()), if one has been stored here.
let trainingPromise;
/**
 * p5.js entry point. Creates the canvas and the off-screen doodle buffer.
 * In the background it then loads the tfjs-vis and TensorFlow.js libraries,
 * the model and data helper scripts, and the MNIST data set, and starts training.
 *
 * setup() itself returns immediately. draw() does nothing until `mnist` and `nn` exist.
 * The promise of the training run is kept in `trainingPromise`. Any failure while
 * loading or training is reported on the console rather than being lost.
 */
function setup() {
  createCanvas(canvaswidth, canvasheight);
  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS); // doodle on larger canvas
  doodle.pixelDensity(1);

  // Loaded strictly one after another, in this order: the helper scripts
  // presumably use the tf / tfvis globals defined by the libraries before them.
  const scripts = [
    "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js",
    "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js",
    "/uploads/andrewmerrigan/tsModel.js",
    "/uploads/andrewmerrigan/mnist.js"
  ];

  // maybe have a loading screen while loading the JS and the data set
  AB.loadingScreen();

  // $.getScript returns a thenable that rejects when a script fails to load.
  const scriptsLoaded = scripts.reduce(
    (previous, url) => previous.then(() => $.getScript(url)),
    Promise.resolve()
  );

  scriptsLoaded
    .then(() => {
      console.log("All JS loaded");
      nn = getModel();
      return loadMNIST();
    })
    .then((data) => {
      mnist = data;
      console.log("All data loaded into mnist object:");
      console.log(mnist);
      AB.removeLoading(); // if no loading screen exists, this does nothing
      trainingPromise = startTraining();
      return trainingPromise;
    })
    .catch((err) => {
      console.error("Loading or training failed:", err);
    });
}
/**
 * Prepare tfjs-vis live charts for training.
 *
 * If the page has a "#show-visor" button, clicking it opens the tfjs-vis visor.
 * Calling this again (e.g. on restart) adds another identical listener, which is
 * harmless because the handler only ever opens the visor.
 *
 * @returns {object} the callbacks from tfvis.show.fitCallbacks, plotting 'acc' and
 *   'val_acc' in the visor's "Training" tab (used by startTraining()).
 */
function watchTraining() {
  const showVisorButton = document.querySelector('#show-visor');
  // querySelector returns null when the button is missing; calling
  // addEventListener on null would throw and abort training before fit().
  if (showVisorButton !== null) {
    showVisorButton.addEventListener('click', () => {
      const visorInstance = tfvis.visor();
      // toggle() flips visibility, so only call it while the visor is closed.
      if (!visorInstance.isOpen()) {
        visorInstance.toggle();
      }
    });
  }
  const metrics = ['acc', 'val_acc'];
  // Charts go in a visor tab; use document.getElementById("metrics") instead to
  // render them inline in the page.
  const container = {
    tab: 'Training',
    name: 'Training Metrics',
    styles: {
      height: '90vh'
    }
  };
  return tfvis.show.fitCallbacks(container, metrics, {
    height: 200
  });
}
// Build a PIXELS x PIXELS greyscale p5 image from a flat array of 0-255
// brightness values (one value per pixel, row by row).
function getImage(img) {
  const picture = createImage(PIXELS, PIXELS); // blank image, filled in below
  picture.loadPixels();
  const px = picture.pixels; // RGBA bytes, 4 per pixel
  for (let p = 0; p < PIXELSSQUARED; p++) {
    const grey = img[p];
    const base = 4 * p;
    px[base] = grey; // red
    px[base + 1] = grey; // green
    px[base + 2] = grey; // blue
    px[base + 3] = 255; // alpha: fully opaque
  }
  picture.updatePixels();
  return picture;
}
// Turn a raw array of 0-255 brightness values into the network's input vector:
// a new array holding the first PIXELSSQUARED values, each scaled into 0..1.
function getInputs(img) {
  return Array.from({ length: PIXELSSQUARED }, (_, k) => img[k] / 255);
}
// True while restartTraining() is waiting for the stopped run to finish, so that
// repeated clicks cannot queue more than one new run.
let restartPending = false;
/**
 * Resume training after a stop was requested through `stop_training`.
 * Does nothing if no stop was requested, or if a restart is already queued.
 *
 * The stop request is only acted on from inside the training callbacks, so the
 * previous nn.fit() may still be running when this is called. TF.js will not run
 * two fit() calls on one model at once. Clearing `stop_training` early would also
 * cancel the pending stop. So the flag stays set until the previous run (the
 * promise in `trainingPromise`, when one was stored) has settled. Only then is
 * training started again.
 */
function restartTraining() {
  if (!stop_training || restartPending) {
    return;
  }
  restartPending = true;
  Promise.resolve(trainingPromise)
    .catch(() => undefined) // a run that failed must not block the restart
    .then(() => {
      restartPending = false;
      stop_training = false;
      trainingPromise = startTraining();
      return trainingPromise;
    })
    .catch((err) => {
      console.error("Training failed:", err);
    });
}
let trainImage;
// train the network with a single exemplar, from global var "train_index", show visual on or off
function startTraining() {
const [trainXs, trainYs] = mnist.getTrainTensor(TRAIN_DATA_SIZE);
const [testXs, testYs] = mnist.getTestTensor(TEST_DATA_SIZE);
// tf.enableDebugMode()
fitCallbacks = watchTraining()
// Override the callbacks provided by ts viz to allow training to be stopped and for images to cycle.
let customCallbacks = {
onEpochEnd: (epoch, logs) => {
fitCallbacks.onEpochEnd(epoch, logs);
status.value = `epoch ${epoch + 1}`;
if (stop_training) {
this.model.stopTraining = true;
}
},
onBatchEnd: (batch, logs) => {
fitCallbacks.onBatchEnd(batch, logs);
if (batch % 8 == 0) {
const trainImageTs = tf.slice(trainXs, batch * BATCH_SIZE + Math.floor(BATCH_SIZE / 2), 1).reshape([784]);
trainImage = getImage(trainImageTs.arraySync()); // get image from data array
trainImageTs.dispose();
AB.msg("
", 5);
}
}
}
return nn.fit(trainXs, trainYs, {
batchSize: BATCH_SIZE,
validationData: [testXs, testYs],
epochs: EPOCHS,
shuffle: true,
callbacks: customCallbacks
});
}
// --- the draw function -------------------------------------------------------------
/**
 * p5.js per-frame callback. Once the data set and model are ready it:
 * - repaints the canvas;
 * - shows the latest training image;
 * - redraws and re-guesses the demo and doodle images, so the guesses can change as training proceeds;
 * - extends the doodle while the mouse is pressed in the top-left doodle area.
 */
function draw() {
  const notReady = typeof mnist === 'undefined' || typeof nn === 'undefined';
  if (notReady) {
    return;
  }

  // (White doodle on black background on a yellow canvas would need e.g.
  //  background('#ffffcc'); doodle.background('black'); - not done here.)
  background('black');

  // Middle row: the training exemplar, magnified and then at actual size.
  if (trainImage) {
    const rowTop = ZOOMPIXELS + 50;
    image(trainImage, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
    image(trainImage, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }
  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  if (!mouseIsPressed) {
    mousedrag = false; // any drag in progress has ended
    return;
  }

  // Doodling. This assumes the doodle buffer sits at the canvas origin. Button
  // clicks also arrive here, hence the bounds check (which allows a 20px margin).
  const edge = ZOOMPIXELS + 20;
  const insideDoodleArea =
    mouseX < edge && mouseY < edge && pmouseX < edge && pmouseY < edge;
  if (insideDoodleArea) {
    mousedrag = true;
    doodle_exists = true;
    doodle.stroke('#FDFDFD');
    doodle.strokeWeight(DOODLE_THICK);
    doodle.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}
//--- demo -------------------------------------------------------------
/**
 * Pick a random image from the test set (so it was not used in training).
 * Its raw pixel values go into `demo` for drawDemo()/guessDemo(), and its
 * index and true label are shown in message slot 8.
 * Type "demo" in the console to see the raw data.
 */
function makeDemo() {
  demo_exists = true;
  const i = AB.randomIntAtoB(0, TEST_DATA_SIZE - 1);
  let label;
  // tidy() frees every tensor created inside it: the slices, the reshape, the
  // argMax, and the test tensors if getTestTensor builds them on demand.
  // Tensors that already existed before the call are left alone.
  // NOTE(review): if mnist.js lazily caches tensors on the first call, that call
  // must not happen here first; startTraining() normally makes it.
  tf.tidy(() => {
    const [testXs, testYs] = mnist.getTestTensor(TEST_DATA_SIZE);
    demo = tf.slice(testXs, i, 1).reshape([PIXELSSQUARED]).arraySync();
    label = tf.slice(testYs, i, 1).argMax(-1).dataSync()[0];
  });
  thehtml = "Test image no: " + i + " " +
    "Classification: " + label + " ";
  AB.msg(thehtml, 8);
}
// Bottom row: the current demo test image, magnified on the left and at
// actual size to its right.
function drawDemo() {
  const picture = getImage(demo);
  const rowTop = canvasheight - ZOOMPIXELS;
  image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS); // magnified
  image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS); // actual size
}
/**
 * Classify the current demo image with `nn` and show the top two guesses in
 * message slot 9. draw() calls this every frame while a demo exists.
 */
function guessDemo() {
  let scores;
  // tidy() frees every tensor made inside it (the input and the prediction),
  // so nothing is left behind on each frame.
  tf.tidy(() => {
    const demoInput = tf.tensor4d(getInputs(demo), [1, PIXELS, PIXELS, 1]);
    scores = Array.from(nn.predict(demoInput).dataSync()); // one score per output class
  });
  // Class indices ordered by descending score. Sorting the indices (instead of
  // looking sorted scores up with indexOf) keeps tied classes distinct.
  const ranked = Array.from(scores.keys()).sort((a, b) => scores[b] - scores[a]);
  thehtml = " We classify it as: " + greenspan + ranked[0] + "" + " my second guess is: " + ranked[1];
  AB.msg(thehtml, 9);
}
//--- doodle -------------------------------------------------------------
// Top row: the doodle at full size on the left, shrunk to MNIST size on the
// right. `doodle` comes from createGraphics (not createImage), so a p5 image
// snapshot is taken with get() first.
function drawDoodle() {
  const snapshot = doodle.get();
  const shrunkX = ZOOMPIXELS + 50;
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS); // full size
  image(snapshot, shrunkX, 0, PIXELS, PIXELS); // shrunk
}
/**
 * Classify the current doodle with `nn` and show the top three guesses in
 * message slot 2. draw() calls this every frame while a doodle exists.
 *
 * To build the input, the doodle buffer is snapshotted, shrunk to
 * PIXELS x PIXELS and blurred by DOODLE_BLUR. Each pixel's red channel is then
 * scaled to 0..1. The strokes are grey (R = G = B) on black, so red alone is enough.
 */
function guessDoodle() {
  // doodle is createGraphics not createImage: take a p5 image snapshot first
  const img = doodle.get();
  img.resize(PIXELS, PIXELS);
  img.filter(BLUR, DOODLE_BLUR);
  img.loadPixels();
  const inputs = [];
  for (let i = 0; i < PIXELSSQUARED; i++) {
    inputs[i] = img.pixels[i * 4] / 255; // red byte of pixel i (RGBA layout)
  }
  let scores;
  // tidy() frees every tensor made inside it (the input and the prediction),
  // so nothing is left behind on each frame. The input is also stored in the
  // global doodle_inputs, but that tensor is disposed when tidy() returns.
  tf.tidy(() => {
    doodle_inputs = tf.tensor4d(inputs, [1, PIXELS, PIXELS, 1]);
    scores = Array.from(nn.predict(doodle_inputs).dataSync()); // one score per output class
  });
  // Class indices ordered by descending score; sorting indices keeps ties distinct.
  const ranked = Array.from(scores.keys()).sort((a, b) => scores[b] - scores[a]);
  thehtml = " We classify it as: "
    + "Guess one:" + greenspan + ranked[0] + " "
    + "Guess two:" + ranked[1] + " "
    + "Guess three:" + ranked[2];
  AB.msg(thehtml, 2);
}
// Clear the doodle buffer to black and reset the doodle guesses in message slot 2.
function wipeDoodle() {
  doodle_exists = false;
  doodle.background('black');
  const placeholders = ["Guess one: - ", "Guess two: - ", "Guess three: -"].join("");
  thehtml = " We classify it as: " + placeholders;
  AB.msg(thehtml, 2);
}
// Log an input vector to the console as a grid: one line per image row of
// PIXELS values, each printed to two decimal places.
function showInputs(inputs) {
  let grid = "";
  for (let rowStart = 0; rowStart < inputs.length; rowStart += PIXELS) {
    grid += "\n"; // each row of pixels starts on a new line
    const rowEnd = Math.min(rowStart + PIXELS, inputs.length);
    for (let j = rowStart; j < rowEnd; j++) {
      grid += " " + inputs[j].toFixed(2);
    }
  }
  console.log(grid);
}