// Bytes per doodle sample: one byte per pixel of a 28x28 single-channel bitmap
// (28 * 28 = 784). This is also the neural network's input size.
const len = 784;
// Number of images read from each category's .bin file. prepareData() puts the
// first 80% of them in the training set and the rest in the test set.
const total_data = 1000;
// Class indices. Each is stored as a sample's `.label`, used as the hot index of
// trainSetup()'s 10-element one-hot target, and decoded back to a name in guessSetup().
// They must stay in the order of the network's 10 outputs.
const AIRPLANE = 0;
const APPLE = 1;
const BUS = 2;
const CAR = 3;
const CAT = 4;
const CRUISESHIP = 5;
const ELEPHANT = 6;
const HELICOPTER = 7;
const SUBMARINE = 8;
const TRUCK = 9;
// Raw file contents, assigned from loadBytes() in beforesetup(); prepareData()
// reads their `.bytes` field.
let airplane_data = {};
let apple_data = {};
let bus_data = {};
let car_data = {};
let cat_data = {};
let cruiseship_data = {};
let elephant_data = {};
let helicopter_data = {};
let submarine_data = {};
let truck_data = {};
// Per-category holders; prepareData() gives each a `.training` and a `.test` array.
let airplane = {};
let apple = {};
let bus = {};
let car = {};
let cat = {};
let cruiseship = {};
let elephant = {};
let helicopter = {};
let submarine = {};
let truck = {};
// NOTE(review): not referenced anywhere in the visible code.
let training_runs = 0;
// Display name of the most recent guess; written by guessSetup(), shown by updateHtml().
let guessObject = "";
// The NeuralNetwork instance, created in AB.world.newRun() after nn.js has loaded.
let nn;
// All categories' samples pooled together by netSetup().
let training = [];
let test = [];
// NOTE(review): not referenced anywhere in the visible code.
let initiated = false;
// ml5 options built in beforesetup(); only consumed by commented-out ml5 code.
let options;
// NOTE(review): only referenced by commented-out ml5 code in beforesetup().
let doodleClassifier;
//--- start of AB.msgs structure: ---------------------------------------------------------
// Emit a series of AB.msg calls that lay out the five numbered panels of the run header.
// Panels 2 (training accuracy) and 3 (latest guess) are re-emitted later by updateHtml().
var thehtml;
{
  // [html, header slot] pairs, emitted in slot order.
  const headerPanels = [
    // 1. Doodle: button that clears the drawing canvas.
    [ "<hr> <h1> 1. Doodle </h1>" +
      "<button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ", 1 ],
    // 2. Training: accuracy placeholder plus a one-epoch training button.
    [ "<hr> <h1> 2. Training </h1> Accuracy: -%<br> " +
      " <button onclick='trainSetup(training);' class='normbutton' >Train for 1 Epoch</button> <br> ", 2 ],
    // 3. Guess: empty guess text plus the classify button.
    [ "<hr> <h1> 3. Guess </h1> Guess: <br> " +
      " <button onclick='guessSetup();' class='normbutton' >Guess Drawing</button> <br> ", 3 ],
    // 4. Save: store the current model.
    [ "<hr> <h1> 4. Save </h1>" +
      " <button onclick='saveData();' class='normbutton' >Save Model</button> <br> ", 4 ],
    // 5. Restore: load a previously saved model.
    [ "<hr> <h1> 5. Restore </h1>" +
      " <button onclick='restoreData();' class='normbutton' >Restore Model</button> <br> ", 5 ],
  ];
  for (const [panelHtml, slot] of headerPanels) {
    thehtml = panelHtml;
    AB.msg ( thehtml, slot );
  }
}
// NOTE(review): not referenced anywhere in the visible code.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> " ;
//--- end of AB.msgs structure: ---------------------------------------------------------
AB.world.newRun = function()
{
  // Runs once at the start of a run. It initialises the world, loads the Coding Train
  // scripts one after another, then builds the network and the datasets.
  AB.runReady = false;   // not ready for a screenshot until the network and data exist
  ABWorld.init ( 'black' );
  AB.headerRHS();
  //AB.loadingScreen();

  // Loaded strictly in this order; presumably nn.js depends on matrix.js.
  const scripts = [
    "/uploads/codingtrain/matrix.js",
    "/uploads/codingtrain/nn.js",
    "/uploads/codingtrain/mnist.js",
  ];

  // Load scripts[index], then continue with the next one. Once all are loaded,
  // build the network and mark the run as ready.
  const loadFrom = function ( index )
  {
    if ( index === scripts.length )
    {
      // One input per pixel (len), 128 hidden nodes, one output per class.
      nn = new NeuralNetwork ( len, 128, 10 );
      netSetup();
      console.log ("All JS loaded");
      // Previously runReady was never set back to true.
      AB.runReady = true;
      return;
    }
    $.getScript ( scripts[index], function() { loadFrom ( index + 1 ); } )
      .fail ( function ( jqxhr, settings, exception )
      {
        // Without this, a missing script silently leaves nn undefined.
        console.error ( "Failed to load " + scripts[index], exception );
      });
  };
  loadFrom ( 0 );
};
// Per-step hook, intentionally empty. Per-frame P5 drawing is done in draw() instead.
AB.world.nextStep = function()
{
};

