// Code viewer for World: CCNN

// Cloned by AC on 10 Dec 2020 from World "Character recognition neural network (clone by AC)" by AC 
// Please leave this clone trail here.
 


// Cloned by AC on 2 Dec 2020 from World "Character recognition neural network" by "Coding Train" project 
// Please leave this clone trail here.
 

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications 


// --- defined by MNIST - do not change these ---------------------------------------

// Each MNIST image is a PIXELS x PIXELS greyscale square, stored flattened
// row-major (pixel at row r, column c lives at index r*PIXELS + c).
const PIXELS        = 28;                       // images in data set are tiny 
const PIXELSSQUARED = PIXELS * PIXELS;          // length of one flattened image

// number of training and test exemplars in the data set:
const NOTRAIN = 60000;
const NOTEST  = 10000;



//--- can modify all these --------------------------------------------------

// no of nodes in network 
const noinput  = PIXELSSQUARED;     // one input node per image pixel
const nohidden = 64;  //64 initially
const nooutput = 10;                // one output node per digit 0-9

const learningrate = 0.1;   // default 0.1  

// should we train every timestep or not 
let do_training = false;

// how many to train and test per timestep 
const TRAINPERSTEP = 30;
const TESTPERSTEP  = 5;

// multiply it by this to magnify for display 
const ZOOMFACTOR    = 7;                        
const ZOOMPIXELS    = ZOOMFACTOR * PIXELS; 

// 3 rows of
// large image + 50 gap + small image    
// 50 gap between rows 

const canvaswidth = ( PIXELS + ZOOMPIXELS ) + 50;
const canvasheight = ( ZOOMPIXELS * 3 ) + 100;    // 3 magnified rows + two 50px gaps


const DOODLE_THICK = 12;    // thickness of doodle lines 
const DOODLE_BLUR = 3;      // blur factor applied to doodles 


let mnist;      
// all data is loaded into this 
// mnist.train_images
// mnist.train_labels
// mnist.test_images
// mnist.test_labels


let nn;             // the current NeuralNetwork; replaced by setupNetwork()

// training progress: pass number and position within the training set
let trainrun = 1;
let train_index = 0;

// testing progress and running score for the current test pass
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

// images in LHS:
let doodle, demo;           // doodle: createGraphics buffer; demo: raw test-image array
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?  


// save inputs to global var to inspect
// type these names in console 
var train_inputs, test_inputs, demo_inputs, doodle_inputs, aug_doodle_inputs;

// ==================================================================================================
// === AC: additional global variable starts here  ===========================================================
// ==================================================================================================

//data augmentation 
// Curriculum order in which digit classes are unlocked when bDigitOrdering is on.
// NOTE: trainit() consumes this with shift(); it is not refilled by setupNetwork().
const digitOrder = [1, 0,6,4,9,8,5,7,2,3];
var currentDigits = [];                 // digit classes unlocked so far
const AUGMENTPERBATCH = 500;            // training images augmented per augmentData() call
let last_augmented_index = -1;          // index where the last augmented batch ended
const roundDownLevels = 5;              // NOTE(review): not referenced in the code visible here
var bDigitOrdering = false;

// neural network count
let nnNO = 0;

// One entry per replaced network, pushed by setupNetwork().
var train_log =[];

//CNN
var CNNTrained =false;          // set by finishedTraining() once the ml5 model is trained
let cnn_doodle_result;          // HTML snippet with the latest CNN doodle guess

//canvas
let canvas;

// Boolean for Modification to doodles 
var bBlur = false;
var bCenter = true;

// Boolean for Data Augmentation
var bAugment = true;            // master switch: augment training data batch-by-batch
var bAugmented = false;         // true once augmentData() has reached the end of the set

// Individual augmentation transforms that augment() is allowed to apply.
var bIShift=false;
var bICrop=true;
var bIRotate=false;
var bIRoundDown = false;

let bDoodleUpdated = false;     // set when a stroke ends; guessDoodle re-classifies once
var bTesting = true;

// Matrix.randomize() is changed to point to this. Must be defined by user of Matrix. 

function randomWeight()
{
    // Initial weight: a random float between LOW and HIGH, drawn with
    // AB.randomFloatAtoB. (The Coding Train default range is -1 to 1.)
    const LOW = -0.5;
    const HIGH = 0.5;
    return AB.randomFloatAtoB ( LOW, HIGH );
}



// CSS trick: make the run header bigger, up to half the viewport width
// and most of the viewport height.
const runHeaderLimits = { "max-width": "50vw", "max-height": "95vh" };
$("#runheaderbox").css ( runHeaderLimits );



//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// thehtml[k] holds the HTML last written to run-header slot k.
var thehtml=[];

// Button captions, keyed by action name. Only named properties are ever stored,
// so this is a plain object rather than an array used as a dictionary.
let strActionText = {};
strActionText.blur = bBlur?"Blur:ON":"Blur:OFF";
strActionText.center = bCenter?"Center:ON":"Center:OFF";

// 1 Doodle header 
thehtml[1] = "<hr> <h1> 1. Doodle " +
            "<button onclick='toggleBlurAction();' class='normbutton' >" + strActionText.blur + " </button>" + 
            "<button onclick='toggleCenterAction();' class='normbutton' >" + strActionText.center + " </button>" + 
            "</h1>" +
        " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg ( thehtml[1], 1 );

// 2 Doodle variable data (guess) - written by guessDoodle()

strActionText.training = do_training?"Stop":"Start";
strActionText.ordering = bDigitOrdering?"Order:ON":"Order:OFF";
strActionText.augmenting = bAugment?"Augment:ON":"Augment:OFF";

