// Code viewer for World: Character recognition neural network
// Cloned by Tom McAllister on 9 Dec 2020 from World "Character recognition neural network" by "Coding Train" project 
// Please leave this clone trail here.
// Doodle geometry: each doodle is a PIXELS x PIXELS bitmap stored one byte per pixel.
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;
const len = PIXELSSQUARED; // bytes per doodle in each .bin file (784)
const totalData = 1000; // number of doodles read from each category file
const learningRate = 0.1;
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS; // on-screen size of the enlarged doodle
const DOODLE_THICK = 10; // stroke weight used when drawing into the doodle buffer

// These are for labelling, used for the supervised learning of the model
const BREAD = 0;
const BROCCOLI = 1;
const CLOCK = 2;
const DOLPHIN = 3;
const LADDER = 4;
const MONALISA = 5;
const ZEBRA = 6;

// Maps a class index (the arg-max of the network output) to its display name.
// Keyed by the label constants above so the labels and names cannot drift apart.
var labelMap = new Map([
    [BREAD, 'Bread'],
    [BROCCOLI, 'Broccoli'],
    [CLOCK, 'Clock'],
    [DOLPHIN, 'Dolphin'],
    [LADDER, 'Ladder'],
    [MONALISA, 'Mona Lisa'],
    [ZEBRA, 'Zebra'],
]);

// Here to hold the raw data that is collected by the preload function
let breadsData;
let broccolisData;
let clocksData;
let dolphinsData;
let laddersData;
let monaLisasData;
let zebrasData;

// Per-category {training, testing} splits, filled in by prepareData() during setup()
let breads = {};
let broccolis = {};
let clocks = {};
let dolphins = {};
let ladders = {};
let monaLisas = {};
let zebras = {};

// Pooled samples from all categories
let training = [];
let testing = [];

let doodle;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;
let trained = false;

let canvas;
let drawings = [];
let currentPath = [];
let isDrawing = false;

let nn; // neural network; assigned asynchronously in setup() once nn.js has loaded

//--- start of AB.msgs structure: ---------------------------------------------------------
// Static section headings in the AB message area. Slot 2 carries the "Accuracy"
// heading (test() writes the score into slot 3) and slot 4 carries the
// "Model Guess" heading (guess() writes the label into slot 5).
// `thehtml` is shared scratch space that test() and guess() also reassign.
var thehtml;

for (const [heading, slot] of [
    ["<hr> <h1> Accuracy : </h1> ", 2],
    ["<hr> <h1> Model Guess : </h1> ", 4],
]) {
    thehtml = heading + "  ";
    AB.msg(thehtml, slot);
}
  
   
// Loads the binary doodle files into the program.
// p5 runs preload() before setup() and waits for every loadBytes() request to
// finish. prepareData() later reads each file's contents from the returned
// object's `.bytes` property.
function preload() {
    // Use one absolute prefix for every file, matching the script URLs in setup().
    // This keeps every request resolving the same way regardless of the page's own path.
    const DATA_DIR = '/uploads/tom/';
    breadsData = loadBytes(DATA_DIR + 'breads1000.bin');
    broccolisData = loadBytes(DATA_DIR + 'broccolis1000.bin');
    clocksData = loadBytes(DATA_DIR + 'clocks1000.bin');
    dolphinsData = loadBytes(DATA_DIR + 'dolphins1000.bin');
    laddersData = loadBytes(DATA_DIR + 'ladders1000.bin');
    monaLisasData = loadBytes(DATA_DIR + 'monalisas1000.bin');
    zebrasData = loadBytes(DATA_DIR + 'zebras1000.bin');
    // NOTE(review): loadBytes is asynchronous, so this prints when the requests
    // are issued, not when the downloads have completed.
    console.log("Loaded Raw Training Data");
}


