Code viewer for World: A-Z Character Recognition
 
// --- Network / training state ---------------------------------------------
var layer_defs;         // convnetjs layer definitions (built in setup())
var net;                // the convnetjs network
var trainer;            // the convnetjs trainer bound to `net`
let trainrun = 1;       // current pass over the training data (odd -> trainingData_1, even -> trainingData_2)
let train_index = 0;    // next exemplar within the current training file
let total_trains = 0;   // total number of exemplars trained so far
// Each time total_trains reaches a multiple of this, print stats and auto-run the
// doodle tests. -1 disables it (e.g. 5000 to enable).
let autoTestsAt = -1;

// Rolling 100-sample windows of the data loss and the L2 weight-decay loss.
var xLossWindow = new ErrorStore(100);
var wLossWindow = new ErrorStore(100);

// --- Testing state --------------------------------------------------------
let testrun = 1;
let test_index = 0;     // next exemplar in testingData
let total_tests = 0;
let total_correct = 0;

// --- Geometry -------------------------------------------------------------
const PIXELS        = 28;                       // images in data set are tiny
const PIXELSSQUARED = PIXELS * PIXELS;          // flat length of one image

const ZOOMFACTOR    = 7;                        // display magnification
const ZOOMPIXELS    = PIXELS * ZOOMFACTOR;      // side of a magnified image / the doodle area

// Width: magnified image + original-size image + a 50px gap.
// Height: four magnified-image heights plus a 150px margin.
const canvaswidth  = PIXELS + ZOOMPIXELS + 50;
const canvasheight = 4 * ZOOMPIXELS + 150;

// --- Doodle rendering -----------------------------------------------------
const DOODLE_THICK = 18;    // stroke weight of doodle lines
const DOODLE_BLUR = 3;      // blur applied once when a doodle stroke ends

// --- Main-loop controls ---------------------------------------------------
let do_training = true;     // while true, draw() trains (and tests) every step
const TRAINPERSTEP = 30;    // exemplars trained and tested per draw() step
var jsLoaded = false;       // becomes true once convnet.js and LineChart.js are loaded

// --- Doodle state (top-left of the canvas) --------------------------------
let doodle;
let doodle_exists = false;
let mousedrag = false;      // true while a mouse drag is drawing on the doodle

var doodle_inputs;
var lossChart;
var testChart;
var showGraphs = true;      // set false to skip chart updates for performance

// Class labels in network output order (index 0 -> "A" ... 25 -> "Z").
const letters = [..."ABCDEFGHIJKLMNOPQRSTUVWXYZ"];


// make run header bigger
AB.headerCSS ( { "max-height": "95vh" } );