strActionText.shift = bIShift?"Shift:ON":"Shift:OFF";
strActionText.rotate = bIRotate?"Rotate:ON":"Rotate:OFF";
strActionText.rounddown = bIRoundDown?"RoundDown:ON":"RoundDown:OFF";
strActionText.crop = bICrop?"Crop:ON":"Crop:OFF";

// The per-transform augmentation buttons are only shown while augmentation is on.
let strAugment = bAugment?"<button onclick='toggleShiftAction();' class='normbutton' >" + strActionText.shift + " </button>" +
    "<button onclick='toggleRotateAction();' class='normbutton' >" + strActionText.rotate + " </button>" + 
    "<button onclick='toggleRoundDownAction();' class='normbutton' >" + strActionText.rounddown + " </button>" +
    "<button onclick='toggleCropAction();' class='normbutton' >" + strActionText.crop + " </button>":"";

// 3 Training header
thehtml[3] = "<hr> <h1> 2. Training "+ 
        "<button onclick='toggleTrainingAction();' class='normbutton' >" + strActionText.training + " </button>" + 
        "<button onclick='setupNetwork();' class='normbutton' >Refresh NN</button>"+ " </h1> " + 
        "Data Preprocessing: " +
        "<button onclick='toggleOrderAction();' class='normbutton' >" + strActionText.ordering + " </button>" + 
        "<button onclick='toggleAugmentAction();' class='normbutton' >" + strActionText.augmenting + " </button>" +
        "<br> " + strAugment; 

AB.msg ( thehtml[3], 3 );

// 4 variable training data - written by trainit()

// 5 Testing header
thehtml[5] = "<h3> Hidden tests </h3> " ;
AB.msg ( thehtml[5], 5 );

// 6 variable testing data - written by testit()

// 7 Demo header 
thehtml[7] = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br>" +
    " The network is <i>not</i> trained on any of these images. <br> " +
    " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg ( thehtml[7], 7 );

// 8 Demo variable data (random demo ID) - written by makeDemo()
// 9 Demo variable data (changing guess) - written by guessDemo()

// Opening tag used to highlight guesses and scores; callers close it with </span>.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> "  ;

//--- end of AB.msgs structure: ---------------------------------------------------------


// AC: image Classifier 
// Both are assigned in setup() once ml5 has been loaded:
let digitClassifier;    // ml5.neuralNetwork for "imageClassification" on PIXELS x PIXELS x 4 inputs
let featureExtractor;   // ml5 "MobileNet" feature extractor (numLabels: 10)

function setup() 
{
  canvas = createCanvas ( canvaswidth, canvasheight );

  // The doodle is drawn on its own buffer at the magnified size.
  doodle = createGraphics ( ZOOMPIXELS, ZOOMPIXELS );
  doodle.pixelDensity(1);

  // Show a loading screen while the helper JS and the data set load.
  AB.loadingScreen();

  // Helper scripts are fetched strictly one after another:
  // matrix.js -> nn.js -> mnist.js, and only then is the network built
  // and the data loaded.
  const afterMnistJs = function ()
  {
      console.log ("All JS loaded");
      setupNetwork();
      loadData();
  };
  const afterNnJs = function ()
  {
      $.getScript ( "/uploads/codingtrain/mnist.js", afterMnistJs );
  };
  const afterMatrixJs = function ()
  {
      $.getScript ( "/uploads/ac2021/nn.js", afterNnJs );
  };
  $.getScript ( "/uploads/codingtrain/matrix.js", afterMatrixJs );

  // ml5 is requested without waiting for the chain above.
  $.getScript ( "https://unpkg.com/ml5@latest/dist/ml5.min.js", function ()
  {
      const options = {
          inputs: [PIXELS, PIXELS, 4],
          task: "imageClassification",
          debug: true
      };
      digitClassifier  = ml5.neuralNetwork(options);
      featureExtractor = ml5.featureExtractor("MobileNet", { numLabels: 10 }, modelLoaded);
  });
}

// ==================================================================================================
// === AC: additional coding starts here  ===========================================================
// ==================================================================================================

/**
 * Build a fresh NeuralNetwork, reset training/testing progress, and bump nnNO.
 *
 * When replacing an earlier network (nnNO > 0), its final progress is pushed to
 * train_log as ["train", nnNO, trainrun, train_index, total_tests,
 * total_correct, accuracy]. This must happen BEFORE the counters are reset.
 * The previous version logged afterwards, so every entry held 1/0/0/0 and a
 * NaN accuracy. accuracy is null when no tests had been run.
 */
function setupNetwork()
{
    if (nnNO > 0)
    {
        const accuracy = total_tests > 0 ? total_correct / total_tests : null;
        train_log.push(["train", nnNO, trainrun, train_index, total_tests, total_correct, accuracy]);
    }

    nn = new NeuralNetwork(  noinput, nohidden, nooutput );
    nn.setLearningRate ( learningrate );
    trainrun = 1;
    train_index = 0;
    resetTestData();

    nnNO++;
}

/**
 * Restart testing from the first test exemplar with a zeroed score, and clear
 * the running test-score panel (run-header slot 6).
 * (Removed a stray bare `console.log` expression that did nothing.)
 */
function resetTestData()
{
    testrun = 1;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;

    AB.msg ( "", 6 );
}

/**
 * Shift a flattened PIXELS x PIXELS input so the drawn content sits roughly in
 * the middle. The move is half the difference between opposite empty margins.
 * Positive offsets mean there is more empty space on the left/top.
 */
function centerInputs(inputs)
{
    const margins = checkMargins(inputs);
    const [left, right, top, bottom] = margins;

    const hShift = left - right;
    const vShift = top - bottom;

    console.log(margins);
    console.log(hShift + "," + vShift);

    return shiftInputs(inputs, Math.floor(hShift / 2), Math.floor(vShift / 2));
}

