// Presumably the number of doodle categories (there are three labels below).
// NOTE(review): not referenced in the visible part of this file — confirm it is used elsewhere.
const CLASSES = 3;
// Bytes per image: setup() divides the combined byte count by this to get the
// number of images. 784 = 28 * 28, matching the 28x28 resize done before guessing.
const IMAGE_SIZE = 784;
// Class labels: passed to data.load() and compared against the predicted class index.
const CAT = 0;
const RAINBOW = 1;
const TRAIN = 2;
// Results of loadBytes() for each category; assigned in preload().
let catsData;
let trainsData;
let rainbowsData;
// DoodleData instance holding every loaded image; created in setup().
let data;
// Classifier instance; created in setup() once the helper scripts have loaded.
let model;
/**
 * Starts loading the three raw doodle data sets (cats, trains, rainbows) and
 * stores whatever loadBytes() returns in the matching file-level variables.
 * setup() also calls this explicitly once the helper scripts have loaded.
 */
function preload() {
  // Paths are listed in the order the loads are issued: cats, trains, rainbows.
  [catsData, trainsData, rainbowsData] = [
    '/uploads/cindy15/cats1000.bin',
    '/uploads/cindy15/trains1000.bin',
    '/uploads/cindy15/rainbows1000.bin',
  ].map((binPath) => loadBytes(binPath));
}
/**
 * p5.js setup hook.
 *
 * Creates the 280x280 drawing canvas, loads the helper scripts one after
 * another (each script starts loading only after the previous one finished),
 * builds the shuffled training set and the classifier once they are all
 * available, injects the control buttons, and wires the button handlers:
 *   - #train: trains the classifier on the loaded data,
 *   - #test:  placeholder, does nothing yet,
 *   - #guess: classifies the current drawing and logs "cat" / "rainbow" / "train",
 *   - #clear: wipes the canvas back to white.
 */
function setup() {
  createCanvas(280, 280);
  background(255);

  // Scripts in dependency order; the data/classifier helpers come last.
  const scriptUrls = [
    "https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/p5.min.js",
    "https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/addons/p5.dom.min.js",
    "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.8.0",
    "/uploads/cindy15/loadbinary.js",
    "/uploads/cindy15/data.js",
    "/uploads/cindy15/classifier.js",
  ];

  // Runs once every script above has loaded: builds the data set and the model.
  const onScriptsLoaded = () => {
    console.log("All JS loaded");
    preload();
    // NOTE(review): this reads .bytes immediately after preload(). If
    // loadBytes() is asynchronous, .bytes may not be populated yet — confirm
    // against loadbinary.js.
    const total =
      (catsData.bytes.length + rainbowsData.bytes.length + trainsData.bytes.length) / IMAGE_SIZE;
    data = new DoodleData(total);
    data.load(catsData.bytes, CAT);
    data.load(rainbowsData.bytes, RAINBOW);
    data.load(trainsData.bytes, TRAIN);
    data.shuffle();
    model = new Classifier();
  };

  // Loads scriptUrls[index], then recurses to the next one; replaces the
  // six-level callback pyramid while keeping the strictly sequential order.
  const loadScriptAt = (index) => {
    if (index >= scriptUrls.length) {
      onScriptsLoaded();
      return;
    }
    $.getScript(scriptUrls[index], () => loadScriptAt(index + 1));
  };
  loadScriptAt(0);

  // The buttons must exist in the page before select() can find them, so the
  // HTML is injected before any handler is attached.
  const html = " <body> <button id='train'>train</button> <button id='test'>test</button> <button id='guess'>guess</button> <button id='clear'>clear</button></body>";
  AB.msg(html);

  // Attaches a click handler to the button matching `selector`; a missing
  // button is reported instead of aborting the rest of setup().
  const onPress = (selector, handler) => {
    const button = select(selector);
    if (button == null) {
      console.warn(`Button ${selector} not found; handler not attached`);
      return;
    }
    button.mousePressed(handler);
  };

  // Maps a predicted class index to the label that gets logged.
  const labels = { [CAT]: "cat", [RAINBOW]: "rainbow", [TRAIN]: "train" };

  // Wires the #train button from the injected HTML (previously a separate
  // button labelled "#train" was created and the HTML one was left dead).
  onPress('#train', () => {
    // The model only exists after the asynchronous script chain has finished.
    if (!model) {
      console.log("Model is not ready yet");
      return;
    }
    model.train(data);
  });

  onPress('#test', () => {});

  onPress('#guess', () => {
    if (!model) {
      console.log("Model is not ready yet");
      return;
    }
    // Downscale the drawing to the 28x28 resolution of the training images.
    const img = get();
    img.resize(28, 28);
    img.loadPixels();
    const inputs = [];
    for (let i = 0; i < IMAGE_SIZE; i++) {
      // Pixels are RGBA, so the red channel of pixel i is at i * 4. The value
      // is inverted and scaled so a black stroke becomes 1 and white paper 0.
      const bright = img.pixels[i * 4];
      inputs[i] = (255 - bright) / 255.0;
    }
    // NOTE(review): assumes Classifier exposes predict(inputs) returning a
    // plain array of per-class scores — confirm against classifier.js.
    const guess = model.predict(inputs);
    const classification = guess.indexOf(max(guess));
    const label = labels[classification];
    if (label !== undefined) {
      console.log(label);
    }
  });

  onPress('#clear', () => {
    background(255);
  });
}
/**
 * p5.js draw hook: sets an 8px black pen and, while the mouse button is held,
 * joins the previous and current mouse positions with a line so dragging
 * leaves a continuous stroke on the canvas.
 */
function draw() {
  strokeWeight(8);
  stroke(0);

  // Nothing to draw unless the user is currently pressing the mouse.
  if (!mouseIsPressed) {
    return;
  }
  line(pmouseX, pmouseY, mouseX, mouseY);
}