// Code viewer for World: CA686_Assignment_2
// Global configuration and mutable state for the doodle/letter CNN world.
// NOTE(review): the dataset appears to be EMNIST Letters (26-letter label
// table, 27-class softmax in setup()) — inferred, not stated; confirm.
'use strict';
// Side length of one square input image, in pixels.
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;
// Number of training / test images iterated over by loadData(), trainData()
// and testData(); presumably the dataset's split sizes — confirm.
const NOTRAIN = 124800;
const NOTEST = 20800;
// NOTE(review): noinput, nohidden, nooutput and learningrate are not
// referenced anywhere in this file (the CNN is configured in setup() and
// trained with adadelta). nooutput = 10 also disagrees with the 27-class
// softmax layer; these look like leftovers from an earlier network.
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;
const learningrate = 0.001;
// Toggled by the "start or end train" button; draw() only trains/tests while true.
let do_training = false;
// Training and test samples processed per draw() frame while training is on.
const TRAINPERSTEP = 12;
const TESTPERSTEP = 2;
// NOTE(review): ZOOMFACTOR is not referenced in this file.
const ZOOMFACTOR = 10;
// On-screen side length of the magnified images.
const ZOOMPIXELS = 8 * PIXELS;
const canvaswidth = PIXELS + ZOOMPIXELS + 150;
const canvasheight = 3 * ZOOMPIXELS + 200;
// Doodle pen stroke weight, and the blur radius applied when a stroke ends.
const DOODLE_THICK = 15;
const DOODLE_BLUR = 1;
// Dataset object filled in by loadData(); draw() does nothing while undefined.
let mnist;
let cnn; // NOTE(review): not referenced in this file.
let cnn_model; // convnetjs.Net built in setup()
let cnn_trainer; // convnetjs.SGDTrainer wrapping cnn_model
let doodle; // off-screen p5 graphics buffer the user draws on
let demo; // pixel data of the currently displayed demo test image
// Doodle scoring counters. GUESS starts at 1 so the displayed score
// (GUESS - WRONG) / GUESS never divides by zero.
let DOODLE_TOTAL_GUESS = 1;
let DOODLE_TOTAL_WRONG = 0;
// Training progress: pass counter and index within the current pass.
let trainrun = 1;
let train_index = 0;
// Test progress and running accuracy counters, reset after each full pass.
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;
let doodle_exists = false;
let demo_exists = false;
// True while a doodle stroke is in progress; used to blur once on mouse release.
let mousedrag = false;
// Most recent normalised input arrays; presumably kept for inspection with
// showInputs() from the console — confirm.
var train_inputs;
var test_inputs;
var demo_inputs;
var doodle_inputs;

/**
 * Draws a random weight between -0.5 and 0.5 via AB.randomFloatAtoB.
 * NOTE(review): not referenced elsewhere in this file.
 * @returns {number}
 */
function randomWeight() {
    const bound = 0.5;
    return AB.randomFloatAtoB(-bound, bound);
}
// Limit the AB page header to 95% of the viewport height.
AB.headerCSS({ "max-height": "95vh" });
// Display names for class labels: label k is shown as letters[k - 1]
// (label 1 -> "A" ... label 26 -> "Z"). Never reassigned, hence const.
const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".split("");
// Scratch buffer for HTML written into AB message areas; reused at runtime.
var thehtml;

// Static UI text written once at load time into message areas 1, 3, 5 and 7.
// (Previously one comma-operator expression; split into plain statements.)
thehtml = "Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear</button>  <button class='normbutton' onclick='guessDoodle();'>Guess</button> <br>";
AB.msg(thehtml, 1);

thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>   <button onclick='do_training = !do_training;' class='normbutton' >start or end train</button> <br> ";
AB.msg(thehtml, 3);

thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br> The network is <i>not</i> trained on any of these images. <br>  <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);
// Opening tag for highlighted results (bold, x-large, dark green); callers append "</span>".
const greenspan = `<span style='font-weight:bold; font-size:x-large; color:darkgreen'> `;