/**
 * Shift a flattened image by r columns (horizontal) and c rows (vertical).
 * Either offset may be negative. Returns `inputs` itself when both are 0.
 * Offsets with the same sign go through performShift in one pass. Mixed
 * signs are done as two passes: horizontal first, then vertical, each in
 * the direction its own sign needs.
 */
function shiftInputs(inputs, r=0, c=0)  //can be positive or negative
{
    if (r == 0 && c == 0)
        return inputs;

    if ((r >= 0) && (c >= 0))
    {
        console.log("both forward " + r + "," + c)
        return performShift(inputs, r, c);
    }

    if ((r <= 0) && (c <= 0))
    {
        console.log(" both backward " + r + "," + c)
        return performShift(inputs, r, c, true);
    }

    if (r > 0)
    {
        console.log("forward -> backward " + r + "," + c)
        return performShift(performShift(inputs, r, 0), 0, c, true);
    }

    console.log("backward -> forward "+ r + "," + c )
    return performShift(performShift(inputs, r, 0, true), 0, c);
}

/**
 * Return a new flattened image in which output pixel (row, col) takes the value
 * of source pixel (row + c, col + r). Source pixels outside the image read as 0.
 * So positive r moves content left, negative r moves it right, positive c moves
 * it up, and negative c moves it down.
 *
 * The previous version used flat-index offsets (i + c*PIXELS + r). Any
 * horizontal move therefore wrapped pixels from one row into the end/start of
 * its neighbour. It also logged on every call.
 *
 * @param {ArrayLike<number>} inputs  flattened row-major image (row width PIXELS)
 * @param {number} r  horizontal offset in columns
 * @param {number} c  vertical offset in rows
 * @param {boolean} bBackward  kept for backward compatibility; the 2-D indexing
 *                             handles either sign, so it no longer changes the result
 * @returns {number[]} a new array; `inputs` is not modified
 */
function performShift(inputs, r, c, bBackward = false)
{
    const shifted = new Array(inputs.length);

    for (let i = 0; i < inputs.length; i++)
    {
        const row = Math.floor(i / PIXELS);
        const col = i % PIXELS;
        const srcRow = row + c;
        const srcCol = col + r;
        const src = srcRow * PIXELS + srcCol;

        // Column bounds are checked separately so nothing wraps between rows.
        const inside = srcCol >= 0 && srcCol < PIXELS && srcRow >= 0 && src < inputs.length;
        shifted[i] = inside ? inputs[src] : 0;
    }

    return shifted;
}
function modelLoaded()
{
    // Ready-callback handed to ml5.featureExtractor in setup(); it only reports.
    const message = "Model Loaded";
    console.log(message);
}

let classifier;     // classifier built on top of featureExtractor by train_features()

/**
 * Create a classifier from the MobileNet feature extractor, then wait for the
 * training images to be queued. Training itself is not started yet.
 * TODO: once prepareTrainingData adds images, call classifier.train(...) here
 * and log the loss (previously left as commented-out code).
 */
async function train_features()
{
    classifier = featureExtractor.classification();
    await prepareTrainingData();
}

/**
 * Sum the empty margins (via checkMargins) over training images, optionally only
 * those whose label equals n % 10. Logs and returns
 * [sumLeft, sumRight, sumTop, sumBottom, imageCount].
 * The previous version stored the count at index 5, leaving a hole at index 4,
 * and returned nothing.
 *
 * @param {number|null} n  digit to filter on (taken mod 10), or null/undefined for all
 * @returns {number[]} the five totals described above
 */
function checkAllMargins(n = null)
{
    const totals = [0, 0, 0, 0, 0];
    const images = mnist.train_images;

    for (let i = 0; i < images.length; i++)
    {
        if ((n == null) || (mnist.train_labels[i] == (n % 10)))
        {
            const m = checkMargins(getInputs(images[i]));
            for (let k = 0; k < 4; k++)
                totals[k] += m[k];
            totals[4]++;
        }
    }

    console.log(totals);
    return totals;
}

let img_base;

/**
 * Placeholder for queuing MNIST images into `classifier`. The per-image
 * classifier.addImage(...) calls are not enabled yet, so nothing is queued and
 * the returned promise resolves to an empty array.
 */
function prepareTrainingData()
{
    const pending = [];
    return Promise.all(pending);
}

/**
 * Feed the first `count` training images (as RGBA pixel arrays) into the ml5
 * digitClassifier, normalise, and start training. CNNTrained is cleared now and
 * set again by finishedTraining, which is passed to train() as its callback.
 *
 * Changes from the previous version:
 * - `count` (default 10000, the old hard-coded value) is clamped to the data
 *   size, so a smaller data set no longer indexes past its end.
 * - The last log line no longer claims training has finished, because here it
 *   has only been started.
 *
 * @param {number} count  maximum number of training images to add
 */
function trainCNN(count = 10000)
{
    CNNTrained = false;

    const total = Math.min(count, mnist.train_images.length);
    for (let i = 0; i < total; i++)
    {
        digitClassifier.addData(getImage(mnist.train_images[i], true), { label: mnist.train_labels[i].toString() });
    }

    console.log("CNN-finished adding data " + total);
    digitClassifier.normalizeData();
    console.log("CNN-finished normalizing data");
    digitClassifier.train({ epochs: 50, batchSize: 32 }, finishedTraining);
    console.log("CNN-training started");
}

function finishedTraining()
{
    // Callback handed to digitClassifier.train() in trainCNN; marks the CNN as
    // usable so guessDoodle starts classifying with it.
    const trainingDone = true;
    CNNTrained = trainingDone;
}