//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// The second AB.msg argument appears to select the header slot (1-5), so each section
// below keeps its own place. `thehtml` is reused as a scratch buffer for every section.
// Element ids created here are looked up by the rest of the file (guess_*, doodlePred,
// doodles_Results, pauseBtn, training_*, testing_*, ctrlTest_*), so keep them in sync.
var thehtml;

  // 1 Doodle header 
  thehtml = "<hr> <h2> 1. Doodle </h2> " +
        " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button><br> ";
   AB.msg ( thehtml, 1 );

    // 2 Top-six guess table: guesses 1-3 in the left column pair, 4-6 in the right
    //   (cells guess_letter_N / guess_percent_N are filled in by guessDoodle()),
    //   followed by the label box and save / load-and-test buttons for stored doodles.
    thehtml = "<table> <tr><th></th><th>Letter</th><th>Confidence</th><th></th><th>Letter</th><th>Confidence</th></tr>";
    thehtml = thehtml + "<tr><td>First Guess</td><td id='guess_letter_1'></td><td id='guess_percent_1'></td><td>Fourth Guess</td><td id='guess_letter_4'></td><td id='guess_percent_4'></td></tr>";
    thehtml = thehtml + "<tr><td>Second Guess</td><td id='guess_letter_2'></td><td id='guess_percent_2'></td><td>Fifth Guess</td><td id='guess_letter_5'></td><td id='guess_percent_5'></td></tr>";
    thehtml = thehtml + "<tr><td>Third Guess</td><td id='guess_letter_3'></td><td id='guess_percent_3'></td><td>Sixth Guess</td><td id='guess_letter_6'></td><td id='guess_percent_6'></td></tr></table>";
    thehtml = thehtml + "<hr>Draw a doodle. Type the correct answer in the box and then click 'Save doodle'<input type='text' id='doodlePred' size='1' value='A'>"+
    "<button onclick='saveDoodle();' class='normbutton' >Save doodle</button> &nbsp; &nbsp; &nbsp; &nbsp;       <button onclick='loadDoodlesAndTest();' class='normbutton' >Load doodles and test</button>"+
    " Results: <div id='doodles_Results' style='color:darkgreen'>results will be displayed here</div>";
   AB.msg ( thehtml, 2 );
   
  // 3 Training & testing progress: pause/resume button plus the cells that train()
  //   (training_run / training_num) and test() (testing_last / testing_best) update.
  thehtml = "<hr> <h2> 2. Training & Testing </h2> " +
        " <button id='pauseBtn' onclick='pauseTraining()' class='normbutton' >Pause training</button> <br> ";
    thehtml = thehtml + "<table> <tr><th><div style='width: 150px;'>Training</div></th><th>Testing (Last 100)</th><th>Testing (Best 100)</th></tr>";    
    thehtml = thehtml + "<tr><td id = 'training_run'></td><td id='testing_last'></td><td id='testing_best'></td></tr>";
    thehtml = thehtml + "<tr><td id = 'training_num'></td><td></td><td></td></tr></table>";
  AB.msg ( thehtml, 3 );
  // 4 Control test: one header column per letter plus "Result"; the second row holds
  //   cells ctrlTest_0 .. ctrlTest_25 and ctrlTest_Result, written by controlTest().
  thehtml = "<hr> A set from the test data. Useful for comparing results.<br> " +
        " <button onclick='controlTest();' class='normbutton' >Run control test</button> <br> ";
    thehtml = thehtml + "<table> <tr>";
    var rowHtml = "";
    for (var l=0; l<26; l++){
        thehtml = thehtml + "<th>" + letters[l] + "</th>";
        rowHtml = rowHtml + "<td id='ctrlTest_"+ l + "'>0</td>"
    }
    thehtml = thehtml + "<th>Result</th>"
    rowHtml = rowHtml + "<td id='ctrlTest_Result'>0%</td>"
    thehtml = thehtml + "</tr><tr>" + rowHtml + "</tr></table>";
   AB.msg ( thehtml, 4 );
     
  // 5 Snapshot controls: save to / load from localStorage, or download the bundled snapshot
  thehtml = "<hr> <h2> 3. Save/Load Snapshot </h2> Save current state to a snapshot. Load a previously saved snapshot <br>  " +
        " <button onclick='saveSnapshot();' class='normbutton' >Save snapshot</button> <button onclick='loadSnapshot();' class='normbutton' >Load snapshot</button><br> "+
        "<br> Download a snapshot - 60k trainings <br>" +
        " <button onclick='downloadSnapshot();' class='normbutton' >Download snapshot</button>"
  AB.msg ( thehtml, 5 );
      

//--- end of AB.msgs structure: ---------------------------------------------------------





//---- normal P5 code -------------------------------------------------------
// Training data was split into two files so that it could be uploaded to Ancient Brain;
// a third file is used as the (held-out) testing data, a fourth as the control set.
var trainingData_1;
var trainingData_2;
var testingData;
var controlTestData;

/**
 * p5 preload hook: fetch the four JSON data files strictly one after another,
 * store each in its global, and remove the loading screen after the last one.
 */
