/* Code viewer for World: CA686: Practical 2 ... */
/*
    Author: Kieron Drumm.
    Student Number: 13314446.
    Module: CA686 (Foundations of Artificial Intelligence).
    Assignment: Practical 2.
    Description: The following world is heavily inspired by the "Character Recognition Neural Network" world
                 (https://ancientbrain.com/world.php?world=2337620282). It is an extension, using Tensorflow.js
                 in order to attempt to identify any input (doodles) from the user as a character from A-Z.
*/


/*
    The following program will use Tensorflow.js to carry out any data normalisation and training. Any knowledge
    or sample code gathered about Tensorflow has been pulled from the following places:
        - The official website for Tensorflow.js:
          https://js.tensorflow.org/api/latest/
        
        - Tensorflow.js's official tutorial on constructing CNNs:
          https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#0
        
        - Daniel Shiffman's Coding Train series, giving a general introduction to Tensorflow.js:
          https://www.youtube.com/playlist?list=PLRqwX-V7Uu6YIeVA3dNxbR9PYj4wV31oQ
        
        - Daniel Shiffman's Coding Train series on training a colour classifier with Tensorflow.js:
          https://www.youtube.com/playlist?list=PLRqwX-V7Uu6bmMRCIoTi72aNWHo7epX4L
*/

/* Constants */
/* The number of epochs (full passes over the training data) for which we should train our model. */
const numEpochs = 3;

/* The side length, in pixels, of each image in the EMNIST letters dataset, and the number of pixels per image. */
/* Taken from the "Character Recognition Neural Network" world. */
const pixels = 28;
const pixelsSquared = pixels * pixels;

/* The zoom factor to use when displaying each character image, and the resulting side length on screen. */
/* Taken from the "Character Recognition Neural Network" world. */
const zoomFactor = 10;
const zoomPixels = zoomFactor * pixels;

/* The dimensions of our canvas. The zoomed doodle (top) and the zoomed demo image (bottom) are stacked
   vertically with a 50px gap between them. An actual-size (pixels x pixels) copy of each is drawn at
   x = zoomPixels + 50, so the canvas must be wide enough to include those copies as well. */
/* Adapted from the "Character Recognition Neural Network" world. */
const canvasHeight = (zoomPixels * 2) + 50;
const canvasWidth = zoomPixels + 50 + pixels;

/* The blur radius applied to the doodle once a stroke is finished, and the stroke weight used while doodling. */
/* Taken from the "Character Recognition Neural Network" world. */
const doodleBlur = 3;
const doodleThickness = 18;

/* Some HTML that styles an element in bold, extra-large, dark green text. */
const greenSpan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'>";

/* The default splash message, to be added to while training. */
const splashMessage = "<h1>Training in Progress...</h1></br>";

/* Variables */
/* The current (1-based) epoch that we are at, during the training phase. */
let currentEpoch = 0;

/* The pixel values of the test image currently being demonstrated. */
/* Taken from the "Character Recognition Neural Network" world. */
let demo;

/* Whether or not a demo is ready to be displayed. */
/* Taken from the "Character Recognition Neural Network" world. */
let demoExists = false;

/* Whether or not we have made a prediction about the current demo image. */
let demoGuessed = false;

/* The p5 graphics buffer that the user doodles on. */
/* Taken from the "Character Recognition Neural Network" world. */
let doodle;

/* Whether or not a doodle is ready to be displayed. */
/* Taken from the "Character Recognition Neural Network" world. */
let doodleExists = false;

/* Whether or not we have made a prediction about this doodle yet. */
let doodleGuessed = false;

/* Whether or not a model is ready for use (trained in this session, or loaded from local storage). */
let modelTrained = false;

/* Denotes whether or not the mouse is currently being dragged across the doodle. */
/* Taken from the "Character Recognition Neural Network" world. */
let mouseDrag = false;

/* Scratch HTML used when building messages for display. */
/* Taken from the "Character Recognition Neural Network" world. */
let theHTML;

/* The model that we are going to train (or load). */
let networkModel;

/* The data with which we train and test our model. The training images/labels and both encoded label sets
   are tensors. testImages (an array of per-image pixel arrays) and testLabels (a flat array of labels)
   are kept as plain arrays for the demo. */
let encodedTestLabels;
let encodedTrainingLabels;
let testImages;
let testLabels;
let trainingImages;
let trainingLabels;