/**
 * Augment (in place in mnist.train_images) the next batch of up to
 * AUGMENTPERBATCH images, starting at train_index. Sets bAugmented once the
 * end of the training set has been reached.
 *
 * The previous version had the comparison inverted. It took the larger of
 * train_index + AUGMENTPERBATCH and NOTRAIN. So it normally augmented the whole
 * remaining set in a single frame, and near the end indexed past NOTRAIN.
 *
 * @returns {number} index one past the last augmented image
 */
function augmentData()
{
    const batchEnds = Math.min(train_index + AUGMENTPERBATCH, NOTRAIN);

    for (let i = train_index; i < batchEnds; i++)
    {
        mnist.train_images[i] = augment(mnist.train_images[i])[0];
    }

    if (batchEnds >= NOTRAIN)
    {
        bAugmented = true;
    }
    return batchEnds;
}

/**
 * Measure the empty border around the content of a flattened PIXELS x PIXELS
 * image. A pixel counts as content when its value is > 0.
 *
 * Returns [left, right, top, bottom]: the number of fully empty columns/rows
 * between each edge and the content's bounding box. An image with no content
 * returns [0, 0, 0, 0].
 *
 * Fixes over the previous column-scan version:
 * - `right` was one too small: the first blank column itself was not counted.
 * - A blank column inside the content was mistaken for the right margin.
 * - An empty image reported a bottom margin of PIXELS-1.
 *
 * @param {ArrayLike<number>} inputs  row-major pixel values (row r, column c at r*PIXELS + c)
 * @returns {number[]} [left, right, top, bottom]
 */
function checkMargins(inputs)
{
    let minRow = PIXELS;
    let maxRow = -1;
    let minCol = PIXELS;
    let maxCol = -1;

    for (let r = 0; r < PIXELS; r++)
    {
        for (let c = 0; c < PIXELS; c++)
        {
            if (inputs[r * PIXELS + c] > 0)
            {
                if (r < minRow) minRow = r;
                if (r > maxRow) maxRow = r;
                if (c < minCol) minCol = c;
                if (c > maxCol) maxCol = c;
            }
        }
    }

    if (maxRow < 0)
        return [0, 0, 0, 0];       // nothing drawn: no meaningful margins

    return [minCol, PIXELS - 1 - maxCol, minRow, PIXELS - 1 - maxRow];
}

/**
 * For a 90-degree rotation of a flattened, row-major PIXELS x PIXELS image:
 * given the index x of a pixel in the ROTATED image, return the index in the
 * SOURCE image whose value belongs there. It is used as
 * rotated[x] = source[getRotatedPixelIndex(x)].
 *
 * Clockwise:         dest(row, col) <- src(PIXELS-1-col, row)
 * Counter-clockwise: dest(row, col) <- src(col, PIXELS-1-row)
 *
 * 3x3 example, dest -> src:
 *   clockwise:         0->6 1->3 2->0 3->7 4->4 5->1 6->8 7->5 8->2
 *   counter-clockwise: 0->2 1->5 2->8 3->1 4->4 5->7 6->0 7->3 8->6
 *
 * The counter-clockwise branch previously multiplied the result by 4 (an RGBA
 * leftover), which produced out-of-range indices into a PIXELSSQUARED array.
 *
 * @param {number} x  flat index into the rotated image
 * @param {boolean} bClockwise  rotation direction
 * @returns {number} flat index into the source image
 */
function getRotatedPixelIndex( x, bClockwise = true)
{
    const col = x % PIXELS;
    const row = Math.floor(x / PIXELS);
    let srcRow;
    let srcCol;

    if (bClockwise)
    {
        srcRow = PIXELS - 1 - col;
        srcCol = row;
    }
    else
    {
        srcRow = col;
        srcCol = PIXELS - 1 - row;
    }

    return srcRow * PIXELS + srcCol;
}


// load data set from local file (on this server)

function loadData()    
{
  // loadMNIST hands the whole data set to this callback; keep it globally,
  // log it, and dismiss the loading screen (a no-op if none exists).
  const onLoaded = (data) =>
  {
    mnist = data;
    console.log ("All data loaded into mnist object:")
    console.log(mnist);
    AB.removeLoading();
  };

  loadMNIST ( onLoaded );
}



function getImage ( img, bGetPixelArray = false )      // make a P5 image object from a raw data array   
{
    // Build a PIXELS x PIXELS p5 image: each brightness value goes into R, G
    // and B, with a fully opaque alpha. Returns the image, or its RGBA pixel
    // array when bGetPixelArray is set.
    const picture = createImage (PIXELS, PIXELS);
    picture.loadPixels();

    for (let p = 0; p < PIXELSSQUARED; p++)
    {
        const grey = img[p];
        const base = p * 4;
        picture.pixels.fill(grey, base, base + 3);
        picture.pixels[base + 3] = 255;
    }

    picture.updatePixels();

    return bGetPixelArray ? picture.pixels : picture;
}


function getInputs ( img )      // convert img array into normalised input array 
{
    // Scale the first PIXELSSQUARED raw brightness values (0-255) to 0-1.
    return Array.from({ length: PIXELSSQUARED }, (_, i) => img[i] / 255);
}

 