function preload()
{
    const sources = [
        { url: '/uploads/paul79/training_images_letters_1.json', store: (d) => { trainingData_1 = d; } },
        { url: '/uploads/paul79/training_images_letters_2.json', store: (d) => { trainingData_2 = d; } },
        { url: '/uploads/paul79/training_images_letters_3.json', store: (d) => { testingData = d; } },
        { url: '/uploads/paul79/control_set_letters_26.json',    store: (d) => { controlTestData = d; } },
    ];

    // Request source `position`; its callback stores the data and requests the next one.
    const loadFrom = (position) => {
        if (position === sources.length) {
            AB.removeLoading();     // if no loading screen exists, this does nothing
            return;
        }
        const { url, store } = sources[position];
        loadJSON(url, (data) => {
            const fileNumber = position + 1;
            console.log("JSON Data " + fileNumber + " finished reading.");
            console.log("JSON Data " + fileNumber + " size: " + data.size);
            store(data);
            loadFrom(fileNumber);
        });
    };

    loadFrom(0);
}

/**
 * p5 setup hook: create the canvas and the off-screen doodle buffer, then fetch
 * convnet.js followed by LineChart.js and, once both are in, build the network
 * and its trainer.
 * See https://cs.stanford.edu/people/karpathy/convnetjs/docs.html
 */
function setup()
{
    createCanvas ( canvaswidth, canvasheight );

    doodle = createGraphics ( ZOOMPIXELS, ZOOMPIXELS );       // doodle on larger canvas
    doodle.pixelDensity(1);

    $.getScript ( "/uploads/paul79/convnet.js", () => {
        $.getScript ( "/uploads/paul79/LineChart.js", () => {
            console.log ("All JS loaded");
            jsLoaded = true;

            // 28x28x1 input -> conv 5x5 (8 filters, relu) -> pool 2 -> conv 5x5 (16 filters, relu)
            // -> pool 3 -> softmax over the 26 letters.
            // (Earlier, commented-out variants used 24x24 cropped inputs and other filter counts.)
            layer_defs = [
                { type:'input',   out_sx:28, out_sy:28, out_depth:1 },
                { type:'conv',    sx:5, filters:8,  stride:1, pad:2, activation:'relu' },
                { type:'pool',    sx:2, stride:2 },
                { type:'conv',    sx:5, filters:16, stride:1, pad:2, activation:'relu' },
                { type:'pool',    sx:3, stride:3 },
                { type:'softmax', num_classes:26 },
            ];

            net = new convnetjs.Net();
            net.makeLayers(layer_defs);

            trainer = new convnetjs.SGDTrainer(net, { method:'adadelta', batch_size:20, l2_decay:0.001 });
        });
    });
}

/**
 * Make a PIXELS x PIXELS p5 image from a flat array of greyscale values.
 * Each value is copied into R, G and B; alpha is set fully opaque.
 * @param {number[]} img - PIXELSSQUARED brightness values (presumably 0..255)
 * @returns the populated p5 image
 */
function getImage ( img )
{
    const picture = createImage(PIXELS, PIXELS);    // blank image, filled in below
    picture.loadPixels();
    const rgba = picture.pixels;

    for (let p = 0; p < PIXELSSQUARED; p++)
    {
        const grey = img[p];
        const base = 4 * p;
        rgba[base]     = grey;
        rgba[base + 1] = grey;
        rgba[base + 2] = grey;
        rgba[base + 3] = 255;
    }

    picture.updatePixels();
    return picture;
}


/**
 * Convert a flat greyscale image array into a 28x28x1 convnetjs volume,
 * scaling each value by 1/255 (so 0..255 becomes 0..1).
 * @param {number[]} img - PIXELSSQUARED brightness values
 * @returns {convnetjs.Vol} the network input
 */
function getInputs ( img )
{
    const vol = new convnetjs.Vol(28, 28, 1, 0.0);
    let p = 0;
    while (p < PIXELSSQUARED)
    {
        vol.w[p] = img[p] / 255;
        p += 1;
    }
    return vol;
}

/**
 * Train the network on one exemplar: entry `train_index` of the current training
 * file (trainingData_1 on odd passes, trainingData_2 on even passes).
 * Records both loss components, updates the progress cells, advances the index
 * (starting the next pass at the end of a file) and, every 200 trains, adds a
 * point to the loss chart.
 * @param {boolean} show - when truthy, also draw the exemplar magnified and at original size
 */
