// Cloned by Tom McAllister on 9 Dec 2020 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.
// Doodles are PIXELS x PIXELS images stored as one byte per pixel.
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;
const len = PIXELSSQUARED; // bytes per doodle in the .bin files (784)
const totalData = 1000; // doodles read from each category's data file
const learningRate = 0.1;
// The doodle is displayed ZOOMFACTOR times larger than its PIXELS x PIXELS size.
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
const DOODLE_THICK = 10; // stroke weight used when drawing on the doodle
// These are for labelling, used for the supervised learning of the model.
// Each label is also the index of the matching network output.
const BREAD = 0;
const BROCCOLI = 1;
const CLOCK = 2;
const DOLPHIN = 3;
const LADDER = 4;
const MONALISA = 5;
const ZEBRA = 6;
// Maps a label (network output index) to its display name, so the best guess
// for a doodle can be shown as a String. Keyed by the constants above so the
// two lists cannot drift apart.
const labelMap = new Map([
  [BREAD, 'Bread'],
  [BROCCOLI, 'Broccoli'],
  [CLOCK, 'Clock'],
  [DOLPHIN, 'Dolphin'],
  [LADDER, 'Ladder'],
  [MONALISA, 'Mona Lisa'],
  [ZEBRA, 'Zebra'],
]);
// Raw byte data for each doodle category, assigned by preload()
let breadsData;
let broccolisData;
let clocksData;
let dolphinsData;
let laddersData;
let monaLisasData;
let zebrasData;

// Per-category holders; prepareData() gives each a `training` and `testing` array
let breads = {};
let broccolis = {};
let clocks = {};
let dolphins = {};
let ladders = {};
let monaLisas = {};
let zebras = {};

// Combined datasets across all categories, assembled in setup()
let training = [];
let testing = [];

// Canvas / drawing state
let canvas;
let doodle;
let doodle_exists = false;
let demo_exists = false;
let mousedrag = false;
let drawings = [];
let currentPath = [];
let isDrawing = false;

