Code viewer for World: Character recognition neur...

// Cloned by Abdelshafa Abdala on 28 Nov 2021 from World "Character recognition neural network (2020-C686I: Assignment 2)" by Andrey Totev 
// Please leave this clone trail here.
 
// ---- Image geometry ----
// MNIST images are PIXELS x PIXELS; the networks are fed a CROP_PIXELS x CROP_PIXELS crop.
const CROP_PIXELS = 24;
const PIXELS = 28;
const PIXELSSQUARED = PIXELS * PIXELS;

// ---- Dataset sizes (number of training / test images cycled through) ----
const NOTRAIN = 60000;
const NOTEST = 10000;

// ---- Shallow-network dimensions ----
const noinput = PIXELSSQUARED;
const nohidden = 64;
const nooutput = 10;

const learningrate = 0.1;

// Cleared by the "Stop training" button; draw() only trains/tests while this is true.
let do_training = true;

// ---- Per-frame work and layout ----
const BATCH_SIZE = 50;
// draw() calls trainit() this many times per frame, and trainit() uses it as the batch size.
const TRAINPERSTEP = 50;
// Number of test images classified per frame.
const TESTPERSTEP = 5;

// Magnification of the enlarged views. ZOOMPIXELS is derived from it so the two
// can never drift apart (previously the factor 7 was duplicated as a literal).
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;

// Width: zoomed view + 50px gap + original-size view.
const canvaswidth = PIXELS + ZOOMPIXELS + 50;
// Height: three rows of zoomed views (doodle, training, demo) separated by 50px gaps.
const canvasheight = 3 * ZOOMPIXELS + 100;

// Stroke weight for doodle drawing, and the radius passed to doodle.filter(BLUR, ...).
const DOODLE_THICK = 18;
const DOODLE_BLUR = 0;

// MNIST data object, filled in asynchronously once the data has loaded.
let mnist;
// Which network setup() builds: 0 = shallow ReLU, 1 = shallow tanh,
// 2 = default CNN, 3 = network loaded from the JSON file.
const theNN = 3;

// ---- Mutable sketch state ----
let nn;
let doodle;
let demo;
let trainrun = 1;
let train_index = 0;

let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;

// NOTE(review): left as `var` so these remain properties of the global object, in case
// other scripts read them that way. TODO confirm and switch to `let`.
var train_inputs, test_inputs, demo_inputs, doodle_inputs, thehtml;

/**
 * Random initial weight, drawn via AB.randomFloatAtoB(-0.5, 0.5).
 * @returns {number}
 */
function randomWeight() {
    const lowerBound = -0.5;
    const upperBound = 0.5;
    return AB.randomFloatAtoB(lowerBound, upperBound);
}

// Allow the run header to grow to almost the full viewport height.
$("#runheaderbox").css({ "max-height": "95vh" });

// Static section headings. Odd message slots hold the headings; the even slots
// after them are overwritten with live status while the sketch runs.

// Slot 1: doodle section, with a button that clears the drawing.
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br>  Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg(thehtml, 1);

// Slot 3: training section, with a button that stops training.
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>   <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
AB.msg(thehtml, 3);

// Slot 5: heading for the running test score.
thehtml = "<h3> Hidden tests </h3> ";
AB.msg(thehtml, 5);

// Slot 7: demo section, with a button that picks a random test image.
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br> The network is <i>not</i> trained on any of these images. <br>  <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg(thehtml, 7);

// Opening tag used to highlight numbers (scores, predicted digits) in status messages.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";
    
    function setup()
    
    {
        createCanvas(canvaswidth,canvasheight),
        (
            doodle=createGraphics(ZOOMPIXELS,ZOOMPIXELS)).pixelDensity(1),
            AB.loadingScreen(),$.getScript("/uploads/codingtrain/mnist.js",function()
            {
                $.getScript("/uploads/atotev/mathutils.js",function()
                
                {
                    $.getScript("/uploads/atotev/webcnn.js",function()
                    
                    {
                        $.getJSON("/uploads/atotev/cnn_mnist_10_20_98accuracy.json",function(e)
                        
                        {
                            console.log("All JS loaded"),
                            0===theNN?nn=createShallowNetwork(ACTIVATION_RELU):
                            1===theNN?nn=createShallowNetwork(ACTIVATION_TANH):
                            2===theNN?nn=createDefaultNetwork():
                            3===theNN?nn=loadNetworkFromJSON(e):
                            console.log("Unknown NN type: "+theNN),loadData()})})})})}
                            