function train (show)
{
    const dataset = (trainrun % 2 === 1) ? trainingData_1 : trainingData_2;
    const { img, label } = dataset.images[train_index];

    if (show)
    {
        const picture = getImage(img);
        image(picture, 0,               ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS);   // magnified
        image(picture, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS,     PIXELS);       // original
    }

    const stats = trainer.train(getInputs(img), label);
    xLossWindow.add(stats.cost_loss);
    wLossWindow.add(stats.l2_decay_loss);

    document.getElementById("training_run").innerHTML = "Train Run: " + trainrun;
    document.getElementById("training_num").innerHTML = "Train Tests: " + train_index;

    train_index += 1;
    if (train_index >= dataset.images.length)
    {
        // End of this file: wrap around and move on to the next pass.
        train_index = 0;
        console.log( "finished trainrun: " + trainrun );
        trainrun += 1;
    }

    // Plot data loss + weight-decay loss; get_average() is -1 until enough samples exist.
    if (showGraphs && total_trains % 200 === 0)
    {
        const dataLoss = xLossWindow.get_average();
        const decayLoss = wLossWindow.get_average();
        if (dataLoss >= 0 && decayLoss >= 0)
        {
            createLossChart(dataLoss + decayLoss, total_trains);
        }
    }
    total_trains += 1;
}

/**
 * Toggle training on/off and relabel the pause button to match
 * ("Resume Training" while paused, "Pause Training" while running).
 */
function pauseTraining(){
    do_training = !do_training;
    const label = do_training ? "Pause Training" : "Resume Training";
    document.getElementById("pauseBtn").innerHTML = label;
}

// Outcomes of the held-out tests, squashed into blocks of 100 (feeds the test chart
// and the "Testing (Last 100)" / "Testing (Best 100)" cells), plus the best block score seen.
var testResults = new TestResultsSquash(100, 10);
var testingBest = 0;

/**
 * Test the network on one exemplar of the held-out set (entry `test_index`),
 * update the running totals and the on-page results, then advance `test_index`
 * (wrapping back to 0 at the end of the set).
 */
function test(){
    const { img, label } = testingData.images[test_index];

    // feed forward and get prediction
    net.forward(getInputs(img));
    const prediction = net.getPrediction();
    // Loose equality kept deliberately: label comes straight from the JSON data file.
    const correct = (prediction == label);

    test_index++;
    total_tests++;
    if (correct){
        total_correct++;
    }

    // Always record the outcome so the Last/Best 100 figures work even with the
    // charts switched off; only the chart rebuild is gated by showGraphs.
    testResults.add(correct ? 1 : 0);
    if (showGraphs && total_tests % 10 === 0) {
        createTestChart();
    }

    if (test_index % 100 === 0){
        const lastBlock = testResults.getLast();    // -1 until the first block of 100 completes
        if (lastBlock !== -1){
            const pct = Math.round(lastBlock);
            document.getElementById("testing_last").innerHTML = pct + "%";
            if (pct > testingBest){
                testingBest = pct;
                document.getElementById("testing_best").innerHTML = pct + "%";
            }
        }
    }

    if (test_index >= testingData.images.length){
        test_index = 0;
    }
}

/**
 * Run the network over the 26-image control set. Cell ctrlTest_i is set to 1 when
 * image i is classified correctly and to 0 otherwise, and ctrlTest_Result shows the
 * overall percentage. Does nothing if the control data or the network is not loaded yet.
 */
function controlTest(){
    if (controlTestData == null || net == null) return;

    const COUNT = 26;   // one result cell per letter column in the control table
    let numCorrect = 0;
    for (let i = 0; i < COUNT; i++){
        const { img, label } = controlTestData.images[i];
        // feed forward to make prediction
        net.forward(getInputs(img));
        const res = (net.getPrediction() == label) ? 1 : 0;
        numCorrect += res;
        document.getElementById("ctrlTest_" + i).innerHTML = res;
    }
    // Write the summary once, after all the per-letter cells are filled in.
    document.getElementById("ctrlTest_Result").innerHTML = Math.round((numCorrect / COUNT) * 100) + "%";
}

