Code viewer for World: Doodle recogniser for a-z
// Cloned by Tristan Everitt on 25 Nov 2022 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.


// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications


// Options: balanced, byclass, bymerge, digits, letters, mnist
// Selects which EMNIST split is downloaded (used to build the data file URLs).
const mnistType = 'balanced';

// Options: MultilayerPerceptron, Convolutional, TensorConvolutional
// Selects which network constructor createNetwork() calls.
const nnType = 'Convolutional';
// When true, the charts are built after loading and refreshed from draw().
const showCharts = true;

// Parsed contents of a-z_doodles.json; assigned by loadPersistedDoodles().
let persistedDoodles;


// --- defined by MNIST - do not change these ---------------------------------------

const PIXELS = 28;                       // images in data set are tiny
const PIXELSSQUARED = PIXELS * PIXELS;


//--- can modify all these --------------------------------------------------

// should we train every timestep or not
let do_training = true;

// how many to train and test per timestep
const TRAINPERSTEP = 50;
const TESTPERSTEP = 10;

// multiply it by this to magnify for display
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;

// Canvas layout: 3 rows (doodle, training image, demo image), each made of
// large image + 50 gap + small image,
// with a 50 gap between rows

const canvaswidth = (PIXELS + ZOOMPIXELS) + 50;
const canvasheight = (ZOOMPIXELS * 3) + 100;


const DOODLE_THICK = 18;    // thickness of doodle lines
const DOODLE_BLUR = 3;      // blur factor applied to doodles

// Data set object with train_images, train_labels, test_images, test_labels;
// stays undefined until loading finishes (draw() checks for this).
let mnist;
// Label -> character code table parsed from the EMNIST mapping file.
let mapping;
// draw() refreshes the charts whenever train_index is a multiple of this.
let uiRefreshSteps = 100;
// If true, images are transposed once while loading; if false, getImage()
// transposes them at display time instead.
let transposeOnLoad = false;

// Chart.js instances, created by the create*Chart() functions.
let empiricalChart;
let trainAccuracyChart;
let lossChart;
// URL prefixes for the helper scripts / persisted JSON and for the EMNIST files.
const persistedDataUrlPrefix = '/uploads/tristan/';
const emnistDataUrlPrefix = 'https://d32y778xrl4hxt.cloudfront.net/emnist/'; // 'uploads/emnist/'

// Per-character sample counts for the "Dataset Distribution" chart:
// dataMap holds character -> count, data is the array handed to Chart.js.
let empiricalDatasets = {};
empiricalDatasets['train'] = {label: 'Train', data: [], dataMap: {}, borderWidth: 1};
empiricalDatasets['test'] = {label: 'Test', data: [], dataMap: {}, borderWidth: 1};

// The network instance, created by createNetwork().
let nn;

// Current pass number over the training set and position within it.
let trainrun = 1;
let train_index = 0;

// Current pass number over the test set and position within it.
let testrun = 1;
let test_index = 0;

// images in LHS:
// doodle is a p5 graphics buffer (see wipeDoodle); demo is a raw pixel array.
let doodle, demo;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?

// make run header bigger
AB.headerCSS({"max-height": "95vh"});


//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Each AB.msg call takes the HTML and a numbered slot; the static headers are
// written here, the variable slots are written later by the functions named below.
// thehtml is a shared scratch variable that those functions also assign to.
var thehtml;

// 1 Doodle header
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> " +
    " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> " +
    " Print your doodle to the console. <button onclick='printDoodleToConsole();' class='normbutton' >Print doodle</button> <br> ";
AB.msg(thehtml, 1);

// 2 Doodle variable data (guess) - written by guessDoodle()

// 3 Training header
thehtml = "<hr> <h1> 2. Training: " + nnType + "</h1><br/>Middle row: Training image magnified (left) and original (right). <br>  " +
    " <button onclick='do_training = !do_training;' class='normbutton' >Toggle Model Training</button> <br> ";
// loadBestModel() only supports the Convolutional network, so only offer the button for it
if ('Convolutional' === nnType) {
    thehtml += " <button onclick='loadBestModel();' class='normbutton' >Load Trained Model</button> <br> ";
}
AB.msg(thehtml, 3);

// 4 variable training data - written by trainit()

// 5 Testing header
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// 6 variable testing data - written by testit()

// 7 Demo header
thehtml = "<hr> <h1> 3. Demo </h1>Bottom row: Test image magnified (left) and  original (right). <br>" +
    " The network is <i>not</i> trained on any of these images. <br> " +
    " <button onclick='makeDemo();' class='normbutton' >Demo mnist image</button><button onclick='loadRandomDoodle();' class='normbutton' >Demo Random doodle</button> <br> ";
AB.msg(thehtml, 7);

// (disabled) Empirical chart header - the charts are built by setupCharts() instead
//thehtml = "<hr><h1> 4. Empirical Data</h1> <canvas id=\"empiricalChart\"/> <br> ";
//AB.msg(thehtml, 7);

// 8 Demo variable data (random demo ID) - written by makeDemo()
// 9 Demo variable data (changing guess) - written by guessDemo()

// Opening tag used to highlight scores and guesses; callers append "</span>" themselves.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";

//--- end of AB.msgs structure: ---------------------------------------------------------

//--- end of AB.msgs structure: ---------------------------------------------------------


/**
 * p5 entry point. Creates the canvas and an empty doodle buffer, puts up the
 * loading screen, then downloads all helper scripts in parallel. Once every
 * script has loaded and the DOM is ready, loadData() starts the data download.
 * A failed script download is reported on the console instead of leaving the
 * loading screen up with no explanation.
 */
function setup() {
    createCanvas(canvaswidth, canvasheight);
    wipeDoodle();

    // Show a loading screen while the scripts and the data set are fetched.
    AB.loadingScreen();

    const scriptUrls = [
        // Source: /uploads/codingtrain/matrix.js
        persistedDataUrlPrefix + 'matrix.js',
        // Source: /uploads/codingtrain/nn.js
        persistedDataUrlPrefix + 'nn.js',
        // Source: /uploads/codingtrain/mnist.js
        persistedDataUrlPrefix + 'mnist.js',
        // Source: http://cs.stanford.edu/people/karpathy/convnetjs/build/convnet-min.js
        persistedDataUrlPrefix + 'convnet-min.js',
        // Source: https://cs.stanford.edu/people/karpathy/convnetjs/build/util.js
        persistedDataUrlPrefix + 'convnet_util.js',
        'https://cdn.jsdelivr.net/npm/chart.js@4.0.1/dist/chart.umd.js',
        'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js',
        'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js',
    ];

    // Resolves when the document is ready, so loading never starts before the DOM exists.
    const domReady = $.Deferred(function (deferred) {
        $(deferred.resolve);
    });

    $.when(...scriptUrls.map(url => $.getScript(url)), domReady)
        .done(function () {
            console.log("All JS loaded");
            loadData();
        })
        .fail(function (jqxhr, settings, exception) {
            // $.when rejects as soon as any one download fails.
            console.error("Failed to load a required script", exception);
        });
}


