/* ---- MNIST geometry ---- */
/* Every MNIST digit is a PIXELS x PIXELS greyscale image. */
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;

/* Sizes of the MNIST training and test sets. */
const NOTRAIN = 60000;
const NOTEST = 10000;

/* ---- Network shape and learning ---- */
/* One input node per pixel; one output node per digit class. */
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;
/* Learning rate handed to the network in setup(). */
const learningrate = 0.1;

/* While true, draw() trains and tests on every frame; the "Stop training" button clears it. */
let do_training = true;
/* Number of exemplars trained on, and tested, per frame. */
const TRAINPERSTEP = 30;
const TESTPERSTEP = 30;

/* ---- Canvas layout ---- */
/* Magnification used when a digit is drawn at large size. */
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
/* A magnified image, a 50px gap, then an actual-size copy. */
const canvaswidth = ZOOMPIXELS + 50 + PIXELS;
/* Three rows of magnified images (doodle, training, demo) separated by 50px gaps. */
const canvasheight = 3 * ZOOMPIXELS + 100;
/* Doodle stroke width, and the blur radius applied when a stroke ends. */
const DOODLE_THICK = 18;
const DOODLE_BLUR = 3;

/* ---- Mutable state ---- */
/* MNIST data set (undefined until loadMNIST delivers it) and the network. */
let mnist;
let nn;
/* Training position: pass number (starting at 1) and index within the pass. */
let trainrun = 1;
let train_index = 0;
/* Testing position and the running score for the current pass. */
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;
/* Doodle graphics buffer and the current demo image. */
let doodle;
let demo;
let doodle_exists = false;
let demo_exists = false;
/* True while a mouse drag on the doodle pad is in progress. */
let mousedrag = false;
/* Last input vector fed to the network by each path (not read in the visible code). */
var train_inputs;
var test_inputs;
var demo_inputs;
var doodle_inputs;
/* Return a random initial weight between -0.5 and 0.5, drawn with AB.randomFloatAtoB.
   (Not referenced in the visible code.) */
function randomWeight() {
  const lowest = -0.5;
  const highest = 0.5;
  const weight = AB.randomFloatAtoB(lowest, highest);
  return weight;
}
/* Cap the run header's height at 95% of the viewport. */
AB.headerCSS({ "max-height": "95vh" });

/* Scratch variable for HTML sent to AB.msg. The section headers below occupy
   message slots 1, 3, 5 and 7; later code fills slot 2 (doodle guesses),
   4 (training progress), 6 (test score), 8 (demo label) and 9 (demo guess). */
var thehtml;

/* Slot 1: doodle section, with a button wired to wipeDoodle(). */
thehtml = [
  "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br>",
  "Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br>",
].join("");
AB.msg(thehtml, 1);

/* Slot 3: training section, with a button that sets do_training to false. */
thehtml = [
  "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>",
  "<button onclick='do_training = false;' class='normbutton' >Stop training</button> <br>",
].join("");
AB.msg(thehtml, 3);

/* Slot 5: heading above the running test score. */
thehtml = "<h3> Hidden tests </h3>";
AB.msg(thehtml, 5);

/* Slot 7: demo section, with a button wired to makeDemo(). */
thehtml = [
  "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br>",
  "The network is <i>not</i> trained on any of these images. <br>",
  "<button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br>",
].join("");
AB.msg(thehtml, 7);

/* Opening tag used to highlight results (scores and guesses) in the header. */
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'>";
/* p5 entry point: create the canvas and the doodle buffer, show the loading
   screen, then fetch the helper scripts one after another. Once the last one
   has loaded, build the network and start loading the MNIST data. */
function setup() {
  createCanvas(canvaswidth, canvasheight);

  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
  doodle.pixelDensity(1);

  AB.loadingScreen();

  /* Fetched strictly in this order; each request starts only after the
     previous script's success callback has fired. */
  const scripts = [
    "/uploads/codingtrain/matrix.js",
    "/uploads/codingtrain/nn.js",
    "/uploads/codingtrain/mnist.js",
  ];

  const loadFrom = (k) => {
    if (k < scripts.length) {
      $.getScript(scripts[k], () => loadFrom(k + 1));
      return;
    }
    /* NeuralNetwork (and loadMNIST, used by loadData) are presumably
       provided by the scripts above. */
    nn = new NeuralNetwork(noinput, nohidden, nooutput);
    nn.setLearningRate(learningrate);
    loadData();
  };

  loadFrom(0);
}
/* Request the MNIST data set. When it arrives, store it in the global `mnist`
   (draw() does nothing until then) and remove the loading screen. */
function loadData() {
  const onLoaded = (data) => {
    mnist = data;
    AB.removeLoading();
  };
  loadMNIST(onLoaded);
}
/* Build a PIXELS x PIXELS p5 image from a flat array of greyscale values.
   Each value is copied into the R, G and B channels, and alpha is set to 255. */