// Model state
let trained = false;
let nn; // neural network
//--- start of AB.msgs structure: ---------------------------------------------------------
// Static headings in the run-page message area: slot 2 heads the accuracy
// (test() writes the value to slot 3) and slot 4 heads the model's guess
// (guess() writes the value to slot 5).
var thehtml = "<hr> <h1> Accuracy : </h1>  ";
AB.msg(thehtml, 2);
thehtml = "<hr> <h1> Model Guess : </h1>  ";
AB.msg(thehtml, 4);
// Loads the binary doodle files into the program.
// Requests one .bin file per category via p5's loadBytes(); under p5's
// preload() contract, setup() runs only after these requests complete.
// All paths are absolute (matching the $.getScript URLs in setup()); the
// bread file was previously absolute while the others were page-relative.
function preload() {
  const DATA_DIR = '/uploads/tom/';
  breadsData = loadBytes(`${DATA_DIR}breads1000.bin`);
  broccolisData = loadBytes(`${DATA_DIR}broccolis1000.bin`);
  clocksData = loadBytes(`${DATA_DIR}clocks1000.bin`);
  dolphinsData = loadBytes(`${DATA_DIR}dolphins1000.bin`);
  laddersData = loadBytes(`${DATA_DIR}ladders1000.bin`);
  monaLisasData = loadBytes(`${DATA_DIR}monalisas1000.bin`);
  zebrasData = loadBytes(`${DATA_DIR}zebras1000.bin`);
  console.log("Loaded Raw Training Data");
}
// Creates the canvas and the off-screen doodle buffer, adds the control
// buttons, loads the matrix / neural-network libraries, and splits the
// preloaded doodle bytes into a combined, shuffled training set and a
// combined testing set.
function setup() {
  canvas = createCanvas(200, 200);
  doodle = createGraphics(200, 200); // doodle on larger canvas
  doodle.pixelDensity(1);
  background(0);
  canvas.mousePressed(startPath);
  canvas.mouseReleased(function() {
    // endPath is not defined in this file; calling it unguarded threw a
    // ReferenceError on every mouse release over the canvas.
    if (typeof endPath === 'function') {
      endPath();
    }
  });
  AB.loadingScreen();
  AB.msg(`
<button onclick= 'train();' class=largenormbutton > Train </button>
<button onclick= 'test();' class=largenormbutton > Test </button>
<button onclick= 'guess();' class=largenormbutton > Guess </button>
<button onclick= 'clearDoodle();' class=largenormbutton > Clear </button>
`);
  // The libraries arrive asynchronously, so nn stays undefined until the
  // inner callback runs.
  $.getScript("/uploads/tom/matrix.js", function() {
    $.getScript("/uploads/tom/nn.js", function() {
      console.log("Matrix and Neural Network Libraries Loaded");
      // One input per pixel, 128 hidden nodes, one output per label.
      nn = new NeuralNetwork(PIXELSSQUARED, 128, labelMap.size);
      nn.setLearningRate(learningRate);
    });
  });
  // Each entry: the category's split holder, its raw bytes, and its label.
  const categories = [
    [breads, breadsData, BREAD],
    [broccolis, broccolisData, BROCCOLI],
    [clocks, clocksData, CLOCK],
    [dolphins, dolphinsData, DOLPHIN],
    [ladders, laddersData, LADDER],
    [monaLisas, monaLisasData, MONALISA],
    [zebras, zebrasData, ZEBRA],
  ];
  // Preparing the data for training and testing and applying labels for training
  for (const [category, data, label] of categories) {
    prepareData(category, data, label);
  }
  AB.removeLoading();
  // Adds the training portion (80%) of each doodle type to the training array
  for (const [category] of categories) {
    training = training.concat(category.training);
  }
  shuffle(training, true);
  // Adds the testing portion (20%) of each doodle type to the testing array
  for (const [category] of categories) {
    testing = testing.concat(category.testing);
  }
}
////////////////////////////////////////////////////////// TRAIN, TEST, GUESS, CLEAR (BUTTONS) //////////////////////////////////////////////////////////
// Runs one pass over the shuffled training set, feeding each sample to
// nn.train() with its bytes scaled to 0-1 and a one-hot target for its label.
// Sets `trained` when done. Does nothing if the network library has not
// finished loading yet (nn is created asynchronously in setup()).
function train() {
  if (!nn) {
    console.log("Neural network not loaded yet - please wait");
    return;
  }
  console.log("Starting training");
  shuffle(training, true);
  for (const sample of training) {
    // Raw pixel bytes are 0-255; the network expects 0-1.
    const inputs = Array.from(sample, (x) => x / 255);
    // One-hot target with one slot per known label.
    const targets = new Array(labelMap.size).fill(0);
    targets[sample.label] = 1;
    nn.train(inputs, targets);
  }
  console.log("Finished training");
  trained = true;
}
// Evaluates the network on the held-out testing set: a sample counts as
// correct when the highest-scoring output matches its label. Logs the
// accuracy and shows it (formatted to 2 decimals) in AB message slot 3.
// Does nothing if the network library has not loaded or there is no test data.
function test() {
  if (!nn) {
    console.log("Neural network not loaded yet - please wait");
    return;
  }
  if (testing.length === 0) {
    // Avoids a 0/0 = NaN accuracy.
    console.log("No testing data available");
    return;
  }
  console.log("Testing");
  let correct = 0;
  for (const sample of testing) {
    const inputs = Array.from(sample, (x) => x / 255);
    const prediction = nn.predict(inputs);
    // Same argmax helper guess() uses.
    if (indexOfMax(prediction) === sample.label) {
      correct++;
    }
  }
  const percent = 100 * correct / testing.length;
  const formatted = nf(percent, 2, 2);
  console.log("Percent: " + formatted + "%");
  thehtml = "<h2>" + formatted + " % </h2>";
  AB.msg(thehtml, 3);
}
// Classifies the current doodle. The doodle buffer is downsampled to
// PIXELS x PIXELS, and the red channel of each pixel (scaled to 0-1) is the
// network input. The best-scoring label's name goes to AB message slot 5.
// Does nothing if the network library has not finished loading yet.
function guess() {
  if (!nn) {
    console.log("Neural network not loaded yet - please wait");
    return;
  }
  console.log("Guessing...");
  const img = doodle.get();
  img.resize(PIXELS, PIXELS);
  img.loadPixels();
  // pixels[] holds 4 entries per pixel (RGBA); index i * 4 is pixel i's red value.
  const inputs = [];
  for (let i = 0; i < PIXELSSQUARED; i++) {
    inputs[i] = img.pixels[i * 4] / 255;
  }
  // Kept for compatibility: exposes the last inputs as the global doodle_inputs.
  doodle_inputs = inputs;
  const prediction = nn.predict(inputs);
  const bestGuess = labelMap.get(indexOfMax(prediction));
  console.log("Best guess: " + bestGuess);
  thehtml = "<h2>" + bestGuess + " </h2>";
  AB.msg(thehtml, 5);
}
////////////////////////////////////////////////////////////END OF BUTTON CODE////////////////////////////////////////////////////////////////////
///////////////////////DRAW CODE - CONSTANT LOOP THAT UPDATES CANVAS/////////////////
// p5 frame callback. Repaints the background (green once trained, black
// before), shows the current doodle if there is one, and while the mouse is
// pressed draws a white stroke segment from the previous to the current mouse
// position onto the doodle buffer.
function draw() {
  background(trained ? 'green' : 'black');
  if (doodle_exists) {
    drawDoodle();
  }
  if (!mouseIsPressed) {
    return;
  }
  // NOTE(review): the +20 appears to be a margin past the zoomed doodle area - confirm.
  const MAX = ZOOMPIXELS + 20;
  // Lower bound added: presses left of / above the canvas (e.g. on the page
  // buttons) give negative coordinates and must not start a stroke.
  const withinDoodleArea = (v) => v >= 0 && v < MAX;
  if (withinDoodleArea(mouseX) && withinDoodleArea(mouseY) &&
      withinDoodleArea(pmouseX) && withinDoodleArea(pmouseY)) {
    mousedrag = true;
    doodle_exists = true;
    doodle.stroke('white');
    doodle.strokeWeight(DOODLE_THICK);
    doodle.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}
///////////////////////////////////////////////////////////////HELPER CODE////////////////////////////////////////////////////////////////////////////////////
// Splits one category's raw bytes into labelled training and testing samples.
// Each sample is a `len`-byte subarray view of data.bytes with a `label`
// property attached. The first trainFraction of the samples go to
// category.training and the rest to category.testing.
//
// category      - object that receives fresh `training` and `testing` arrays
// data          - object with a `bytes` typed array (from p5 loadBytes())
// label         - numeric class label attached to every sample
// trainFraction - optional share of samples used for training (default 0.8)
function prepareData(category, data, label, trainFraction = 0.8) {
  // Read at most totalData samples, and never past the end of the bytes:
  // a truncated file would otherwise yield empty or partial samples.
  const count = Math.min(totalData, Math.floor(data.bytes.length / len));
  const threshold = Math.floor(trainFraction * count);
  category.training = [];
  category.testing = [];
  for (let i = 0; i < count; i++) {
    const offset = i * len;
    const sample = data.bytes.subarray(offset, offset + len);
    sample.label = label;
    if (i < threshold) {
      category.training.push(sample);
    } else {
      category.testing.push(sample);
    }
  }
}
// Returns the index of the largest element of an array (or array-like).
// Ties resolve to the earliest index; an empty input returns 0.
function indexOfMax(arr) {
  let bestIndex = 0;
  for (let i = 1; i < arr.length; i++) {
    if (arr[i] > arr[bestIndex]) {
      bestIndex = i;
    }
  }
  return bestIndex;
}
// Canvas mousePressed handler: remembers the mouse position where the press
// began in the globals px / py.
// NOTE(review): px and py are not declared anywhere in this file, so this
// creates implicit globals.
function startPath() {
  [px, py] = [mouseX, mouseY];
}
// Renders a snapshot of the doodle buffer twice: scaled to
// ZOOMPIXELS x ZOOMPIXELS at the canvas origin, and as a PIXELS x PIXELS
// thumbnail further to the right.
// NOTE(review): the thumbnail x is ZOOMPIXELS + 50 (246 with the current
// constants), outside the 200 x 200 canvas created in setup(), so it is not
// visible - confirm intent.
function drawDoodle() {
  const snapshot = doodle.get();
  const thumbnailX = ZOOMPIXELS + 50;
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
  image(snapshot, thumbnailX, 0, PIXELS, PIXELS);
}
// Clears the doodle so the next one can be drawn. It marks the doodle as
// absent, which stops draw() from rendering it. It then fills the doodle
// buffer with the colour draw() paints behind it: green once trained, black
// before.
function clearDoodle() {
  doodle_exists = false;
  const fillColour = trained ? 'green' : 'black';
  doodle.background(fillColour);
}