'use strict';
// ---------------------------------------------------------------------------
// Configuration constants and shared mutable state for the sketch.
// ---------------------------------------------------------------------------

// Side length, in pixels, of one dataset image (images are PIXELS x PIXELS).
const PIXELS = 28;
// Number of pixels in one image.
const PIXELSSQUARED = PIXELS * PIXELS;
// Number of training images that loadData() transposes and trainData() cycles through.
const NOTRAIN = 124800;
// Number of test images that loadData() transposes and testData()/makeDemo() draw from.
const NOTEST = 20800;
// noinput, nohidden, nooutput and learningrate are not referenced anywhere else
// in the visible code.
// NOTE(review): presumably left over from an earlier fully-connected network —
// confirm before removing.
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;
const learningrate = 0.001;
// Toggled by the "start or end train" button; draw() only trains/tests while true.
let do_training = false;
// Training and test exemplars processed per draw() call while training is on.
const TRAINPERSTEP = 12;
const TESTPERSTEP = 2;
// Not referenced in the visible code.
const ZOOMFACTOR = 10;
// Side length of the magnified image views and of the doodle buffer.
const ZOOMPIXELS = 8 * PIXELS;
// Canvas dimensions passed to createCanvas() in setup().
const canvaswidth = PIXELS + ZOOMPIXELS + 150;
const canvasheight = 3 * ZOOMPIXELS + 200;
// Stroke weight used while doodling, and the parameter passed to the BLUR
// filter applied once when the mouse is released.
const DOODLE_THICK = 15;
const DOODLE_BLUR = 1;
// Dataset object handed to the loadMNIST() callback; draw() is a no-op until it is set.
let mnist;
// Not referenced in the visible code.
let cnn;
// The convnetjs.Net instance and the convnetjs.SGDTrainer that trains it
// (both are created in setup()).
let cnn_model;
let cnn_trainer;
// Off-screen graphics buffer holding the user's doodle.
let doodle;
// Pixel data of the test image currently shown by the demo.
let demo;
// Running doodle statistics, persisted by dataSavedOnAB() and restored in setup().
// The guess counter starts at 1, which keeps the initial score division away from zero.
let DOODLE_TOTAL_GUESS = 1;
let DOODLE_TOTAL_WRONG = 0;
// Pass counter and current position within the training set.
let trainrun = 1;
let train_index = 0;
// Pass counter and current position within the test set.
let testrun = 1;
let test_index = 0;
// Tally for the current test pass; both are reset when the pass completes.
let total_tests = 0;
let total_correct = 0;
// Whether there is a doodle / a demo image to render in draw().
let doodle_exists = false;
let demo_exists = false;
// True while the user is drawing; draw() uses it to blur the doodle exactly once on release.
let mousedrag = false;
// Most recent network inputs for each path. They are written by trainData(),
// testData(), guessDemo() and guessDoodle() but never read in the visible code.
var train_inputs;
var test_inputs;
var demo_inputs;
var doodle_inputs;
/**
 * Draws one random weight by delegating to AB.randomFloatAtoB.
 *
 * @returns {number} Whatever AB.randomFloatAtoB yields for the bounds -0.5 and 0.5.
 */
function randomWeight() {
  const lowerBound = -0.5;
  const upperBound = 0.5;
  return AB.randomFloatAtoB(lowerBound, upperBound);
}
// Cap the height of the AB header panel.
AB.headerCSS({ "max-height": "95vh" });

// Display names for the letter classes; the rest of the file indexes this with (label - 1).
let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".split("");

// Scratch variable holding the most recently built HTML fragment.
var thehtml;

// Slot 1: doodle controls.
thehtml = "Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear</button> <button class='normbutton' onclick='guessDoodle();'>Guess</button> <br>";
AB.msg(thehtml, 1);

// Slot 3: training section header and the training toggle.
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br> <button onclick='do_training = !do_training;' class='normbutton' >start or end train</button> <br> ";
AB.msg(thehtml, 3);

// Slot 5: heading above the test statistics.
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// Slot 7: demo section header and the demo button.
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br> The network is <i>not</i> trained on any of these images. <br> <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);