// Kick off the data download. The label -> character-code table delivered by
// loadMnistMapping() is kept in the global "mapping".

function loadData() {
    const storeMapping = (labelToCode) => {
        mapping = labelToCode;
    };
    loadMnistMapping(storeMapping);
}

/**
 * Downloads the EMNIST label mapping file and then runs the rest of the
 * start-up sequence in order: hand the mapping to the callback, create the
 * network, load the persisted doodles, load the data set, build the charts
 * and remove the loading screen.
 *
 * Each useful line of the mapping file is "<label> <character code>". Blank or
 * malformed lines (such as the empty string after a trailing newline) are
 * skipped so they do not add a bogus entry, which would inflate the class
 * count derived from the mapping.
 *
 * @param {function(Object)} callback receives the label -> code map (values are strings)
 * @returns {Promise} settles when the whole sequence has finished; failures are logged
 */
function loadMnistMapping(callback) {
    const url = `${emnistDataUrlPrefix}${mnistType}/emnist-${mnistType}-mapping.txt`;

    return fetch(url)
        .then(r => {
            if (!r.ok) {
                throw new Error('Failed to download ' + url + ' (HTTP ' + r.status + ')');
            }
            return r.text();
        })
        .then(text => {
            const maps = {};
            text.split('\n').forEach(line => {
                // trim() also removes a trailing '\r' from Windows line endings
                const split = line.trim().split(/\s+/);
                if (split.length >= 2) {
                    maps[split[0]] = split[1];
                }
            });
            return maps;
        })
        .then((maps) => callback(maps))
        .then(() => createNetwork())
        .then(() => loadPersistedDoodles())
        .then(() => loadMnistData(function (data) {
            mnist = data;

            console.log("All data loaded into mnist object:")
            //console.log(mnist);

            if (showCharts) {
                setupCharts();
            }

            AB.removeLoading();     // if no loading screen exists, this does nothing
        }))
        .catch(err => {
            console.error('Failed to load the data set', err);
        });
}

/**
 * Instantiates the network selected by nnType and stores it in the global nn.
 * Every network is built for PIXELS x PIXELS inputs and one output class per
 * entry of the label mapping.
 *
 * @throws {Error} when nnType is not one of the supported names
 */
function createNetwork() {
    // Factories are lazy so that only the selected constructor is ever touched.
    const factories = {
        MultilayerPerceptron: (classes) => new MultilayerPerceptron(PIXELS, PIXELS, classes),
        Convolutional: (classes) => new Convolutional(PIXELS, PIXELS, classes),
        TensorConvolutional: (classes) => new TensorConvolutional(PIXELS, PIXELS, classes),
    };

    if (!Object.prototype.hasOwnProperty.call(factories, nnType)) {
        throw new Error("Unsupported NN type: " + nnType);
    }

    nn = factories[nnType](Object.keys(mapping).length);
}

/**
 * Downloads the four EMNIST files (train/test images and labels) in parallel
 * and passes the assembled data set object to the callback.
 *
 * @param {function(Object)} callback receives {train_images, train_labels, test_images, test_labels}
 * @returns {Promise} resolves with the callback's return value once all files are in
 */
function loadMnistData(callback) {
    const base = `${emnistDataUrlPrefix}${mnistType}/emnist-${mnistType}-`;
    const sources = {
        train_images: base + 'train-images-idx3-ubyte',
        train_labels: base + 'train-labels-idx1-ubyte',
        test_images: base + 'test-images-idx3-ubyte',
        test_labels: base + 'test-labels-idx1-ubyte',
    };

    const dataset = {};
    // One download per file; each stores its result under the file's key.
    const downloads = Object.entries(sources).map(([key, url]) =>
        loadMnistFile(url).then(contents => {
            dataset[key] = contents;
        }));

    return Promise.all(downloads).then(() => callback(dataset));
}

/**
 * Downloads and decodes one IDX file (the binary format of the MNIST/EMNIST files).
 *
 * The header is big-endian: a magic number, the item count and - for image
 * files only - the row and column counts. The header is read according to the
 * magic number, so a short label file (fewer than 16 bytes in total) is decoded
 * instead of failing while the header is being read.
 *
 * @param {string} file URL of the file to download
 * @returns {Promise<Uint8Array|Array>} for a label file, a Uint8Array over all bytes
 *     after the header; for an image file, an array with one pixel array per image
 *     (transposed when transposeOnLoad is set)
 * @throws {Error} when the download fails, the file is too short, or the magic number is unknown
 */
async function loadMnistFile(file) {
    console.log('Downloading ' + file);
    const response = await fetch(file);
    if (!response.ok) {
        throw new Error('Failed to download ' + file + ' (HTTP ' + response.status + ')');
    }
    const buffer = await response.arrayBuffer();

    const LABEL_MAGIC = 2049;
    const IMAGE_MAGIC = 2051;
    const LABEL_HEADER_BYTES = 8;     // magic + count
    const IMAGE_HEADER_BYTES = 16;    // magic + count + rows + cols

    if (buffer.byteLength < LABEL_HEADER_BYTES) {
        throw new Error('File too short to be an IDX file: ' + file);
    }
    const view = new DataView(buffer);
    const magic = view.getUint32(0, false);
    const count = view.getUint32(4, false);

    let result;
    if (magic === LABEL_MAGIC) {
        // one byte per label
        result = new Uint8Array(buffer, LABEL_HEADER_BYTES);
    } else if (magic === IMAGE_MAGIC) {
        if (buffer.byteLength < IMAGE_HEADER_BYTES) {
            throw new Error('Image file header is truncated: ' + file);
        }
        const imageSize = view.getUint32(8, false) * view.getUint32(12, false);
        const data = new Uint8Array(buffer, IMAGE_HEADER_BYTES);
        result = [];
        for (let i = 0; i < count; i++) {
            // subarray shares the downloaded buffer, it does not copy
            let oneDArray = data.subarray(imageSize * i, imageSize * (i + 1));
            if (transposeOnLoad) {
                oneDArray = transposeImage(oneDArray);
            }
            result.push(oneDArray);
        }
    } else {
        throw new Error("Unknown file type " + magic);
    }

    console.log('Processed ' + file);
    return result;
}