// End-of-run hook; no end-of-run work is performed.
AB.world.endRun = function()
{
};
//---- setup -------------------------------------------------------
// Do NOT make a setup function.
// This is done for you in the API. The API setup just creates a canvas.
// Anything else you want to run at the start should go into the following two functions.

/**
 * Runs before the API creates the canvas.
 * Calls loadBytes() for every category's bitmap file, storing the results in the
 * *_data globals that prepareData() later reads via `.bytes`. It also builds the
 * ml5 options object, which is currently only used by commented-out code.
 */
function beforesetup() // Optional
{
  airplane_data = loadBytes("/uploads/gfar97/airplaneK.bin");
  apple_data = loadBytes("/uploads/gfar97/appleK.bin");
  bus_data = loadBytes("/uploads/gfar97/busK.bin");
  car_data = loadBytes("/uploads/gfar97/carK.bin");
  cat_data = loadBytes("/uploads/gfar97/catK.bin");
  cruiseship_data = loadBytes("/uploads/gfar97/cruise_shipK.bin");
  elephant_data = loadBytes("/uploads/gfar97/elephantK.bin");
  helicopter_data = loadBytes("/uploads/gfar97/helicopterK.bin");
  submarine_data = loadBytes("/uploads/gfar97/submarineK.bin");
  truck_data = loadBytes("/uploads/gfar97/truckK.bin");
  options = {
    inputs: [28, 28, 1],
    task: 'imageClassification',
    // Must be a real boolean: a string such as "false" would still be truthy.
    debug: true,
  };
  //doodleClassifier = ml5.neuralNetwork(options);
}
/**
 * Runs once the API has created the canvas.
 * Sizes it to a 400x400 square and paints it black, which is the empty-doodle state.
 */
function aftersetup() // Optional
{
  const side = 400;
  resizeCanvas(side, side);
  background(0, 0, 0);
}
/**
 * Builds the datasets once the network exists.
 * It splits each category's raw bytes into training and test samples, then
 * appends every category's samples to the global `training` and `test` arrays,
 * in class order (airplane first, truck last). Finally it refreshes the header.
 */
