// Code viewer for World: Doodle Classifier
// Bytes per sample: prepareData() slices each raw .bin buffer into chunks of this
// size, and guessSetup() feeds this many pixels of a 28x28 image to the network.
const len = 784;
// Samples read from each category file; prepareData() puts the first 80% into
// training and the remainder into test.
const total_data = 1000;

// Class labels. Each value is the network output index for that category and the
// position of the 1 in the one-hot target built by trainSetup().
const AIRPLANE = 0;
const APPLE = 1;
const BUS = 2;
const CAR = 3;
const CAT = 4;
const CRUISESHIP = 5;
const ELEPHANT = 6;
const HELICOPTER = 7;
const SUBMARINE = 8;
const TRUCK = 9;

// Raw per-category data, assigned from loadBytes() in beforesetup();
// prepareData() reads each one's .bytes buffer.
let airplane_data = {};
let apple_data = {};
let bus_data = {};
let car_data = {};
let cat_data = {};
let cruiseship_data = {};
let elephant_data = {};
let helicopter_data = {};
let submarine_data = {};
let truck_data = {};

// Per-category split: prepareData() gives each object .training and .test arrays
// of labelled samples.
let airplane = {};
let apple = {};
let bus = {};
let car = {};
let cat = {};
let cruiseship = {};
let elephant = {};
let helicopter = {};
let submarine = {};
let truck = {};



// NOTE(review): training_runs is not referenced anywhere else in this file view.
let training_runs = 0;
// Display name of the most recent guess, shown in the Guess panel by updateHtml().
let guessObject = "";

// The NeuralNetwork (from the loaded nn.js), created in AB.world.newRun.
let nn;
// All categories' training / test samples, combined by netSetup(). `training` is
// also referenced by name from the "Train for 1 Epoch" button's onclick string.
let training = [];
let test = [];
// NOTE(review): initiated is not referenced anywhere else in this file view.
let initiated = false;

// ml5 classifier options (set in beforesetup) and the classifier itself, whose
// creation is currently commented out.
let options;
let doodleClassifier;

//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Each entry below is [panel HTML, header slot]. Slots 2 (Training) and 3 (Guess)
// are written again with live values by updateHtml().
var thehtml;

for (const [panelHtml, panelSlot] of [
    // 1 Doodle: clear the drawing canvas
    ["<hr> <h1> 1. Doodle </h1>" +
        "<button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ", 1],
    // 2 Training: accuracy placeholder until the first test pass
    ["<hr> <h1> 2. Training </h1> Accuracy: -%<br> " +
        " <button onclick='trainSetup(training);' class='normbutton' >Train for 1 Epoch</button> <br> ", 2],
    // 3 Guess: no guess yet
    ["<hr> <h1> 3. Guess </h1> Guess: <br> " +
        " <button onclick='guessSetup();' class='normbutton' >Guess Drawing</button> <br> ", 3],
    // 4 Save the current model
    ["<hr> <h1> 4. Save </h1>" +
        " <button onclick='saveData();' class='normbutton' >Save Model</button> <br> ", 4],
    // 5 Restore a saved model
    ["<hr> <h1> 5. Restore </h1>" +
        " <button onclick='restoreData();' class='normbutton' >Restore Model</button> <br> ", 5],
])
{
    thehtml = panelHtml;
    AB.msg(thehtml, panelSlot);
}

const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> "  ;

//--- end of AB.msgs structure: ---------------------------------------------------------



AB.world.newRun = function ()
{
    // Runs once at the start of a run.

    // Block screenshots while loading (the original note refers to a splash screen).
    // NOTE(review): nothing in this file view sets AB.runReady back to true.
    AB.runReady = false;
    ABWorld.init('black');
    AB.headerRHS();

    // Last step, once every helper script is loaded: build the 784-128-10 network
    // and split the loaded doodle data into training / test sets.
    const whenLibrariesLoaded = function ()
    {
        nn = new NeuralNetwork(784, 128, 10);
        netSetup();
        console.log("All JS loaded");
    };

    // Scripts are fetched strictly one after another (matrix.js, nn.js, mnist.js),
    // presumably because nn.js builds on matrix.js.
    $.getScript("/uploads/codingtrain/matrix.js", function ()
    {
        $.getScript("/uploads/codingtrain/nn.js", function ()
        {
            $.getScript("/uploads/codingtrain/mnist.js", whenLibrariesLoaded);
        });
    });
};


AB.world.nextStep = function () {
    // Per-step hook. Intentionally empty: per-frame drawing is done in draw().
};


AB.world.endRun = function () {
    // End-of-run hook. Nothing to tear down.
};




//---- setup -------------------------------------------------------
// Do NOT make a setup function.
// This is done for you in the API. The API setup just creates a canvas.
// Anything else you want to run at the start should go into the following two functions.


