// Code viewer for World: Doodle Play

// Cloned by Itisha on 18 Nov 2021 from World "Character recognition neural network" by "Coding Train" project 
// Please leave this clone trail here.
 

// Port of Character recognition neural network from here:
// https://github.com/CodingTrain/Toy-Neural-Network-JS/tree/master/examples/mnist
// with many modifications 


// --- defined by the 28 x 28 image format of the data sets - do not change these ----

const PIXELS        = 28;                       // images in the data set are tiny: 28 x 28
const PIXELSSQUARED = PIXELS * PIXELS;

// Number of training and test exemplars in the MNIST data set.
// Kept from the MNIST original this file was ported from; the doodle data used here
// has `totalData` images per category instead, and these two constants are not
// referenced elsewhere in this file.
const NOTRAIN = 60000;
const NOTEST  = 10000;



//--- can modify all these --------------------------------------------------

// number of nodes in the network
const noinput  = PIXELSSQUARED;
const nohidden = 200;       // values tried earlier: 5, 64
const nooutput = 3;         // one output per category (CAT, RAINBOW, TRAIN)

const learningrate = 0.2;   // default 0.1

// should we train / test every timestep or not (set by the Train / Test / Stop buttons)
let do_training = false;
let do_testing  = false;

// how many to train and test per timestep
// NOTE(review): neither constant is referenced elsewhere in this file; draw() trains
// one whole epoch and tests the whole test set on each frame.
const TRAINPERSTEP = 1;     // earlier: 30
const TESTPERSTEP  = 3000;  // earlier: 5

// multiply the 28 x 28 image by this to magnify it for display
const ZOOMFACTOR    = 10;
const ZOOMPIXELS    = ZOOMFACTOR * PIXELS + 50;     // side of the square doodle area

// Canvas layout: the large doodle area (ZOOMPIXELS square) at the top left, a shrunk
// PIXELS x PIXELS copy to its right, plus 100 px of margin in each direction.
const canvaswidth  = (PIXELS + ZOOMPIXELS) + 100;
const canvasheight = ZOOMPIXELS + 100;


const DOODLE_THICK = 5;     // stroke weight of doodle lines
const DOODLE_BLUR  = 3;     // blur factor for doodles (not applied anywhere in this file)

// --- new variables added by Itisha ----------------------------------------

const len       = PIXELSSQUARED;    // bytes per image in the .bin files (one byte per pixel)
const totalData = 1000;             // images per category file

// class labels = index of the corresponding network output / one-hot target
const CAT     = 0;
const RAINBOW = 1;
const TRAIN   = 2;

// raw byte data, loaded in preload()
let catsData;
let trainsData;
let rainbowsData;

// per-category { training, testing } image arrays, filled in by prepareData()
let cats     = {};
let trains   = {};
let rainbows = {};
let save_flag = false;

// --- end of new variables added by Itisha ---------------------------------

let nn;     // the NeuralNetwork, created in setup() once its JS has loaded

let trainrun = 1;
let train_index = 0;

let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;

// images in LHS:
let doodle;
let demo;
let doodle_exists = false;
let demo_exists = false;

let mousedrag = false;      // are we in the middle of a mouse drag drawing?


// save inputs to global vars to inspect
// type these names in the console
var train_inputs;
var test_inputs;
var demo_inputs;
var doodle_inputs;


// Matrix.randomize() is changed to point to this, so any user of Matrix must define it.

/**
 * Produce one initial network weight: a random float between -0.5 and 0.5,
 * obtained from AB.randomFloatAtoB (the Coding Train default range is -1 to 1).
 * @returns {number} the random weight
 */
function randomWeight() {
  const low = -0.5;
  const high = 0.5;
  const weight = AB.randomFloatAtoB(low, high);
  return weight;
}


// make the run header bigger
AB.headerCSS({ "max-height": "95vh" });


//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Message 1 (set here) holds the doodle instructions and the control buttons; each
// button's onclick calls the global function / sets the global flag written in it.
var thehtml;

