/*
Author: Kieron Drumm.
Student Number: 13314446.
Module: CA686 (Foundations of Artificial Intelligence).
Assignment: Practical 2.
Description: The following world is heavily inspired by the "Character Recognition Neural Network" world
(https://ancientbrain.com/world.php?world=2337620282). It is an extension, using Tensorflow.js
in order to attempt to identify any input (doodles) from the user as a character from A-Z.
*/
/*
The following program will use Tensorflow.js to carry out any data normalisation and training. Any knowledge
or sample code gathered about Tensorflow has been pulled from the following places:
- The official website for Tensorflow.js:
https://js.tensorflow.org/api/latest/
- Tensorflow.js's official tutorial on constructing CNNs:
https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#0
- Daniel Shiffman's Coding Train series, giving a general introduction to Tensorflow.js:
https://www.youtube.com/playlist?list=PLRqwX-V7Uu6YIeVA3dNxbR9PYj4wV31oQ
- Daniel Shiffman's Coding Train series on training a colour classifier with Tensorflow.js:
https://www.youtube.com/playlist?list=PLRqwX-V7Uu6bmMRCIoTi72aNWHo7epX4L
*/
/* Constants */
/* How many epochs the model is trained for. */
const numEpochs = 3;
/* The side length (in pixels) of one dataset image, and the number of pixels in a whole image. */
/* Taken from the "Character Recognition Neural Network" world. */
const pixels = 28;
const pixelsSquared = pixels ** 2;
/* The enlargement factor used when displaying a character, and the resulting side length. */
/* Taken from the "Character Recognition Neural Network" world. */
const zoomFactor = 10;
const zoomPixels = pixels * zoomFactor;
/* The canvas is one enlarged image wide, and two enlarged images plus 50 pixels tall. */
/* Taken from the "Character Recognition Neural Network" world. */
const canvasWidth = zoomPixels;
const canvasHeight = 2 * zoomPixels + 50;
/* The stroke weight used while doodling, and the blur parameter applied to the doodle buffer. */
/* Taken from the "Character Recognition Neural Network" world. */
const doodleThickness = 18;
const doodleBlur = 3;
/* The opening tag used to render a result in bold, extra-large, dark green text. */
const greenSpan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'>";
/* The heading that prefixes every progress message shown on the splash screen. */
const splashMessage = "<h1>Training in Progress...</h1></br>";

/* Variables */
/* Training state. */
/* The epoch that training is currently working through. */
let currentEpoch = 0;
/* Becomes true once a model is available for predictions. */
let modelTrained = false;
/* The model that is trained (or restored) and then used for predictions. */
let networkModel;

/* Demo state. */
/* The pixels of the test image that is currently being demonstrated. */
/* Taken from the "Character Recognition Neural Network" world. */
let demo;
/* Whether or not there is a demo image to display. */
/* Taken from the "Character Recognition Neural Network" world. */
let demoExists = false;
/* Whether or not a prediction has already been made for the current demo image. */
let demoGuessed = false;

/* Doodle state. */
/* The off-screen buffer that holds the user's doodle. */
/* Taken from the "Character Recognition Neural Network" world. */
let doodle;
/* Whether or not there is a doodle to display. */
/* Taken from the "Character Recognition Neural Network" world. */
let doodleExists = false;
/* Whether or not a prediction has already been made for the current doodle. */
let doodleGuessed = false;
/* Whether or not the mouse is in the middle of a drag on the doodle area. */
/* Taken from the "Character Recognition Neural Network" world. */
let mouseDrag = false;

/* The most recent HTML message that was built for display. */
/* Taken from the "Character Recognition Neural Network" world. */
let theHTML;