/*
    This function will do the following:
        - Build the UI messages and the doodle's off-screen graphics buffer.
        - Load Tensorflow.js.
        - If a previously trained model is found in this browser's local storage, load it, along with the
          test data used for demos.
        - Otherwise, read in the EMNIST letters dataset, normalise it, construct a convolutional neural network
          using Tensorflow.js, train it to predict which character from A to Z a user-provided doodle is
          representing, and save it to local storage.
*/
function setup() {
    /* Where the trained model is saved in local storage. The "-v2" suffix means that models saved by earlier
       versions of this world are not reused. Those models were trained on un-normalised pixels and 1-based
       labels, so they are incompatible with the current encoding. */
    const modelName = "drummk2-tensorflow-network-model-v2";
    const modelURL = "localstorage://" + modelName;

    /* One output class per letter of the alphabet. */
    const numClasses = 26;

    /* Show a fatal error in the console and on the splash screen, which is only removed once a model is ready. */
    const reportError = (error) => {
        console.error(error);
        AB.splashHtml("<h1>Error</h1><p>" + ((error && error.message) || error) + "</p>");
    };

    /* Read in the dataset, then build, train and save a new model. */
    async function trainNewModel() {
        /*
            In order to avoid replication, I will be using pre-uploaded binary files containing the EMNIST letters dataset, belonging to Vrushali Golde:
                - https://ancientbrain.com/uploads/vrushali/emnist-letters-train-images-idx3-ubyte.bin
                - https://ancientbrain.com/uploads/vrushali/emnist-letters-train-labels-idx1-ubyte.bin
                - https://ancientbrain.com/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin
                - https://ancientbrain.com/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin

            The following code is a modified form of something that was written by Vrushali Golde to read in the training and
            test datasets that we will be using: https://ancientbrain.com/uploads/vrushali/vru_mnist.js.
        */
        const files = {
            trainingImages: "/uploads/vrushali/emnist-letters-train-images-idx3-ubyte.bin",
            trainingLabels: "/uploads/vrushali/emnist-letters-train-labels-idx1-ubyte.bin",
            testImages: "/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin",
            testLabels: "/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin"
        };

        /* Retrieve all of the files in parallel, then store them in a centralised object. */
        const mnistData = {};
        await Promise.all(Object.keys(files).map(async (file) => {
            mnistData[file] = await loadBinaryFile(files[file]);
        }));

        /* Convert the training images into a [numImages, 28, 28, 1] tensor. Pixel values are scaled from 0-255
           to 0-1, so that training sees the same input range as prediction (getInputs() and guessDoodle() both
           divide by 255). tf.tidy disposes of the intermediate tensors. */
        const numTrainingImages = mnistData.trainingImages.length;
        trainingImages = tf.tidy(() => tf.div(
            tf.reshape(
                tf.tensor(mnistData.trainingImages, [numTrainingImages, pixelsSquared], "float32"),
                [numTrainingImages, pixels, pixels, 1]
            ),
            255
        ));
        /* NOTE(review): EMNIST images are commonly reported to be stored transposed relative to upright
           handwriting, whereas doodles are drawn upright. TODO confirm whether the training (and demo)
           images should be transposed. */

        /* Convert our training labels to a tensor, and keep the test images and (raw) test labels for demos. */
        trainingLabels = tf.tensor1d(mnistData.trainingLabels, "int32");
        testImages = mnistData.testImages;
        testLabels = mnistData.testLabels;

        /* EMNIST letters labels run from 1 ('a') to 26 ('z'). Shift them to 0-25 before one-hot encoding, so that
           output index 0 means 'a'. With the raw labels, every 'z' (26) would be out of range for a depth of 26.
           It would encode as an all-zero target, so 'z' could never be learned. */
        const toZeroBased = (labels) => Int32Array.from(labels, (label) => label - 1);
        encodedTestLabels = tf.oneHot(toZeroBased(testLabels), numClasses);
        encodedTrainingLabels = tf.oneHot(toZeroBased(mnistData.trainingLabels), numClasses);

        /* Initialise and configure our model, before we begin training. */
        networkModel = tf.sequential();

        /* Configure our first hidden layer; only the first layer of a sequential model needs an input shape. */
        networkModel.add(tf.layers.conv2d({
            activation: "tanh",
            filters: 8,
            inputShape: trainingImages.shape.slice(1),
            kernelInitializer: "glorotUniform",
            kernelSize: 5,
            strides: 1
        }));

        /* Carry out a max pooling operation on the output of the first hidden layer. */
        networkModel.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

        /* Configure our second hidden layer, this time with additional filters. */
        networkModel.add(tf.layers.conv2d({
            activation: "tanh",
            filters: 16,
            kernelInitializer: "glorotUniform",
            kernelSize: 5,
            strides: 1
        }));

        /* Carry out a max pooling operation on the output of the second hidden layer. */
        networkModel.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

        /* Flatten the output of our second hidden layer, so that it can be sent into the output layer. */
        networkModel.add(tf.layers.flatten());

        /* Configure our output layer. It's worth noting that:
               - "Softmax" is used for the activation, as it suits multi-class classification problems such as this.
                 It also pairs well with the categorical cross entropy loss function used below.
               - This layer has one unit per letter of the alphabet.
        */
        networkModel.add(tf.layers.dense({
            activation: "softmax",
            kernelInitializer: "glorotUniform",
            units: numClasses
        }));

        /* Stochastic gradient descent for optimisation and categorical cross entropy for the loss function.
           NOTE(review): a learning rate of 0.0001 over 3 epochs is very small. The Tensorflow.js codelab uses
           tf.train.adam(). TODO tune. */
        networkModel.compile({
            loss: "categoricalCrossentropy",
            optimizer: tf.train.sgd(0.0001)
        });

        /* Train the model for 'n' epochs, using 10% of our training set for validation and reshuffling the
           dataset between epochs. Progress is reported on the splash screen after every batch. */
        const options = {
            callbacks: {
                onBatchEnd: (batchNum, logs) => {
                    AB.splashHtml(splashMessage + "<p><b>Epoch:</b> " + currentEpoch + "/" + numEpochs +
                                  "</br><b>Batch:</b> " + batchNum + "</br><b>Loss:</b> " + logs.loss);
                },
                onEpochBegin: (epochNum) => {
                    currentEpoch = epochNum + 1;
                }
            },
            batchSize: 512,
            epochs: numEpochs,
            shuffle: true,
            validationSplit: 0.1
        };

        await networkModel.fit(trainingImages, encodedTrainingLabels, options);

        /* Saving is best effort (e.g. local storage may be full); the trained model is usable either way. */
        try {
            await networkModel.save(modelURL);
            console.log("Model saved to local storage.");
        } catch (error) {
            console.error("Could not save the model to local storage:", error);
        }

        modelTrained = true;
    }

    /* Initialise the canvas. */
    createCanvas(canvasWidth, canvasHeight);

    /* Initialise the doodle. */
    doodle = createGraphics(zoomPixels, zoomPixels);
    doodle.pixelDensity(1);

    /* Initialise some elements on the UI. */
    /* The header for the doodle. */
    theHTML = "<hr><h1>1. Doodle</h1>Please draw your doodle in the top-left.</br></br>" +
              "To clear the doodle, click here: <button onclick='wipeDoodle();' class='normbutton'>Clear</button></br>";
    AB.msg(theHTML, 1);

    /* The header for the demo. */
    theHTML = "<hr><h1>2. Demo</h1>Demonstrations of the trained network can be viewed here.</br></br>" +
              "The network is <i>not</i> trained on any of these images.</br></br>" +
              "To test out an image, click here: <button onclick='makeDemo();' class='normbutton'>Test Image</button></br>";
    AB.msg(theHTML, 3);

    /* The header for the model settings. Only this world's saved model is removed, rather than clearing
       all of the site's local storage. */
    theHTML = "<hr><h1>3. Model Settings</h1>" +
              "To delete the currently trained model, click here: <button onclick='tf.io.removeModel(\"" + modelURL +
              "\").catch(console.error);' class='normbutton'>Delete Trained Model</button></br>";
    AB.msg(theHTML, 6);

    /* Display a splash screen while the neural network is being trained. */
    AB.newSplash();
    AB.splashHtml("<h1>Training in Progress...</h1>");

    /* Load Tensorflow.js at the very beginning. */
    $.getScript("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js", () => {
        /* Determine whether or not a model has been trained before by checking the local storage for this browser. */
        const existingModel = localStorage.getItem("tensorflowjs_models/" + modelName + "/info");

        /* If a model has already been trained, then skip the training phase,
           otherwise, train a model from scratch. */
        if (existingModel) {
            tf.loadLayersModel(modelURL).then((loadedModel) => {
                console.log("Loading test data");
                loadTestData();

                console.log("Pre-existing model loaded from local storage.");
                networkModel = loadedModel;
                modelTrained = true;
            }).catch(reportError);
        } else {
            console.log("No model was found in local storage; commencing with training.");
            trainNewModel().catch(reportError);
        }
    }).fail((jqxhr, settings, exception) => {
        reportError(exception || new Error("Could not load Tensorflow.js."));
    });
}