/**
 * Downloads the persisted doodles JSON and stores the parsed result in the
 * global persistedDoodles.
 *
 * The promise is returned so that a caller chaining on it waits until the
 * doodles are actually available; code run from draw() reads
 * persistedDoodles.doodles and must not start before this has finished.
 *
 * @returns {Promise<Object>} resolves with the parsed JSON; rejects if the download fails
 */
function loadPersistedDoodles() {
    const url = persistedDataUrlPrefix + 'a-z_doodles.json';
    console.log('Downloading ' + url);
    return fetch(url)
        .then(res => {
            if (!res.ok) {
                throw new Error('Failed to download ' + url + ' (HTTP ' + res.status + ')');
            }
            return res.json();
        })
        .then(out => persistedDoodles = out);
}

/**
 * Replaces the current network with the pre-trained model stored as JSON.
 * Training is paused while the model downloads and the previous training
 * state is restored afterwards, whether or not loading succeeded.
 *
 * The JSON is parsed before a new network is created, so a failed download or
 * malformed file leaves the network that is currently in use untouched.
 *
 * @returns {Promise} settles once loading has finished; failures are logged
 * @throws {Error} synchronously, when nnType is not 'Convolutional'
 */
function loadBestModel() {
    if (nnType !== 'Convolutional') {
        throw new Error("Loading model unsupported for NN type: " + nnType);
    }
    const prevTrainingState = do_training;
    do_training = false;
    const url = `${persistedDataUrlPrefix}a-z_best_model.json`;
    console.log('Fetching: ' + url);
    return fetch(url)
        .then(r => {
            if (!r.ok) {
                throw new Error('Failed to download ' + url + ' (HTTP ' + r.status + ')');
            }
            return r.text();
        })
        .then(json => {
            const model = JSON.parse(json);
            createNetwork();
            nn.network.fromJSON(model);
            console.log('Loaded ' + url);
        })
        .catch(err => {
            console.error('Could not load ' + url, err);
        })
        .finally(() => {
            do_training = prevTrainingState;
        });
}

/**
 * Returns a plain-array copy of img with values scaled towards the 0..1 range.
 * Any value above 1 is divided by 255; values of 1 or less are copied as they
 * are (treated as already normalised).
 *
 * @param {ArrayLike<number>} img pixel values
 * @returns {number[]} new array of the same length
 */
function normaliseImage(img) {
    return Array.from(img, (value) => (value > 1 ? value / 255 : value));
}

/**
 * Transposes a PIXELS x PIXELS image stored row by row in a flat array:
 * the value at (row, col) of the input ends up at (col, row) of the output.
 *
 * The result is written directly by index, without a 2D intermediate and
 * without undeclared loop variables, so the function also works in strict mode.
 *
 * @param {ArrayLike<number>} img flat image of PIXELS * PIXELS values
 * @returns {number[]} new flat array holding the transposed image
 */
function transposeImage(img) {
    const size = PIXELS;
    const transposed = new Array(size * size);

    for (let row = 0; row < size; row++) {
        for (let col = 0; col < size; col++) {
            transposed[col * size + row] = img[row * size + col];
        }
    }

    return transposed;
}

/**
 * Returns a random initial weight between -0.5 and 0.5.
 * (The Coding Train default range is -1 to 1.)
 */
function randomWeight() {
    const weight = AB.randomFloatAtoB(-0.5, 0.5);
    return weight;
}

/**
 * Builds a PIXELS x PIXELS greyscale p5 image from a flat array of brightness
 * values. If the data was not transposed while loading, it is transposed here
 * before being drawn.
 *
 * @param {ArrayLike<number>} img flat array of brightness values
 * @returns the populated p5 image object
 */
function getImage(img) {
    const source = transposeOnLoad ? img : transposeImage(img);

    const picture = createImage(PIXELS, PIXELS);    // blank image, filled in below
    picture.loadPixels();

    // Each source value becomes one opaque RGBA pixel with R = G = B.
    for (let p = 0, offset = 0; p < PIXELSSQUARED; p++, offset += 4) {
        const grey = source[p];
        picture.pixels[offset] = grey;
        picture.pixels[offset + 1] = grey;
        picture.pixels[offset + 2] = grey;
        picture.pixels[offset + 3] = 255;
    }

    picture.updatePixels();
    return picture;
}

/**
 * Trains the network on one batch, shows the last image of the batch in the
 * middle row of the canvas and updates the training status in the run header.
 *
 * Entries without an image (a caller that stepped past the end of the data
 * set) are skipped, and the end of a training run is detected with ">=" so the
 * wrap-around still happens when the data set size is not a multiple of the
 * batch size.
 *
 * @param {Array} batch array of [label, image] pairs
 */
function trainit(batch) {
    const images = [];
    const labels = [];
    const trainCounts = empiricalDatasets['train']['dataMap'];

    batch.forEach(([label, img]) => {
        if (img === undefined || img === null) {
            return;     // no such sample: the caller ran past the end of the data set
        }
        images.push(img);
        labels.push(label);

        // per-character tally for the distribution chart
        const char = fromCharCode(label);
        trainCounts[char] = (trainCounts[char] || 0) + 1;
    });

    if (images.length > 0) {
        // Display the last image of the batch and the magnified version of it
        const theimage = getImage(images[images.length - 1]);
        image(theimage, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);      // magnified
        image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);      // original

        nn.train(images, labels);
    }

    const total = mnist.train_images.length;
    const position = Math.min(train_index, total);     // never report more than 100%
    const progress = ((position / total) * 100).toFixed(3);

    thehtml = " trainrun: " + trainrun + "<br> no: " + position + " of " + total + " (" + progress + "%)";
    AB.msg(thehtml, 4);

    if (train_index >= total) {
        train_index = 0;
        console.log("finished trainrun: " + trainrun);
        trainrun++;
    }
}


/**
 * Tests the network on the single test image at the global test_index, records
 * the result on the network, writes the scores to the run header and advances
 * test_index. At the end of the test set the run counter is bumped and the
 * network's counters are reset.
 *
 * @param {number[]} doodlesCorrect [exact matches, case-insensitive matches]
 *     for the persisted doodles, as returned by calculateDoodleCorrect()
 */