// Opening tag used to highlight scores and classifications; callers append "</span>".
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";
/**
 * Sketch entry point: creates the canvas and doodle buffer, loads the three
 * helper scripts one after another, builds the convolutional network and its
 * trainer, restores any previously saved state, and finally starts loading
 * the dataset via loadData().
 *
 * Writes the globals `doodle`, `cnn_model`, `cnn_trainer`, and (when a usable
 * snapshot is restored) `DOODLE_TOTAL_GUESS` / `DOODLE_TOTAL_WRONG`.
 */
function setup() {
  createCanvas(canvaswidth, canvasheight);

  // Off-screen buffer the user draws into.
  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
  doodle.pixelDensity(1);

  AB.loadingScreen();

  // Each script is requested only after the previous one has finished loading.
  $.getScript("/uploads/codingtrain/matrix.js", function () {
    $.getScript("/uploads/chinya07/convnet-min.js", function () {
      $.getScript("/uploads/chinya07/mymnist.js", function () {
        console.log("All JS loaded");

        // input -> conv/relu -> pool -> conv/relu -> pool -> softmax.
        // The input size follows PIXELS so it always agrees with getCnnInputs().
        const layer_defs = [
          { type: "input", out_sx: PIXELS, out_sy: PIXELS, out_depth: 1 },
          { type: "conv", sx: 5, filters: 8, stride: 1, pad: 2, activation: "relu" },
          { type: "pool", sx: 2, stride: 2 },
          { type: "conv", sx: 5, filters: 16, stride: 1, pad: 2, activation: "relu" },
          { type: "pool", sx: 3, stride: 3 },
          // 27 classes although only 26 letters are displayed: the rest of the
          // file maps class index k to letters[k - 1].
          { type: "softmax", num_classes: 27 },
        ];

        cnn_model = new convnetjs.Net();

        // Build the fresh network BEFORE asking for saved state. SOURCE does not
        // show whether AB.restoreData invokes its callback immediately or later;
        // with this order the restored weights can never be overwritten by the
        // freshly initialised layers in either case.
        cnn_model.makeLayers(layer_defs);
        cnn_trainer = new convnetjs.SGDTrainer(cnn_model, {
          method: "adadelta",
          batch_size: 32,
          l2_decay: 0.001,
        });

        AB.restoreData(function (snapshot) {
          console.log(snapshot);

          // Nothing saved yet (undefined or null): keep the fresh network.
          if (snapshot === undefined || snapshot === null) {
            return;
          }

          if (snapshot.cnn) {
            cnn_model.fromJSON(snapshot.cnn);
          }

          // Only accept counters that keep the score computation meaningful.
          const savedGuesses = snapshot.doodle_total_guess;
          const savedWrong = snapshot.doodle_total_wrong;
          if (Number.isFinite(savedGuesses) && savedGuesses > 0 && Number.isFinite(savedWrong)) {
            DOODLE_TOTAL_GUESS = savedGuesses;
            DOODLE_TOTAL_WRONG = savedWrong;
          }

          const data = "Doodle score:" + ((DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100).toFixed(2);
          AB.msg(data, 2);
        });

        loadData();
      });
    });
  });
}
/**
 * Requests the dataset through loadMNIST(). Once it arrives, the dataset is
 * stored in the global `mnist`, imageRotate() is applied to the first NOTRAIN
 * training images and the first NOTEST test images, and the loading screen is
 * removed.
 */
function loadData() {
  // Applies imageRotate() to images[0] .. images[count - 1], in order.
  const rotateLeading = (images, count) => {
    let position = 0;
    while (position < count) {
      imageRotate(images[position]);
      position += 1;
    }
  };

  loadMNIST((dataset) => {
    mnist = dataset;
    rotateLeading(mnist.train_images, NOTRAIN);
    rotateLeading(mnist.test_images, NOTEST);
    console.log("All data loaded into mnist object:");
    AB.removeLoading();
  });
}
/**
 * Transposes, in place, a PIXELS x PIXELS image stored as a flat array:
 * the value at (row, col) is swapped with the value at (col, row).
 *
 * @param {Array<number>|Uint8Array} ctx Flat pixel array; modified in place.
 */
function imageRotate(ctx) {
  for (let row = 0; row < PIXELS; row += 1) {
    const rowStart = row * PIXELS;
    // Only walk the upper triangle (diagonal included) so each pair is swapped once.
    for (let col = row; col < PIXELS; col += 1) {
      const here = rowStart + col;
      const mirrored = col * PIXELS + row;
      const held = ctx[mirrored];
      ctx[mirrored] = ctx[here];
      ctx[here] = held;
    }
  }
}
/**
 * Builds a PIXELS x PIXELS image from a flat array of grey levels.
 * Each grey level is copied into the red, green and blue channels and the
 * alpha channel is set to 255.
 *
 * @param {Array<number>|Uint8Array} id Flat array of PIXELSSQUARED grey levels.
 * @returns {object} The image returned by createImage(), with its pixels filled in.
 */
function getImage(id) {
  const img = createImage(PIXELS, PIXELS);
  img.loadPixels();

  for (let pixel = 0; pixel < PIXELSSQUARED; pixel += 1) {
    const grey = id[pixel];
    const base = pixel * 4;
    for (let channel = 0; channel < 3; channel += 1) {
      img.pixels[base + channel] = grey;
    }
    img.pixels[base + 3] = 255;
  }

  img.updatePixels();
  return img;
}
/**
 * Converts a flat array of grey levels into network inputs by dividing each
 * of the first PIXELSSQUARED values by 255.
 *
 * @param {Array<number>|Uint8Array} id Flat array of grey levels.
 * @returns {number[]} New plain array of PIXELSSQUARED scaled values.
 */
function getInputs(id) {
  return Array.from({ length: PIXELSSQUARED }, (_unused, position) => id[position] / 255);
}
/**
 * Trains the network on the training exemplar at `train_index`, reports
 * progress in message slot 4, then advances `train_index`. When the index
 * reaches NOTRAIN it wraps to 0, the state is saved via dataSavedOnAB() and
 * `trainrun` is incremented.
 *
 * @param {boolean} addedRenderer When truthy, the exemplar is also drawn
 *   (magnified and at original size) in the middle row of the canvas.
 */
function trainData(addedRenderer) {
  const pixels = mnist.train_images[train_index];
  const label = mnist.train_labels[train_index];

  if (addedRenderer) {
    const picture = getImage(pixels);
    const rowTop = ZOOMPIXELS + 50;
    image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
    image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);
  }

  train_inputs = getInputs(pixels);
  cnn_trainer.train(getCnnInputs(train_inputs), label);

  thehtml = ` trainrun: ${trainrun}<br> no: ${train_index}`;
  AB.msg(thehtml, 4);

  train_index++;
  if (train_index === NOTRAIN) {
    // One full pass over the training set is complete.
    train_index = 0;
    dataSavedOnAB();
    console.log("finished trainrun: " + trainrun);
    trainrun++;
  }
}
/**
 * Copies a flat input array into a new convnetjs.Vol constructed as
 * Vol(PIXELS, PIXELS, 1, 0).
 *
 * @param {ArrayLike<number>} t Flat array holding at least PIXELSSQUARED values.
 * @returns {object} The new volume, whose first PIXELSSQUARED entries of `w` mirror `t`.
 */
