// Cloned by Vyoma Patel on 18 Nov 2021 from World "Character recognition neural network" by "Coding Train" project
// Please leave this clone trail here.
// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications
// --- defined by the data set - do not change these ---------------------------------
var PIXELS = 28; // images in dataset are tiny (PIXELS x PIXELS)
var PIXELSSQUARED = PIXELS * PIXELS;
// number of training and test exemplars in the data set:
var NOTRAIN = 124800;
var NOTEST = 20800;
//--- can modify all these --------------------------------------------------
// no of nodes in network
// NOTE(review): noinput, nohidden, nooutput, learningrate and nn are not read anywhere in the
// visible code - presumably leftovers from the original (non-convolutional) port.
var noinput = PIXELSSQUARED;
var nohidden = 64;
var nooutput = 10;
var learningrate = 0.1; // default 0.1
// should we train every timestep or not
let do_training = true;
// how many to train and test per timestep
const TRAINPERSTEP = 20;//30;
const TESTPERSTEP = 5;
// multiply it by this to magnify for display
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
// 3 rows of
// large image + 50 gap + small image
// 50 gap between rows
const canvaswidth = ( PIXELS + ZOOMPIXELS ) + 50;
const canvasheight = ( ZOOMPIXELS * 3 ) + 100;
const DOODLE_THICK = 18; // thickness of doodle lines
const DOODLE_BLUR = 3; // blur factor applied to doodles
let mnist;
// all data is loaded into this
// mnist.train_images
// mnist.train_labels
// mnist.test_images
// mnist.test_labels
let nn;
// The convolutional network and its trainer. Both are assigned in setup() once the
// ConvNetJS script has loaded and are read by the train/test/guess functions.
// Declared here so they are real globals instead of implicit ones (an assignment to an
// undeclared name throws a ReferenceError in strict mode / ES modules).
let mycnnModel;
let mycnnTrain;
// Display names of the classes. The code looks names up as alphabets[label - 1],
// i.e. label 1 is shown as "A" and label 26 as "Z".
const alphabets = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"];
let trainrun = 1;
let train_index = 0;
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;
// images in LHS:
let doodle, demo;
let doodle_exists = false;
let demo_exists = false;
let mousedrag = false; // are we in the middle of a mouse drag drawing?
// save inputs to global var to inspect
// type these names in console
var train_inputs, test_inputs, demo_inputs, doodle_inputs;
var thehtml;
// Weight initialiser. Matrix.randomize() is changed to point to this, so it must be
// defined by the user of Matrix.
// Returns the value AB.randomFloatAtoB produces for the bounds -0.5 and 0.5.
function randomWeight() {
  const lowerBound = -0.5;
  const upperBound = 0.5;
  return AB.randomFloatAtoB(lowerBound, upperBound);
}
// make run header bigger
AB.headerCSS({ "max-height": "95vh" });

//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// The odd-numbered slots (1, 3, 5, 7) hold the static headers written here; slots
// 2, 4, 6, 8 and 9 are written later by guessDoodle, trainit, testit, makeDemo and guessDemo.

// 1 Doodle header
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear and Re-draw</button> <br> ";
AB.msg(thehtml, 1);

// 2 Training header (with pause / resume buttons that flip do_training)
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br> <button onclick='do_training = false;' class='normbutton' >Pause training</button> <button onclick='do_training = true;' class='normbutton' >Resume training</button> <br> ";
AB.msg(thehtml, 3);

// Sub-header for the running test score
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// 3 Demo header
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br> The network is <i>not</i> trained on any of these images. <br> <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);