function testit(doodlesCorrect) {
    const sample = mnist.test_images[test_index];
    const expected = mnist.test_labels[test_index];

    // per-character tally for the distribution chart
    const tally = empiricalDatasets['test']['dataMap'];
    const char = fromCharCode(expected);
    tally[char] = (tally[char] || 0) + 1;

    nn.recordResult(nn.predict(sample) === expected);

    const percent = nn.accuracy * 100;
    const testCount = mnist.test_images.length;
    const doodleCount = persistedDoodles.doodles.length;
    const progress = ((nn.total_tests / testCount) * 100).toFixed(3);
    const [exactHits, looseHits] = doodlesCorrect;
    const doodlePercent = (exactHits / doodleCount) * 100.0;
    const doodlePercentIgnoreCase = (looseHits / doodleCount) * 100.0;

    const parts = [
        " testrun: " + testrun + "<br> no: " + nn.total_tests + " of " + testCount + " (" + progress + "%) <br>",
        " correct: " + nn.total_correct + "<br>",
        "  mnist score: " + greenspan + percent.toFixed(2) + "</span>",
        " <hr/>persisted doodles: " + doodleCount + "<br>",
        " doodle score: " + greenspan + doodlePercent.toFixed(2) + "</span>",
        " Case Insensitive: " + greenspan + doodlePercentIgnoreCase.toFixed(2) + "</span>",
    ];
    thehtml = parts.join("");
    AB.msg(thehtml, 6);

    test_index++;
    if (test_index !== testCount) {
        return;
    }

    // whole test set done: start the next run from scratch
    console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
    testrun++;
    test_index = 0;
    nn.total_tests = 0;
    nn.total_correct = 0;
}

/**
 * Runs the network over every persisted doodle and counts the correct guesses.
 * Each doodle entry is an object with a single key (the expected character)
 * whose value is a comma-separated list of pixel values.
 *
 * @returns {number[]} [exact matches, matches ignoring upper/lower case]
 */
function calculateDoodleCorrect() {
    let exactMatches = 0;
    let caseInsensitiveMatches = 0;

    for (const entry of persistedDoodles.doodles) {
        const expected = Object.keys(entry)[0];
        const pixels = new Uint8Array(Object.values(entry)[0].split(','));
        const guessed = fromCharCode(nn.predict(pixels));

        exactMatches += guessed === expected ? 1 : 0;
        caseInsensitiveMatches += guessed.toLowerCase() === expected.toLowerCase() ? 1 : 0;
    }

    return [exactMatches, caseInsensitiveMatches];
}


// --- the draw function -------------------------------------------------------------
// every step:
// running totals of samples trained on / tested since start-up
let trainingStep = 0;
let testingStep = 0;

/**
 * p5 draw loop, called every step. Does nothing until the data set has loaded.
 * Each step it optionally trains on one batch and runs a few tests, redraws
 * the demo and doodle images with fresh guesses, handles doodle drawing with
 * the mouse, and refreshes the charts.
 *
 * A training batch never reads past the end of the training set (the last
 * batch of a run may be shorter), and training/testing waits until the
 * persisted doodles have loaded because the test statistics read them.
 */
function draw() {
    // check if libraries and data loaded yet:
    if (typeof mnist == 'undefined') return;

    background('black');

    if (do_training && persistedDoodles) {
        const totalTrainImages = mnist.train_images.length;
        const batch = [];
        for (let i = 0; i < TRAINPERSTEP && train_index < totalTrainImages; i++) {
            batch.push([mnist.train_labels[train_index], mnist.train_images[train_index]]);
            train_index++;
            trainingStep++;
        }

        if (batch.length > 0) {
            trainit(batch);
        }

        const doodlesCorrect = calculateDoodleCorrect();
        // do some testing per step
        for (let i = 0; i < TESTPERSTEP; i++) {
            testit(doodlesCorrect);
            testingStep++;
        }
    }

    // keep drawing demo and doodle images
    // and keep guessing - we will update our guess as time goes on

    if (demo_exists) {
        drawDemo();
        guessDemo();
    }
    if (doodle_exists) {
        drawDoodle();
        guessDoodle();
    }

    // detect doodle drawing
    // (restriction) the following assumes doodle starts at 0,0

    if (mouseIsPressed) {       // gets called when we click buttons, as well as if in doodle corner
        const MAX = ZOOMPIXELS + 20;     // can draw up to this pixels in corner
        if ((mouseX < MAX) && (mouseY < MAX) && (pmouseX < MAX) && (pmouseY < MAX)) {
            mousedrag = true;       // start a mouse drag
            doodle_exists = true;
            doodle.stroke('white');
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    } else if (mousedrag) {
        // we are exiting a drawing: blur the finished doodle just once
        mousedrag = false;
        doodle.filter(BLUR, DOODLE_BLUR);
    }

    if (showCharts && train_index % uiRefreshSteps === 0) {
        updateCharts();
    }
}


//--- demo -------------------------------------------------------------
// demo some test image and predict it
// get it from test set so have not used it in training


/**
 * Picks a random image from the test set as the demo image and shows its
 * index in the run header. The image is stored before demo_exists is raised,
 * so the flag is never set while demo is still empty.
 */
function makeDemo() {
    const i = AB.randomIntAtoB(0, mnist.test_images.length - 1);

    demo = mnist.test_images[i];
    demo_exists = true;

    thehtml = "Test image no: " + i + "<br>";
    AB.msg(thehtml, 8);
}

/**
 * Shrinks the current doodle to PIXELS x PIXELS and extracts one value per
 * pixel, taken from the alpha channel and read row by row. The values are
 * transposed before being returned.
 *
 * @returns {Array} a one-element array wrapping the transposed pixel array
 *     (callers either stringify the result or read element 0)
 */
function getSanitisedDoodlePixels() {
    // doodle is createGraphics not createImage, so take an image snapshot first
    const shrunk = doodle.get();
    shrunk.resize(PIXELS, PIXELS);
    shrunk.loadPixels();

    const alphas = [];
    for (let row = 0; row < shrunk.height; row++) {
        for (let column = 0; column < shrunk.width; column++) {
            // get() yields [r, g, b, a]; only the alpha value is kept
            alphas.push(shrunk.get(column, row)[3]);
        }
    }

    return [transposeImage(alphas)];
}

/**
 * Draws the demo image in the bottom row of the canvas: magnified on the left
 * and at its original size on the right.
 */
function drawDemo() {
    const picture = getImage(demo);             // image built from the raw data array
    const top = canvasheight - ZOOMPIXELS;      // the bottom row sits against the canvas edge
    image(picture, 0, top, ZOOMPIXELS, ZOOMPIXELS);          // magnified
    image(picture, ZOOMPIXELS + 50, top, PIXELS, PIXELS);    // original
}

/**
 * Draws the current doodle in the top row of the canvas: at full size on the
 * left and shrunk to PIXELS x PIXELS on the right.
 */
function drawDoodle() {
    // doodle is createGraphics not createImage, so draw a snapshot of it
    const snapshot = doodle.get();
    const smallLeft = ZOOMPIXELS + 50;
    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);      // original
    image(snapshot, smallLeft, 0, PIXELS, PIXELS);      // shrunk
}