{
  // 1 Doodle header: instructions plus the first button row, then the Save / Restore row.
  // Buttons within a row are separated by two spaces.
  const firstRow = [
    " Draw doodle of a Train, Rainbow or a Cat in the top LHS. <br><br> <button onclick='wipeDoodle();' class='normbutton' > Clear</button>",
    "<button onclick='guessDoodle();' class='normbutton' > Guess</button>",
    "<button onclick='do_training = true;' class='normbutton' >Train</button>",
    "<button onclick='do_testing = true;' class='normbutton' >Test</button>",
    "<button onclick='do_training = false; do_testing = false;' class='normbutton' >Stop</button>",
  ].join("  ");

  const secondRow = [
    "<button onclick='save_data();' class='normbutton' >Save</button>",
    "<button onclick='restore_data();' class='normbutton' >Restore</button>",
  ].join("  ");

  thehtml = "<hr> <h1> Doodle </h1>" + firstRow + "<br><br>" + secondRow;
  AB.msg(thehtml, 1);
}

// opening tag for highlighted results; the code that uses it appends "</span>"
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> ";

//--- end of AB.msgs structure: ---------------------------------------------------------


/**
 * p5 preload hook: request the three doodle data sets (raw bytes) from this server.
 * p5 waits for these loads to finish before calling setup().
 */
function preload()
{
  const files = [
    '/uploads/itisha312/cats1000.bin',
    '/uploads/itisha312/trains1000.bin',
    '/uploads/itisha312/rainbows1000.bin',
  ];
  [catsData, trainsData, rainbowsData] = files.map((path) => loadBytes(path));
}


/**
 * Split one category's raw bytes into training and testing images.
 *
 * Image i occupies bytes [i * size, (i + 1) * size) of data.bytes. The first
 * floor(trainFraction * count) images go to category.training, the rest to
 * category.testing. Each image is a subarray view (no copy) of data.bytes with a
 * `label` property attached.
 *
 * @param {object} category - receives fresh `training` and `testing` arrays (overwritten)
 * @param {{bytes: Uint8Array}} data - loaded byte data (as returned by p5 loadBytes)
 * @param {number} label - class index stored on every image
 * @param {number} [count=totalData] - number of images to read
 * @param {number} [size=len] - bytes per image
 * @param {number} [trainFraction=0.8] - share of the images used for training
 */
function prepareData(category, data, label, count = totalData, size = len, trainFraction = 0.8) {
  // Loop-invariant: compute the split point once rather than on every iteration.
  const threshold = Math.floor(trainFraction * count);
  category.training = [];
  category.testing = [];
  for (let i = 0; i < count; i++) {
    const offset = i * size;
    const image = data.bytes.subarray(offset, offset + size);
    image.label = label;
    if (i < threshold) {
      category.training.push(image);
    } else {
      category.testing.push(image);
    }
  }
}

/**
 * Save button handler: store the current network `nn` on the server via AB.saveData.
 */
function save_data()
{
    const network = nn;
    AB.saveData(network);
}

/**
 * Restore button handler: if the run is logged in and saved data exists on the
 * server, rebuild the global network `nn` from it, reset its learning rate and redraw.
 */
function restore_data()
{
    if (!AB.runloggedin)
    {
        console.log('Not logged in - nothing to restore');
        return;
    }

    // asynchronous - need callback function
    AB.queryDataExists(function (exists)
    {
        if (!exists)
        {
            console.log('No saved data to restore');
            return;
        }

        // The parameter must NOT be called `nn`: that would shadow the global, and the
        // assignment below would update only a local, discarding the restored network.
        AB.restoreData(function (savedData)
        {
            console.log('Restoring data from server');
            console.log(savedData);
            console.log('calling saved nn obj');
            // NOTE(review): assumes NeuralNetwork (neural_nw.js, not shown here) can
            // rebuild itself from the restored object - confirm.
            nn = new NeuralNetwork(savedData);
            nn.setLearningRate(learningrate);
            redraw();
        });
    });
}

/**
 * p5 setup hook: create the canvas and the off-screen doodle buffer, then load
 * matrix.js followed by neural_nw.js and build the network. A loading screen is
 * shown until both scripts have loaded.
 */