/**
 * p5 draw loop. Once the scripts and testing data have loaded it:
 *  - runs TRAINPERSTEP train+test steps while training is enabled (and, every
 *    autoTestsAt trains, logs doodle/test statistics),
 *  - shows the loss and test charts if they exist,
 *  - draws and classifies the doodle if there is one,
 *  - handles drawing a new doodle stroke with the mouse.
 */
function draw()
{
    // check if libraries and data loaded yet:
    if ( jsLoaded === false || typeof testingData == 'undefined' ) return;

    background ('black');
    if ( do_training )
    {
        for (let step = 0; step < TRAINPERSTEP; step++)
        {
            train(step === 0);    // show only one per step - still flashes by
            test();

            if (autoTestsAt !== -1 && total_trains % autoTestsAt === 0){
                // Do 10 doodle tests and average them (will only work if there are saved doodles in localStorage!)
                const doodleScores = [];
                for (let run = 0; run < 10; run++){
                    doodleScores.push (loadDoodlesAndTest());
                }
                let average = doodleScores.reduce((a, b) => a + b, 0) / doodleScores.length;
                console.log("Auto tests after " + total_trains + " trains:");
                console.log("Average Doodle percentage: " + average);

                // Mean of the last 5 completed blocks of 100 tests (i.e. the last 500 tests).
                const last500Tests = testResults.getLastN(5);
                average = last500Tests.reduce((a, b) => a + b, 0) / last500Tests.length;
                console.log("Average Test percentage: " + average);
            }
        }
    }
    if (typeof lossChart !== 'undefined'){
        lossChart.show();
    }
    if (typeof testChart !== 'undefined'){
        testChart.show();
    }

    if ( doodle_exists )
    {
        drawDoodle();
        guessDoodle();
    }

    // Detect doodle drawing. (restriction) assumes the doodle starts at canvas (0,0).
    // mouseIsPressed is also true when clicking page buttons, so only draw when both the
    // current and previous pointer positions are inside the doodle corner. p5 mouse
    // coordinates are canvas-relative and go negative left of / above the canvas, so
    // the lower bound is checked as well.
    if ( mouseIsPressed )
    {
        const MAX = ZOOMPIXELS + 20;     // can draw up to this many pixels into the corner
        const inCorner = (x, y) => x >= 0 && y >= 0 && x < MAX && y < MAX;
        if ( inCorner(mouseX, mouseY) && inCorner(pmouseX, pmouseY) )
        {
            doodle_exists = true;
            mousedrag = true;       // start a mouse drag
            doodle.stroke('white');
            doodle.strokeWeight( DOODLE_THICK );
            doodle.line(mouseX, mouseY, pmouseX, pmouseY);
        }
    }
    else if ( mousedrag )
    {
        // Exiting a drawing: blur the finished stroke once.
        mousedrag = false;
        doodle.filter (BLUR, DOODLE_BLUR);
    }
}


//--- doodle -------------------------------------------------------------

/**
 * Draw the doodle buffer twice: full size in the top-left corner and shrunk to
 * PIXELS x PIXELS to its right (the size the network actually sees).
 */
function drawDoodle()
{
    // doodle is createGraphics not createImage, so take a p5.Image copy first
    const snapshot = doodle.get();
    const shrunkX = ZOOMPIXELS + 50;

    image(snapshot, 0,       0, ZOOMPIXELS, ZOOMPIXELS);   // original
    image(snapshot, shrunkX, 0, PIXELS,     PIXELS);       // shrunk
}
      
/**
 * Classify the current doodle and fill in the top-six guess table.
 * The doodle is shrunk to PIXELS x PIXELS; the red channel of each RGBA pixel,
 * scaled to 0..1, becomes the single-channel network input.
 */