/**
 * Build a WebCNN from a JSON description.
 *
 * Optional top-level `momentum`, `lambda` and `learningRate` fields are applied when
 * present. Every entry of `json.layers` is passed to `newLayer`. Conv and
 * fully-connected layers that carry both `weights` and `biases` have those values
 * copied in.
 * @param {object} json - parsed network description
 * @returns {WebCNN} the initialized network
 */
function loadNetworkFromJSON(json) {
    const net = new WebCNN();
    if (json.momentum !== undefined) net.setMomentum(json.momentum);
    if (json.lambda !== undefined) net.setLambda(json.lambda);
    if (json.learningRate !== undefined) net.setLearningRate(json.learningRate);

    // First create every layer, so that net.layers[i] corresponds to json.layers[i].
    for (const layerSpec of json.layers) {
        console.log(layerSpec);
        net.newLayer(layerSpec);
    }

    // Then copy stored parameters into the layer types that have them.
    json.layers.forEach((layerSpec, index) => {
        switch (layerSpec.type) {
            case LAYER_TYPE_CONV:
            case LAYER_TYPE_FULLY_CONNECTED:
                if (layerSpec.weights !== undefined && layerSpec.biases !== undefined) {
                    net.layers[index].setWeightsAndBiases(layerSpec.weights, layerSpec.biases);
                }
                break;
            default:
                break;
        }
    });

    net.initialize();
    return net;
}

/**
 * Small CNN on CROP_PIXELS x CROP_PIXELS single-channel input:
 * conv(10, 5x5) -> maxpool(2x2) -> conv(20, 5x5) -> maxpool(2x2) -> softmax(nooutput).
 * @returns {WebCNN}
 */