function setup()
{
  createCanvas(canvaswidth, canvasheight);
  background(0);
  doodle = createGraphics(ZOOMPIXELS, ZOOMPIXELS);       // doodle on larger canvas
  doodle.pixelDensity(1);

  // Previously a failed script load left the loading screen up with no diagnostics.
  const reportLoadFailure = (script) => (jqxhr, settings, exception) => {
    console.error('Failed to load ' + script, exception);
  };

  // JS load other JS: neural_nw.js is requested only after matrix.js has loaded
  // (presumably because the network code depends on Matrix).
  $.getScript("/uploads/itisha312/matrix.js", function ()
  {
    $.getScript("/uploads/itisha312/neural_nw.js", function ()
    {
      console.log("All JS loaded");
      AB.removeLoading();
      // Making the neural network
      console.log('creating a new nn obj');
      nn = new NeuralNetwork(noinput, nohidden, nooutput);
      nn.setLearningRate(learningrate);
    }).fail(reportLoadFailure("/uploads/itisha312/neural_nw.js"));
  }).fail(reportLoadFailure("/uploads/itisha312/matrix.js"));

  // loading screen while the JS loads; removed in the callback above
  AB.loadingScreen();
}


/**
 * Train the network for one epoch.
 *
 * Shuffles `training` in place, then calls nn.train once per image. Each call gets
 * the pixel bytes scaled to 0..1 and a one-hot target for the image's label. When
 * done, reports the number of images trained in run-header message 4.
 *
 * @param {Array<Uint8Array>} training - images, each carrying a numeric `label` (shuffled in place)
 * @param {object} nn - network exposing train(inputs, targets)
 */
function trainEpoch(training, nn) {
  shuffle(training, true);      // p5 shuffle; `true` modifies the array in place
  console.log('Begin Training');
  for (const image of training) {
    const inputs = Array.from(image, (x) => x / 255);
    const targets = [0, 0, 0];  // one-hot over the 3 categories
    targets[image.label] = 1;
    nn.train(inputs, targets);
  }
  // The loop is synchronous, so the page cannot repaint mid-epoch. Post the final count
  // once instead of rewriting the same message for every image.
  thehtml = " Trained dataset: " + training.length;
  AB.msg(thehtml, 4);
}

/**
 * Classify every image in `testing` and return the percentage classified correctly.
 * The predicted class is the index of the largest network output (first one on ties).
 *
 * @param {Array<Uint8Array>} testing - images, each carrying a numeric `label`
 * @param {object} nn - network exposing predict(inputs) returning an array of outputs
 * @returns {number} percent correct in 0..100; 0 for an empty test set
 */
function testAllDoodles(testing, nn) {
  console.log('TestAllDoodles');
  if (testing.length === 0) {
    return 0;   // avoid 0 / 0 = NaN being displayed as the score
  }
  let correct = 0;
  for (const image of testing) {
    const inputs = Array.from(image, (x) => x / 255);
    const guess = nn.predict(inputs);
    const classification = guess.indexOf(Math.max(...guess));
    if (classification === image.label) {
      correct++;
    }
  }
  return (100 * correct) / testing.length;
}

/**
 * Name the category the network picked.
 *
 * The winner is the index of the largest output (first one on ties). The result
 * holds 'CAT', 'RAINBOW' or 'TRAIN' at index 0. It is an empty array if the winning
 * index is none of CAT, RAINBOW, TRAIN.
 *
 * @param {number[]} guess - network outputs, one per category
 * @returns {string[]} array containing the top label, or empty
 */
function find_doodle_labels (guess)
{
  const winner = guess.indexOf(max(guess));
  switch (winner)
  {
    case CAT:
      return ['CAT'];
    case RAINBOW:
      return ['RAINBOW'];
    case TRAIN:
      return ['TRAIN'];
    default:
      return [];
  }
}


// --- the draw function -------------------------------------------------------------
// every step:

// Training / testing sets built from the loaded bytes. The byte data does not change
// after preload(), so build the sets once (on the first frame). Previously every frame
// re-sliced all images and re-concatenated the arrays.
let preparedSets = null;

/** Build (once) and return { training, testing } from the three category data sets. */
function getPreparedSets()
{
  if (preparedSets === null)
  {
    prepareData(cats, catsData, CAT);
    prepareData(rainbows, rainbowsData, RAINBOW);
    prepareData(trains, trainsData, TRAIN);
    preparedSets = {
      training: [...cats.training, ...rainbows.training, ...trains.training],
      testing:  [...cats.testing,  ...rainbows.testing,  ...trains.testing],
    };
  }
  return preparedSets;
}

/**
 * p5 draw hook. Each frame it:
 *  - trains one epoch (if Train is on);
 *  - tests the whole test set (if Test is on);
 *  - paints the doodle;
 *  - extends the doodle with the current mouse stroke.
 */