function guessDoodle()
{
    // doodle is createGraphics not createImage, so take a p5.Image copy first
    const img = doodle.get();
    img.resize ( PIXELS, PIXELS );
    img.loadPixels();

    // set up inputs
    const x = new convnetjs.Vol(28, 28, 1, 0.0);
    for (let i = 0; i < PIXELSSQUARED; i++)
    {
        x.w[i] = img.pixels[i * 4] / 255;
    }

    // feed forward: a.w holds the probability of each output class
    const a = net.forward(x);

    // Rank classes by probability, highest first. A numeric comparator is consistent
    // (returns 0 on ties), unlike a comparator that only ever returns 1 or -1.
    const preds = Array.from(a.w, (p, k) => ({ k, p }));
    preds.sort((first, second) => second.p - first.p);

    const shown = Math.min(6, preds.length);
    for (let wi = 0; wi < shown; wi++){
        const rank = wi + 1;
        const roundedPC = Math.round(preds[wi].p * 100);
        document.getElementById("guess_percent_" + rank).innerHTML = roundedPC + "%";
        document.getElementById("guess_letter_" + rank).innerHTML = letters[preds[wi].k];
    }
}

/**
 * Clear the doodle: draw() stops drawing/classifying it, and the buffer is
 * painted black so the next doodle starts from a clean slate.
 */
function wipeDoodle()
{
    doodle_exists = false;          // draw() skips drawDoodle()/guessDoodle() from now on
    doodle.background('black');     // erase every previous stroke
}

/**
 * Save the current doodle to localStorage (key 'doodles', value {doodles: [...]}).
 * Each entry is {pixels: RGBA values of the doodle shrunk to PIXELS x PIXELS,
 * label: letter typed into the #doodlePred box}.
 * The label is trimmed and upper-cased. Anything other than a single letter A-Z is
 * rejected with an alert, because loadDoodlesAndTest() scores a doodle by the label's
 * index in `letters`, so such a label could never be counted as correct.
 */
function saveDoodle(){
    const label = document.getElementById("doodlePred").value.trim().toUpperCase();
    if (!letters.includes(label)){
        alert("Enter a single letter A-Z before saving the doodle.");
        return;
    }

    // doodle is createGraphics not createImage, so take a p5.Image copy first
    const img = doodle.get();
    img.resize ( PIXELS, PIXELS );
    img.loadPixels();

    const jsonimage = {'pixels': [...img.pixels], 'label': label};

    const stored = localStorage.getItem('doodles');
    const doodleSaveList = (stored === null) ? {'doodles': []} : JSON.parse(stored);
    doodleSaveList.doodles.push(jsonimage);
    localStorage.setItem('doodles', JSON.stringify(doodleSaveList));
}

/**
 * Classify every doodle saved in localStorage (key 'doodles') and report the
 * percentage classified correctly in #doodles_Results.
 * Labels are compared case-insensitively, so doodles saved with a lower-case
 * letter still count.
 * @returns {number} rounded percentage correct (0-100), or -1 when there are no
 *   saved doodles (the user is alerted instead of the function throwing).
 */
function loadDoodlesAndTest(){
    const stored = localStorage.getItem('doodles');
    const doodles = (stored === null) ? [] : JSON.parse(stored).doodles;
    if (!Array.isArray(doodles) || doodles.length === 0){
        alert("Save doodle(s), then you can load.");
        return -1;
    }

    let numCorrect = 0;
    for (const saved of doodles){
        const x = new convnetjs.Vol(28, 28, 1, 0.0);
        for (let p = 0; p < PIXELSSQUARED; p++)
        {
            x.w[p] = saved.pixels[p * 4] / 255;     // red channel of each stored RGBA pixel
        }

        // feed forward to make prediction
        net.forward(x);
        const expected = letters.indexOf(String(saved.label).trim().toUpperCase());
        if (net.getPrediction() === expected){
            numCorrect++;
        }
    }
    const pc = Math.round(numCorrect / doodles.length * 100);
    document.getElementById("doodles_Results").innerHTML = "Doodles tested: " + doodles.length + " Correct: " + pc + "%";
    return pc;
}