function createDefaultNetwork() {
    const net = new WebCNN();
    net.newLayer({ name: "image", type: LAYER_TYPE_INPUT_IMAGE, width: CROP_PIXELS, height: CROP_PIXELS, depth: 1 });
    net.newLayer({ name: "conv1", type: LAYER_TYPE_CONV, units: 10, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false });
    net.newLayer({ name: "pool1", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    net.newLayer({ name: "conv2", type: LAYER_TYPE_CONV, units: 20, kernelWidth: 5, kernelHeight: 5, strideX: 1, strideY: 1, padding: false });
    net.newLayer({ name: "pool2", type: LAYER_TYPE_MAX_POOL, poolWidth: 2, poolHeight: 2, strideX: 2, strideY: 2 });
    // Output width uses the shared `nooutput` constant, consistent with createShallowNetwork.
    net.newLayer({ name: "out", type: LAYER_TYPE_FULLY_CONNECTED, units: nooutput, activation: ACTIVATION_SOFTMAX });
    net.initialize();
    net.setLearningRate(0.01);
    net.setMomentum(0.9);
    net.setLambda(0);
    return net;
}

/**
 * One-hidden-layer network: input -> fully connected(nohidden, activation) -> softmax(nooutput).
 * @param {*} activation - WebCNN activation constant for the hidden layer
 * @returns {WebCNN}
 */
function createShallowNetwork(activation) {
    const net = new WebCNN();
    net.newLayer({ name: "image", type: LAYER_TYPE_INPUT_IMAGE, width: CROP_PIXELS, height: CROP_PIXELS, depth: 1 });
    net.newLayer({ name: "hidden", type: LAYER_TYPE_FULLY_CONNECTED, units: nohidden, activation: activation });
    net.newLayer({ name: "out", type: LAYER_TYPE_FULLY_CONNECTED, units: nooutput, activation: ACTIVATION_SOFTMAX });
    net.initialize();
    net.setLearningRate(0.01);
    net.setMomentum(0.9);
    net.setLambda(0);
    return net;
}

/** Load MNIST into the global `mnist`, then dismiss the loading screen. */
function loadData() {
    loadMNIST((data) => {
        mnist = data;
        console.log("All data loaded into mnist object:");
        console.log(mnist);
        AB.removeLoading();
    });
}

/**
 * Re-center a drawn digit inside a size x size frame.
 *
 * Reads the red channel of an RGBA pixel array. It finds the bounding box of
 * pixels whose value is exactly 255 and shifts the contents of that box so it sits
 * in the middle of the frame. Values outside the box are dropped.
 * @param {ArrayLike<number>} rgba - RGBA pixel data, 4 entries per pixel
 * @param {number} size - frame width/height in pixels
 * @returns {number[]} flat size*size array of single-channel values (all zeros if no 255 pixel)
 */
function centerImage(rgba, size) {
    // Red channel as a size x size grid.
    const grid = [];
    for (let row = 0; row < size; row++) {
        grid[row] = [];
        for (let col = 0; col < size; col++) {
            grid[row][col] = rgba[4 * (row * size + col)];
        }
    }

    // Bounding box of full-intensity pixels.
    let top = Number.MAX_VALUE;
    let left = Number.MAX_VALUE;
    let bottom = -1;
    let right = -1;
    for (let row = 0; row < grid.length; row++) {
        const first = grid[row].indexOf(255);
        if (first < 0) continue;
        const last = grid[row].lastIndexOf(255);
        left = Math.min(left, first);
        right = Math.max(right, last);
        top = Math.min(top, row);
        bottom = Math.max(bottom, row);
    }

    const centered = new Array(size * size).fill(0);
    if (bottom < 0) return centered; // nothing at full intensity: blank output

    const rowShift = Math.floor((size - bottom - top) / 2);
    const colShift = Math.floor((size - right - left) / 2);
    // Loop variables are block-scoped; the old version leaked globals `i` and `j`.
    for (let row = top; row <= bottom; row++) {
        for (let col = left; col <= right; col++) {
            centered[(row + rowShift) * size + (col + colShift)] = grid[row][col];
        }
    }
    return centered;
}

/**
 * Turn a flat array of grayscale values into an opaque size x size p5.Image.
 * @param {ArrayLike<number>} values - size*size gray levels (presumably 0-255)
 * @param {number} [size=PIXELS]
 * @returns {p5.Image}
 */
function getImage(values, size = PIXELS) {
    const img = createImage(size, size);
    img.loadPixels();
    for (let p = 0; p < size * size; p++) {
        const gray = values[p];
        const base = 4 * p;
        img.pixels[base + 0] = gray;
        img.pixels[base + 1] = gray;
        img.pixels[base + 2] = gray;
        img.pixels[base + 3] = 255;
    }
    img.updatePixels();
    return img;
}

/**
 * Crop a size x size window at a uniformly random position of a PIXELS x PIXELS image.
 * Every offset from 0 to PIXELS - size inclusive can be chosen. The old version could
 * never pick the largest offset, so the last row and column were never used.
 * @param {ArrayLike<number>} values - flat PIXELS*PIXELS image
 * @param {number} size - crop width/height
 * @returns {number[]}
 */
function randomCrop(values, size) {
    const positions = PIXELS - size + 1;
    const rowOffset = Math.floor(Math.random() * positions);
    const colOffset = Math.floor(Math.random() * positions);
    return crop(values, size, rowOffset, colOffset);
}

/**
 * Copy a size x size window out of a flat PIXELS x PIXELS image.
 * The default offsets (2, 2) give the centered 24x24 crop of a 28x28 image.
 * @param {ArrayLike<number>} values - flat PIXELS*PIXELS image
 * @param {number} size - crop width/height
 * @param {number} [rowOffset=2]
 * @param {number} [colOffset=2]
 * @returns {number[]} flat size*size array
 */
function crop(values, size, rowOffset = 2, colOffset = 2) {
    const out = [];
    for (let row = rowOffset; row < rowOffset + size; row++) {
        for (let col = colOffset; col < colOffset + size; col++) {
            out.push(values[row * PIXELS + col]);
        }
    }
    return out;
}

/**
 * Scale the first PIXELSSQUARED values of an image by 1/255.
 * @param {ArrayLike<number>} values
 * @returns {number[]}
 */
function getInputs(values) {
    const inputs = [];
    for (let p = 0; p < PIXELSSQUARED; p++) {
        inputs[p] = values[p] / 255;
    }
    return inputs;
}

/**
 * Package an image as the {width, height, data} object passed to WebCNN.
 * @param {ArrayLike<number>} values - flat PIXELS*PIXELS grayscale image
 * @param {number} size - crop width/height fed to the network
 * @param {boolean} [randomize=true] - random crop (training augmentation) when true;
 *   deterministic center crop when false, so inference does not change between calls
 * @returns {{width: number, height: number, data: *}}
 */
function toModelFormat(values, size, randomize = true) {
    const centerOffset = Math.floor((PIXELS - size) / 2);
    const cropped = randomize
        ? randomCrop(values, size)
        : crop(values, size, centerOffset, centerOffset);
    return { width: size, height: size, data: getImage(cropped, size).pixels };
}

/**
 * Advance training by one call.
 *
 * draw() calls this TRAINPERSTEP times per frame. Only a call where train_index is a
 * multiple of TRAINPERSTEP trains (one batch of TRAINPERSTEP images). The other calls
 * just advance the index. The index wraps to 0, and trainrun is incremented, near
 * the end of the training set.
 * @param {boolean} show - passed by draw(); currently unused
 */
function trainit(show) {
    if (train_index % TRAINPERSTEP !== 0) {
        train_index++;
        return;
    }

    const batchImages = [];
    const batchLabels = [];
    // Block-scoped counter (the old version assigned an implicit global `i`).
    for (let k = 0; k < TRAINPERSTEP; k++) {
        batchImages.push(toModelFormat(mnist.train_images[train_index + k], CROP_PIXELS));
        batchLabels.push(mnist.train_labels[train_index + k]);
    }

    // Show the first image of the batch: magnified in the middle row, plus actual size.
    const firstImage = mnist.train_images[train_index];
    const shown = getImage(firstImage);
    image(shown, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);
    image(shown, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS);
    train_inputs = firstImage;

    nn.trainCNNClassifier(batchImages, batchLabels);

    thehtml = " trainrun: " + trainrun + "<br> no: " + train_index;
    AB.msg(thehtml, 4);

    train_index++;
    if (train_index + TRAINPERSTEP >= NOTRAIN) {
        train_index = 0;
        console.log("finished trainrun: " + trainrun);
        trainrun++;
    }
}

/**
 * Classify the next test image, update the running score and show it in slot 6.
 * Inference uses the deterministic center crop, so the score does not depend on
 * random augmentation.
 */
function testit() {
    const img = mnist.test_images[test_index];
    const label = mnist.test_labels[test_index];
    test_inputs = getInputs(img);

    const guess = findMax(nn.classifyImages([toModelFormat(img, CROP_PIXELS, false)]));
    total_tests++;
    // Explicit numeric comparison instead of loose `==`.
    if (guess === Number(label)) total_correct++;

    const percent = (total_correct / total_tests) * 100;
    thehtml = " testrun: " + testrun + "<br> no: " + total_tests + " <br>  correct: " + total_correct +
        "<br>  score: " + greenspan + percent.toFixed(2) + "</span>";
    AB.msg(thehtml, 6);

    test_index++;
    if (test_index === NOTEST) {
        console.log("finished testrun: " + testrun + " score: " + percent.toFixed(2));
        testrun++;
        test_index = 0;
        total_tests = 0;
        total_correct = 0;
    }
}

/**
 * Indices of the highest and second-highest scores of the first classification result.
 * When a new maximum is found, the previous maximum moves to second place. The old
 * version dropped it, which gave a wrong runner-up.
 * @param {Array} outputs - result of nn.classifyImages; outputs[0].getValue(0, 0, k)
 *   is read as the score of class k
 * @param {number} [classes=10] - number of classes to scan
 * @returns {number[]} [bestIndex, secondIndex]
 */
function find12(outputs, classes = 10) {
    let best = -1;
    let bestScore = -Infinity;
    let second = -1;
    let secondScore = -Infinity;
    for (let k = 0; k < classes; k++) {
        const score = outputs[0].getValue(0, 0, k);
        if (score > bestScore) {
            second = best;
            secondScore = bestScore;
            best = k;
            bestScore = score;
        } else if (score > secondScore) {
            second = k;
            secondScore = score;
        }
    }
    return [best, second];
}

/**
 * Index of the highest score of the first classification result (ties: lowest index).
 * @param {Array} outputs - result of nn.classifyImages
 * @param {number} [classes=10]
 * @returns {number}
 */
function findMax(outputs, classes = 10) {
    let best = 0;
    let bestScore = -Infinity;
    for (let k = 0; k < classes; k++) {
        const score = outputs[0].getValue(0, 0, k);
        if (score > bestScore) {
            bestScore = score;
            best = k;
        }
    }
    return best;
}

/**
 * p5.js draw hook; does nothing until MNIST has loaded.
 * Redraws the demo and doodle rows, and extends the doodle while the mouse is pressed
 * inside the top-left area. The frame after a stroke ends applies the blur filter.
 * Otherwise, while training is enabled, it runs one training step and TESTPERSTEP tests.
 */
function draw() {
    if (mnist === undefined) return;

    background("black");
    if (demo_exists) {
        drawDemo();
        guessDemo();
    }
    if (doodle_exists) {
        drawDoodle();
        guessDoodle();
    }

    if (mouseIsPressed) {
        const limit = ZOOMPIXELS + 20;
        if (mouseX < limit && mouseY < limit && pmouseX < limit && pmouseY < limit) {
            mousedrag = true;
            doodle_exists = true;
            doodle.stroke("white");
            doodle.strokeWeight(DOODLE_THICK);
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    } else if (mousedrag) {
        mousedrag = false;
        console.log("Exiting draw. Now blurring.");
        doodle.filter(BLUR, DOODLE_BLUR);
    } else if (do_training) {
        for (let k = 0; k < TRAINPERSTEP; k++) trainit(k === 0);
        for (let k = 0; k < TESTPERSTEP; k++) testit();
    }
}

/** Pick a random test image as the demo and report its index and label in slot 8. */
function makeDemo() {
    demo_exists = true;
    const index = AB.randomIntAtoB(0, NOTEST - 1);
    demo = mnist.test_images[index];
    const label = mnist.test_labels[index];
    thehtml = "Test image no: " + index + "<br>Classification: " + label + "<br>";
    AB.msg(thehtml, 8);
}

/** Draw the demo image in the bottom row, magnified and at actual size. */
function drawDemo() {
    const img = getImage(demo);
    image(img, 0, canvasheight - ZOOMPIXELS, ZOOMPIXELS, ZOOMPIXELS);
    image(img, ZOOMPIXELS + 50, canvasheight - ZOOMPIXELS, PIXELS, PIXELS);
}

/** Classify the demo image (center crop, deterministic) and show the result in slot 9. */
function guessDemo() {
    demo_inputs = getInputs(demo);
    const guess = findMax(nn.classifyImages([toModelFormat(demo, CROP_PIXELS, false)]));
    thehtml = " We classify it as: " + greenspan + guess + "</span>";
    AB.msg(thehtml, 9);
}

/** Draw the doodle buffer in the top row, full size and shrunk to PIXELS x PIXELS. */
function drawDoodle() {
    const img = doodle.get();
    image(img, 0, 0, ZOOMPIXELS, ZOOMPIXELS);
    image(img, ZOOMPIXELS + 50, 0, PIXELS, PIXELS);
}

/**
 * Shrink the doodle to PIXELS x PIXELS, re-center it and classify it (center crop).
 * Shows the top two guesses in slot 2.
 */
function guessDoodle() {
    const img = doodle.get();
    img.resize(PIXELS, PIXELS);
    img.loadPixels();
    const centered = centerImage(img.pixels, PIXELS);
    const [first, second] = find12(nn.classifyImages([toModelFormat(centered, CROP_PIXELS, false)]));
    thehtml = " We classify it as: " + greenspan + first + "</span> <br> No.2 guess is: " +
        greenspan + second + "</span>";
    AB.msg(thehtml, 2);
}

/** Clear the doodle buffer and stop drawing/classifying it. */
function wipeDoodle() {
    doodle_exists = false;
    doodle.background("black");
}

/**
 * Debug helper: log values to 2 decimal places, starting a new line every PIXELS values.
 * @param {number[]} inputs
 */
function showInputs(inputs) {
    let text = "";
    for (let k = 0; k < inputs.length; k++) {
        if (k % PIXELS === 0) text += "\n";
        text += " " + inputs[k].toFixed(2);
    }
    console.log(text);
}