/**
 * Classifies the current doodle and shows the best two guesses in the run header.
 */
function guessDoodle() {
    // Workaround when the data produced doesn't map properly to the Uint8Array:
    // produce a string of comma-separated 0-255 values and then parse it back.
    const serialised = getSanitisedDoodlePixels().toString();
    const inputs = new Uint8Array(serialised.split(','));

    // feed forward to get the no.1 and no.2 guesses, then turn them into characters
    const guesses = nn.predictTwo(inputs).map(prediction => fromCharCode(prediction));

    thehtml = " We classify it as: " + greenspan + guesses[0] + "</span> <br>" +
        " No.2 guess is: " + greenspan + guesses[1] + "</span>";
    AB.msg(thehtml, 2);
}

/**
 * Classifies the current demo image and shows the best two guesses in the run header.
 */
function guessDemo() {
    // no.1 and no.2 guesses, converted to characters
    const guesses = nn.predictTwo(demo).map(prediction => fromCharCode(prediction));

    thehtml = " We classify it as: " + greenspan + guesses[0] + "</span> <br>" +
        " No.2 guess is: " + greenspan + guesses[1] + "</span>";
    AB.msg(thehtml, 9);
}

/**
 * Converts a class label into its character, using the character code that
 * the label mapping holds for that label.
 *
 * @param {number|string} label class label used as the key into mapping
 * @returns {string} the corresponding character
 */
function fromCharCode(label) {
    const characterCode = mapping[label];
    return String.fromCharCode(characterCode);
}

/**
 * Discards the current doodle: clears the doodle_exists flag and replaces the
 * doodle with a fresh ZOOMPIXELS x ZOOMPIXELS graphics buffer (the doodle is
 * drawn on the larger canvas) with a pixel density of 1.
 */
function wipeDoodle() {
    doodle_exists = false;

    const side = ZOOMPIXELS;
    doodle = createGraphics(side, side);
    doodle.pixelDensity(1);
}

/**
 * Uses a randomly chosen persisted doodle as the demo image. Each persisted
 * entry is an object with a single key (the character) whose value is a
 * comma-separated list of pixel values.
 */
function loadRandomDoodle() {
    const chosen = AB.randomElementOfArray(persistedDoodles.doodles);
    const [character, csv] = Object.entries(chosen)[0] || [];

    console.log('Loaded random doodle: ' + character);
    demo = new Uint8Array(csv.split(','));
    demo_exists = true;
}

/**
 * Prints the pixel values of the current doodle to the console as a JSON array.
 */
function printDoodleToConsole() {
    const [pixels] = getSanitisedDoodlePixels();
    console.log(JSON.stringify(pixels));
}

/**
 * Builds the chart panel: a styled, absolutely positioned container appended
 * to the 'ab-threepage' element, holding one 300px-high div per chart. The
 * loss chart is added only when the network has its own 'cost_loss' property.
 */
function setupCharts() {
    const page = document.getElementById('ab-threepage');

    const makeDiv = (id, style) => {
        const element = document.createElement("div");
        element.setAttribute('id', id);
        element.setAttribute('style', style);
        return element;
    };

    const containerStyle = '' +
        '    flex-wrap: wrap;' +
        '    align-items: center;' +
        '    justify-content: center;' +
        '    font-family: Georgia, Verdana, "Times New Roman", Sans-serif;' +
        '    font-size: 16px;' +
        '    background: rgba(255, 255, 255, 0.8);' +
        '    color: black;' +
        '    border-radius: 10px;' +
        '    border: 1px solid black;' +
        '    padding: 10px;' +
        '    text-align: left;' +
        '    word-wrap: break-word;' +
        '    overflow: auto;' +
        '    z-index: 20;' +
        '    display: inline-block;' +
        '    min-width: 600px;' +
        '    max-width: 800px;' +
        '    max-height: 100vh;' +
        '    position: absolute;' +
        '    top: 20px;' +
        '    left: 850px;';

    const container = makeDiv('chart-container', containerStyle);
    page.append(container);

    // Wraps the canvas produced by createChart in its own fixed-height div.
    const addPanel = (id, createChart) => {
        const panel = makeDiv(id, '/*flex: 1 1 30%;*/height: 300px;');
        panel.append(createChart());
        container.append(panel);
    };

    addPanel('chart-1', createEmpiricalChart);
    addPanel('chart-2', createTrainAccuracyChart);
    if (nn.hasOwnProperty('cost_loss')) {
        addPanel('chart-3', createLossChart);
    }
}


/**
 * Creates the stacked bar chart titled 'Dataset Distribution' on a new canvas
 * and stores the chart in the global empiricalChart. It starts with
 * placeholder labels and all-zero 'Train' and 'Test' series.
 *
 * @returns {HTMLCanvasElement} the canvas the chart draws on
 */
function createEmpiricalChart() {
    const canvas = document.createElement("canvas");
    canvas.setAttribute('id', 'empiricalChart');

    const placeholderLabels = ['A', 'B', 'C', '1', '2', '3'];
    const emptySeries = (name) => ({
        label: name,
        data: placeholderLabels.map(() => 0),
        borderWidth: 1
    });

    const options = {
        responsive: true,
        plugins: {
            title: {display: true, text: 'Dataset Distribution'},
        },
        scales: {
            x: {stacked: true},
            y: {beginAtZero: true, stacked: true}
        }
    };

    empiricalChart = new Chart(canvas, {
        type: 'bar',
        data: {
            labels: placeholderLabels,
            datasets: [emptySeries('Train'), emptySeries('Test')]
        },
        options: options
    });

    return canvas;
}

/**
 * Creates the line chart titled 'Model Accuracy' on a new canvas and stores
 * the chart in the global trainAccuracyChart. It starts with one zero point
 * for each of the 'Train' and 'Test' series; animation is switched off.
 *
 * @returns {HTMLCanvasElement} the canvas the chart draws on
 */