/**
 * Builds the convnetjs layer definitions for the classifier:
 * PIXELS x PIXELS x 1 input -> conv 5x5 (8 filters) -> pool 2 ->
 * conv 5x5 (16 filters) -> pool 3 -> softmax with 27 classes.
 * 27 classes because labels are shown as letters[label - 1]; presumably
 * labels run 1..26 and class 0 is unused — confirm against mymnist.js.
 * @returns {Object[]} layer definitions for convnetjs.Net#makeLayers
 */
function buildCnnLayerDefs() {
    return [
        { type: "input", out_sx: PIXELS, out_sy: PIXELS, out_depth: 1 },
        { type: "conv", sx: 5, filters: 8, stride: 1, pad: 2, activation: "relu" },
        { type: "pool", sx: 2, stride: 2 },
        { type: "conv", sx: 5, filters: 16, stride: 1, pad: 2, activation: "relu" },
        { type: "pool", sx: 3, stride: 3 },
        { type: "softmax", num_classes: 27 }
    ];
}

/**
 * AB.restoreData callback: applies a snapshot previously written by
 * dataSavedOnAB() (network weights plus doodle score counters) and
 * refreshes the doodle score display. Missing/partial snapshots are ignored
 * field by field so the score can never become NaN.
 * @param {Object|undefined|null} snapshot
 */
function restoreSavedState(snapshot) {
    console.log(snapshot);
    if (snapshot === undefined || snapshot === null) {
        return;
    }
    if (snapshot.cnn) {
        cnn_model.fromJSON(snapshot.cnn);
    }
    if (typeof snapshot.doodle_total_guess === "number" && snapshot.doodle_total_guess > 0) {
        DOODLE_TOTAL_GUESS = snapshot.doodle_total_guess;
        DOODLE_TOTAL_WRONG = typeof snapshot.doodle_total_wrong === "number" ? snapshot.doodle_total_wrong : 0;
    }
    const score = (DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100;
    AB.msg("Doodle score:" + score.toFixed(2), 2);
}

/**
 * p5.js entry point: creates the canvas and doodle buffer, loads the helper
 * scripts one after another, builds the CNN and its trainer, restores any
 * saved state, then starts loading the image data.
 */
function setup() {
    createCanvas(canvaswidth, canvasheight);
    doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);
    doodle.pixelDensity(1);
    AB.loadingScreen();
    // Scripts are chained so each is available before the next one runs.
    $.getScript("/uploads/codingtrain/matrix.js", function () {
        $.getScript("/uploads/chinya07/convnet-min.js", function () {
            $.getScript("/uploads/chinya07/mymnist.js", function () {
                console.log("All JS loaded");
                cnn_model = new convnetjs.Net();
                // Build the layers BEFORE requesting the saved snapshot: if the
                // restore callback ran first, makeLayers would overwrite the
                // restored weights with fresh random ones.
                cnn_model.makeLayers(buildCnnLayerDefs());
                cnn_trainer = new convnetjs.SGDTrainer(cnn_model, {
                    method: "adadelta",
                    batch_size: 32,
                    l2_decay: 0.001
                });
                AB.restoreData(restoreSavedState);
                loadData();
            });
        });
    });
}


/**
 * Loads the dataset via loadMNIST(), stores it in the global `mnist`,
 * transposes the first NOTRAIN training images and the first NOTEST test
 * images in place (see imageRotate), then removes the loading screen.
 * NOTE(review): the transpose presumably corrects the dataset's stored
 * orientation so images display upright — confirm.
 */
function loadData() {
    loadMNIST(function (dataset) {
        mnist = dataset;

        const transposeFirst = function (images, count) {
            for (let n = 0; n < count; n++) {
                imageRotate(images[n]);
            }
        };
        transposeFirst(mnist.train_images, NOTRAIN);
        transposeFirst(mnist.test_images, NOTEST);

        console.log("All data loaded into mnist object:");
        AB.removeLoading();
    });
}

/**
 * Transposes a PIXELS x PIXELS row-major image in place (element (r, c)
 * swaps with (c, r)). Despite the name this is a transpose, not a rotation.
 * @param {number[]|Uint8Array} pixels flat image data, mutated in place
 */
function imageRotate(pixels) {
    for (let row = 0; row < PIXELS; row++) {
        // Start right of the diagonal: diagonal elements stay where they are.
        for (let col = row + 1; col < PIXELS; col++) {
            const upper = row * PIXELS + col;
            const lower = col * PIXELS + row;
            const held = pixels[upper];
            pixels[upper] = pixels[lower];
            pixels[lower] = held;
        }
    }
}