function netSetup() {
  // [holder object, raw loadBytes result, class label] per category.
  const categories = [
    [airplane, airplane_data, AIRPLANE],
    [apple, apple_data, APPLE],
    [bus, bus_data, BUS],
    [car, car_data, CAR],
    [cat, cat_data, CAT],
    [cruiseship, cruiseship_data, CRUISESHIP],
    [elephant, elephant_data, ELEPHANT],
    [helicopter, helicopter_data, HELICOPTER],
    [submarine, submarine_data, SUBMARINE],
    [truck, truck_data, TRUCK],
  ];

  for (const [holder, raw, label] of categories) {
    prepareData(holder, raw, label);
  }

  for (const [holder] of categories) {
    training = training.concat(holder.training);
  }
  for (const [holder] of categories) {
    test = test.concat(holder.test);
  }

  updateHtml();
}
/**
 * Runs one training epoch.
 * The samples are shuffled in place with p5 shuffle's in-place flag. The network is
 * then trained once on each sample: pixel bytes are scaled from 0..255 to 0..1, and
 * the target is a 10-element one-hot vector at the sample's label. The header is
 * refreshed afterwards.
 * @param {Array} training - indexable pixel arrays, each carrying a numeric `.label`
 */
function trainSetup(training) {
  shuffle(training, true);
  for (const sample of training) {
    const inputs = Array.from(sample, (byte) => byte / 255);
    const targets = new Array(10).fill(0);
    targets[sample.label] = 1;
    nn.train(inputs, targets);
  }
  updateHtml();
  console.log("Epoch training finished");
}
/**
 * Measures the network's classification accuracy on labelled samples.
 * Each sample's bytes are scaled to 0..1 and passed to nn.predict(). A sample is
 * correct when the index of the largest output equals its `.label`.
 * @param {Array} testing - indexable pixel arrays, each carrying a numeric `.label`
 * @returns {number} percentage correct in [0, 100]. Returns 0 for an empty set,
 *   where the old code produced 0/0 = NaN.
 */
function testSetup(testing) {
  if (testing.length === 0) {
    return 0;
  }
  let correct = 0;
  for (const sample of testing) {
    const inputs = Array.from(sample, (byte) => byte / 255);
    const prediction = nn.predict(inputs);
    // Predicted class = position of the strongest output (the first one on ties).
    const classification = prediction.indexOf(Math.max(...prediction));
    if (classification === sample.label) {
      correct++;
    }
  }
  return (correct / testing.length) * 100;
}
/**
 * Splits one category's raw bitmap bytes into training and test samples.
 * Sample i is bytes [i * itemLength, (i + 1) * itemLength) of data.bytes. It is a
 * subarray view (no copy) tagged with `.label`. The first 80% of the samples go to
 * category.training and the rest go to category.test.
 *
 * @param {Object} category - receives fresh `training` and `test` arrays
 * @param {{bytes: Uint8Array}} data - raw file contents. Presumably a typed array of
 *   back-to-back itemLength-byte images (subarray is used on it).
 * @param {number} label - class index stored on every sample
 * @param {number} [total=total_data] - maximum number of samples to read
 * @param {number} [itemLength=len] - bytes per sample
 */
function prepareData(category, data, label, total = total_data, itemLength = len) {
  // Never read past the end of the buffer. A short file would otherwise produce
  // truncated or empty samples that break training.
  const count = Math.min(total, Math.floor(data.bytes.length / itemLength));
  // Loop-invariant, so it is computed once instead of on every iteration.
  const threshold = Math.floor(0.8 * count);
  category.training = [];
  category.test = [];
  for (let i = 0; i < count; i++) {
    const offset = i * itemLength;
    const sample = data.bytes.subarray(offset, offset + itemLength);
    sample.label = label;
    if (i < threshold) {
      category.training.push(sample);
    } else {
      category.test.push(sample);
    }
  }
}
/** Erases the doodle by repainting the whole canvas black. */
function wipeDoodle() {
  const black = [0, 0, 0];
  background(...black);
  console.log("Cleared doodle");
}
/**
 * Classifies the current canvas drawing and shows the result.
 * The canvas is captured with get() and shrunk to 28x28. The network input is the
 * red entry of each RGBA pixel (index i * 4), scaled to 0..1. For the white-on-black
 * doodle all channels are equal, so red is the grey level. The name of the
 * strongest output class is stored in guessObject and displayed via updateHtml().
 */