function createTrainAccuracyChart() {
    const canvas = document.createElement("canvas");
    canvas.setAttribute('id', 'trainAccuracyChart');

    const initialSeries = ['Train', 'Test'].map((name) => ({
        label: name,
        data: [0],
        borderWidth: 1
    }));

    const config = {
        type: 'line',
        data: {labels: ['0'], datasets: initialSeries},
        options: {
            responsive: true,
            animation: {duration: 0},
            plugins: {
                title: {display: true, text: 'Model Accuracy'},
            },
            scales: {
                y: {beginAtZero: true}
            }
        }
    };

    trainAccuracyChart = new Chart(canvas, config);
    return canvas;
}

/**
 * Creates the line chart titled 'Model Loss' on a new canvas and stores the
 * chart in the global lossChart. It starts with one zero point for each of
 * the 'Cost' and 'L2 Decay' series; animation is switched off.
 *
 * @returns {HTMLCanvasElement} the canvas the chart draws on
 */
function createLossChart() {
    const canvas = document.createElement("canvas");
    canvas.setAttribute('id', 'lossChart');

    const initialSeries = ['Cost', 'L2 Decay'].map((name) => ({
        label: name,
        data: [0],
        borderWidth: 1
    }));

    const config = {
        type: 'line',
        data: {labels: ['0'], datasets: initialSeries},
        options: {
            responsive: true,
            animation: {duration: 0},
            plugins: {
                title: {display: true, text: 'Model Loss'},
            },
            scales: {
                y: {beginAtZero: true}
            }
        }
    };

    lossChart = new Chart(canvas, config);
    return canvas;
}

/**
 * Refreshes all charts from the current statistics. Networks flagged with
 * nn.tensor are skipped entirely; the loss chart is refreshed only when the
 * network has its own 'cost_loss' property.
 */
function updateCharts() {
    if (nn.tensor) {
        return;
    }

    updateEmpiricalChart();
    updateTrainAccuracyChart();

    const tracksLoss = nn.hasOwnProperty('cost_loss');
    if (tracksLoss) {
        updateLossChart();
    }
}

/**
 * Rebuilds the dataset distribution chart from the per-character counts.
 * The labels are the sorted union of all characters seen in training or
 * testing; a character missing from one of the two sets is charted as 0.
 */
function updateEmpiricalChart() {
    const trainSet = empiricalDatasets['train'];
    const testSet = empiricalDatasets['test'];

    const seen = new Set([...Object.keys(trainSet['dataMap']), ...Object.keys(testSet['dataMap'])]);
    const sortedLabels = [...seen].sort();
    empiricalChart.data.labels = sortedLabels;

    // one count per label, in label order
    const countsFor = (dataMap) => sortedLabels.map(label => dataMap[label] || 0);
    trainSet['data'] = countsFor(trainSet['dataMap']);
    testSet['data'] = countsFor(testSet['dataMap']);

    empiricalChart.data.datasets = [trainSet, testSet];

    empiricalChart.update();
}

/**
 * Redraws the accuracy chart from nn.metrics['train']['accuracy']: the keys
 * become the x-axis labels and the values the single 'Training' series.
 */
function updateTrainAccuracyChart() {
    const history = nn.metrics['train']['accuracy'];

    trainAccuracyChart.data.labels = Object.keys(history);
    trainAccuracyChart.data.datasets = [{
        label: 'Training',
        data: Object.values(history),
        borderWidth: 1,
    }];

    trainAccuracyChart.update();
}

/**
 * Redraws the loss chart from nn.metrics['train']['cost_loss_l2_decay_loss']:
 * the keys become the x-axis labels and the values the single 'Loss' series.
 */
function updateLossChart() {
    const history = nn.metrics['train']['cost_loss_l2_decay_loss'];

    lossChart.data.labels = Object.keys(history);
    lossChart.data.datasets = [{
        label: 'Loss',
        data: Object.values(history),
        borderWidth: 1,
    }];

    lossChart.update();
}

/**
 * Convolutional network built with convnetjs: three conv(5x5, relu) + 2x2 pool
 * stages with 8, 16 and 32 filters, followed by a softmax over out_nodes
 * classes, trained with an adadelta SGDTrainer.
 *
 * The layer definitions are held in a local constant rather than leaking into
 * an implicit global, so the constructor also works in strict mode.
 *
 * @param {number} width image width in pixels
 * @param {number} height image height in pixels
 * @param {number} out_nodes number of output classes
 */
function Convolutional(width, height, out_nodes) {
    this.out_nodes = out_nodes;
    this.width = width;
    this.height = height;
    this.tensor = false;    // marks this as not being the TensorFlow-based network

    // NOTE(review): the input layer is declared 4 pixels smaller than the
    // width x height volumes that train/predict/predictTwo feed in - confirm
    // whether this is intended or whether the full size should be declared here.
    const layerDefs = [
        {type: 'input', out_sx: this.width - 4, out_sy: this.height - 4, out_depth: 1},
        {type: 'conv', sx: 5, filters: 8, stride: 1, pad: 2, activation: 'relu'},
        {type: 'pool', sx: 2, stride: 2},
        {type: 'conv', sx: 5, filters: 16, stride: 1, pad: 2, activation: 'relu'},
        {type: 'pool', sx: 2, stride: 2},
        {type: 'conv', sx: 5, filters: 32, stride: 1, pad: 2, activation: 'relu'},
        {type: 'pool', sx: 2, stride: 2},
        {type: 'softmax', num_classes: this.out_nodes},
    ];

    this.network = new convnetjs.Net();
    this.network.makeLayers(layerDefs);

    this.trainer = new convnetjs.SGDTrainer(this.network, {method: 'adadelta', batch_size: 20, l2_decay: 0.001});

    // running test statistics, maintained by recordResult()
    this.accuracy = 0.0;
    this.total_tests = 0;
    this.total_correct = 0;

    if (showCharts) {
        // averaging windows over the loss values reported by the trainer
        this.cost_loss = new cnnutil.Window(100, 10);
        this.l2_decay_loss = new cnnutil.Window(100, 10);
        this.l1_decay_loss = new cnnutil.Window(100, 10);

        // history for the charts, keyed by the test count at which each value was recorded
        this.metrics = {
            train: {
                accuracy: {},
                cost_loss: {},
                cost_loss_l2_decay_loss: {},
                l1_decay_loss: {},
                l2_decay_loss: {},
            },
            test: {},
        };
    }
}