function draw()
{
  background('black');

  const { training, testing } = getPreparedSets();

  // The network is created asynchronously in setup(); until then skip training and
  // testing instead of throwing on an undefined `nn` every frame.
  const networkReady = nn !== undefined;

  if (do_training && networkReady)
  {
    trainEpoch(training, nn);
  }
  else
  {
    thehtml = "";
    AB.msg(thehtml, 4);
  }

  if (do_testing && networkReady)
  {
    const percent = testAllDoodles(testing, nn);
    thehtml = " <br> Percentage : " + greenspan + nf(percent, 2, 2) + "%" + "</span>";
    AB.msg(thehtml, 3);
  }
  else
  {
    thehtml = " ";
    AB.msg(thehtml, 3);
  }

  if (doodle_exists)
  {
    drawDoodle();
  }

  // detect doodle drawing
  // (restriction) the following assumes doodle starts at 0,0
  // mouseIsPressed is also true when the header buttons are clicked. Those lie outside
  // the canvas, where coordinates can be negative. So require the current and previous
  // positions to be inside the corner on both sides; otherwise clicking a button marks
  // the doodle as drawn.
  if (mouseIsPressed)
  {
    const MAX = ZOOMPIXELS + 50;      // can draw up to this many pixels in the corner
    const inCorner = (x, y) => x >= 0 && y >= 0 && x < MAX && y < MAX;
    if (inCorner(mouseX, mouseY) && inCorner(pmouseX, pmouseY))
    {
      mousedrag = true;       // start a mouse drag
      doodle_exists = true;
      doodle.stroke('white');
      doodle.strokeWeight(DOODLE_THICK);
      doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
  }
}


//--- doodle -------------------------------------------------------------

/**
 * Paint the current doodle onto the main canvas twice: full size at the top-left
 * corner, and shrunk to PIXELS x PIXELS just to its right. The shrunk copy is the
 * resolution guessDoodle() resizes to before classifying.
 */
function drawDoodle()
{
    // doodle is createGraphics, not createImage, so take an image snapshot of it first
    const snapshot = doodle.get();
    const thumbX = ZOOMPIXELS + 70;

    image(snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS);     // original
    image(snapshot, thumbX, 0, PIXELS, PIXELS);        // shrunk
}
      
      
/**
 * Guess button handler: classify the current doodle.
 *
 * Shrinks the doodle to PIXELS x PIXELS and feeds each pixel's red channel, scaled to
 * 0..1, to the network. Shows the top label in run-header message 2. If nothing has
 * been drawn, asks the user to draw something instead.
 */
function guessDoodle()
{
  // Nothing drawn (or wiped): skip the resize / pixel read entirely.
  if (!doodle_exists)
  {
    thehtml = " <br> Please draw something !!";
    AB.msg(thehtml, 2);
    return;
  }

  // doodle is createGraphics not createImage
  const img = doodle.get();
  img.resize(PIXELS, PIXELS);
  img.loadPixels();

  // set up inputs: pixels holds 4 entries (RGBA) per pixel; strokes are drawn in
  // white, so the red channel alone carries the stroke intensity
  const inputs = [];
  for (let i = 0; i < PIXELSSQUARED; i++)
  {
    inputs[i] = img.pixels[i * 4] / 255;
  }

  doodle_inputs = inputs;     // can inspect in console

  // feed forward to make prediction
  const prediction   = nn.predict(inputs);                 // array of outputs
  const doodle_guess = find_doodle_labels(prediction);     // [top label]

  thehtml = " <br> It looks like you drew a : " + greenspan + doodle_guess[0] + "</span>";
  AB.msg(thehtml, 2);
}


/**
 * Clear button handler: paint the doodle buffer black, mark it as empty,
 * and blank the guess message (run-header message 2).
 */
function wipeDoodle()
{
    doodle.background('black');
    doodle_exists = false;

    thehtml = "";
    AB.msg(thehtml, 2);
}




// --- debugging --------------------------------------------------
// in console
// showInputs(demo_inputs);
// showInputs(doodle_inputs);

/*
function showInputs ( inputs )
// display inputs row by row, corresponding to square of pixels 
{
    var str = "";
    for (let i = 0; i < inputs.length; i++) 
    {
      if ( i % PIXELS == 0 )    str = str + "\n";                                   // new line for each row of pixels 
      var value = inputs[i];
      str = str + " " + value.toFixed(2) ; 
    }
    console.log (str);
}
*/