function trainit (show)        // train on the exemplar at global "train_index"; show = draw it on the canvas
{
  const img   = mnist.train_images[train_index];
  const label = mnist.train_labels[train_index];

  // Digit ordering (curriculum): every 3000 exemplars, unlock the next digit class.
  const unlockNext = bDigitOrdering && (train_index % 3000) == 0 && digitOrder.length > 0;
  if (unlockNext)
    currentDigits.push(digitOrder.shift());

  // Optionally draw this exemplar in the middle row: magnified inside a white
  // frame, plus original size to its right.
  if (show)
  {
    const theimage = getImage ( img );
    const rowTop = ZOOMPIXELS + 50;
    noFill();
    stroke(255, 255, 255);
    rect(0, rowTop - 1, ZOOMPIXELS + 1, ZOOMPIXELS + 1);
    image ( theimage, 1,                rowTop, ZOOMPIXELS, ZOOMPIXELS );   // magnified
    image ( theimage, ZOOMPIXELS + 50,  rowTop, PIXELS,     PIXELS     );   // original
  }

  // One-hot target for the label.
  const inputs  = getInputs ( img );
  const targets = new Array(10).fill(0);
  targets[label] = 1;

  train_inputs = inputs;        // can inspect in console 

  const digitAllowed = !bDigitOrdering || currentDigits.includes(label);
  if (digitAllowed)
    nn.train ( inputs, targets );

  const progress = " trainrun: " + trainrun + "<br> no: " + train_index;
  thehtml[4] = bDigitOrdering
      ? "Current Training numbers :" + currentDigits.toString() + " <br>" + progress
      : progress;
  AB.msg ( thehtml[4], 4 );

  // Advance, wrapping to a new training pass at the end of the set.
  train_index++;
  if ( train_index == NOTRAIN )
  {
    train_index = 0;
    console.log( "finished trainrun: " + trainrun );
    trainrun++;
  }
}

/**
 * Apply randomised augmentations to one flattened PIXELS x PIXELS image.
 *
 * If none of the four flags is set by the caller, each transform is chosen
 * independently with probability 4/10. A transform only runs when its global
 * enable switch (bIShift / bIRoundDown / bIRotate / bICrop) is also on.
 *
 * Fixes over the previous version:
 * - Works on a copy: the caller's array is no longer modified in place.
 * - Shift passes (horizontal, vertical) in the order performShift expects.
 *   The two offsets were swapped before.
 * - Round-down quantises the current (possibly shifted) image instead of
 *   re-reading the untouched input, which discarded the shift.
 * - Crop uses row-major indexing and no longer logs once per image.
 *
 * @param {ArrayLike<number>} inputs  flattened image (Array or typed array)
 * @param {boolean} bShift, bRoundDown, bRotate, bCrop  transforms to request
 * @param {*} shiftMargins  currently unused; kept for interface compatibility
 * @returns {Array} [augmentedImage, bShift, bRoundDown, bRotate, bCrop]
 */
function augment(inputs, bShift = false, bRoundDown = false, bRotate = false, bCrop = false, shiftMargins = null){

    if (!(bShift || bRoundDown || bRotate || bCrop))
    {
        bShift     = Math.floor(Math.random() * 10) <= 3;
        bRoundDown = Math.floor(Math.random() * 10) <= 3;
        bRotate    = Math.floor(Math.random() * 10) <= 3;
        bCrop      = Math.floor(Math.random() * 10) <= 3;
    }

    // slice() copies while keeping the container type (plain Array or typed array).
    let temp = inputs.slice();

    if (bShift && bIShift)
    {
        // Move the content left/up by a random amount smaller than its empty
        // left/top margin, so nothing is pushed off the edge.
        const margins  = checkMargins(inputs);
        const nColumns = Math.floor(Math.random() * margins[0]);
        const nRows    = Math.floor(Math.random() * margins[2]);
        temp = performShift(temp, nColumns, nRows);
    }

    if (bRoundDown && bIRoundDown)
    {
        // Quantise down to multiples of 0.25.
        // NOTE(review): getInputs divides by 255, so mnist.train_images appears to
        // hold raw 0-255 values; on integer pixels this step then changes nothing.
        for (let i = 0; i < PIXELSSQUARED; i++)
        {
            temp[i] = Math.floor(temp[i] * 4) / 4;
        }
    }

    if (bRotate && bIRotate)
    {
        // 90-degree clockwise rotation.
        const rotated = [];
        for (let j = 0; j < PIXELSSQUARED; j++)
        {
            rotated[j] = temp[getRotatedPixelIndex(j)];
        }
        temp = rotated;
    }

    if (bCrop && bICrop)
    {
        // Blank out one randomly chosen quadrant.
        const half     = PIXELS / 2;
        const rowStart = Math.floor(Math.random() * 2) * half;
        const colStart = Math.floor(Math.random() * 2) * half;
        for (let r = rowStart; r < rowStart + half; r++)
        {
            for (let c = colStart; c < colStart + half; c++)
                temp[r * PIXELS + c] = 0;
        }
    }

    return [temp, bShift, bRoundDown, bRotate, bCrop];
}



/**
 * Test the network on the exemplar at global "test_index", update the running
 * score (slot 6), then advance. At the end of the test set, log the final score
 * and start a new test pass with a zeroed score.
 *
 * With digit ordering on, exemplars whose digit has not been unlocked for
 * training are skipped but STILL advance test_index. The previous version
 * returned before advancing, so testing stalled on the first locked digit.
 */
function testit()
{ 
  const img   = mnist.test_images[test_index];
  const label = mnist.test_labels[test_index];

  const scored = !bDigitOrdering || currentDigits.includes(label);

  if (scored)
  {
    const inputs = getInputs ( img );
    test_inputs = inputs;        // can inspect in console 
    const prediction = nn.predict(inputs);       // array of outputs 
    const guess      = findMax(prediction);      // the top output 

    total_tests++;
    if (guess == label)  total_correct++;

    const percent = (total_correct / total_tests) * 100 ;

    thehtml[6] =  " testrun: " + testrun + "<br> no: " + total_tests + " <br> " +
          " correct: " + total_correct + "<br>" +
          "  score: " + greenspan + percent.toFixed(2) + "</span>";
    AB.msg ( thehtml[6], 6 );
  }

  test_index++;
  if ( test_index == NOTEST ) 
  {
    const percent = total_tests > 0 ? (total_correct / total_tests) * 100 : 0;
    console.log( "finished testrun: " + testrun + " score: " + percent.toFixed(2) );
    testrun++;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
  }
}