/**
 * Records the outcome of one test and updates the running accuracy. When
 * charts are enabled, every 50th test also stores a snapshot of the accuracy
 * (as a percentage) and of the averaged loss values, keyed by the test count.
 *
 * @param {boolean} pass whether the prediction was correct
 */
Convolutional.prototype.recordResult = function (pass) {
    this.total_tests += 1;
    this.total_correct += pass ? 1 : 0;
    this.accuracy = this.total_correct / this.total_tests;

    const snapshotDue = showCharts && this.total_tests % 50 === 0;
    if (!snapshotDue) {
        return;
    }

    const at = this.total_tests;
    const history = this.metrics['train'];
    history['accuracy'][at] = this.accuracy * 100.0;
    history['cost_loss'][at] = this.cost_loss.get_average();
    history['l1_decay_loss'][at] = this.l1_decay_loss.get_average();
    history['l2_decay_loss'][at] = this.l2_decay_loss.get_average();
    history['cost_loss_l2_decay_loss'][at] = this.cost_loss.get_average() + this.l2_decay_loss.get_average();
}

/**
 * Trains the network on each image/label pair in turn. Every image is copied
 * into a fresh width x height x 1 convnetjs volume; when charts are enabled
 * the losses reported by the trainer are added to the averaging windows.
 *
 * @param {Array} images flat pixel arrays
 * @param {Array} labels class labels, parallel to images
 */
Convolutional.prototype.train = function (images, labels) {
    const pixelCount = this.width * this.height;

    for (const [sample, pixels] of images.entries()) {
        const volume = new convnetjs.Vol(this.width, this.height, 1, 0);
        for (let p = 0; p < pixelCount; p++) {
            volume.w[p] = pixels[p];
        }

        const stats = this.trainer.train(volume, labels[sample]);
        if (showCharts) {
            this.cost_loss.add(stats.cost_loss);
            this.l2_decay_loss.add(stats.l2_decay_loss);
            this.l1_decay_loss.add(stats.l1_decay_loss);
        }
    }
};

/**
 * Classifies one image: copies it into a width x height x 1 convnetjs volume,
 * runs a forward pass and returns the network's prediction.
 *
 * @param {ArrayLike<number>} img flat pixel array
 * @returns the value of network.getPrediction() after the forward pass
 */
Convolutional.prototype.predict = function (img) {
    const pixelCount = this.width * this.height;
    const input = new convnetjs.Vol(this.width, this.height, 1, 0);

    let p = 0;
    while (p < pixelCount) {
        input.w[p] = img[p];
        p += 1;
    }

    this.network.forward(input);
    return this.network.getPrediction();
};

//--- find no.1 and no.2 output nodes ---------------------------------------
/**
 * Classifies one image and returns the indexes of the two most probable
 * classes, read from the output activations of the network's last layer.
 *
 * Both the best and the runner-up are tracked while scanning, so the second
 * guess is correct wherever it sits relative to the best one (including when
 * the best class is index 0 or the runner-up comes after the best).
 *
 * @param {ArrayLike<number>} img flat pixel array
 * @returns {number[]} [index of the best class, index of the second-best class];
 *     with a single output class both entries are 0
 */
Convolutional.prototype.predictTwo = function (img) {
    const vol = new convnetjs.Vol(this.width, this.height, 1, 0);
    for (let i = 0; i < this.width * this.height; i++) {
        vol.w[i] = img[i];
    }

    this.network.forward(vol);

    // Based on convnet's getPrediction(), extended to give the two best guesses
    const lastLayer = this.network.layers[this.network.layers.length - 1];
    const p = lastLayer.out_act.w;

    let best = 0;
    let second = -1;    // -1 until a runner-up has been seen
    for (let i = 1; i < p.length; i++) {
        if (p[i] > p[best]) {
            second = best;      // the previous best becomes the runner-up
            best = i;
        } else if (second === -1 || p[i] > p[second]) {
            second = i;
        }
    }
    if (second === -1) {
        second = best;          // only one class: there is no distinct runner-up
    }

    return [best, second];
}