function guessSetup() {
  const img = get();
  img.resize(28, 28);
  img.loadPixels();
  // pixels holds 4 entries (R, G, B, A) per pixel; take the first (red) of each.
  const inputs = Array.from({ length: len }, (_, i) => img.pixels[i * 4] / 255.0);

  const prediction = nn.predict(inputs);
  const classification = prediction.indexOf(Math.max(...prediction));

  // Display name for each class index; keyed by the label constants so they stay in sync.
  const classNames = {
    [AIRPLANE]: "Airplane",
    [APPLE]: "Apple",
    [BUS]: "Bus",
    [CAR]: "Car",
    [CAT]: "Cat",
    [CRUISESHIP]: "Cruise Ship",
    [ELEPHANT]: "Elephant",
    [HELICOPTER]: "Helicopter",
    [SUBMARINE]: "Submarine",
    [TRUCK]: "Truck",
  };
  const name = classNames[classification];
  // An unknown index (e.g. -1 when the output contains NaN) keeps the previous guess.
  if (name !== undefined) {
    guessObject = name;
  }
  updateHtml();
}
/**
 * Returns a random float between -0.5 and 0.5 from AB.randomFloatAtoB.
 * It is not called in this file; presumably nn.js uses it for weight
 * initialisation. The Coding Train default range is -1 to 1.
 */
function randomWeight()
{
  const halfWidth = 0.5;
  const low = -halfWidth;
  return AB.randomFloatAtoB(low, halfWidth);
}
/**
 * Re-renders the live header sections.
 * Section 2 shows the current test-set accuracy from testSetup(test), and section 3
 * shows the latest guess from guessObject.
 * Note: every call re-evaluates the whole test set.
 */
function updateHtml() {
  // Round for display. The raw percentage can carry floating-point noise,
  // e.g. 140 / 2000 * 100 === 7.000000000000001.
  const accuracy = testSetup(test).toFixed(2);

  // 2 Training section
  const trainingHtml = "<hr> <h1> 2. Training </h1> Accuracy: " + accuracy + "%<br> " +
    " <button onclick='trainSetup(training);' class='normbutton' >Train for 1 Epoch</button> <br> ";
  AB.msg ( trainingHtml, 2 );

  // 3 Guess section
  const guessHtml = "<hr> <h1> 3. Guess </h1> Guess: " + guessObject + "<br> " +
    " <button onclick='guessSetup();' class='normbutton' >Guess Drawing</button> <br> ";
  AB.msg ( guessHtml, 3 );
}
/** Passes the current network object `nn` to AB.saveData; restoreData() reads it back via AB.restoreData. */
function saveData() {
  const model = nn;
  AB.saveData(model);
}
/**
 * Restores a network saved by saveData() into the live `nn`.
 * Biases are restored along with weights. Restoring weights alone leaves the model
 * inconsistent with what was saved.
 * The saved matrices presumably come back as plain data (TODO confirm AB.restoreData's
 * format). Their fields are therefore copied onto nn's existing matrix objects,
 * rather than replacing those objects, so any prototype methods used during
 * training are kept.
 */
function restoreData()
{
  AB.restoreData ( function ( a )
  {
    // NOTE(review): assumes nn.js's NeuralNetwork keeps these four matrix fields.
    const matrixFields = [ "weights_ih", "weights_ho", "bias_h", "bias_o" ];
    for ( const field of matrixFields )
    {
      // Skip fields missing on either side instead of overwriting with undefined.
      if ( a[field] && nn[field] )
      {
        Object.assign ( nn[field], a[field] );
      }
    }
  });
}
//---- draw -------------------------------------------------------
/**
 * p5 per-frame callback. It sets a thick white stroke and, while the mouse button
 * is held, draws a segment from the previous mouse position to the current one,
 * which builds up the doodle.
 * (Per-step P5 code could alternatively go in AB.world.nextStep().)
 */
function draw() // Optional
{
  strokeWeight(8);
  stroke(255);
  if (!mouseIsPressed) {
    return;
  }
  line(pmouseX, pmouseY, mouseX, mouseY);
}