function beforesetup()      // Optional
{
    // Runs before the canvas is created. Starts loading the raw sample file for each
    // category with p5's loadBytes(); netSetup() -> prepareData() later reads each
    // result's .bytes buffer.
    // NOTE(review): confirm the files have finished loading before netSetup() runs.
    airplane_data = loadBytes("/uploads/gfar97/airplaneK.bin");
    apple_data = loadBytes("/uploads/gfar97/appleK.bin");
    bus_data = loadBytes("/uploads/gfar97/busK.bin");
    car_data = loadBytes("/uploads/gfar97/carK.bin");
    cat_data = loadBytes("/uploads/gfar97/catK.bin");
    cruiseship_data = loadBytes("/uploads/gfar97/cruise_shipK.bin");
    elephant_data = loadBytes("/uploads/gfar97/elephantK.bin");
    helicopter_data = loadBytes("/uploads/gfar97/helicopterK.bin");
    submarine_data = loadBytes("/uploads/gfar97/submarineK.bin");
    truck_data = loadBytes("/uploads/gfar97/truckK.bin");

    // Options for an ml5 image classifier over 28x28 single-channel images. Only used
    // by the commented-out classifier below.
    // NOTE(review): debug is the string "true", not a boolean — confirm before enabling ml5.
    options = {
        inputs: [28, 28, 1],
        task: 'imageClassification',
        debug: "true",
    };

    // doodleClassifier = ml5.neuralNetwork(options);
}


function aftersetup()       // Optional
{
    // Runs once the canvas exists: make it a 400x400 square and paint it black,
    // giving the white strokes from draw() a black background.
    const canvasSide = 400;
    resizeCanvas(canvasSide, canvasSide);
    background(0, 0, 0);
}

/**
 * Split every category's raw data into labelled training / test samples and
 * rebuild the shared `training` and `test` arrays from them, then refresh the
 * header panels.
 */
function netSetup() {
    // Each category object, its raw loadBytes() result and its class label, in label order.
    const categories = [
        [airplane, airplane_data, AIRPLANE],
        [apple, apple_data, APPLE],
        [bus, bus_data, BUS],
        [car, car_data, CAR],
        [cat, cat_data, CAT],
        [cruiseship, cruiseship_data, CRUISESHIP],
        [elephant, elephant_data, ELEPHANT],
        [helicopter, helicopter_data, HELICOPTER],
        [submarine, submarine_data, SUBMARINE],
        [truck, truck_data, TRUCK],
    ];

    // Split every category first, so a failure part-way leaves the shared sets untouched.
    for (const [category, data, label] of categories) {
        prepareData(category, data, label);
    }

    // Rebuild the shared sets from scratch rather than appending to them, so calling
    // netSetup() again does not duplicate every sample.
    const allTraining = [];
    const allTest = [];
    for (const [category] of categories) {
        allTraining.push(...category.training);
        allTest.push(...category.test);
    }
    training = allTraining;
    test = allTest;

    // Refresh the panels; this also runs a first accuracy pass over the test set.
    updateHtml();
}

/**
 * Train the network for one epoch: one nn.train() call per sample, in a freshly
 * shuffled order.
 * @param {Array} training - byte-array samples, each carrying a numeric .label;
 *   shuffled in place.
 */
function trainSetup(training) {
    // p5 shuffle with `true` reorders the caller's array itself.
    shuffle(training, true);

    for (const sample of training) {
        // Scale 0..255 bytes into 0..1 network inputs.
        const inputs = Array.from(sample, (value) => value / 255);
        // One-hot target over the 10 classes.
        const targets = new Array(10).fill(0);
        targets[sample.label] = 1;

        nn.train(inputs, targets);
    }

    updateHtml();
    console.log("Epoch training finished");
}

/**
 * Measure the network's accuracy on a labelled sample set.
 * @param {Array} testing - byte-array samples, each carrying a numeric .label.
 * @returns {number} percentage (0-100) of samples whose highest-scoring output
 *   index equals their label; 0 for an empty set.
 */
function testSetup(testing) {
    // Without this guard an empty set gives 0 / 0 = NaN.
    if (testing.length === 0) {
        return 0;
    }

    let correct = 0;
    for (const sample of testing) {
        // Scale 0..255 bytes into 0..1 network inputs.
        const inputs = Array.from(sample, (value) => value / 255);
        const prediction = nn.predict(inputs);
        // Index of the highest score is the predicted class.
        const classification = prediction.indexOf(Math.max(...prediction));

        if (classification === sample.label) {
            correct++;
        }
    }
    return (correct / testing.length) * 100;
}

/**
 * Slice one category's raw bytes into `total_data` samples of `len` bytes each and
 * split them into training (first 80%) and test (remainder) arrays.
 * @param {Object} category - receives fresh .training and .test arrays.
 * @param {Object} data - object whose .bytes typed array holds the samples back to back.
 * @param {number} label - class index attached to every sample as .label.
 */
