// Code viewer for World: Assignment 2 - doodle reco...
/**
 * Runs one training pass of the global network `nn` over every sample.
 *
 * The sample list is first passed to `shuffle(training, true)`; with p5's
 * shuffle the `true` flag reorders the given array itself, so the caller's
 * array order changes.
 *
 * @param {Array} training - Samples; each is an array-like of numbers that
 *   also carries a numeric `label` property (0-3).
 */
function trainEpoch(training) {
  shuffle(training, true);
  for (const sample of training) {
    // Scale every raw value by 1/255 before feeding it to the network.
    const scaled = Array.from(sample, (value) => value / 255);
    // One-hot target vector: four slots, the sample's label slot set to 1.
    const oneHot = new Array(4).fill(0);
    oneHot[sample.label] = 1;
    nn.train(scaled, oneHot);
  }
}

/**
 * Measures how often the global network `nn` classifies the given samples
 * correctly.
 *
 * @param {Array} testing - Samples; each is an array-like of numbers that
 *   also carries a numeric `label` property.
 * @returns {number} Percentage (0-100) of samples whose predicted class
 *   equals their label; 0 when `testing` is empty.
 */
function testAll(testing) {
  // Without this guard an empty set would evaluate 100 * 0 / 0 = NaN.
  if (testing.length === 0) {
    return 0;
  }

  let correct = 0;
  // Evaluate every sample once; the network is only queried, not trained.
  for (const sample of testing) {
    // Scale the raw values by 1/255, the same scaling trainEpoch applies.
    const inputs = Array.from(sample, (x) => x / 255);
    const guess = nn.predict(inputs);

    // The predicted class is the index of the largest output value.
    const classification = guess.indexOf(max(guess));

    if (classification === sample.label) {
      correct++;
    }
  }
  return (100 * correct) / testing.length;
}


/**
 * Splits one category's raw bytes into labelled training and testing samples.
 *
 * Reads `totalData` consecutive samples of `len` bytes each from
 * `data.bytes`. The first 80% (rounded down) go to `category.training`, the
 * rest to `category.testing`. Every sample is a `subarray` view of
 * `data.bytes` with a `label` property attached.
 *
 * @param {Object} category - Receives fresh `training` and `testing` arrays
 *   (any existing ones are replaced).
 * @param {Object} data - Object whose `bytes` property is a typed array.
 * @param {number} label - Class label stored on every sample.
 */
function prepareData(category, data, label) {
  category.training = [];
  category.testing = [];
  // The split point does not depend on the loop index, so compute it once.
  const threshold = Math.floor(0.8 * totalData);
  for (let i = 0; i < totalData; i++) {
    const offset = i * len;
    const sample = data.bytes.subarray(offset, offset + len);
    sample.label = label;
    if (i < threshold) {
      category.training.push(sample);
    } else {
      category.testing.push(sample);
    }
  }
}


// Number of values per sample: prepareData slices `len` bytes per drawing and
// the guess handler reads `len` pixels from the 28x28 resized canvas image.
const len = 784;
// Number of samples prepareData reads from each category's byte buffer.
const totalData = 1000;

// Class labels; trainEpoch uses them as indices into the 4-slot target vector.
const APPLE = 0;
const BUS = 1;
const DIAMOND = 2;
const HAND = 3;

// Raw per-category data, assigned in preload() from loadBytes().
let appleData;
let busData;
let diamondData;
let handData;

// Per-category containers; prepareData() gives each a `training` and a
// `testing` array.
let apple = {};
let bus = {};
let diamond = {};
let hand = {};

// The neural network instance, created in setup().
let nn;

/**
 * Requests the four category data files with `loadBytes` and stores the
 * results in `appleData`, `busData`, `diamondData` and `handData`.
 * Presumably invoked by p5.js before setup() — confirm against the host page.
 */
function preload() {
  // Listed in the same order as the variables they are assigned to below.
  const sources = [
    '/uploads/harshakulkarni/apple.bin',
    '/uploads/harshakulkarni/bus.bin',
    '/uploads/harshakulkarni/diamond.bin',
    '/uploads/harshakulkarni/hand.bin',
  ];
  // Wrap in an arrow so loadBytes receives only the path argument.
  const loaded = sources.map((path) => loadBytes(path));
  [appleData, busData, diamondData, handData] = loaded;
}
/**
 * Builds the sketch: creates the canvas, splits each category's data into
 * training/testing samples, creates the global network `nn`, and attaches
 * click handlers to the elements selected by '.train', '.test', '.guess'
 * and '.clear'.
 *
 * Status text is reported through `AB.msg(text, slot)` using slots 4
 * (training), 7 (test accuracy) and 12 (prediction / cleared canvas).
 */