// p5 setup: creates the canvas and the doodle buffer, wires the mouse handlers and
// control buttons, asynchronously loads the matrix / neural-network libraries,
// and pools each category's prepared samples into `training` (shuffled) and `testing`.
function setup() {
    canvas = createCanvas(200, 200);
    doodle = createGraphics(200, 200); // off-screen buffer the user draws into
    doodle.pixelDensity(1); // one stored pixel per unit, independent of display density
    background(0);
    canvas.mousePressed(startPath);
    canvas.mouseReleased(function() {
        endPath();
    });

    AB.loadingScreen();
    AB.msg(`
            <button onclick= 'train();' class=largenormbutton > Train </button>  
            <button onclick= 'test();'  class=largenormbutton > Test  </button>
            <button onclick= 'guess();' class=largenormbutton > Guess </button>  
            <button onclick= 'clearDoodle();' class=largenormbutton > Clear </button> 
        `);

    // nn.js is fetched only after matrix.js has loaded (presumably it depends on it).
    // `nn` stays undefined until both scripts have arrived.
    $.getScript("/uploads/tom/matrix.js", function() {
        $.getScript("/uploads/tom/nn.js", function() {
            console.log("Matrix and Neural Network Libraries Loaded");
            // One input per pixel, 128 hidden nodes, one output per label in labelMap.
            nn = new NeuralNetwork(PIXELSSQUARED, 128, labelMap.size);
            nn.setLearningRate(learningRate);
        });
    });

    // [category object, raw bytes from preload(), numeric label] for every doodle type.
    const categories = [
        [breads, breadsData, BREAD],
        [broccolis, broccolisData, BROCCOLI],
        [clocks, clocksData, CLOCK],
        [dolphins, dolphinsData, DOLPHIN],
        [ladders, laddersData, LADDER],
        [monaLisas, monaLisasData, MONALISA],
        [zebras, zebrasData, ZEBRA],
    ];

    // Preparing the data for training and testing and applying labels for training
    for (const [category, raw, label] of categories) {
        prepareData(category, raw, label);
    }
    AB.removeLoading();

    // Pool the training portion (80%) of every category, then shuffle in place so
    // the categories are interleaved.
    for (const [category] of categories) {
        training = training.concat(category.training);
    }
    shuffle(training, true);

    // Pool the testing portion (20%) of every category.
    for (const [category] of categories) {
        testing = testing.concat(category.testing);
    }
}


////////////////////////////////////////////////////////// TRAIN, TEST, GUESS, CLEAR   (BUTTONS) //////////////////////////////////////////////////////////
// Runs one training epoch. It shuffles `training` in place, then feeds every
// sample to nn.train with pixel bytes scaled to [0, 1] and a one-hot target
// vector that has one slot per entry in labelMap. Sets `trained` when finished.
// Does nothing (apart from logging) if the neural-network library has not loaded yet.
function train() {
    if (!nn) {
        console.log("Neural network not loaded yet - cannot train");
        return;
    }
    console.log("Starting training");
    shuffle(training, true);

    const outputCount = labelMap.size;
    for (const sample of training) {
        const inputs = Array.from(sample, (x) => x / 255);
        const targets = new Array(outputCount).fill(0);
        targets[sample.label] = 1;
        nn.train(inputs, targets);
    }

    console.log("Finished training");
    trained = true;
}

// Evaluates the network on every sample in `testing`. It logs the accuracy and
// shows it in AB.msg slot 3, under the "Accuracy" heading in slot 2.
// Does nothing (apart from logging) if the library has not loaded yet or there
// is no test data.
function test() {
    console.log("Testing");
    if (!nn) {
        console.log("Neural network not loaded yet - cannot test");
        return;
    }
    if (testing.length === 0) {
        // Avoids reporting NaN% from a 0/0 division.
        console.log("No testing data - nothing to test");
        return;
    }

    let correct = 0;
    for (const sample of testing) {
        const inputs = Array.from(sample, (x) => x / 255);
        // Same arg-max helper that guess() uses, so both agree on the predicted class.
        const classification = indexOfMax(nn.predict(inputs));
        if (classification === sample.label) {
            correct++;
        }
    }

    const percent = 100 * correct / testing.length;
    const shown = nf(percent, 2, 2);
    console.log("Percent: " + shown + "%");

    thehtml = "<h2>" + shown + " % </h2>";
    AB.msg(thehtml, 3);
}

// Most recent input vector that guess() passed to nn.predict.
let doodle_inputs = null;

// Classifies the current doodle. It downscales the doodle buffer to
// PIXELS x PIXELS, turns it into an input vector in [0, 1], runs it through `nn`,
// and shows the best-matching label name in AB.msg slot 5.
// Does nothing (apart from logging) if the library has not loaded yet.
function guess() {
    if (!nn) {
        console.log("Neural network not loaded yet - cannot guess");
        return;
    }
    console.log("Guessing...");

    const img = doodle.get();
    img.resize(PIXELS, PIXELS);
    img.loadPixels();

    // Until clearDoodle() first paints it, the doodle buffer is transparent with
    // white strokes. Canvas pixel data is un-premultiplied, so after resize() the
    // anti-aliased edge pixels come back as white with partial alpha. Reading only
    // the red channel would make every partially covered pixel full intensity.
    // Weighting red by alpha gives the actual stroke coverage. For opaque pixels
    // (alpha 255) this equals the red channel.
    const inputs = [];
    for (let i = 0; i < PIXELSSQUARED; i++) {
        const red = img.pixels[i * 4];
        const alpha = img.pixels[i * 4 + 3];
        inputs[i] = (red * alpha) / (255 * 255);
    }
    doodle_inputs = inputs;

    const prediction = nn.predict(inputs);
    const label = labelMap.get(indexOfMax(prediction));

    console.log("Best guess: " + label);

    thehtml = "<h2>" + label + " </h2>";
    AB.msg(thehtml, 5);
}