// Opening tag used to highlight scores and guesses; callers append "</span>" themselves.
var bluespan = "<span style='font-weight:bold; font-size:x-large; color:darkblue'> ";
// p5 setup hook: creates the canvas and the off-screen doodle buffer, shows the loading
// screen, then loads the three helper scripts one after another. When the last one has
// arrived it builds the convolutional network and its trainer (stored in the globals
// mycnnModel / mycnnTrain) and kicks off loadData().
function setup() {
  createCanvas(canvaswidth, canvasheight);

  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
  doodle.pixelDensity(1);
  wipeDoodle();

  // JS load other JS
  // show a loading screen while the scripts and the data set are fetched
  AB.loadingScreen();

  // Runs once every script below has been loaded.
  const buildNetworkAndLoadData = function () {
    console.log("All JS Files loaded");

    // Layer stack: input -> conv -> pool -> conv -> pool -> softmax.
    // Modelled on the ConvNetJS MNIST demo:
    // https://cs.stanford.edu/people/karpathy/convnetjs/demo/mnist.html
    const layerDefs = [
      // one 28 x 28 single-channel image
      { type: "input", out_sx: 28, out_sy: 28, out_depth: 1 },
      // 5 x 5 convolution, 8 filters
      { type: "conv", sx: 5, filters: 8, stride: 1, pad: 2, activation: "relu" },
      // 2 x 2 pooling
      { type: "pool", sx: 2, stride: 2 },
      // 5 x 5 convolution, 16 filters
      { type: "conv", sx: 5, filters: 16, stride: 1, pad: 2, activation: "relu" },
      // 3 x 3 pooling
      { type: "pool", sx: 3, stride: 3 },
      // 27 output classes; labels are displayed via alphabets[label - 1], so class 0
      // is presumably never a real label - TODO confirm against the data set
      { type: "softmax", num_classes: 27 },
    ];

    mycnnModel = new convnetjs.Net();
    mycnnModel.makeLayers(layerDefs);

    const trainerOptions = {
      method: "adadelta",
      momentum: 0.9,
      batch_size: 10,
      l2_decay: 0.001,
    };
    mycnnTrain = new convnetjs.SGDTrainer(mycnnModel, trainerOptions);

    loadData();
  };

  // Order matters: each script is only requested from the success callback of the previous one.
  const scriptUrls = [
    "/uploads/codingtrain/matrix.js",
    "/uploads/vyoma136/convnet.js",
    "/uploads/vyoma136/Vyoma_MNIST.js",
  ];

  const loadScriptAt = function (position) {
    if (position === scriptUrls.length) {
      buildNetworkAndLoadData();
      return;
    }
    $.getScript(scriptUrls[position], function () {
      loadScriptAt(position + 1);
    });
  };

  loadScriptAt(0);
}
// load data set from local file (on this server)
// Asks loadMNIST for the data set. When it arrives the data is stored in the global
// "mnist", the first NOTRAIN training images and the first NOTEST test images are passed
// through rotateImage (which changes them in place), and the loading screen is removed.
function loadData() {
  loadMNIST(function (loaded) {
    mnist = loaded;

    for (let t = 0; t < NOTRAIN; t += 1) {
      rotateImage(mnist.train_images[t]);
    }
    for (let t = 0; t < NOTEST; t += 1) {
      rotateImage(mnist.test_images[t]);
    }

    console.log("All data loaded into Emnist object.");
    console.log(mnist);
    AB.removeLoading(); // if no loading screen exists, this does nothing
  });
}
// Makes a P5 image object (PIXELS x PIXELS) from a raw data array.
// id: array with one grey value per pixel; each value is copied into the red, green and
//     blue channels and alpha is set to 255.
// Returns the image after updatePixels() has been called on it.
function getImage(id) {
  const picture = createImage(PIXELS, PIXELS);
  picture.loadPixels();

  // "base" is the index of the pixel's first (red) channel in the RGBA pixel array
  for (let p = 0, base = 0; p < PIXELSSQUARED; p += 1, base += 4) {
    const grey = id[p];
    picture.pixels[base] = grey;
    picture.pixels[base + 1] = grey;
    picture.pixels[base + 2] = grey;
    picture.pixels[base + 3] = 255;
  }

  picture.updatePixels();
  return picture;
}
// Converts an image data array into a normalised input array.
// id: array-like of pixel values; the first PIXELSSQUARED entries are read.
// Returns a new array of length PIXELSSQUARED where each entry is the pixel value
// divided by 255 (so 0..255 becomes 0..1).
function getInputs(id) {
  return Array.from({ length: PIXELSSQUARED }, (_, p) => id[p] / 255);
}
// Transposes a PIXELS x PIXELS image in place: the value at (row, col) is swapped with
// the value at (col, row), i.e. the image is mirrored across its main diagonal.
// ctx: flat, row-major pixel array; it is modified directly and nothing is returned.
function rotateImage(ctx) {
  for (let row = 0; row < PIXELS; row += 1) {
    // start at the diagonal so every pair is swapped exactly once
    for (let col = row; col < PIXELS; col += 1) {
      const here = row * PIXELS + col;
      const mirrored = col * PIXELS + row;
      const held = ctx[here];
      ctx[here] = ctx[mirrored];
      ctx[mirrored] = held;
    }
  }
}
// Trains the network with a single exemplar, taken from the global "train_index".
// canCreateDiscussions: truthy to also draw the exemplar in the middle row
//                       (magnified on the left, original size on the right).
// Side effects: stores the normalised inputs in "train_inputs", writes progress to
// message slot 4, advances "train_index" and, after the last exemplar, wraps it to 0
// and increments "trainrun".
function trainit(canCreateDiscussions) {
  const pixels = mnist.train_images[train_index];
  const target = mnist.train_labels[train_index];

  // optional - show visual of the image
  if (canCreateDiscussions) {
    const picture = getImage(pixels);
    const rowTop = ZOOMPIXELS + 50;
    image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS); // magnified
    image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS); // original
  }

  const normalised = getInputs(pixels); // get inputs from data array
  train_inputs = normalised; // can inspect in console

  const volume = getmycnnInputs(normalised);
  mycnnTrain.train(volume, target);

  thehtml = ` trainrun: ${trainrun}<br> no: ${train_index}`;
  AB.msg(thehtml, 4);

  train_index += 1;
  if (train_index == NOTRAIN) {
    train_index = 0;
    console.log(`finished trainrun: ${trainrun}`);
    trainrun += 1;
  }
}
// Wraps an input array in a ConvNetJS volume.
// obj: array-like of input values; the first PIXELSSQUARED entries are copied.
// Returns a new convnetjs.Vol constructed with (28, 28, 1, 0) whose "w" array holds
// those entries in the same order.
function getmycnnInputs(obj) {
  const volume = new convnetjs.Vol(28, 28, 1, 0);
  let slot = 0;
  while (slot < PIXELSSQUARED) {
    volume.w[slot] = obj[slot];
    slot += 1;
  }
  return volume;
}
// Tests the network with a single exemplar, taken from the global "test_index".
// Runs a forward pass, takes the strongest output (findMax) as the guess, draws the
// exemplar in the middle row, updates the running totals and writes the score to
// message slot 6. After the last test exemplar the run is logged and the index and
// totals are reset for the next test run.
function testit() {
  const pixels = mnist.test_images[test_index];
  const expected = mnist.test_labels[test_index];

  const normalised = getInputs(pixels);
  const volume = getmycnnInputs(normalised);
  test_inputs = normalised; // can inspect in console

  const outputs = mycnnModel.forward(volume).w;
  const guess = findMax(outputs);

  const picture = getImage(pixels);
  const rowTop = ZOOMPIXELS + 50;
  image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS);
  image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS);

  total_tests += 1;
  if (guess == expected) {
    total_correct += 1;
  }

  const percent = (total_correct / total_tests) * 100;
  const score = percent.toFixed(2);
  thehtml = ` testrun: ${testrun}<br> no: ${total_tests} <br> correct: ${total_correct}<br> score: ${bluespan}${score}</span>`;
  AB.msg(thehtml, 6);

  test_index += 1;
  if (test_index == NOTEST) {
    console.log(`finished testrun: ${testrun} score: ${percent.toFixed(2)}`);
    testrun += 1;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
  }
}
//--- find no.1 and no.2 output nodes ---------------------------------------
// Returns the indexes of the largest and second-largest entries of an array.
// prices: array-like of numbers (e.g. the output activations of the network).
// Returns [indexOfLargest, indexOfSecondLargest]. On ties the earliest index wins.
// With fewer than two entries the missing positions stay at index 0
// (so [] and [x] both give [0, 0]).
// The running maxima start at -Infinity rather than 0, so arrays that contain zero
// or negative values are ranked correctly as well.
function find12(prices) {
  let bestIndex = 0;
  let secondIndex = 0;
  let bestValue = -Infinity;
  let secondValue = -Infinity;

  for (let i = 0; i < prices.length; i++) {
    const value = prices[i];
    if (value > bestValue) {
      // new leader: the previous leader becomes the runner-up
      secondIndex = bestIndex;
      secondValue = bestValue;
      bestIndex = i;
      bestValue = value;
    } else if (value > secondValue) {
      secondIndex = i;
      secondValue = value;
    }
  }

  return [bestIndex, secondIndex];
}
// Returns the index of the largest entry of an array.
// array: array-like of numbers (e.g. the output activations of the network).
// Returns the index of the first occurrence of the maximum, or 0 for an empty array.
// The running maximum starts at -Infinity rather than 0, so arrays made up only of
// zero or negative values are handled correctly as well.
function findMax(array) {
  let maxI = 0;
  let maxValue = -Infinity;

  for (let i = 0; i < array.length; i++) {
    if (array[i] > maxValue) {
      maxI = i;
      maxValue = array[i];
    }
  }

  return maxI;
}
// p5 draw hook, run once per frame. Does nothing until the data set ("mnist") is loaded.
// Each frame: repaints the background and the doodle frame, runs TRAINPERSTEP training
// steps and TESTPERSTEP test steps (unless training is paused), redraws / re-classifies
// the demo image and the doodle if they exist, and handles mouse drawing on the doodle.
function draw() {
  if (mnist === undefined) {
    return;
  }

  // background and the green "doodle here" frame in the top left
  background("black");
  strokeWeight(1);
  stroke("green");
  rect(0, 0, ZOOMPIXELS, ZOOMPIXELS);
  textSize(10);
  textAlign(CENTER);
  text("DOODLE HERE", ZOOMPIXELS / 2, ZOOMPIXELS / 2.2);

  if (do_training) {
    // only the first training exemplar of the frame is drawn
    for (let step = 0; step < TRAINPERSTEP; step += 1) {
      trainit(step === 0);
    }
    for (let step = 0; step < TESTPERSTEP; step += 1) {
      testit();
    }
  }

  if (demo_exists) {
    drawDemo();
    guessDemo();
  }

  if (doodle_exists) {
    drawDoodle();
    guessDoodle();
  }

  if (!mouseIsPressed) {
    // a drag has just ended: blur the finished doodle once
    if (mousedrag) {
      mousedrag = false;
      doodle.filter(BLUR, DOODLE_BLUR);
    }
    return;
  }

  // can draw up to this many pixels from the top-left corner
  const limit = ZOOMPIXELS + 20;
  const insideDoodleArea = mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit;
  if (insideDoodleArea) {
    mousedrag = true; // start (or continue) a mouse drag
    doodle_exists = true;
    doodle.stroke("red");
    strokeJoin(ROUND);
    doodle.strokeWeight(DOODLE_THICK);
    doodle.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}
// Picks a random test image as the demo image (global "demo"), marks it as existing and
// writes its index and its label (shown as a letter via alphabets[label - 1]) to
// message slot 8.
function makeDemo() {
  demo_exists = true;

  const pick = AB.randomIntAtoB(0, NOTEST - 1);
  demo = mnist.test_images[pick];
  const label = mnist.test_labels[pick];

  thehtml = `Test image no: ${pick}<br>Classification: ${alphabets[label - 1]}<br>`;
  AB.msg(thehtml, 8);
}
// Draws the current demo image in the bottom row of the canvas:
// magnified on the left, original size on the right.
function drawDemo() {
  const picture = getImage(demo);
  const rowTop = canvasheight - ZOOMPIXELS;
  image(picture, 0, rowTop, ZOOMPIXELS, ZOOMPIXELS); // magnified
  image(picture, ZOOMPIXELS + 50, rowTop, PIXELS, PIXELS); // original
}
// Classifies the current demo image: normalises it, runs a forward pass and writes the
// letter for the strongest output (alphabets[guess - 1]) to message slot 9.
// The normalised inputs are kept in "demo_inputs" for inspection in the console.
function guessDemo() {
  const normalised = getInputs(demo);
  demo_inputs = normalised;

  const outputs = mycnnModel.forward(getmycnnInputs(normalised)).w;
  const guess = findMax(outputs);

  thehtml = ` We classify it as: ${bluespan}${alphabets[guess - 1]}</span>`;
  AB.msg(thehtml, 9);
}
// Draws the doodle buffer in the top row of the canvas:
// full size on the left, shrunk to PIXELS x PIXELS on the right.
function drawDoodle() {
  const snapshot = doodle.get();
  const shrunkLeft = ZOOMPIXELS + 20;
  image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS); // full size
  image(snapshot, shrunkLeft, 0, PIXELS, PIXELS); // shrunk
}
// Classifies the current doodle: takes a copy of the doodle buffer, resizes it to
// PIXELS x PIXELS, reads the first (red) channel of every pixel scaled to 0..1, runs a
// forward pass and writes the letters for the two strongest outputs to message slot 2.
// The inputs are kept in "doodle_inputs" for inspection in the console.
function guessDoodle() {
  const shrunk = doodle.get();
  shrunk.resize(PIXELS, PIXELS);
  shrunk.loadPixels();

  // pixels are stored as 4 channels per pixel; index 4 * p is the red channel
  const inputs = [];
  for (let p = 0; p < PIXELSSQUARED; p += 1) {
    inputs[p] = shrunk.pixels[4 * p] / 255;
  }
  doodle_inputs = inputs;

  const outputs = mycnnModel.forward(getmycnnInputs(inputs)).w;
  const ranked = find12(outputs);
  const firstLetter = alphabets[ranked[0] - 1];
  const secondLetter = alphabets[ranked[1] - 1];

  thehtml = ` Our 1st Guess is: ${bluespan}${firstLetter}</span> <br> Our 2nd Guess is: ${bluespan}${secondLetter}</span>`;
  AB.msg(thehtml, 2);
}
// Clears the doodle: marks it as not existing (so draw() stops drawing / classifying it)
// and paints the doodle buffer black.
function wipeDoodle() {
  doodle_exists = false;
  const blankColour = "black";
  doodle.background(blankColour);
}
// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);
// Debug helper: prints an input array to the console as a grid with PIXELS values per
// row. Every value is preceded by a space and shown with two decimals; every row
// (including the first) starts with a newline.
// groups: array-like of numbers, e.g. demo_inputs or doodle_inputs.
function showInputs(groups) {
  let dump = "";
  for (let cell = 0; cell < groups.length; cell += 1) {
    const rowBreak = cell % PIXELS === 0 ? "\n" : "";
    dump += `${rowBreak} ${groups[cell].toFixed(2)}`;
  }
  console.log(dump);
}
;