function getCnnInputs(t) {
  const volume = new convnetjs.Vol(PIXELS, PIXELS, 1, 0);
  let slot = 0;
  while (slot < PIXELSSQUARED) {
    volume.w[slot] = t[slot];
    slot += 1;
  }
  return volume;
}
/**
 * Classifies the test exemplar at `test_index`, draws it in the middle row of
 * the canvas, updates the running tally and reports it in message slot 6.
 * When `test_index` reaches NOTEST the pass is logged, `testrun` is
 * incremented and the index and tallies are reset to 0.
 */
function testData() {
  const pixels = mnist.test_images[test_index];
  const label = mnist.test_labels[test_index];

  const inputs = getInputs(pixels);
  const volume = getCnnInputs(inputs);
  test_inputs = inputs;
  const predicted = findMax(cnn_model.forward(volume).w);

  const picture = getImage(pixels);
  const rowTop = ZOOMPIXELS + 50;
  image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
  image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);

  total_tests++;
  // Loose comparison kept deliberately: the label's type comes from the dataset loader.
  if (predicted == label) {
    total_correct++;
  }

  const percent = total_correct / total_tests * 100;
  thehtml = ` testrun: ${testrun}<br> no: ${total_tests} <br> correct: ${total_correct}<br> score: ${greenspan}${percent.toFixed(2)}</span>`;
  AB.msg(thehtml, 6);

  test_index++;
  if (test_index === NOTEST) {
    console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
    testrun++;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
  }
}
/**
 * Finds the positions of the largest and second-largest values.
 *
 * The running maxima start at -Infinity rather than 0, so arrays whose values
 * are all zero or negative are ranked correctly instead of collapsing to [0, 0].
 * For ties, the earliest position wins first place.
 *
 * @param {ArrayLike<number>} diffs Values to rank.
 * @returns {number[]} `[indexOfLargest, indexOfSecondLargest]`; `[0, 0]` when
 *   `diffs` has fewer than two elements.
 */
