/**
 * Runs one training pass over every sample, in a freshly shuffled order.
 *
 * @param {Array} training - samples, each an array-like of 0-255 byte values
 *   carrying a numeric `label` property; reordered in place by p5's shuffle.
 * @param {number} [outputCount=4] - length of the one-hot target vector; it
 *   must match the network's output layer size.
 */
function trainEpoch(training, outputCount = 4) {
  // shuffle(array, true) reorders the caller's array in place.
  shuffle(training, true);
  for (const data of training) {
    // Scale raw 0-255 bytes into [0, 1].
    const inputs = Array.from(data, (x) => x / 255);
    // One-hot target: 1 at the sample's class index, 0 elsewhere.
    const targets = new Array(outputCount).fill(0);
    targets[data.label] = 1;
    nn.train(inputs, targets);
  }
}
/**
 * Measures classification accuracy of the global network on a held-out set.
 *
 * @param {Array} testing - samples, each an array-like of 0-255 byte values
 *   carrying a numeric `label` property.
 * @returns {number} percentage (0-100) of samples whose highest-scoring
 *   network output index equals their label; 0 for an empty set.
 */
function testAll(testing) {
  if (testing.length === 0) {
    // Nothing to evaluate; avoid returning 0/0 = NaN.
    return 0;
  }
  let correct = 0;
  for (const data of testing) {
    // Same [0, 1] scaling as used during training.
    const inputs = Array.from(data, (x) => x / 255);
    const guess = nn.predict(inputs);
    // The predicted class is the index of the largest output.
    const classification = guess.indexOf(max(guess));
    if (classification === data.label) {
      correct++;
    }
  }
  return (100 * correct) / testing.length;
}
/**
 * Splits one category's raw byte buffer into labelled training/testing samples.
 *
 * The first 80% of the samples go to `category.training` and the rest go to
 * `category.testing`. Each sample is a `subarray` view of `data.bytes` (the
 * bytes are not copied), with a `label` property attached.
 *
 * @param {object} category - receives fresh `training` and `testing` arrays.
 * @param {{bytes: Uint8Array}} data - loaded file; presumably the p5
 *   `loadBytes` result, with one sample every `size` bytes.
 * @param {number} label - class index stored on every sample.
 * @param {number} [count=totalData] - maximum number of samples to take.
 * @param {number} [size=len] - bytes per sample.
 */
function prepareData(category, data, label, count = totalData, size = len) {
  // Never read past the end of the buffer: a short file would otherwise yield
  // truncated or empty sample views.
  const available = Math.floor(data.bytes.length / size);
  const total = Math.min(count, available);
  const threshold = Math.floor(0.8 * total);

  category.training = [];
  category.testing = [];
  for (let i = 0; i < total; i++) {
    const offset = i * size;
    const sample = data.bytes.subarray(offset, offset + size);
    sample.label = label;
    if (i < threshold) {
      category.training.push(sample);
    } else {
      category.testing.push(sample);
    }
  }
}
// Bytes per sample: one value per pixel of a 28x28 bitmap.
const len = 28 * 28;
// Samples taken from each category file (split 80/20 into train/test).
const totalData = 1000;

// Class indices; each also serves as that class's position in the
// network's one-hot output vector.
const [APPLE, BUS, DIAMOND, HAND] = [0, 1, 2, 3];

// Raw byte payloads filled by loadBytes() in preload().
let appleData;
let busData;
let diamondData;
let handData;

// Per-category { training, testing } sample lists filled by prepareData().
let apple = {};
let bus = {};
let diamond = {};
let hand = {};

// The NeuralNetwork instance, created in setup().
let nn;
/**
 * p5.js preload hook: requests the four category bitmap files and stores
 * each loadBytes() result in its matching global.
 */
function preload() {
  [appleData, busData, diamondData, handData] = [
    '/uploads/harshakulkarni/apple.bin',
    '/uploads/harshakulkarni/bus.bin',
    '/uploads/harshakulkarni/diamond.bin',
    '/uploads/harshakulkarni/hand.bin',
  ].map((path) => loadBytes(path));
}
/**
 * p5.js setup hook. Builds the train/test splits, creates the network and
 * wires the page's .train, .test, .guess and .clear buttons.
 */
function setup() {
  createCanvas(600, 600);
  background(255);

  // Slice each category's raw bytes into labelled training/testing samples.
  prepareData(apple, appleData, APPLE);
  prepareData(bus, busData, BUS);
  prepareData(diamond, diamondData, DIAMOND);
  prepareData(hand, handData, HAND);

  // One input per pixel, 64 hidden nodes, one output per category
  // (argument meaning presumed from nn.js — confirm).
  nn = new NeuralNetwork(len, 64, 4);

  // Pool every category's samples; trainEpoch() shuffles `training` in place.
  const training = [
    ...apple.training,
    ...bus.training,
    ...diamond.training,
    ...hand.training,
  ];
  const testing = [
    ...apple.testing,
    ...bus.testing,
    ...diamond.testing,
    ...hand.testing,
  ];

  // Display name for each class index, as logged and shown by the guess button.
  const classNames = {
    [APPLE]: 'APPLE',
    [BUS]: 'bus',
    [DIAMOND]: 'diamond',
    [HAND]: 'hand',
  };

  let epochCounter = 0;
  select('.train').mousePressed(() => {
    trainEpoch(training);
    epochCounter++;
    console.log(`Epoch: ${epochCounter}`);
    const thehtml = `Trained ${epochCounter} time/s`;
    AB.msg(thehtml, 4);
  });

  select('.test').mousePressed(() => {
    const percent = testAll(testing);
    console.log(`Percent: ${nf(percent, 2, 2)}%`);
    const thehtml = `Percentage ${nf(percent, 2, 2)}%`;
    AB.msg(thehtml, 7);
  });

  select('.guess').mousePressed(() => {
    // Downsample the current canvas to the 28x28 network input.
    const img = get();
    img.resize(28, 28);
    img.loadPixels();
    const inputs = [];
    for (let i = 0; i < len; i++) {
      // pixels is RGBA, so index i * 4 is pixel i's red channel. Inverting
      // maps the black strokes on the white canvas to 1.0.
      const bright = img.pixels[i * 4];
      inputs[i] = (255 - bright) / 255.0;
    }
    const guess = nn.predict(inputs);
    console.log(guess);
    const classification = guess.indexOf(max(guess));
    console.log(classification);
    const pred = classNames[classification];
    if (pred !== undefined) {
      console.log(pred);
    }
    const thehtml = `Prediction: ${pred}`;
    AB.msg(thehtml, 12);
  });

  select('.clear').mousePressed(() => {
    background(255);
    const thehtml = 'Guess: ';
    AB.msg(thehtml, 12);
  });
}
/**
 * p5.js draw hook: while the mouse button is held, draws a thick black
 * segment from the previous mouse position to the current one.
 */
function draw() {
  strokeWeight(10);
  stroke(0);
  if (!mouseIsPressed) {
    return;
  }
  line(pmouseX, pmouseY, mouseX, mouseY);
}
// Fetch the helper scripts: nn.js and matrix.js (presumably the NeuralNetwork
// class used by setup() and its matrix support) plus the page helpers in
// html.js. NOTE(review): $.getScript is asynchronous and nothing here waits
// for it, so setup()'s `new NeuralNetwork(...)` presumably relies on these
// having finished loading first — confirm.
[
  "/uploads/ayasu/nn.js",
  "/uploads/ayasu/matrix.js",
  "/uploads/harshakulkarni/html.js",
].forEach((src) => $.getScript(src));