//--- find no.1 (and maybe no.2) output nodes ---------------------------------------
// (restriction) assumes array values start at 0 (which is true for output nodes) 


/**
 * Return [indexOfLargest, indexOfSecondLargest] for an array of non-negative
 * values. Ties keep the earlier index.
 * When a new maximum is found, the old leader now becomes the runner-up. The
 * previous version dropped it, so e.g. [0.5, 0.9, 0.1] reported no.2 as index 2
 * instead of 0.
 *
 * @param {ArrayLike<number>} a  output-node activations (assumed >= 0)
 * @returns {number[]} [no1, no2]
 */
function find12 (a)
{
  let no1 = 0;
  let no2 = 0;
  let no1value = 0;     
  let no2value = 0;

  for (let i = 0; i < a.length; i++) 
  {
    if (a[i] > no1value) 
    {
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = a[i];
    }
    else if (a[i] > no2value) 
    {
      no2 = i;
      no2value = a[i];
    }
  }

  return [ no1, no2 ];
}


// just get the maximum - separate function for speed - done many times 
// find our guess - the max of the output nodes array

function findMax (a)        
{
  // Index of the first strictly-largest value. Values are assumed >= 0, so 0 is
  // the starting floor and an all-zero array yields index 0.
  let bestIndex = 0;
  let bestValue = 0;

  a.forEach((value, index) =>
  {
    if (value > bestValue)
    {
      bestIndex = index;
      bestValue = value;
    }
  });

  return bestIndex;
}




// --- the draw function -------------------------------------------------------------
// every step:
 
// Runs every step. Order per step:
//   1. bail out until the MNIST data is loaded,
//   2. augment the next batch of training data if due,
//   3. clear the canvas and frame the doodle area,
//   4. train (and maybe test) a slice of exemplars,
//   5. redraw and re-guess the demo and doodle images,
//   6. handle doodle drawing with the mouse.
function draw() 
{
  // check if libraries and data loaded yet:
  if ( typeof mnist == 'undefined' ) return;
  
  // Augment the next batch once training has moved past the end of the last
  // augmented batch (augmentData returns where that batch ended).
  if (do_training && bAugment && train_index > last_augmented_index && !bAugmented ) {
      console.log( "augment");
      last_augmented_index = augmentData();
  }


// how can we get white doodle on black background on yellow canvas?
//        background('#ffffcc');    doodle.background('black');

    background ('black');

// AC: Added a white square bracket.
    // (a white frame around the top-left doodle area)
    noFill();
    stroke(255,255,255);
    rect(0, 0, ZOOMPIXELS+1, ZOOMPIXELS+1);
    
if ( do_training )    
{
  // do some training per step 
    for (let i = 0; i < TRAINPERSTEP; i++) 
    {
      if (i == 0)    trainit(true);    // show only one per step - still flashes by  
      else           trainit(false);
    }
    
  // do some testing per step 
  // NOTE: testing only runs while train_index > 10000, so it pauses at the start
  // of every training pass (trainit wraps train_index back to 0).
    if(train_index > 10000 && bTesting) 
        for (let i = 0; i < TESTPERSTEP; i++) 
            testit();
}

  // keep drawing demo and doodle images 
  // and keep guessing - we will update our guess as time goes on 
  
  if ( demo_exists )
  {
    drawDemo();
    guessDemo();
  }
  if ( doodle_exists ) 
  {
    drawDoodle();
    guessDoodle();
    
  }


// detect doodle drawing 
// (restriction) the following assumes doodle starts at 0,0 

  if ( mouseIsPressed )         // gets called when we click buttons, as well as if in doodle corner  
  {
     // console.log ( mouseX + " " + mouseY + " " + pmouseX + " " + pmouseY );
     // NOTE(review): only upper bounds are checked, so pointer positions left of or
     // above the canvas (negative coordinates, if p5 reports them) would also draw.
     var MAX = ZOOMPIXELS + 20;     // can draw up to this pixels in corner 
     if ( (mouseX < MAX) && (mouseY < MAX) && (pmouseX < MAX) && (pmouseY < MAX) )
     {
        mousedrag = true;       // start a mouse drag 
        doodle_exists = true;
        doodle.stroke('white');
        doodle.strokeWeight( DOODLE_THICK );
        doodle.line(mouseX, mouseY, pmouseX, pmouseY);      
     }
  }
  else 
  {
      // are we exiting a drawing
      if ( mousedrag )
      {
            mousedrag = false;
            // Tell guessDoodle to classify the (centred) doodle once more.
            bDoodleUpdated = true;
            // console.log ("Exiting draw. Now blurring.");
            if(bBlur)
                doodle.filter (BLUR, DOODLE_BLUR);    // just blur once
            
            if (bCenter)
            {
                // prepare an augment version of input.
                // Down-sample the doodle to PIXELS x PIXELS, take the first (red)
                // channel scaled to 0-1, and store a centred copy that guessDoodle
                // classifies while no stroke is in progress.
                let img = doodle.get();
                img.resize ( PIXELS, PIXELS );     
                img.loadPixels();
            
                // set up inputs   
                let inputs = [];
                for (let i = 0; i < PIXELSSQUARED ; i++) 
                {
                    inputs[i] = img.pixels[i * 4] / 255;
                }
              
                aug_doodle_inputs = centerInputs(inputs); 
            
            
                //console.log(find12(nn.predict(aug_doodle_inputs)));
                //console.log(checkMargins(inputs));
            }
            
            
      }
  }
  
  /*if(keyIsPressed & CNNTrained)
  {
      let d = createGraphics(PIXELS, PIXELS);
      d.copy(canvas, 0, 0, ZOOMPIXELS, ZOOMPIXELS, 0, 0, PIXELS, PIXELS)
      console.log("w, h" + d.width + "," + d.height)
      image ( d,   ZOOMPIXELS+50,    PIXELS,    PIXELS,         PIXELS      );      // shrunk
      digitClassifier.classify({image:d}, classifyFinished);
  }*/
}