function find12(diffs) {
  let topIndex = 0;
  let topValue = -Infinity;
  let runnerUpIndex = 0;
  let runnerUpValue = -Infinity;

  for (let j = 0; j < diffs.length; j++) {
    const value = diffs[j];
    if (value > topValue) {
      // The previous leader is demoted to second place.
      runnerUpIndex = topIndex;
      runnerUpValue = topValue;
      topIndex = j;
      topValue = value;
    } else if (value > runnerUpValue) {
      runnerUpIndex = j;
      runnerUpValue = value;
    }
  }

  return [topIndex, runnerUpIndex];
}
/**
 * Returns the index of the largest value (argmax).
 *
 * The running maximum starts at -Infinity rather than 0, so an array made up
 * only of negative values yields its true argmax instead of always 0.
 * For ties, the earliest position wins.
 *
 * @param {ArrayLike<number>} array Values to scan.
 * @returns {number} Index of the largest value, or 0 for an empty array.
 */
function findMax(array) {
  let maxI = 0;
  let maxValue = -Infinity;
  for (let i = 0; i < array.length; i++) {
    if (array[i] > maxValue) {
      maxI = i;
      maxValue = array[i];
    }
  }
  return maxI;
}
/**
 * Per-frame callback. Does nothing until the dataset (`mnist`) is available.
 * Otherwise it clears the canvas, runs TRAINPERSTEP training steps and
 * TESTPERSTEP test steps when `do_training` is on, redraws the demo and the
 * doodle if they exist, and handles doodle drawing with the mouse.
 */