function getImage(img) {
  const picture = createImage(PIXELS, PIXELS);
  picture.loadPixels();
  const px = picture.pixels;
  for (let p = 0; p < PIXELSSQUARED; p += 1) {
    const grey = img[p];
    const base = p * 4;            // 4 buffer entries (RGBA) per pixel
    for (let channel = 0; channel < 3; channel += 1) {
      px[base + channel] = grey;
    }
    px[base + 3] = 255;            // fully opaque
  }
  picture.updatePixels();
  return picture;
}
/* Turn the first PIXELSSQUARED greyscale values of `img` (0-255) into
   network inputs by dividing each by 255. */
function getInputs(img) {
  return Array.from({ length: PIXELSSQUARED }, (_, k) => img[k] / 255);
}
/*
 * Train the network on one exemplar: the training image at train_index.
 *
 * show - when truthy, also draw that image on the middle row of the canvas
 *        (magnified on the left, actual size on the right).
 *
 * Side effects: stores the input vector in train_inputs, calls nn.train
 * once, writes the training position into header slot 4, and advances
 * train_index, wrapping to 0 and incrementing trainrun after NOTRAIN exemplars.
 */
function trainit(show) {
  const img = mnist.train_images[train_index];
  const label = mnist.train_labels[train_index];

  if (show) {
    const theimage = getImage(img);
    image(theimage, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);
    image(theimage, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);
  }

  const inputs = getInputs(img);
  const targets = makeTargets(label);
  train_inputs = inputs;
  nn.train(inputs, targets);

  thehtml = "trainrun: " + trainrun + "<br> no: " + train_index;
  AB.msg(thehtml, 4);

  train_index++;
  if (train_index === NOTRAIN) {
    train_index = 0;
    trainrun++;
  }
}

/*
 * Build the training target vector (one entry per output node) for class
 * `label`. Most entries are 0, and the true class gets 1. TARGET_NOISE_PICKS
 * random classes (picked with replacement) instead get a random confidence
 * between 0 and TARGET_NOISE_MAX. The true class is written last, so a pick
 * that lands on it is overwritten with 1.
 *
 * The noise value must not be rounded: Math.round maps every value below 0.5
 * to 0, which would silently disable the noise.
 */
function makeTargets(label) {
  const TARGET_NOISE_PICKS = 3;
  const TARGET_NOISE_MAX = 0.5;
  const targets = new Array(nooutput).fill(0);
  for (let k = 0; k < TARGET_NOISE_PICKS; k++) {
    const slot = Math.floor(Math.random() * targets.length);
    targets[slot] = AB.randomFloatAtoB(0.0, TARGET_NOISE_MAX);
  }
  targets[label] = 1;
  return targets;
}
/* Classify the test image at test_index and update the running score for the
   current test pass, which is shown in header slot 6. After the last test
   image, the pass counter advances and the score starts again from zero. */
function testit() {
  const sample = mnist.test_images[test_index];
  const expected = mnist.test_labels[test_index];

  test_inputs = getInputs(sample);
  const answer = findMax(nn.predict(test_inputs));

  total_tests += 1;
  if (answer == expected) {
    total_correct += 1;
  }

  const percent = (total_correct / total_tests) * 100;
  thehtml = `testrun: ${testrun}<br> no: ${total_tests}<br>correct: ${total_correct}<br>score: ${greenspan}${percent.toFixed(2)}</span>`;
  AB.msg(thehtml, 6);

  test_index += 1;
  if (test_index !== NOTEST) {
    return;
  }
  /* End of the test set: begin a new pass with a fresh score. */
  testrun += 1;
  test_index = 0;
  total_tests = 0;
  total_correct = 0;
}
/*
 * Return the indices of the largest and second-largest values in `a` as
 * [first, second]. It is used to report the doodle's top two guesses.
 *
 * On ties, the earlier index ranks higher. Positions that cannot be filled
 * (an empty or single-element array) default to index 0. The sentinels start
 * at -Infinity so that zero and negative values are ranked correctly.
 */
function find12(a) {
  let no1 = 0;
  let no2 = 0;
  let no1value = -Infinity;
  let no2value = -Infinity;
  for (let i = 0; i < a.length; i++) {
    const value = a[i];
    if (value > no1value) {
      /* New leader: the previous leader drops to second place. */
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = value;
    } else if (value > no2value) {
      no2 = i;
      no2value = value;
    }
  }
  return [no1, no2];
}
/*
 * Return the index of the largest value in `a`, i.e. the predicted class.
 * On ties, the earliest index wins. An empty array yields 0. The sentinel
 * starts at -Infinity so that all-zero or negative outputs still pick the
 * true maximum rather than defaulting to index 0.
 */
function findMax(a) {
  let best = 0;
  let bestValue = -Infinity;
  for (let i = 0; i < a.length; i++) {
    if (a[i] > bestValue) {
      best = i;
      bestValue = a[i];
    }
  }
  return best;
}
/* p5 per-frame callback. It does nothing until the MNIST data has arrived.
   After that, each frame:
   - clears the canvas;
   - while training is enabled, trains on TRAINPERSTEP exemplars (drawing
     only the first) and tests on TESTPERSTEP;
   - redraws and re-classifies the demo and doodle images when present;
   - handles mouse strokes on the doodle pad. */