// Every (train number, loss) point plotted so far; cleared by resetNetFromSnapshot().
var graphData = {xData:[], yData:[]};

/**
 * Append a (trainNumber, loss) point and rebuild the training-loss chart from the
 * whole history. draw() calls lossChart.show() each frame.
 * @param {number} loss - averaged data loss + weight-decay loss
 * @param {number} trainNumber - total trains so far (x coordinate)
 */
function createLossChart(loss, trainNumber){
    graphData.xData.push(trainNumber);
    graphData.yData.push(loss);

    // Locals, not implicit globals: these used to leak onto the global object.
    const colors = ['#ff0000'];
    const lineLabels = ["Loss (training)"];
    const points = graphData.xData.map((x, j) => createVector(x, graphData.yData[j]));

    // Remaining args passed through unchanged to LineChart (250x250 chart, anchored
    // 250px above the bottom of the canvas); x spans the plotted range, y starts at 0.
    lossChart = new LineChart([points], colors, lineLabels, 250, 250, 5, canvasheight - 250,
        [min(graphData.xData), max(graphData.xData)], [0, max(graphData.yData)]);
}

/**
 * Rebuild the hidden-test success chart (percent correct per block of tests)
 * from testResults. draw() calls testChart.show() each frame.
 * If there is nothing to plot yet, the existing chart (if any) is left as is.
 */
function createTestChart(){
    const results = testResults.getData();
    if (results.xData.length === 0) return;

    // Locals, not implicit globals: these used to leak onto the global object.
    const colors = ['#0000ff'];
    const lineLabels = ["Hidden Tests"];
    const points = results.xData.map((x, j) => createVector(x, results.yData[j]));

    // Same chart geometry as the loss chart, 250px higher; y is a percentage (0-100).
    testChart = new LineChart([points], colors, lineLabels, 250, 250, 5, canvasheight - 500,
        [min(results.xData), max(results.xData)], [0, 100]);
}


/**
 * Sliding window over the most recent `size` numeric samples, with a running sum
 * so the mean is cheap to read. Used to smooth the training losses.
 * (Kept as a function declaration so it is hoisted: it is instantiated near the top of the file.)
 * @param {number} [size=100] - maximum number of samples kept
 * @param {number} [minsize=20] - samples needed before get_average() reports a value
 */
function ErrorStore (size, minsize){
    this.v = [];                                                // samples, oldest first
    this.size = (size === undefined) ? 100 : size;
    this.minsize = (minsize === undefined) ? 20 : minsize;
    this.sum = 0;                                               // running total of this.v

    // Append a sample, evicting the oldest once the window is over capacity.
    this.add = function (sample) {
        this.v.push(sample);
        this.sum += sample;
        if (this.v.length > this.size) {
            this.sum -= this.v.shift();
        }
    };

    // Mean of the window, or -1 while fewer than `minsize` samples are held.
    this.get_average = function () {
        const count = this.v.length;
        return (count < this.minsize) ? -1 : this.sum / count;
    };

    // Forget every sample.
    this.reset = function () {
        this.v = [];
        this.sum = 0;
    };
}

/**
 * Records 0/1 test outcomes and "squashes" each completed block of `size` outcomes
 * into a percentage correct. (Kept as a function declaration so it is hoisted:
 * it is instantiated earlier in the file than it is defined.)
 * @param {number} [size=100] - outcomes per block
 * @param {number} [minsize=10] - stored for callers; not used by the methods below
 */