function draw() {
  if (mnist === undefined) {
    return;
  }

  background("black");

  if (do_training) {
    for (let step = 0; step < TRAINPERSTEP; step += 1) {
      trainData(false);
    }
    for (let step = 0; step < TESTPERSTEP; step += 1) {
      testData();
    }
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }

  if (doodle_exists) {
    drawDoodle();
  }

  if (!mouseIsPressed) {
    // The stroke just ended: blur the doodle once, then forget the drag.
    if (mousedrag) {
      mousedrag = false;
      doodle.filter(BLUR, DOODLE_BLUR);
    }
    return;
  }

  // Only draw while both the current and the previous mouse position lie
  // inside the doodle area (plus a 20-pixel margin).
  const limit = ZOOMPIXELS + 20;
  const insideDoodleArea = [mouseX, mouseY, pmouseX, pmouseY].every((coord) => coord < limit);
  if (insideDoodleArea) {
    mousedrag = true;
    doodle_exists = true;
    doodle.stroke("white");
    doodle.strokeWeight(DOODLE_THICK);
    doodle.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}
/**
 * Picks a random test image (index from AB.randomIntAtoB(0, NOTEST - 1)),
 * stores it in `demo`, marks the demo as existing and shows the image number
 * and its labelled letter in message slot 8.
 */
function makeDemo() {
  demo_exists = true;

  const pick = AB.randomIntAtoB(0, NOTEST - 1);
  demo = mnist.test_images[pick];

  // Labels are offset by one relative to positions in `letters`.
  const label = mnist.test_labels[pick];
  thehtml = `Test image no: ${pick}<br>Classification: ${letters[label - 1]}<br>`;
  AB.msg(thehtml, 8);
}
/**
 * Draws the current demo image along the bottom of the canvas: magnified on
 * the left and at original size to its right.
 */
function drawDemo() {
  const picture = getImage(demo);
  const rowTop = canvasheight - ZOOMPIXELS;
  const thumbnailX = ZOOMPIXELS + 50;
  image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
  image(picture, thumbnailX, rowTop, PIXELS, PIXELS);
}
/**
 * Runs the network on the current demo image and shows the predicted letter
 * in message slot 9. The inputs used are also kept in `demo_inputs`.
 */
function guessDemo() {
  demo_inputs = getInputs(demo);
  const volume = getCnnInputs(demo_inputs);
  const scores = cnn_model.forward(volume).w;
  const winner = findMax(scores);

  // Class indices are offset by one relative to positions in `letters`.
  thehtml = ` We classify it as: ${greenspan}${letters[winner - 1]}</span>`;
  AB.msg(thehtml, 9);
}
/**
 * Persists the doodle counters and the serialised network through AB.saveData.
 * The saved object has the keys `doodle_total_guess`, `doodle_total_wrong` and `cnn`.
 */
function dataSavedOnAB() {
  AB.saveData({
    doodle_total_guess: DOODLE_TOTAL_GUESS,
    doodle_total_wrong: DOODLE_TOTAL_WRONG,
    cnn: cnn_model.toJSON(),
  });
}
/**
 * Draws the current doodle along the top of the canvas: at ZOOMPIXELS size on
 * the left and shrunk to PIXELS x PIXELS to its right.
 */
function drawDoodle() {
  const snapshot = doodle.get();
  const thumbnailX = ZOOMPIXELS + 50;
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
  image(snapshot, thumbnailX, 0, PIXELS, PIXELS);
}
/**
 * Handler for the "wrong" button: counts one more wrong doodle guess, shows
 * the updated score in message slot 2 and saves the state.
 */
function guessWrong() {
  DOODLE_TOTAL_WRONG++;

  const correct = DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG;
  const percent = correct / DOODLE_TOTAL_GUESS * 100;
  AB.msg("Doodle score:" + percent.toFixed(2), 2);

  dataSavedOnAB();
}
// *****************************************************************************************************************************************
/**
 * Handler for the "Guess" button: counts one more guess, shrinks a copy of
 * the doodle to PIXELS x PIXELS, feeds the first channel of each pixel
 * (scaled by 1/255) to the network, shows the score and the predicted letter
 * together with a "wrong" button in message slot 2, and saves the state.
 */
function guessDoodle() {
  const snapshot = doodle.get();
  DOODLE_TOTAL_GUESS++;

  snapshot.resize(PIXELS, PIXELS);
  snapshot.loadPixels();

  // Pixels are stored four values per pixel; only the first of each group is used.
  doodle_inputs = Array.from(
    { length: PIXELSSQUARED },
    (_unused, position) => snapshot.pixels[4 * position] / 255
  );

  const volume = getCnnInputs(doodle_inputs);
  const winner = findMax(cnn_model.forward(volume).w);

  const percent = (DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100;
  thehtml = "Doodle score:" + percent.toFixed(2)
    + "<br>We classify it as: " + greenspan + letters[winner - 1]
    + "</span><button class='normbutton' onclick='guessWrong()'>wrong</button>";
  AB.msg(thehtml, 2);

  dataSavedOnAB();
}
// *****************************************************************************************************************************************
/**
 * Handler for the "Clear" button: marks the doodle as absent, paints the
 * doodle buffer black and shows the current doodle score in message slot 2.
 */
function wipeDoodle() {
  doodle_exists = false;
  doodle.background("black");

  const correct = DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG;
  const percent = correct / DOODLE_TOTAL_GUESS * 100;
  AB.msg("Doodle score:" + percent.toFixed(2), 2);
}
/**
 * Debug helper: logs the values as a grid with PIXELS values per row.
 * Each value is printed with two decimals and preceded by a space; each row
 * starts with a newline.
 *
 * @param {ArrayLike<number>} groups Values to print.
 */
function showInputs(groups) {
  const pieces = [];
  for (let position = 0; position < groups.length; position += 1) {
    const startsRow = position % PIXELS == 0;
    if (startsRow) {
      pieces.push("\n");
    }
    pieces.push(" " + groups[position].toFixed(2));
  }
  console.log(pieces.join(""));
}