/**
 * Converts flat grey-level data into an opaque greyscale p5 image of
 * PIXELS x PIXELS. Each value is copied to R, G and B; alpha is 255.
 * @param {ArrayLike<number>} greyLevels PIXELSSQUARED grey values (presumably 0..255)
 * @returns {p5.Image}
 */
function getImage(greyLevels) {
    const img = createImage(PIXELS, PIXELS);
    img.loadPixels();
    for (let i = 0; i < PIXELSSQUARED; i++) {
        const grey = greyLevels[i];
        const base = 4 * i; // RGBA, 4 bytes per pixel
        img.pixels[base] = grey;
        img.pixels[base + 1] = grey;
        img.pixels[base + 2] = grey;
        img.pixels[base + 3] = 255; // fully opaque
    }
    img.updatePixels();
    return img;
}

/**
 * Scales the first PIXELSSQUARED grey values into [0, 1] by dividing by 255.
 * @param {ArrayLike<number>} id raw pixel values
 * @returns {number[]} normalised inputs (new array)
 */
function getInputs(id) {
    return Array.from({ length: PIXELSSQUARED }, function (_, k) {
        return id[k] / 255;
    });
}

/**
 * Trains the CNN on the training image at `train_index`, optionally drawing
 * it in the middle row, reports progress in message area 4, and advances
 * the index. After the last image, wraps to 0, saves state and bumps `trainrun`.
 * @param {boolean} addedRenderer when truthy, draw the training image
 */
function trainData(addedRenderer) {
    const pixels = mnist.train_images[train_index];
    const label = mnist.train_labels[train_index];

    if (addedRenderer) {
        const picture = getImage(pixels);
        const top = ZOOMPIXELS + 50;
        image(picture, 0, top, ZOOMPIXELS, ZOOMPIXELS);
        image(picture, ZOOMPIXELS + 50, top, PIXELS, PIXELS);
    }

    const inputs = getInputs(pixels);
    train_inputs = inputs;
    cnn_trainer.train(getCnnInputs(inputs), label);

    thehtml = ` trainrun: ${trainrun}<br> no: ${train_index}`;
    AB.msg(thehtml, 4);

    train_index += 1;
    if (train_index === NOTRAIN) {
        train_index = 0;
        dataSavedOnAB();
        console.log("finished trainrun: " + trainrun);
        trainrun += 1;
    }
}

/**
 * Wraps flat inputs in a PIXELS x PIXELS x 1 convnetjs volume. The volume is
 * created filled with 0, then its first PIXELSSQUARED weights are copied from `t`.
 * @param {ArrayLike<number>} t normalised pixel inputs
 * @returns {convnetjs.Vol}
 */
function getCnnInputs(t) {
    const volume = new convnetjs.Vol(PIXELS, PIXELS, 1, 0);
    let k = PIXELSSQUARED;
    while (k > 0) {
        k -= 1;
        volume.w[k] = t[k];
    }
    return volume;
}

/**
 * Classifies the test image at `test_index`, draws it in the middle row,
 * updates the running accuracy in message area 6, and advances the index.
 * After the last test image, logs the pass score, bumps `testrun` and
 * resets the index and accuracy counters.
 */
function testData() {
    const pixels = mnist.test_images[test_index];
    const label = mnist.test_labels[test_index];
    const inputs = getInputs(pixels);
    const volume = getCnnInputs(inputs);
    test_inputs = inputs;
    const prediction = findMax(cnn_model.forward(volume).w);

    const picture = getImage(pixels);
    image(picture, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);
    image(picture, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);

    total_tests += 1;
    // Loose equality kept on purpose: the label type comes from mymnist.js,
    // which is not visible here.
    if (prediction == label) {
        total_correct += 1;
    }
    const percent = (total_correct / total_tests) * 100;

    thehtml = ` testrun: ${testrun}<br> no: ${total_tests} <br>  correct: ${total_correct}<br>  score: ${greenspan}${percent.toFixed(2)}</span>`;
    AB.msg(thehtml, 6);

    test_index += 1;
    if (test_index !== NOTEST) {
        return;
    }
    console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
    testrun += 1;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
}