//--- demo -------------------------------------------------------------
// demo some test image and predict it
// get it from test set so have not used it in training


function makeDemo()
{
    // Pick a random exemplar from the test set (never used for training),
    // make it the demo, and show its number and true label in slot 8.
    demo_exists = true;
    const pick = AB.randomIntAtoB ( 0, NOTEST - 1 );

    demo = mnist.test_images[pick];
    const answer = mnist.test_labels[pick];

    thehtml[8] = "Test image no: " + pick + "<br>" + "Classification: " + answer + "<br>";
    AB.msg ( thehtml[8], 8 );

    // type "demo" in console to see raw data
}


function drawDemo()
{
    // Bottom row: the demo image magnified inside a white frame, plus original
    // size to its right.
    const picture = getImage ( demo );
    const rowTop = canvasheight - ZOOMPIXELS;

    noFill();
    stroke(255,255,255);
    rect(0, rowTop - 1, ZOOMPIXELS+1, ZOOMPIXELS+1);
    image ( picture, 1,               rowTop, ZOOMPIXELS, ZOOMPIXELS );   // magnified
    image ( picture, ZOOMPIXELS+50,   rowTop, PIXELS,     PIXELS     );   // original
}


function guessDemo()
{
  // Classify the current demo image and show the network's top guess in slot 9.
  demo_inputs = getInputs ( demo );     // can inspect in console
  const guess = findMax ( nn.predict ( demo_inputs ) );

  thehtml[9] = " We classify it as: " + greenspan + guess + "</span>" ;
  AB.msg ( thehtml[9], 9 );
}




//--- doodle -------------------------------------------------------------

function drawDoodle()
{
    // doodle is a createGraphics buffer, so take an image snapshot first, then
    // draw it full size at the top-left and shrunk to PIXELS to its right.
    const snapshot = doodle.get();
    const placements = [ [0, ZOOMPIXELS], [ZOOMPIXELS + 50, PIXELS] ];   // [x, size]

    for (const [x, size] of placements)
        image ( snapshot, x, 0, size, size );
}
      


// Off-screen PIXELS x PIXELS buffer reused by guessDoodle() for the CNN
// snapshot. It is created on first use; previously a new createGraphics()
// buffer was made on every call and never removed.
let cnnDoodleBuffer = null;

// Show the network's no.1 and no.2 guesses (plus any CNN result) in slot 2.
function reportDoodleGuess(prediction)
{
    const b = find12(prediction);       // get no.1 and no.2 guesses
    // cnn_doodle_result is unset until the CNN has classified once. Without this
    // fallback the panel showed the literal text "undefined".
    const cnnText = (cnn_doodle_result == null) ? "" : cnn_doodle_result;

    thehtml[2] = " We classify it as: " + greenspan + b[0] + "</span> <br>" +
                 " No.2 guess is: " + greenspan + b[1] + "</span> <br>" + cnnText;
    AB.msg ( thehtml[2], 2 );
}

/**
 * Classify the doodle with the network and update the guess panel.
 *  - While drawing, before any centred copy exists, or with centring off:
 *    classify the live doodle on every call.
 *  - Otherwise: classify the centred copy (aug_doodle_inputs) once per finished
 *    stroke, i.e. whenever bDoodleUpdated has been set.
 * Once the CNN is trained, the doodle corner of the canvas is also shrunk into
 * a reusable buffer, drawn, and passed to digitClassifier.classify.
 */
function guessDoodle() 
{
    if (mousedrag || (aug_doodle_inputs == undefined) || !bCenter)
    {
        // doodle is createGraphics not createImage
        const img = doodle.get();
        img.resize ( PIXELS, PIXELS );
        img.loadPixels();

        // First (red) channel of each RGBA pixel, scaled to 0-1.
        const inputs = [];
        for (let i = 0; i < PIXELSSQUARED ; i++)
        {
            inputs[i] = img.pixels[i * 4] / 255;
        }

        doodle_inputs = inputs;     // can inspect in console 
        reportDoodleGuess(nn.predict(inputs));
    }
    else if (bDoodleUpdated)
    {
        bDoodleUpdated = false;
        const prediction = nn.predict(aug_doodle_inputs);
        console.log("centered");
        reportDoodleGuess(prediction);
    }

    if (CNNTrained)
    {
        if (cnnDoodleBuffer === null)
            cnnDoodleBuffer = createGraphics(PIXELS, PIXELS);
        const d = cnnDoodleBuffer;
        d.copy(canvas, 0, 0, ZOOMPIXELS, ZOOMPIXELS, 0, 0, PIXELS, PIXELS);
        image ( d,   ZOOMPIXELS+50,    PIXELS,    PIXELS,         PIXELS      );      // shrunk
        digitClassifier.classify({image:d}, classifyFinished);
    }
}