/*
    This function is called by p5 on every frame. Once a model is ready, it will do the following:
        - Remove the training splash screen and repaint the canvas black.
        - Draw the current demo image (if any), classifying it if that has not been done yet.
        - Draw the current doodle (if any), classifying it if that has not been done yet.
        - Extend the doodle with any mouse drag, and blur it once the stroke is finished.
*/
function draw() {
    /* Nothing can be shown or classified until a model has been trained or loaded. */
    if (!modelTrained) {
        return;
    }

    AB.removeSplash();

    /* Initialise a black background. */
    background('black');

    /* Has the user requested for a test to be run? */
    /* Taken from the "Character Recognition Neural Network" world. */
    if (demoExists) {
        drawDemo();

        if (!demoGuessed) {
            guessDemo();
        }
    }

    /* Has a doodle been drawn on the canvas? */
    /* Taken from the "Character Recognition Neural Network" world. */
    if (doodleExists) {
        drawDoodle();

        /* Use our neural network to determine the character that the doodle is displaying. */
        if (!doodleGuessed) {
            guessDoodle();
            doodleGuessed = true;
        }
    }

    /* Detect any mouse press actions, and allow the user to doodle a character. */
    /* Taken from the "Character Recognition Neural Network" world. */
    if (mouseIsPressed) {
        const maxBoundary = zoomPixels + 50;
        if ((mouseX < maxBoundary) && (mouseY < maxBoundary) &&
            (pmouseX < maxBoundary) && (pmouseY < maxBoundary)) {
            mouseDrag = true;
            doodleExists = true;
            doodleGuessed = false;
            doodle.stroke('white');
            doodle.strokeWeight(doodleThickness);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    } else if (mouseDrag) {
        mouseDrag = false;
        doodle.filter(BLUR, doodleBlur);

        /* The stroke is finished and the doodle has just been blurred. Classify it again, so that the displayed
           result reflects the final, blurred doodle rather than the last unblurred frame. */
        doodleGuessed = false;
    }
}

/* Utility Functions */
/* Render the current demo image twice along the bottom of the canvas: zoomed to zoomPixels x zoomPixels at the
   left edge, and at actual size (pixels x pixels) 50px to the right of it.
   Adapted from the "Character Recognition Neural Network" world. */
function drawDemo() {
    const demoImage = getImage(demo);
    const topEdge = canvasHeight - zoomPixels;

    /* [left, size] for the zoomed copy, then for the actual-size copy. */
    const placements = [[0, zoomPixels], [zoomPixels + 50, pixels]];
    for (const [left, size] of placements) {
        image(demoImage, left, topEdge, size, size);
    }
}

/* Render the doodle twice along the top of the canvas: zoomed to zoomPixels x zoomPixels at the left edge,
   and at actual size (pixels x pixels) 50px to the right of it.
   Adapted from the "Character Recognition Neural Network" world. */
function drawDoodle() {
    const snapshot = doodle.get();

    for (const [left, size] of [[0, zoomPixels], [zoomPixels + 50, pixels]]) {
        image(snapshot, left, 0, size, size);
    }
}

/* Build a pixels x pixels greyscale p5 image from a flat array of brightness values, one per pixel.
   Each value is copied into the red, green and blue channels, and alpha is set to fully opaque (255).
   Adapted from the "Character Recognition Neural Network" world. */
function getImage(image) {
    const result = createImage(pixels, pixels);
    result.loadPixels();

    /* The p5 pixel buffer holds four consecutive entries (RGBA) per pixel. */
    for (let p = 0; p < pixelsSquared; p++) {
        const base = 4 * p;
        const value = image[p];

        for (let channel = 0; channel < 3; channel++) {
            result.pixels[base + channel] = value;
        }
        result.pixels[base + 3] = 255;
    }

    result.updatePixels();
    return result;
}

/* Convert the first pixelsSquared brightness values (0-255) of an image into model inputs scaled to [0, 1].
   Adapted from the "Character Recognition Neural Network" world. */
function getInputs(image) {
    return Array.from({ length: pixelsSquared }, (_, index) => image[index] / 255);
}

/* Predict the character that is present in the demo image, and display it in message area 5.
   This code is a modification of something taken from the "Character Recognition Neural Network" world. */
function guessDemo() {
    /* Retrieve the demo's pixels, scaled to [0, 1]. */
    const inputs = getInputs(demo);

    /* The inputs must be passed as float32: an int32 tensor would truncate every value in [0, 1) to 0.
       tf.tidy disposes of the input and prediction tensors once the winning index has been read out. */
    const predictedIndex = tf.tidy(() => {
        const inputTensor = tf.tensor4d(inputs, [1, pixels, pixels, 1], "float32");
        const probabilities = networkModel.predict(inputTensor).dataSync();

        /* The most probable class (the first one, in case of a tie). */
        return probabilities.indexOf(Math.max(...probabilities));
    });

    /* Assumes the model's output index 0 corresponds to 'a', 1 to 'b', and so on. */
    const predictedCharacter = String.fromCharCode(97 + predictedIndex);

    /* Output our prediction. */
    AB.msg("The network classified it as: " + greenSpan + predictedCharacter + "</span>", 5);

    demoGuessed = true;
}

/* Predict the character that is present in the doodle, and display it in message area 2.
   This code is a modification of something taken from the "Character Recognition Neural Network" world. */
function guessDoodle() {
    /* Take a snapshot of the doodle, and shrink it to the dimensions required by our model. */
    const snapshot = doodle.get();
    snapshot.resize(pixels, pixels);
    snapshot.loadPixels();

    /* The doodle is drawn in white on black, so the red channel of each RGBA pixel serves as its brightness.
       Scale it to [0, 1]. */
    const inputs = [];
    for (let i = 0; i < pixelsSquared; i++) {
        inputs[i] = snapshot.pixels[i * 4] / 255;
    }

    /* The inputs must be passed as float32: an int32 tensor would truncate every value in [0, 1) to 0.
       tf.tidy disposes of the input and prediction tensors once the winning index has been read out. */
    const predictedIndex = tf.tidy(() => {
        const inputTensor = tf.tensor4d(inputs, [1, pixels, pixels, 1], "float32");
        const probabilities = networkModel.predict(inputTensor).dataSync();

        /* The most probable class (the first one, in case of a tie). */
        return probabilities.indexOf(Math.max(...probabilities));
    });

    /* Assumes the model's output index 0 corresponds to 'a', 1 to 'b', and so on. */
    const predictedCharacter = String.fromCharCode(97 + predictedIndex);

    /* Display our prediction to the user, replacing any previous classification. */
    AB.msg("Classification: " + greenSpan + predictedCharacter + "</span>", 2);
}

/* Read an unsigned-byte IDX binary file (the format used by MNIST/EMNIST) from a specified URL.
   The following code is a modified form of something that was written by Vrushali Golde to read in the
   training and test datasets that we will be using: https://ancientbrain.com/uploads/vrushali/vru_mnist.js.

   The header is a big-endian 32-bit magic number, followed by one big-endian 32-bit size per dimension.
   The magic number is two zero bytes, then the data-type code (0x08 = unsigned byte), then the number of
   dimensions; e.g. 2049 for label files and 2051 for image files.

   @param {string} fileURL - The URL of the file to read.
   @returns {Promise<Int32Array|Int32Array[]>} For a one-dimensional file (e.g. labels), a flat Int32Array of its
            values. Otherwise (e.g. images), an array with one Int32Array view per item, holding that item's values.
   @throws {Error} If the request fails, the file is not an unsigned-byte IDX file, or it is truncated.
*/
async function loadBinaryFile(fileURL) {
    const response = await fetch(fileURL);
    if (!response.ok) {
        throw new Error("Failed to fetch " + fileURL + " (HTTP " + response.status + ").");
    }
    const buffer = await response.arrayBuffer();

    if (buffer.byteLength < 4) {
        throw new Error("File " + fileURL + " is too short to be an IDX file.");
    }

    /* Ascertain the type of file that we are dealing with, using the magic number. */
    const view = new DataView(buffer);
    const magic = view.getUint32(0, false);
    const dataType = (magic >>> 8) & 0xff;
    const numDims = magic & 0xff;
    if ((magic >>> 16) !== 0 || dataType !== 0x08 || numDims === 0) {
        throw new Error("Unsupported IDX file " + fileURL + " (magic number " + magic + ").");
    }

    /* Read the size of each dimension; the first is the number of items (labels or images). */
    const headerBytes = 4 * (numDims + 1);
    if (buffer.byteLength < headerBytes) {
        throw new Error("Truncated IDX header in " + fileURL + ".");
    }
    const dims = Array.from({ length: numDims }, (_, i) => view.getUint32(4 * (i + 1), false));

    /* Every value after the header is one unsigned byte. */
    const data = Int32Array.from(new Uint8Array(buffer, headerBytes));
    const [count, ...itemDims] = dims;
    const itemSize = itemDims.reduce((product, size) => product * size, 1);
    if (data.length < count * itemSize) {
        throw new Error("Truncated IDX data in " + fileURL + ".");
    }

    /* Labels (one dimension): just return the values. */
    if (numDims === 1) {
        return data;
    }

    /* Images (several dimensions): one view per item, sharing the same underlying buffer. */
    return Array.from({ length: count }, (_, i) => data.subarray(i * itemSize, (i + 1) * itemSize));
}

/* Load our test data. To be used when a pre-existing model is loaded, but the test data must still be retrieved
   for demos. The images and raw labels are stored in the testImages/testLabels globals. The labels are also
   one-hot encoded into encodedTestLabels.

   @returns {Promise<void>} Resolves once the data has been stored, or once a load failure has been logged.
*/
function loadTestData() {
    /* Intialise all of the URLs for our test files. */
    const testFiles = {
        testImages: "/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin",
        testLabels: "/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin"
    };

    /* Retrieve all of the files in question, then store them in a centralised object. */
    const mnistData = {};
    return Promise.all(Object.keys(testFiles).map(async (testFile) => {
        mnistData[testFile] = await loadBinaryFile(testFiles[testFile]);
    })).then(() => {
        testImages = mnistData.testImages;

        /* Raw EMNIST letters labels: 1 = 'a' ... 26 = 'z'. */
        testLabels = mnistData.testLabels;

        /* Shift the labels to 0-25 before one-hot encoding. Otherwise 26 ('z') is out of range for a depth of 26
           and would encode as an all-zero vector. */
        encodedTestLabels = tf.oneHot(Int32Array.from(testLabels, (label) => label - 1), 26);
    }).catch((error) => {
        /* Best effort: a loaded model can still classify doodles without the test data. */
        console.error("Could not load the test data:", error);
    });
}

/* Pick a random image from the test data to demo, and display its index and true label in message area 4.
   The network's own classification is made on the next frame, by draw() via guessDemo().
   Adapted from the "Character Recognition Neural Network" world. */
function makeDemo() {
    /* When a saved model is loaded, the test data is fetched asynchronously and may not be ready yet. */
    if (!testImages || !testLabels) {
        AB.msg("The test images are still loading; please try again shortly.<br>", 4);
        return;
    }

    /* Retrieve a random image from our test data. */
    const randomIndex = AB.randomIntAtoB(0, testImages.length - 1);
    demo = testImages[randomIndex];
    const testLabel = testLabels[randomIndex];

    /* EMNIST letters labels are 1-based (1 = 'a', ..., 26 = 'z'), hence 96 rather than 97. */
    AB.msg("Test image no: " + randomIndex + "<br>" +
           "Classification: " + String.fromCharCode(96 + testLabel) + "<br>", 4);

    demoExists = true;
    demoGuessed = false;
}

/* Wipe the doodle. Stop drawing it, paint its graphics buffer black, and reset the displayed classification
   in message area 2. Adapted from the "Character Recognition Neural Network" world. */
function wipeDoodle() {
    /* Clear the doodle. */
    doodleExists = false;
    doodle.background("black");

    /* Clear the previous classification. */
    AB.msg("Classification:", 2);
}