/**
 * Returns the indices of the largest and second-largest entries as
 * [best, secondBest]. Ties keep the earliest index as best. For arrays
 * with fewer than two entries, secondBest is 0.
 *
 * Fixed: the running maxima used to start at 0, so non-positive scores
 * (e.g. raw logits) were ignored and a zero runner-up was never reported.
 * NOTE(review): not referenced elsewhere in this file.
 * @param {ArrayLike<number>} diffs scores
 * @returns {[number, number]}
 */
function find12(diffs) {
    let best = 0;
    let second = 0;
    let bestValue = -Infinity;
    let secondValue = -Infinity;
    for (let j = 0; j < diffs.length; j++) {
        const value = diffs[j];
        if (value > bestValue) {
            // Previous best is demoted to runner-up.
            second = best;
            secondValue = bestValue;
            best = j;
            bestValue = value;
        } else if (value > secondValue) {
            second = j;
            secondValue = value;
        }
    }
    return [best, second];
}

/**
 * Index of the largest element (argmax); ties resolve to the first index,
 * and an empty input returns 0. Used to turn the network's output
 * scores into a predicted class.
 *
 * Fixed: the running maximum started at 0, so all-negative inputs always
 * returned index 0.
 * @param {ArrayLike<number>} array scores (plain or typed array)
 * @returns {number}
 */
function findMax(array) {
    let maxI = 0;
    let maxValue = -Infinity;
    for (let i = 0; i < array.length; i++) {
        if (array[i] > maxValue) {
            maxI = i;
            maxValue = array[i];
        }
    }
    return maxI;
}

/**
 * p5.js per-frame callback. Once the dataset has loaded:
 *  - clears the canvas;
 *  - if training is on, runs TRAINPERSTEP training steps and TESTPERSTEP tests;
 *  - redraws and re-classifies the demo image, and redraws the doodle;
 *  - while the mouse is pressed inside the doodle area, draws a white stroke;
 *    on release after a stroke, blurs the doodle once.
 * (Replaces comma-operator conditions with explicit statements.)
 */