function prepareData(category, data, label) {
    // Loop-invariant: index of the first test sample.
    const threshold = Math.floor(0.8 * total_data);
    category.training = [];
    category.test = [];

    for (let i = 0; i < total_data; i++) {
        const offset = i * len;
        // subarray() returns a new view over data.bytes (no copy); the label is
        // stored as an extra property on that view.
        const sample = data.bytes.subarray(offset, offset + len);
        sample.label = label;

        if (i < threshold) {
            category.training.push(sample);
        } else {
            category.test.push(sample);
        }
    }
}


function wipeDoodle() {
    // Paint the whole canvas black, erasing every stroke drawn so far.
    const black = [0, 0, 0];
    background(...black);
    console.log("Cleared doodle");
}

/**
 * Classify the current canvas drawing. The predicted category's display name is
 * stored in `guessObject` and the header panels are refreshed.
 */
function guessSetup() {
    // Snapshot the canvas and shrink it to 28x28.
    const img = get();
    img.resize(28, 28);
    img.loadPixels();

    // pixels[] is RGBA, 4 entries per pixel. Strokes are drawn with stroke(255) on a
    // black background, so the red channel (offset 0 of each pixel) serves as the
    // brightness.
    const inputs = [];
    for (let i = 0; i < len; i++) {
        inputs[i] = img.pixels[i * 4] / 255.0;
    }

    const prediction = nn.predict(inputs);
    // Index of the highest score is the predicted class.
    const classification = prediction.indexOf(Math.max(...prediction));

    // Display name for each class label.
    const names = {
        [AIRPLANE]: "Airplane",
        [APPLE]: "Apple",
        [BUS]: "Bus",
        [CAR]: "Car",
        [CAT]: "Cat",
        [CRUISESHIP]: "Cruise Ship",
        [ELEPHANT]: "Elephant",
        [HELICOPTER]: "Helicopter",
        [SUBMARINE]: "Submarine",
        [TRUCK]: "Truck",
    };
    // Keep the previous guess when the index is not a known label (e.g. -1 when the
    // prediction contains NaN).
    if (Object.prototype.hasOwnProperty.call(names, classification)) {
        guessObject = names[classification];
    }

    updateHtml();
}

function randomWeight()
{
    // Random float between -0.5 and 0.5 from AB's helper.
    // (The Coding Train default range is -1 to 1; a narrower range is used here.)
    // NOTE(review): not called anywhere else in this file view.
    const low = -0.5;
    const high = 0.5;
    return AB.randomFloatAtoB(low, high);
}

/**
 * Rewrite the Training (slot 2) and Guess (slot 3) run-header panels with the
 * current test-set accuracy and the latest guess.
 */
function updateHtml() {
    // 2 Training panel. testSetup() evaluates the whole test set, so every refresh
    // costs one evaluation pass. Rounded to 2 decimals so the panel does not show
    // raw floats such as 33.333333333333336%.
    const accuracy = testSetup(test).toFixed(2);
    thehtml = "<hr> <h1> 2. Training </h1> Accuracy: " + accuracy + "%<br> " +
        " <button onclick='trainSetup(training);' class='normbutton' >Train for 1 Epoch</button> <br> ";
    AB.msg(thehtml, 2);

    // 3 Guess panel: display name set by the most recent guessSetup() call.
    thehtml = "<hr> <h1> 3. Guess </h1> Guess: " + guessObject + "<br> " +
        " <button onclick='guessSetup();' class='normbutton' >Guess Drawing</button> <br> ";
    AB.msg(thehtml, 3);
}

function saveData() {
    // Hand the live network object to AB's save mechanism.
    const model = nn;
    AB.saveData(model);
}

/**
 * Load a previously saved network through AB.restoreData() and copy its weights
 * and biases into the running network.
 */
function restoreData()
{
    AB.restoreData(function (a)
    {
        if (!a) {
            return;
        }
        // Copy each saved matrix's contents into the network's existing object
        // instead of replacing that object. The restored value is presumably the
        // plain JSON form of the saved network (matrices as {rows, cols, data}, as in
        // the Coding Train Matrix class — TODO confirm). Plain objects would lack the
        // Matrix methods training relies on. The biases are restored too; otherwise
        // the saved weights get paired with this run's random biases.
        ["weights_ih", "weights_ho", "bias_h", "bias_o"].forEach(function (key)
        {
            const saved = a[key];
            if (saved === undefined || saved === null) {
                return;
            }
            const live = nn[key];
            if (live && saved.data !== undefined) {
                live.rows = saved.rows;
                live.cols = saved.cols;
                live.data = saved.data;
            } else {
                // Unknown shape: fall back to plain assignment, as before.
                nn[key] = saved;
            }
        });
    });
}

//---- draw -------------------------------------------------------

function draw()             // Optional
{
    // Called by p5 every frame. Sets up a thick white pen, then, while the mouse
    // button is held, joins the previous and current mouse positions so a drag
    // leaves a continuous stroke.
    strokeWeight(8);
    stroke(255);

    if (!mouseIsPressed) {
        return;
    }
    line(pmouseX, pmouseY, mouseX, mouseY);
}