function draw() {
  if (typeof mnist === 'undefined') {
    return;
  }

  background('black');

  if (do_training) {
    for (let step = 0; step < TRAINPERSTEP; step += 1) {
      trainit(step === 0);
    }
    for (let step = 0; step < TESTPERSTEP; step += 1) {
      testit();
    }
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }
  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  handleDoodleMouse();
}

/* Doodle-pad mouse handling, called once per frame from draw().
   While the button is held and the current and previous mouse positions are
   both within ZOOMPIXELS + 20 on each axis, draw a white stroke between them.
   When the button is released after such a drag, blur the doodle once. */
function handleDoodleMouse() {
  if (!mouseIsPressed) {
    if (mousedrag) {
      mousedrag = false;
      doodle.filter(BLUR, DOODLE_BLUR);
    }
    return;
  }

  const limit = ZOOMPIXELS + 20;
  const onPad = [mouseX, mouseY, pmouseX, pmouseY].every((c) => c < limit);
  if (!onPad) {
    return;
  }

  mousedrag = true;
  doodle_exists = true;
  doodle.stroke('white');
  doodle.strokeWeight(DOODLE_THICK);
  doodle.line(mouseX, mouseY, pmouseX, pmouseY);
}
/*
 * "Demo test image" button handler. It picks a random test image as the demo,
 * shows its index and true label in header slot 8, and lets draw() display
 * and classify it.
 *
 * The button can be clicked before loadData() has stored `mnist`. In that
 * case the handler does nothing. Previously demo_exists was set first and the
 * data access then threw, leaving `demo` undefined; draw() would then fail on
 * every frame once the data arrived.
 */
function makeDemo() {
  if (typeof mnist === 'undefined') return;
  const i = AB.randomIntAtoB(0, NOTEST - 1);
  demo = mnist.test_images[i];
  const label = mnist.test_labels[i];
  /* Only flag the demo once `demo` actually holds an image. */
  demo_exists = true;
  thehtml = "Test image no: " + i + "<br>" +
    "Classification: " + label + "<br>";
  AB.msg(thehtml, 8);
}
/* Draw the current demo digit on the bottom row: magnified at the left edge,
   and at actual size 50px to the right of it. */
function drawDemo() {
  const pic = getImage(demo);
  const top = canvasheight - ZOOMPIXELS;
  image(pic, 0, top, ZOOMPIXELS, ZOOMPIXELS);
  image(pic, ZOOMPIXELS + 50, top, PIXELS, PIXELS);
}
/* Classify the current demo digit and show the network's answer in header slot 9. */
function guessDemo() {
  demo_inputs = getInputs(demo);
  const outputs = nn.predict(demo_inputs);
  thehtml = `We classify it as: ${greenspan}${findMax(outputs)}</span>`;
  AB.msg(thehtml, 9);
}
/* Draw a snapshot of the doodle buffer on the top row: at ZOOMPIXELS size at
   the left edge, and shrunk to PIXELS x PIXELS 50px to the right of it. */
function drawDoodle() {
  const snapshot = doodle.get();
  const placements = [
    [0, ZOOMPIXELS],
    [ZOOMPIXELS + 50, PIXELS],
  ];
  placements.forEach(([x, size]) => {
    image(snapshot, x, 0, size, size);
  });
}
/* Shrink a snapshot of the doodle to PIXELS x PIXELS and feed its red channel,
   scaled to 0-1, to the network. The top two guesses are shown in header slot 2. */
function guessDoodle() {
  const small = doodle.get();
  small.resize(PIXELS, PIXELS);
  small.loadPixels();
  /* The pixel buffer holds 4 entries per pixel (RGBA); take the first (red). */
  doodle_inputs = Array.from({ length: PIXELSSQUARED }, (_, p) => small.pixels[p * 4] / 255);
  const [first, second] = find12(nn.predict(doodle_inputs));
  thehtml = `We classify it as: ${greenspan}${first}</span> <br>No.2 guess is: ${greenspan}${second}</span>`;
  AB.msg(thehtml, 2);
}
/* "Clear doodle" button handler. It stops draw() from drawing or classifying
   the doodle until the next stroke, then paints the doodle buffer black. */
function wipeDoodle() {
  doodle_exists = false;
  doodle.background('black');
}
/*
 * Print a flat input vector to the console as a grid, one row per line, with
 * each value shown to two decimals. Nothing in the visible code calls it;
 * it is presumably meant for manual inspection, e.g. of train_inputs or
 * doodle_inputs.
 *
 * inputs    - array of numbers (each must support toFixed).
 * rowLength - values per printed row. Defaults to PIXELS, i.e. one row of
 *             an MNIST image.
 * Returns the formatted string, which is also logged. Each row begins with
 * a newline.
 */
function showInputs(inputs, rowLength = PIXELS) {
  let str = "";
  for (let i = 0; i < inputs.length; i++) {
    if (i % rowLength === 0) str += "\n";
    str += " " + inputs[i].toFixed(2);
  }
  console.log(str);
  return str;
}