// TensorFlow.js convolutional classifier.
// width, height: image dimensions used for the input shape
// out_nodes: number of output classes (units of the final softmax layer)
// Architecture: conv(5x5, 8 filters, relu) -> maxpool(2x2) ->
//               conv(5x5, 16 filters, relu) -> maxpool(2x2) -> flatten -> dense softmax
function TensorConvolutional(width, height, out_nodes) {
    this.width = width;
    this.height = height;
    this.out_nodes = out_nodes;

    this.channels = 1;
    this.model = tf.sequential();
    this.tensor = true;

    // Running test statistics. They must start as numbers: recordResult()
    // increments them, and incrementing undefined would yield NaN.
    this.accuracy = 0.0;
    this.total_tests = 0;
    this.total_correct = 0;

    this.model.add(tf.layers.conv2d({
        inputShape: [this.width, this.height, this.channels],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));

    this.model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    this.model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    this.model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    this.model.add(tf.layers.flatten());

    // one softmax unit per output class
    this.model.add(tf.layers.dense({
        units: this.out_nodes,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
    }));

    const optimizer = tf.train.adam();
    this.model.compile({
        optimizer: optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy'],
    });
}

// Train the TensorFlow.js model on a batch of images.
// images: list of flat pixel arrays (the first width * height * channels values of each are read)
// labels: list of class indexes in [0, out_nodes), labels[k] belongs to images[k]
// returns: a Promise that settles when training on the batch has finished
//          (resolves with the value fit() resolves to, or null for an empty batch).
// throws:  RangeError if a label is not a valid class index.
//
// The whole batch is handed to a single fit() call because fit() is asynchronous
// and must not be started again while a previous call is still running.
// batchSize 1 with shuffling off keeps one weight update per sample, in order.
// NOTE(review): callers should wait for the returned promise before calling
// train() again — confirm against the calling loop.
TensorConvolutional.prototype.train = function (images, labels) {
    const count = images.length;
    if (count === 0) {
        return Promise.resolve(null);
    }

    const valuesPerImage = this.width * this.height * this.channels;
    const pixels = new Float32Array(count * valuesPerImage);
    const targets = new Float32Array(count * this.out_nodes);   // one-hot rows, zero-filled

    for (let i = 0; i < count; i++) {
        const img = images[i];
        const offset = i * valuesPerImage;
        for (let px = 0; px < valuesPerImage; px++) {
            pixels[offset + px] = img[px];
        }

        const label = Number(labels[i]);
        if (!Number.isInteger(label) || label < 0 || label >= this.out_nodes) {
            throw new RangeError('Label out of range: ' + labels[i]);
        }
        targets[i * this.out_nodes + label] = 1;
    }

    // categoricalCrossentropy needs one-hot targets shaped [samples, classes]
    const xs = tf.tensor4d(pixels, [count, this.width, this.height, this.channels]);
    const ys = tf.tensor2d(targets, [count, this.out_nodes]);

    // tensors are not garbage collected automatically; free them on either outcome
    const release = function () {
        xs.dispose();
        ys.dispose();
    };

    return this.model.fit(xs, ys, {batchSize: 1, epochs: 1, shuffle: false}).then(
        function (history) {
            release();
            return history;
        },
        function (err) {
            release();
            throw err;
        }
    );
};

// Classify one image with the TensorFlow.js model.
// img: flat pixel array holding width * height * channels values
// returns: index of the output class with the highest score
//
// model.predict() returns a Tensor, which cannot be indexed with [0]; the class
// index is taken with argMax and read back synchronously instead.
TensorConvolutional.prototype.predict = function (img) {
    const shape = [1, this.width, this.height, this.channels];
    const model = this.model;

    // tidy() disposes the input and intermediate tensors once the result is read
    return tf.tidy(function () {
        // Float32Array.from gives a float32 input whatever array type img is
        const input = tf.tensor4d(Float32Array.from(img), shape);
        const scores = model.predict(input);
        return tf.argMax(scores, -1).dataSync()[0];
    });
};

// Record the outcome of one test and refresh the running accuracy.
// pass: truthy when the prediction was correct.
TensorConvolutional.prototype.recordResult = function (pass) {
    this.total_tests += 1;

    if (pass) {
        this.total_correct += 1;
    }

    const correct = this.total_correct;
    const tests = this.total_tests;
    this.accuracy = correct / tests;
};
//--- find no.1 and no.2 output classes ----------------------------------------------
// Runs the model on img (flat pixel array holding width * height * channels values)
// and returns [best, second]: the indexes of the highest and second-highest
// output scores.
// Ties keep the earliest index as the best guess.
// With a single output class both entries are 0.
TensorConvolutional.prototype.predictTwo = function (img) {
    const shape = [1, this.width, this.height, this.channels];
    const model = this.model;

    // read the scores back as a plain typed array; tidy() frees the tensors
    const scores = tf.tidy(function () {
        const input = tf.tensor4d(Float32Array.from(img), shape);
        return model.predict(input).dataSync();
    });

    let best = 0;
    let second = -1;    // -1 means "no runner-up seen yet"
    for (let i = 1; i < scores.length; i++) {
        if (scores[i] > scores[best]) {
            // new winner: the previous winner becomes the runner-up
            second = best;
            best = i;
        } else if (second === -1 || scores[i] > scores[second]) {
            second = i;
        }
    }

    if (second === -1) {
        // only one output node: there is no distinct second guess
        second = best;
    }

    return [best, second];
};

// Fully connected classifier built on the NeuralNetwork class.
// width, height: image dimensions; the input layer has width * height nodes
// out_nodes: number of output nodes
// Uses 128 hidden nodes and a learning rate of 0.1.
function MultilayerPerceptron(width, height, out_nodes) {
    Object.assign(this, {
        width,
        height,
        noinput: width * height,
        nohidden: 128,
        out_nodes,
        learningrate: 0.1,   // default 0.1
        tensor: false,

        // running test statistics
        accuracy: 0.0,
        total_tests: 0,
        total_correct: 0,
    });

    const network = new NeuralNetwork(this.noinput, this.nohidden, this.out_nodes);
    this.network = network;
    network.setLearningRate(this.learningrate);

    if (!showCharts) {
        return;
    }
    // chart series, keyed by the number of recorded results
    this.metrics = {
        train: {accuracy: {}},
        test: {},
    };
}

// Record the outcome of one test and refresh the running accuracy.
// pass: truthy when the prediction was correct.
// When charts are enabled, every 50th recorded result also stores the current
// accuracy (as a percentage) in the 'train' metrics, keyed by the number of
// results recorded so far.
MultilayerPerceptron.prototype.recordResult = function (pass) {
    this.total_tests++;
    if (pass) {
        this.total_correct++;
    }
    this.accuracy = this.total_correct / this.total_tests;

    const step = this.total_tests;
    if (!showCharts || step % 50 !== 0) {
        return;
    }
    const accuracySeries = this.metrics['train']['accuracy'];
    accuracySeries[step] = this.accuracy * 100.0;
};

// Train the perceptron on a batch, one sample at a time.
// images: list of images, each passed through normaliseImage() to get the inputs
// labels: list of labels, labels[k] selects the target output node for images[k]
MultilayerPerceptron.prototype.train = function (images, labels) {
    let sample = 0;
    while (sample < images.length) {
        const inputs = normaliseImage(images[sample]);

        // target vector: 1 at the label's position, 0 everywhere else
        const targets = new Array(this.out_nodes).fill(0);
        targets[labels[sample]] = 1;

        this.network.train(inputs, targets);
        sample += 1;
    }
};

// Classify one image: returns the index of the largest output node.
// img: image passed through normaliseImage() before being fed to the network
// Only outputs greater than 0 can win; if none is, index 0 is returned.
// Kept separate from predictTwo() for speed, as it is called many times.
MultilayerPerceptron.prototype.predict = function (img) {
    const outputs = this.network.predict(normaliseImage(img));

    let bestNode = 0;
    let bestValue = 0;
    for (let node = 0; node < outputs.length; node++) {
        const value = outputs[node];
        if (!(value > bestValue)) {
            continue;
        }
        bestNode = node;
        bestValue = value;
    }

    return bestNode;
};

//--- find no.1 (and maybe no.2) output nodes ---------------------------------------
// img: image passed through normaliseImage() before being fed to the network
// returns [no1, no2]: indexes of the largest and second-largest output values
// (restriction) only values greater than 0 are considered, which holds for
// output nodes; positions that are never filled stay at index 0
MultilayerPerceptron.prototype.predictTwo = function (img) {
    const outputs = this.network.predict(normaliseImage(img));

    let first = {index: 0, value: 0};
    let second = {index: 0, value: 0};

    for (let node = 0; node < outputs.length; node++) {
        const candidate = {index: node, value: outputs[node]};

        if (candidate.value > first.value) {
            // new no1: the old no1 drops to no2
            second = first;
            first = candidate;
        } else if (candidate.value > second.value) {
            // new no2
            second = candidate;
        }
    }

    return [first.index, second.index];
};