function draw() {
    if (mnist === undefined) {
        return; // data not loaded yet
    }
    background("black");

    if (do_training) {
        for (let step = 0; step < TRAINPERSTEP; step++) {
            trainData(false);
        }
        for (let step = 0; step < TESTPERSTEP; step++) {
            testData();
        }
    }

    if (demo_exists) {
        drawDemo();
        guessDemo();
    }
    if (doodle_exists) {
        drawDoodle();
    }

    if (mouseIsPressed) {
        // Only draw when both current and previous mouse positions are inside
        // the (slightly padded) doodle area.
        const limit = ZOOMPIXELS + 20;
        if (mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit) {
            mousedrag = true;
            doodle_exists = true;
            doodle.stroke("white");
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    } else if (mousedrag) {
        // Stroke just finished: soften it once.
        mousedrag = false;
        doodle.filter(BLUR, DOODLE_BLUR);
    }
}

/**
 * "Demo test image" button handler: picks a random test image, makes it the
 * current demo, and shows its index and true label in message area 8.
 * Does nothing if the dataset has not loaded yet; previously a click during
 * loading threw on the undefined `mnist`.
 */
function makeDemo() {
    if (mnist === undefined) {
        return;
    }
    demo_exists = true;
    const i = AB.randomIntAtoB(0, NOTEST - 1);
    demo = mnist.test_images[i];
    const label = mnist.test_labels[i];
    // Labels are displayed as letters[label - 1] (label 1 -> "A").
    thehtml = "Test image no: " + i + "<br>Classification: " + letters[label - 1] + "<br>";
    AB.msg(thehtml, 8);
}

/**
 * Draws the current demo image in the bottom row: magnified on the left,
 * actual size to its right.
 */
function drawDemo() {
    const picture = getImage(demo);
    const top = canvasheight - ZOOMPIXELS;
    image(picture, 0, top, ZOOMPIXELS, ZOOMPIXELS);
    image(picture, ZOOMPIXELS + 50, top, PIXELS, PIXELS);
}

/**
 * Classifies the current demo image and shows the predicted letter in
 * message area 9.
 * Fixed: predicted class 0 (an unused softmax slot, since letters are
 * indexed from label 1) used to render as "undefined"; it now shows "?".
 */
function guessDemo() {
    const inputs = getInputs(demo);
    demo_inputs = inputs;
    const predicted = findMax(cnn_model.forward(getCnnInputs(inputs)).w);
    const letter = letters[predicted - 1] !== undefined ? letters[predicted - 1] : "?";
    thehtml = " We classify it as: " + greenspan + letter + "</span>";
    AB.msg(thehtml, 9);
}


/**
 * Persists the doodle score counters and the serialised network through
 * AB.saveData. setup() reads these same three fields back on restore.
 */
function dataSavedOnAB() {
    AB.saveData({
        doodle_total_guess: DOODLE_TOTAL_GUESS,
        doodle_total_wrong: DOODLE_TOTAL_WRONG,
        cnn: cnn_model.toJSON()
    });
}

/**
 * Copies the doodle buffer onto the top row of the canvas: magnified at the
 * left edge, actual size to its right.
 */
function drawDoodle() {
    const frame = doodle.get();
    const placements = [
        [0, ZOOMPIXELS],
        [ZOOMPIXELS + 50, PIXELS]
    ];
    placements.forEach(function ([left, size]) {
        image(frame, left, 0, size, size);
    });
}

/**
 * "wrong" button handler: records that the last doodle guess was wrong,
 * refreshes the score in message area 2, and saves state.
 * Fixed: the wrong-count is capped at the guess-count. Repeated clicks on
 * the same button used to push the score below 0%.
 * NOTE(review): repeated clicks can still over-count up to that cap; a
 * per-guess flag set by guessDoodle() would be the complete fix.
 */
function guessWrong() {
    if (DOODLE_TOTAL_WRONG < DOODLE_TOTAL_GUESS) {
        DOODLE_TOTAL_WRONG++;
    }
    const score = (DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100;
    AB.msg("Doodle score:" + score.toFixed(2), 2);
    dataSavedOnAB();
}
// *****************************************************************************************************************************************
/**
 * "Guess" button handler: downsamples the doodle to PIXELS x PIXELS,
 * classifies it, counts the guess, and shows the score and predicted letter
 * (with a "wrong" button) in message area 2. State is then saved.
 * Fixed: predicted class 0 (unused slot; letters are indexed from label 1)
 * used to render as "undefined"; it now shows "?".
 */
function guessDoodle() {
    const frame = doodle.get();
    DOODLE_TOTAL_GUESS++;
    frame.resize(PIXELS, PIXELS);
    frame.loadPixels();
    // The doodle is drawn white on black, so the red channel of each RGBA
    // pixel is enough to recover its grey level.
    const inputs = [];
    for (let i = 0; i < PIXELSSQUARED; i++) {
        inputs[i] = frame.pixels[4 * i] / 255;
    }
    doodle_inputs = inputs;
    const predicted = findMax(cnn_model.forward(getCnnInputs(inputs)).w);
    const letter = letters[predicted - 1] !== undefined ? letters[predicted - 1] : "?";
    const score = (DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100;

    thehtml = "Doodle score:" + score.toFixed(2) + "<br>We classify it as: " + greenspan + letter + "</span><button class='normbutton' onclick='guessWrong()'>wrong</button>";
    AB.msg(thehtml, 2);
    dataSavedOnAB();
}
// *****************************************************************************************************************************************

/**
 * "Clear" button handler: blanks the doodle buffer to black, marks that no
 * doodle exists, and re-shows the current score in message area 2.
 */
function wipeDoodle() {
    doodle_exists = false;
    doodle.background("black");
    const score = (DOODLE_TOTAL_GUESS - DOODLE_TOTAL_WRONG) / DOODLE_TOTAL_GUESS * 100;
    AB.msg(`Doodle score:${score.toFixed(2)}`, 2);
}

/**
 * Debug helper: logs an input array as a PIXELS-wide grid of values with two
 * decimals. Each row starts on a new line, and every value is preceded by a space.
 * @param {ArrayLike<number>} groups flat input values
 */
function showInputs(groups) {
    const cells = Array.from(groups, function (value, index) {
        const rowBreak = index % PIXELS === 0 ? "\n" : "";
        return rowBreak + " " + value.toFixed(2);
    });
    console.log(cells.join(""));
}