function TestResultsSquash (size, minsize){
    this.vals = [];          // 0/1 outcomes of the block currently being filled
    this.size = typeof(size)==='undefined' ? 100 : size;
    this.minsize = typeof(minsize)==='undefined' ? 10 : minsize;
    this.hundredTest = [];   // percentage correct of each completed block, oldest first
    this.sum = 0;            // sum of this.vals

    // Record one outcome; when the block is full, store its percentage and start a new block.
    this.add = function(res){
        this.vals.push(res);
        this.sum += res;
        if (this.vals.length >= this.size){
            this.hundredTest.push((this.sum/this.size)*100);
            this.vals = [];
            this.sum = 0;
        }
    };

    /**
     * Chart data: one point per completed block at x = size, 2*size, ..., plus a
     * point for the partially filled block, but only when it holds any outcomes.
     * (Previously an empty partial block produced a NaN point, and before the first
     * block completed the partial point's x was NaN.)
     * @returns {{xData: number[], yData: number[]}}
     */
    this.getData = function (){
        const yData = [...this.hundredTest];
        const xData = yData.map((_, i) => this.size * (i + 1));
        if (this.vals.length > 0){
            const lastX = xData.length > 0 ? xData[xData.length - 1] : 0;
            xData.push(lastX + this.vals.length);
            yData.push((this.sum / this.vals.length) * 100);
        }
        return {"xData": xData, "yData": yData};
    };

    // Percentage of the most recent completed block, or -1 if none has completed yet.
    this.getLast = function (){
        if (this.hundredTest.length > 0){
            return this.hundredTest[this.hundredTest.length-1];
        }
        return -1;
    };

    // Percentages of the last n completed blocks (fewer if not that many exist; [] for n <= 0).
    this.getLastN = function (n){
        return n > 0 ? this.hundredTest.slice(-n) : [];
    };

    // Forget all outcomes and block results.
    this.reset = function(){
        this.vals = [];
        this.hundredTest = [];
        this.sum = 0;
    };
}

/** Serialise the current network and store it in localStorage under 'cnnSnapshot'. */
function saveSnapshot(){
    const serialised = JSON.stringify(net.toJSON());
    localStorage.setItem('cnnSnapshot', serialised);
    console.log(`Save snaphot ${serialised}`);
}

/**
 * Restore the network from the snapshot stored in localStorage by saveSnapshot().
 * If nothing has been saved, alert the user and leave the current network alone.
 * Previously, JSON.parse(null) yielded null, and resetNetFromSnapshot replaced the live
 * net with an empty one before fromJSON failed.
 */
function loadSnapshot(){
    const stored = localStorage.getItem('cnnSnapshot');
    if (stored === null){
        alert("No saved snapshot found. Save a snapshot first.");
        return;
    }
    const snap = JSON.parse(stored);
    console.log("Loading snaphot " + snap);
    resetNetFromSnapshot(snap);
}

/**
 * Fetch the pre-trained snapshot shipped with the app (advertised in the header as
 * 60k trainings) and load it into the network.
 */
function downloadSnapshot(){
    const onLoaded = (snap) => {
        console.log("Downloaded snapshot " + snap);
        resetNetFromSnapshot(snap);
    };
    loadJSON('/uploads/paul79/snapshot.json', onLoaded);
}

/**
 * Replace the network with one rebuilt from a snapshot (as produced by net.toJSON())
 * and reset training/testing progress and statistics.
 *
 * A new trainer is created for the restored network. The previous trainer is bound to
 * the old net (and carries that net's optimiser state), so without this, training
 * would keep updating the discarded network while testing used the restored one.
 * The live net is replaced only after fromJSON succeeds, so a bad snapshot leaves
 * the current network intact.
 * @param {object} snap - parsed JSON snapshot
 */
function resetNetFromSnapshot(snap){
    const restored = new convnetjs.Net();
    restored.fromJSON(snap);
    net = restored;
    // Trainer options must match those used in setup().
    trainer = new convnetjs.SGDTrainer(net, {method:'adadelta', batch_size:20, l2_decay:0.001});

    xLossWindow.reset();
    wLossWindow.reset();
    trainrun = 1;
    train_index = 0;

    testrun = 1;
    test_index = 0;
    total_tests = 0;
    total_correct = 0;
    testingBest = 0;        // "Best 100" restarts along with the other test statistics
    testResults.reset();
    graphData = {xData:[], yData:[]};
}