////////////////////////////////////////////////////////////END OF BUTTON CODE////////////////////////////////////////////////////////////////////


///////////////////////DRAW CODE - CONSTANT LOOP THAT UPDATES CANVAS/////////////////
// p5 draw loop, run every frame. It repaints the background and shows the
// current doodle. While the mouse is held over the drawing area, it also adds
// the latest mouse segment to the doodle buffer.
function draw() {
    // A green background signals that train() has completed at least once.
    background(trained ? 'green' : 'black');

    if (doodle_exists) {
        drawDoodle();
    }

    if (mouseIsPressed) {
        const MAX = ZOOMPIXELS + 20;
        // Accept the segment only if both endpoints lie in [0, MAX). p5 keeps
        // updating mouseX/mouseY while the pointer is outside the canvas, so a
        // press above or left of it yields negative coordinates. Without the
        // lower bound, such presses were treated as doodling.
        const inRange = (v) => v >= 0 && v < MAX;
        if (inRange(mouseX) && inRange(mouseY) && inRange(pmouseX) && inRange(pmouseY)) {
            mousedrag = true;
            doodle_exists = true;
            doodle.stroke('white');
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    }
}



///////////////////////////////////////////////////////////////HELPER CODE////////////////////////////////////////////////////////////////////////////////////

// Splits one category's raw bytes into training and testing samples.
// Sample k is the k-th run of `len` bytes in data.bytes. The first
// floor(0.8 * totalData) samples go to category.training and the rest to
// category.testing. Each sample is data.bytes.subarray(...) tagged with a
// `.label` property (a view rather than a copy, assuming data.bytes is a typed
// array, as p5's loadBytes provides).
function prepareData(category, data, label) {
    category.training = [];
    category.testing = [];
    const cutoff = floor(0.8 * totalData);
    for (let k = 0; k < totalData; k++) {
        const start = k * len;
        const sample = data.bytes.subarray(start, start + len);
        sample.label = label;
        const bucket = k < cutoff ? category.training : category.testing;
        bucket.push(sample);
    }
}


// Helper function to get the index of the Max Value in an array.
// Comparisons use strict `>`, so ties resolve to the earliest index. An empty
// array yields 0.
function indexOfMax(arr) {
    let winner = 0;
    let i = 1;
    while (i < arr.length) {
        if (arr[i] > arr[winner]) {
            winner = i;
        }
        i += 1;
    }
    return winner;
}


// Canvas position where the current mouse press began (written by startPath()).
let px;
let py;

// Registered in setup() as the canvas mousePressed handler. It records where
// the press began.
function startPath() {
    px = mouseX;
    py = mouseY;
}

// setup() registers a canvas mouseReleased handler that calls endPath(). This
// definition makes that call resolve instead of throwing a ReferenceError. It
// ends the current drag, which draw() flags via `mousedrag`.
function endPath() {
    mousedrag = false;
}


// Renders the doodle buffer twice onto the main canvas. The first copy is
// enlarged to ZOOMPIXELS x ZOOMPIXELS at the origin; the second is a
// PIXELS x PIXELS preview placed to its right.
// NOTE(review): with the 200px-wide canvas created in setup(), the preview's
// x = ZOOMPIXELS + 50 (246) is off-canvas, so it is never visible. Confirm the
// intended layout.
function drawDoodle() {
    const snapshot = doodle.get();
    const previewX = ZOOMPIXELS + 50;
    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
    image(snapshot, previewX, 0, PIXELS, PIXELS);
}

// Helper function to clear doodle and reset canvas for next doodle.
// It fills the doodle buffer with the colour draw() uses for the canvas
// background ('green' once training has run, 'black' before). It also hides the
// doodle until the user draws again.
function clearDoodle() {
    const fill = trained ? 'green' : 'black';
    doodle_exists = false;
    doodle.background(fill);
}