function setup() {
  createCanvas(600, 600);
  background(255);

  // Preparing the data
  prepareData(apple, appleData, APPLE);
  prepareData(bus, busData, BUS);
  prepareData(diamond, diamondData, DIAMOND);
  prepareData(hand, handData, HAND);

  // Making the neural network. `len` inputs (one per pixel) and 4 outputs
  // (one per category); presumably the arguments are
  // (inputs, hidden, outputs) — confirm against nn.js.
  nn = new NeuralNetwork(len, 64, 4);

  // Pool the per-category samples into one training and one testing list.
  // trainEpoch() shuffles the training list on every pass.
  const training = [
    ...apple.training,
    ...bus.training,
    ...diamond.training,
    ...hand.training,
  ];
  const testing = [
    ...apple.testing,
    ...bus.testing,
    ...diamond.testing,
    ...hand.testing,
  ];

  // Text shown for each class label.
  const names = {
    [APPLE]: "APPLE",
    [BUS]: "bus",
    [DIAMOND]: "diamond",
    [HAND]: "hand",
  };

  // Train button: one epoch per click, reporting how many have run so far.
  let epochCounter = 0;
  const trainButton = select('.train');
  trainButton.mousePressed(function() {
    trainEpoch(training);
    epochCounter++;
    console.log("Epoch: " + epochCounter);
    // Declared locally so no implicit global `thehtml` is created.
    const thehtml = "Trained " + epochCounter + " time/s";
    AB.msg(thehtml, 4);
  });

  // Test button: report accuracy on the held-out samples.
  const testButton = select('.test');
  testButton.mousePressed(function() {
    const percent = testAll(testing);
    console.log("Percent: " + nf(percent, 2, 2) + "%");
    const thehtml = "Percentage " + nf(percent, 2, 2) + "%";
    AB.msg(thehtml, 7);
  });

  // Guess button: classify whatever is currently drawn on the canvas.
  const guessButton = select('.guess');
  guessButton.mousePressed(function() {
    const inputs = [];
    const img = get();
    img.resize(28, 28);
    img.loadPixels();
    for (let i = 0; i < len; i++) {
      // Every 4th entry of `pixels` (the red channel in p5's RGBA layout).
      const bright = img.pixels[i * 4];
      // Invert so dark strokes on the white canvas become values near 1,
      // presumably matching how the training data encodes ink.
      inputs[i] = (255 - bright) / 255.0;
    }

    const guess = nn.predict(inputs);
    console.log(guess);
    // The predicted class is the index of the largest output value.
    const classification = guess.indexOf(max(guess));
    console.log(classification);
    // indexOf yields -1 when no output equals the maximum (e.g. NaN
    // outputs); report that as "unknown" instead of printing "undefined".
    const pred = names[classification] || "unknown";
    console.log(pred);
    const thehtml = "Prediction: " + pred;
    AB.msg(thehtml, 12);
  });

  // Clear button: wipe the canvas and reset the prediction message.
  const clearButton = select('.clear');
  clearButton.mousePressed(function() {
    background(255);
    const thehtml = "Guess: ";
    AB.msg(thehtml, 12);
  });
}


/**
 * Sets a 10-unit black stroke and, while a mouse button is held, draws a
 * line segment from the previous mouse position to the current one.
 * Presumably invoked repeatedly by p5.js as its draw loop.
 */
function draw() {
  strokeWeight(10);
  stroke(0);

  // Nothing to draw unless the mouse is currently pressed.
  if (!mouseIsPressed) {
    return;
  }
  line(pmouseX, pmouseY, mouseX, mouseY);
}


// Load the external scripts this sketch relies on. `NeuralNetwork`, used in
// setup(), is not defined in this file and presumably comes from nn.js.
// NOTE(review): these loads are asynchronous and nothing here waits for them
// before setup() runs — confirm the host page guarantees the ordering.
for (const url of [
  "/uploads/ayasu/nn.js",
  "/uploads/ayasu/matrix.js",
  "/uploads/harshakulkarni/html.js",
]) {
  // Surface load failures instead of ignoring them; otherwise a missing
  // script only shows up later as an unrelated "is not defined" error.
  $.getScript(url).fail(function(jqxhr, settings, exception) {
    console.error("Failed to load script " + url, exception);
  });
}