function wipeDoodle()    
{
    // "Clear doodle" button: stop drawing/guessing the doodle and paint its
    // buffer back to solid black.
    doodle_exists = false;
    doodle.background('black');
}

function classifyFinished(error, result)
// Callback for digitClassifier.classify(). Stores the CNN's top result as HTML in
// cnn_doodle_result, which guessDoodle appends to the message it shows.
// On error, or when no results are returned, the previous cnn_doodle_result is kept.
{
    if (error)
    {
        console.log("classify error:" + error);
        return;                     // result is not usable when the classifier reports an error
    }
    if (!result || result.length === 0)
    {
        return;                     // nothing to report
    }

    const top = result[0];          // first entry is treated as the best guess
    console.log("cnn predicted: " + top.label);
    cnn_doodle_result = "(CNN) We classify it as: " + greenspan + top.label + "</span> <br>" +
        "(CNN) Confidence is: " + greenspan + top.confidence + "</span> <br>";
}



// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);


function showInputs ( inputs )
// Print the inputs to the console as a grid PIXELS values wide (one row of pixels per line),
// each value formatted to 2 decimal places and preceded by a space.
{
    const cells = Array.from(inputs, (value, idx) =>
        (idx % PIXELS === 0 ? "\n" : "") + " " + value.toFixed(2)      // start a new line per pixel row
    );
    console.log(cells.join(""));
}


// --- UI toggle handlers ---------------------------------------------------------
// Each handler flips one global flag, updates that button's caption in strActionText,
// and re-renders the message section containing the button. The section markup is
// built by the shared helpers below, so every handler renders exactly the same HTML.


// HTML for section 1 (doodle controls), built from the current button captions.
function doodleControlsHtml()
{
    return "<hr> <h1> 1. Doodle " +
        "<button onclick='toggleBlurAction();' class='normbutton' >" + strActionText.blur + " </button>" +
        "<button onclick='toggleCenterAction();' class='normbutton' >" + strActionText.center + " </button>" +
        "</h1>" +
        " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
}

// Row of augmentation option buttons; empty string when augmentation is switched off.
function augmentButtonsHtml()
{
    if (!bAugment)
    {
        return "";
    }
    return "<button onclick='toggleShiftAction();' class='normbutton' >" + strActionText.shift + " </button>" +
        "<button onclick='toggleRotateAction();' class='normbutton' >" + strActionText.rotate + " </button>" +
        "<button onclick='toggleRoundDownAction();' class='normbutton' >" + strActionText.rounddown + " </button>" +
        "<button onclick='toggleCropAction();' class='normbutton' >" + strActionText.crop + " </button>";
}

// HTML for section 3 (training controls). Embeds the current strAugment row as-is.
function trainingControlsHtml()
{
    return "<hr> <h1> 2. Training " +
        "<button onclick='toggleTrainingAction();' class='normbutton' >" + strActionText.training + " </button>" +
        "<button onclick='setupNetwork();' class='normbutton' >Refresh NN</button>" + " </h1> " +
        "Data Preprocessing: " +
        "<button onclick='toggleOrderAction();' class='normbutton' >" + strActionText.ordering + " </button>" +
        "<button onclick='toggleAugmentAction();' class='normbutton' >" + strActionText.augmenting + " </button>" +
        "<br> " + strAugment + "<br>";
}

// Rebuild thehtml[1] and show it in message area 1.
function renderDoodleControls()
{
    thehtml[1] = doodleControlsHtml();
    AB.msg(thehtml[1], 1);
}

// Rebuild thehtml[3] and show it in message area 3.
// When rebuildAugmentRow is true, strAugment is regenerated first (augment captions changed).
function renderTrainingControls(rebuildAugmentRow)
{
    if (rebuildAugmentRow)
    {
        strAugment = augmentButtonsHtml();
    }
    thehtml[3] = trainingControlsHtml();
    AB.msg(thehtml[3], 3);
}


function toggleBlurAction()
{
    bBlur = !bBlur;
    strActionText.blur = bBlur ? "Blur:ON" : "Blur:OFF";
    renderDoodleControls();
}

function toggleCenterAction()
{
    bCenter = !bCenter;
    strActionText.center = bCenter ? "Center:ON" : "Center:OFF";
    renderDoodleControls();
}

function toggleTrainingAction()
{
    do_training = !do_training;
    strActionText.training = do_training ? "Stop" : "Start";
    renderTrainingControls(false);
}

function toggleOrderAction()
{
    bDigitOrdering = !bDigitOrdering;
    strActionText.ordering = bDigitOrdering ? "Order:ON" : "Order:OFF";
    renderTrainingControls(false);
}

function toggleAugmentAction()
{
    bAugment = !bAugment;
    strActionText.augmenting = bAugment ? "Augment:ON" : "Augment:OFF";
    renderTrainingControls(true);      // augment row appears/disappears with bAugment
}

function toggleShiftAction()
{
    bIShift = !bIShift;
    strActionText.shift = bIShift ? "Shift:ON" : "Shift:OFF";
    renderTrainingControls(true);
}

function toggleRotateAction()
{
    bIRotate = !bIRotate;
    strActionText.rotate = bIRotate ? "Rotate:ON" : "Rotate:OFF";
    renderTrainingControls(true);
}

function toggleRoundDownAction()
{
    bIRoundDown = !bIRoundDown;
    strActionText.rounddown = bIRoundDown ? "RoundDown:ON" : "RoundDown:OFF";
    renderTrainingControls(true);
}

function toggleCropAction()
{
    bICrop = !bICrop;
    console.log("@toggle " + bICrop);
    strActionText.crop = bICrop ? "Crop:ON" : "Crop:OFF";
    renderTrainingControls(true);
}