/* The images and labels (raw and one-hot encoded) used to train and test the model. */
let trainingImages;
let trainingLabels;
let encodedTrainingLabels;
let testImages;
let testLabels;
let encodedTestLabels;
/*
This function will do the following:
- Initialise the canvas, the doodle buffer and the static UI panels, then display a splash screen.
- Load Tensorflow.js.
- If a trained model is found in local storage, restore it and load the test data.
- Otherwise read in the training and test data, scale the training pixels to the range [0, 1],
  construct the neural network, train it and save it to local storage.
- Set modelTrained once a model is ready, so that draw() can start processing input.
Any failure is logged and reported on the splash screen instead of leaving it up indefinitely.
*/
function setup() {
    /* The name under which the model is persisted, and the key that marks its presence in local storage. */
    const modelName = "drummk2-tensorflow-network-model";
    const modelURL = "localstorage://" + modelName;
    const modelInfoKey = "tensorflowjs_models/" + modelName + "/info";
    /* One output unit per letter of the alphabet. */
    const numClasses = 26;

    /* Report a fatal problem in the console and on the splash screen. */
    const reportFailure = (error) => {
        console.error("The model could not be prepared:", error);
        AB.splashHtml("<h1>Something went wrong</h1><p>The model could not be prepared. " +
            "Please check the browser console for details.</p>");
    };

    /* Restore a previously trained model, skipping the training phase entirely. */
    const restoreModel = async () => {
        const loadedModel = await tf.loadLayersModel(modelURL);
        console.log("Loading test data");
        try {
            /* Awaiting is harmless if loadTestData() returns nothing; if it returns a promise,
               the demo data will be in place before the UI is unlocked. */
            await loadTestData();
        } catch (error) {
            /* Doodles can still be classified without the test data, so carry on. */
            console.error("The test data could not be loaded:", error);
        }
        console.log("Pre-existing model loaded from local storage.");
        networkModel = loadedModel;
        modelTrained = true;
    };

    /* Read in the dataset, build the network and train it from scratch. */
    const trainModel = async () => {
        console.log("No model was found in local storage; commencing with training.");
        /*
        In order to avoid replication, I will be using pre-uploaded binary files containing the mnist dataset, belonging to Vrushali Golde:
        - https://ancientbrain.com/uploads/vrushali/emnist-letters-train-images-idx3-ubyte.bin
        - https://ancientbrain.com/uploads/vrushali/emnist-letters-train-labels-idx1-ubyte.bin
        - https://ancientbrain.com/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin
        - https://ancientbrain.com/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin
        The following code is a modified form of something that was written by Vrushali Golde to read in the training and
        test datasets that we will be using: https://ancientbrain.com/uploads/vrushali/vru_mnist.js.
        */
        const [rawTrainingImages, rawTrainingLabels, rawTestImages, rawTestLabels] = await Promise.all([
            loadBinaryFile("/uploads/vrushali/emnist-letters-train-images-idx3-ubyte.bin"),
            loadBinaryFile("/uploads/vrushali/emnist-letters-train-labels-idx1-ubyte.bin"),
            loadBinaryFile("/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin"),
            loadBinaryFile("/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin")
        ]);

        /* Copy every training image into one flat buffer, scaling each pixel from 0..255 down to 0..1
           so that training sees the same value range that getInputs() and guessDoodle() produce. */
        const numTrainingImages = rawTrainingImages.length;
        const scaledPixels = new Float32Array(numTrainingImages * pixelsSquared);
        rawTrainingImages.forEach((rawImage, imageIndex) => {
            if (rawImage.length !== pixelsSquared) {
                throw new Error("Training image " + imageIndex + " has " + rawImage.length +
                    " pixels; expected " + pixelsSquared + ".");
            }
            const offset = imageIndex * pixelsSquared;
            for (let p = 0; p < pixelsSquared; p++) {
                scaledPixels[offset + p] = rawImage[p] / 255;
            }
        });
        /* Shape the flat buffer into one (pixels x pixels x 1) tensor per image. */
        trainingImages = tf.tensor(scaledPixels, [numTrainingImages, pixels, pixels, 1], "float32");
        /* Convert our training labels to a tensor. */
        trainingLabels = tf.tensor1d(rawTrainingLabels, "int32");
        /* Store our test images and labels in their raw form, for use by the demo. */
        testImages = rawTestImages;
        testLabels = rawTestLabels;
        /* Encode the training and test labels using one-hot encoding. */
        /* NOTE(review): tf.oneHot with a depth of 26 encodes labels 0..25. The published EMNIST-letters
           files are documented as using labels 1..26; confirm which range these files use, as a label of
           26 would be encoded as all zeros and the index-to-letter mapping would be shifted by one. */
        encodedTestLabels = tf.oneHot(testLabels, numClasses);
        encodedTrainingLabels = tf.oneHot(trainingLabels, numClasses);

        /* Initialise and configure our model, before we begin training. */
        networkModel = tf.sequential();
        /* Configure our first hidden layer. */
        networkModel.add(tf.layers.conv2d({
            activation: "tanh",
            filters: 8,
            inputShape: trainingImages.shape.slice(1),
            kernelInitializer: "glorotUniform",
            kernelSize: 5,
            strides: 1
        }));
        /* Carry out a max pooling operation on the output of the first hidden layer. */
        networkModel.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
        /* Configure our second hidden layer, this time with additional filters. */
        networkModel.add(tf.layers.conv2d({
            activation: "tanh",
            filters: 16,
            inputShape: trainingImages.shape.slice(1),
            kernelInitializer: "glorotUniform",
            kernelSize: 5,
            strides: 1
        }));
        /* Carry out a max pooling operation on the output of the second hidden layer. */
        networkModel.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
        /* Flatten the output of our second hidden layer, so that it can be sent into the output layer. */
        networkModel.add(tf.layers.flatten());
        /* Configure our output layer: a softmax over one unit per letter, which pairs with the
           categorical cross entropy loss that is used below. */
        networkModel.add(tf.layers.dense({
            activation: "softmax",
            kernelInitializer: 'glorotUniform',
            units: numClasses
        }));
        /* Stochastic gradient descent for optimisation, categorical cross entropy for the loss. */
        networkModel.compile({
            loss: "categoricalCrossentropy",
            optimizer: tf.train.sgd(0.0001),
        });

        /* Train the model for 'n' epochs, holding back 10% of the training set for validation,
           reshuffling between epochs and reporting progress on the splash screen after every batch. */
        const options = {
            callbacks: {
                onBatchEnd: (batchNum, logs) => {
                    AB.splashHtml(splashMessage + "<p><b>Epoch:</b> " + (currentEpoch) + "/" + options.epochs +
                        "</br><b>Batch:</b> " + batchNum + "</br><b>Loss:</b> " + logs.loss);
                },
                onEpochBegin: (epochNum) => {
                    currentEpoch = epochNum + 1;
                }
            },
            batchSize: 512,
            epochs: numEpochs,
            shuffle: true,
            validationSplit: 0.1
        };
        await networkModel.fit(trainingImages, encodedTrainingLabels, options);

        /* Saving is a convenience for the next visit; a failure here must not block the use of
           the model that has just been trained. */
        try {
            await networkModel.save(modelURL);
            console.log("Model saved to local storage.");
        } catch (saveError) {
            console.warn("The trained model could not be saved to local storage:", saveError);
        }
        modelTrained = true;
    };

    /* Initialise the canvas. */
    createCanvas(canvasWidth, canvasHeight);
    /* Initialise the doodle. */
    doodle = createGraphics(zoomPixels, zoomPixels);
    doodle.pixelDensity(1);
    /* Initialise some elements on the UI. */
    /* The header for the doodle. */
    theHTML = "<hr><h1>1. Doodle</h1>Please draw your doodle in the top-left.</br></br>" +
        "To clear the doodle, click here: <button onclick='wipeDoodle();' class='normbutton'>Clear</button></br>";
    AB.msg(theHTML, 1);
    /* The header for the demo. */
    theHTML = "<hr><h1>2. Demo</h1>Demonstrations of the trained network can be viewed here.</br></br>" +
        "The network is <i>not</i> trained on any of these images.</br></br>" +
        "To test out an image, click here: <button onclick='makeDemo();' class='normbutton'>Test Image</button></br>";
    AB.msg(theHTML, 3);
    /* The header for the model settings. */
    /* NOTE(review): localStorage.clear() removes every key stored for this origin, not only the
       saved model; consider removing just the model's entries instead. */
    theHTML = "<hr><h1>3. Model Settings</h1>" +
        "To delete the currently trained model, click here: <button onclick='localStorage.clear();'" +
        " class='normbutton'>Delete Trained Model</button></br>";
    AB.msg(theHTML, 6);
    /* Display a splash screen while the neural network is being prepared. */
    AB.newSplash();
    AB.splashHtml("<h1>Training in Progress...</h1>");
    /* Load Tensorflow.js at the very beginning; everything else depends on it. */
    $.getScript("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js", () => {
        /* Determine whether or not a model has been trained before by checking the local storage for this browser. */
        const existingModel = localStorage.getItem(modelInfoKey);
        /* If a model has already been trained, then skip the training phase,
           otherwise, train a model from scratch. */
        const preparation = existingModel ? restoreModel() : trainModel();
        preparation.catch(reportFailure);
    }).fail((jqXHR, textStatus, errorThrown) => {
        reportFailure(errorThrown || textStatus);
    });
}
/*
This function runs once per frame and, once a model is available, will do the following:
- Remove the splash screen and repaint the background.
- Display the demo image and the doodle (if any), asking the network for a guess when one is still outstanding.
- Turn mouse presses within the doodle area into strokes on the doodle.
*/
function draw() {
    /* Nothing can be shown or classified until a model is ready. */
    if (!modelTrained) {
        return;
    }
    AB.removeSplash();
    /* Start every frame from a black background. */
    background('black');

    /* Show the requested test image, classifying it if that has not been done yet. */
    /* Taken from the "Character Recognition Neural Network" world. */
    if (demoExists) {
        drawDemo();
        if (!demoGuessed) {
            guessDemo();
        }
    }

    /* Show the doodle, classifying it if it has changed since the last guess. */
    /* Taken from the "Character Recognition Neural Network" world. */
    if (doodleExists) {
        drawDoodle();
        if (!doodleGuessed) {
            guessDoodle();
            doodleGuessed = true;
        }
    }

    /* Taken from the "Character Recognition Neural Network" world. */
    if (!mouseIsPressed) {
        /* The button has just been released after a drag, so the stroke is finished: blur it. */
        if (mouseDrag) {
            mouseDrag = false;
            doodle.filter(BLUR, doodleBlur);
        }
        return;
    }

    /* Only extend the doodle while both ends of the mouse movement are below the boundary. */
    const limit = zoomPixels + 50;
    const withinLimit = [mouseX, mouseY, pmouseX, pmouseY].every((coordinate) => coordinate < limit);
    if (withinLimit) {
        mouseDrag = true;
        doodleExists = true;
        doodleGuessed = false;
        doodle.stroke('white');
        doodle.strokeWeight(doodleThickness);
        doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
}
/* Utility Functions */
/* Render the current demo image on the canvas. This code has been taken from the "Character Recognition Neural Network" world. */
function drawDemo() {
    /* Build a P5 image from the raw demo pixels. */
    const demoImage = getImage(demo);
    /* Both copies sit one enlarged image above the bottom edge of the canvas. */
    const top = canvasHeight - zoomPixels;
    /* The enlarged copy first, then an actual-size copy 50 pixels beyond the enlarged copy's width. */
    image(demoImage, 0, top, zoomPixels, zoomPixels);
    image(demoImage, zoomPixels + 50, top, pixels, pixels);
}
/* Render the user's doodle on the canvas. This code has been taken from the "Character Recognition Neural Network" world. */
function drawDoodle() {
    /* Take a snapshot of the doodle buffer. */
    const snapshot = doodle.get();
    /* The actual-size copy starts 50 pixels beyond the enlarged copy's width. */
    const thumbnailLeft = zoomPixels + 50;
    /* The enlarged copy goes in the top-left corner, followed by the actual-size copy. */
    image(snapshot, 0, 0, zoomPixels, zoomPixels);
    image(snapshot, thumbnailLeft, 0, pixels, pixels);
}
/* Turn an array of brightness values into a greyscale, fully opaque P5 image. This code has been
taken from the "Character Recognition Neural Network" world. */
function getImage(image) {
    /* Start from an empty image of the dataset's dimensions. */
    const picture = createImage(pixels, pixels);
    picture.loadPixels();
    /* Each source value fills four consecutive slots: red, green, blue and alpha. */
    let slot = 0;
    for (let p = 0; p < pixelsSquared; p += 1) {
        const grey = image[p];
        picture.pixels[slot] = grey;
        picture.pixels[slot + 1] = grey;
        picture.pixels[slot + 2] = grey;
        picture.pixels[slot + 3] = 255;
        slot += 4;
    }
    /* Commit the pixel data and hand the finished image back. */
    picture.updatePixels();
    return picture;
}
/* Produce the network inputs for an image: every brightness value divided by 255. This code has
been taken from the "Character Recognition Neural Network" world. */
function getInputs(image) {
    /* One entry per pixel of a dataset image, whatever the length of the source array. */
    return Array.from({ length: pixelsSquared }, (unused, position) => image[position] / 255);
}
/* Predict the character that is present in the demo image, display the prediction in message
slot 5 and mark the demo as guessed. This code is a modification of something taken from the
"Character Recognition Neural Network" world. */
function guessDemo() {
    /* Retrieve the pixels for our demo image, scaled to the range [0, 1] by getInputs(). */
    /* NOTE(review): confirm that the model was trained on pixels in the same [0, 1] range. */
    const inputs = getInputs(demo);
    /* Run the prediction inside tf.tidy() so that the input and output tensors are released
       afterwards; only the plain array of per-character confidences leaves the callback. */
    const prediction = tf.tidy(() => {
        /* The values are fractions, so the tensor must be "float32": an "int32" tensor would
           truncate almost every pixel to zero. */
        const inputTensor = tf.reshape(tf.tensor(inputs, [1, pixelsSquared], "float32"), [1, pixels, pixels, 1]);
        return networkModel.predict(inputTensor).arraySync()[0];
    });
    /* Retrieve the index of the character with the highest confidence level. */
    const predictedIndex = prediction.indexOf(Math.max(...prediction));
    /* Convert this index to the character to which it corresponds (index 0 maps to "a"). */
    const predictedCharacter = String.fromCharCode(97 + predictedIndex);
    /* Output our prediction. */
    theHTML = "The network classified it as: " + greenSpan + predictedCharacter + "</span>";
    AB.msg(theHTML, 5);
    demoGuessed = true;
}
/* Predict the character that is present in the doodle and display the prediction in message
slot 2. This code is a modification of something taken from the "Character Recognition Neural
Network" world. */
function guessDoodle() {
    /* Take a copy of the doodle and shrink it to the dimensions that the model expects. */
    const snapshot = doodle.get();
    snapshot.resize(pixels, pixels);
    snapshot.loadPixels();
    /* Read the first of every four pixel values (the red channel) and scale it to the range [0, 1]. */
    /* NOTE(review): confirm that the model was trained on pixels in the same [0, 1] range. */
    const inputs = [];
    for (let i = 0; i < pixelsSquared; i++) {
        inputs[i] = snapshot.pixels[i * 4] / 255;
    }
    /* Run the prediction inside tf.tidy() so that the input and output tensors are released
       afterwards (this function can run on every frame while the user is drawing); only the
       plain array of per-character confidences leaves the callback. */
    const prediction = tf.tidy(() => {
        /* The values are fractions, so the tensor must be "float32": an "int32" tensor would
           truncate almost every pixel to zero. */
        const inputTensor = tf.reshape(tf.tensor(inputs, [1, pixelsSquared], "float32"), [1, pixels, pixels, 1]);
        return networkModel.predict(inputTensor).arraySync()[0];
    });
    /* Retrieve the index of the character with the highest confidence level. */
    const predictedIndex = prediction.indexOf(Math.max(...prediction));
    /* Convert this index to the character to which it corresponds (index 0 maps to "a"). */
    const predictedCharacter = String.fromCharCode(97 + predictedIndex);
    /* Display our prediction to the user, replacing whatever was shown before. */
    theHTML = "Classification: " + greenSpan + predictedCharacter + "</span>";
    AB.msg(theHTML, 2);
}
/* Read a binary file in the MNIST "idx" layout from a specified URL. The following code is a modified
form of something that was written by Vrushali Golde to read in the training and test datasets that
we will be using: https://ancientbrain.com/uploads/vrushali/vru_mnist.js.
- fileURL: the URL of the file to download.
- Resolves to an Int32Array of labels for a label file (magic number 2049), or to an array holding
  one Int32Array of pixel values per image for an image file (magic number 2051).
- Rejects with an Error if the download fails, if the file is too short for its header, if the
  magic number is not recognised, or if an image file holds less data than its header promises.
*/
async function loadBinaryFile(fileURL) {
    const response = await fetch(fileURL);
    /* Without this check, an error page would be parsed as though it were dataset content. */
    if (!response.ok) {
        throw new Error("Could not download " + fileURL + " (HTTP status " + response.status + ").");
    }
    const buffer = await response.arrayBuffer();

    /* Every file begins with two big-endian 32-bit values: the magic number and the item count. */
    const labelHeaderBytes = 8;
    const imageHeaderBytes = 16;
    if (buffer.byteLength < labelHeaderBytes) {
        throw new Error("The file at " + fileURL + " is too short to contain a header.");
    }
    const view = new DataView(buffer);
    const magicNumber = view.getUint32(0, false);
    const itemCount = view.getUint32(4, false);

    /* Does the file contain labels? Then everything after the header is one label per byte. */
    if (magicNumber === 2049) {
        return Int32Array.from(new Uint8Array(buffer, labelHeaderBytes));
    }

    /* Does the file contain images? Then the header also carries the two image dimensions. */
    if (magicNumber === 2051) {
        if (buffer.byteLength < imageHeaderBytes) {
            throw new Error("The image file at " + fileURL + " is too short to contain a header.");
        }
        const imageSize = view.getUint32(8, false) * view.getUint32(12, false);
        const pixelData = Int32Array.from(new Uint8Array(buffer, imageHeaderBytes));
        if (pixelData.length < itemCount * imageSize) {
            throw new Error("The image file at " + fileURL + " is truncated: expected " +
                (itemCount * imageSize) + " pixel values but found " + pixelData.length + ".");
        }
        /* Each image is a view onto its own slice of the shared pixel data, not a copy. */
        const images = [];
        for (let i = 0; i < itemCount; i++) {
            images.push(pixelData.subarray(imageSize * i, imageSize * (i + 1)));
        }
        return images;
    }

    throw new Error("The file at " + fileURL + " has an unrecognised magic number (" + magicNumber + ").");
}
/* Load our test data. To be used when a pre-existing model is loaded,
but the test data must still be retrieved.
- Downloads the test images and the test labels in parallel, then stores them in testImages and
  testLabels, along with a one-hot encoding of the labels in encodedTestLabels.
- Returns a promise that resolves once all three have been stored, and rejects if either download
  fails, so that callers can wait for the data or react to a failure. */
async function loadTestData() {
    /* Retrieve both files at once; nothing is stored unless both of them arrive. */
    const [loadedImages, loadedLabels] = await Promise.all([
        loadBinaryFile("/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin"),
        loadBinaryFile("/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin")
    ]);
    /* Store our test images. */
    testImages = loadedImages;
    /* Store our test labels. */
    testLabels = loadedLabels;
    /* Encode the test labels using one-hot encoding. */
    encodedTestLabels = tf.oneHot(testLabels, 26);
}
/* Retrieve a random image from the test data to demo, and display its index and its label in
message slot 4. If the test data is not available yet, a notice is displayed instead and the demo
state is left untouched. This code is a modification of something taken from the
"Character Recognition Neural Network" world. */
function makeDemo() {
    /* The button can be pressed before the test data has arrived. Bail out before any state is
       changed, as flagging a demo without an image would make every later frame fail in draw(). */
    if (!testImages || !testLabels || testImages.length === 0) {
        theHTML = "The test images are not available yet; please try again in a moment.<br>";
        AB.msg(theHTML, 4);
        return;
    }
    /* Retrieve a random image from our test data. */
    const randomIndex = AB.randomIntAtoB(0, testImages.length - 1);
    demo = testImages[randomIndex];
    const testLabel = testLabels[randomIndex];
    /* Construct a formatted message about our test sample. */
    /* NOTE(review): this maps label 0 to "a"; confirm that the labels in the dataset are 0-based. */
    theHTML = "Test image no: " + randomIndex + "<br>" +
        "Classification: " + String.fromCharCode(97 + testLabel) + "<br>";
    /* Display the details of the image that we have randomly chosen. */
    AB.msg(theHTML, 4);
    /* Only now is there a demo to draw, and it still needs a prediction. */
    demoGuessed = false;
    demoExists = true;
}
/* Wipe the doodle and reset the classification shown in message slot 2. This code is a
modification of something taken from the "Character Recognition Neural Network" world. */
function wipeDoodle() {
    /* Forget the doodle and paint its buffer black. */
    doodleExists = false;
    doodle.background("black");
    /* Clear the previous classification. No closing tag is emitted here, as no span is opened. */
    theHTML = "Classification:";
    AB.msg(theHTML, 2);
}