// ml5.js: Classifying Drawings with DoodleNet (Mouse)
// The Coding Train / Daniel Shiffman
// https://thecodingtrain.com/learning/ml5/9.1-doodlenet.html
// https://youtu.be/ABN_DWnM5GQ
// Template: https://editor.p5js.org/codingtrain/sketches/AHgkwgPdc
// Mouse: https://editor.p5js.org/codingtrain/sketches/6LLnGY1VY
// Video: https://editor.p5js.org/codingtrain/sketches/fxFKOn3il
// Canvas handle returned by createCanvas() in setup().
let canvas;
// ml5 neural network, created in setup() and trained in trainModel().
let doodleClassifier;
// Our training dataset: one Classification (name + captured pixel arrays) per label.
let classes = [];
// Declared but not used by any live (uncommented) code visible in this file;
// kept so nothing that might reference them elsewhere breaks.
let clearButton;
let resultsDiv;
/**
 * Constructor for one classification (label) in the training dataset.
 * Each instance holds its `name` (also used as the training label) and a
 * `data` list of captured canvas pixel arrays, initially empty.
 * @param {string} name - label shown in the dropdown
 */
function Classification(name) {
  Object.assign(this, { name: name, data: [] });
}
/**
 * p5.js setup hook: creates the drawing canvas, fills the Ancient Brain
 * message boxes with the sketch's controls, and creates the (untrained)
 * ml5 neural network.
 */
function setup() {
  // pixelDensity(0.2) backs the 400x400 on-screen canvas with an 80x80 pixel
  // buffer, which is the size the network's [80, 80, 4] input below expects.
  pixelDensity(0.2);
  canvas = createCanvas(400, 400);
  $("canvas").css({"border":"3px solid #4682B4", "border-radius": "5px",
    "margin":"10px","box-shadow" : "2px 5px #999999"});
  background(255);
  // Box 1: canvas controls.
  AB.msg("<button onclick='clearCanvas()'>Clear Canvas</button>", 1);
  // Box 2: dataset-building instructions and controls (steps numbered 1-4).
  AB.msg("<h3>Create Dataset</h3><p>Instructions:<br>" +
    "1. Create Classifications <br> 2. Draw Doodles on canvas <br> 3. Add doodle to selected classification (Tip: Press 'x' to quickly add current doodle)<br>" +
    "4. Train Model on dataset</p>" +
    "New classification: <input type='text' id='newclass'><button onclick='newClass()'>Create</button><br/>" +
    "Selected: <select id='classes'></select> <button onclick='addData()'>Add doodle</button><br/>" +
    "<button onclick='trainModel()'>Train Model</button>", 2);
  // Box 3: dataset import/export and the bundled example datasets.
  AB.msg("<h3>Dataset import/export</h3><button onclick='exportJson()'>Export Dataset</button><br/>" +
    "<input type='file' id='selectFiles' value='Import'/><button onclick='importJson()'>Import Dataset</button><br/>" +
    "Use example dataset: <button onclick='loadInJson(\"/uploads/theonogo/shapes.json\")'>Shapes</button>" +
    "<button onclick='loadInJson(\"/uploads/theonogo/math.json\")'>Math</button>", 3);
  const options = {
    // 80x80 pixels & 4 channels per pixel, matching the canvas pixel buffer.
    inputs: [80, 80, 4],
    task: 'imageClassification',
    debug: true
  };
  doodleClassifier = ml5.neuralNetwork(options);
}
// Runs when the neural network is done with its asynchronous classification.
// Shows the top (up to) two guesses, then immediately classifies the canvas
// again so the prediction keeps following the drawing.
function gotResults(error, results) {
  if (error) {
    console.error(error);
    return;
  }
  // A model trained on a single classification returns only one result, so
  // only format the guesses that actually exist (results[1] may be missing).
  const ordinals = ["1st", "2nd"];
  const html = (results || []).slice(0, ordinals.length).map((guess, i) =>
    "<p> " + ordinals[i] + " Guess: " + guess.label +
    " - with confidence of " + nf(100 * guess.confidence, 2, 1) + "%</p>"
  ).join("<br>");
  // NOTE(review): setup() puts the import/export controls in box 3 as well;
  // writing the guesses here appears to replace them — confirm this is intended.
  AB.msg(html, 3);
  // Continuously classify canvas
  loadPixels();
  doodleClassifier.classify(Array.prototype.slice.call(pixels), gotResults);
}
// Create a new classification from the name typed into the #newclass input,
// clear the input, and refresh the dropdown. Surrounding whitespace is trimmed;
// blank names are ignored so the dropdown never gets an unnamed entry.
function newClass() {
  const inputEl = document.getElementById('newclass');
  const name = inputEl.value.trim();
  inputEl.value = "";
  if (name === "") {
    return;
  }
  classes.push(new Classification(name));
  updateDropdown();
}
// Wipe the current doodle by repainting the whole canvas white.
function clearCanvas() {
  const WHITE = 255;
  background(WHITE);
}
// Makes it faster to save the current doodle to the dataset by pressing the X key.
// p5 invokes keyPressed for key presses anywhere on the page, including typing
// into the "New classification" text box. Presses made while a form control
// has focus are therefore ignored; otherwise typing a name containing "x"
// would add (and clear) the current doodle.
function keyPressed() {
  const X_KEY_CODE = 88;
  const focused = document.activeElement;
  const isTypingInForm = Boolean(focused) && /^(INPUT|TEXTAREA|SELECT)$/.test(focused.tagName);
  if (keyCode === X_KEY_CODE && !isTypingInForm) {
    addData();
  }
}
// Add the current canvas pixels as one training example of the classification
// selected in the #classes dropdown, then clear the canvas for the next doodle
// and refresh the dropdown counts. If no classification is selected (e.g. 'x'
// was pressed before any classification was created), it warns on the console
// and leaves the canvas untouched instead of throwing a TypeError.
function addData() {
  const index = parseInt($('#classes').val(), 10);
  // val() is null with an empty dropdown, so index may be NaN; classes[NaN] is undefined.
  const target = classes[index];
  if (!target) {
    console.warn("addData: create and select a classification before adding doodles");
    return;
  }
  loadPixels();
  // Store a plain Array copy: a typed array would serialize as an index-keyed
  // object (not a JSON array) when the dataset is exported.
  target.data.push(Array.prototype.slice.call(pixels));
  clearCanvas();
  updateDropdown();
}
// Rebuild the #classes dropdown so every option shows a classification's name
// and its doodle count (option value = index into `classes`), then restore the
// previously selected index.
function updateDropdown() {
  // The current selection must be read before the options are removed.
  const previousIndex = parseInt($('#classes').val(), 10);
  $("#classes").empty();
  const optionList = document.getElementById('classes').options;
  for (let idx = 0; idx < classes.length; idx += 1) {
    const cls = classes[idx];
    const label = `${cls.name} (${cls.data.length} doodles)`;
    optionList.add(new Option(label, String(idx)));
  }
  optionList.selectedIndex = previousIndex;
}
// p5.js draw loop: while the mouse button is held, connect the previous mouse
// position to the current one with a thick stroke so drags become lines.
function draw() {
  if (!mouseIsPressed) {
    return;
  }
  const PEN_WEIGHT = 16;
  strokeWeight(PEN_WEIGHT);
  line(mouseX, mouseY, pmouseX, pmouseY);
}
// Train our classifier: give every stored doodle to the neural network with its
// classification name as the label, normalize the data, and train for 50 epochs;
// finishedTraining() is called when training completes.
function trainModel() {
  let sampleCount = 0;
  for (const cls of classes) {
    sampleCount += cls.data.length;
  }
  // Bail out before touching box 2: it holds the dataset controls, and
  // AB.msg(..., 2) appears to replace them. Training an empty dataset would
  // fail and leave the user with no way to add doodles.
  if (sampleCount === 0) {
    console.warn("trainModel: add at least one doodle before training");
    return;
  }
  AB.msg("<h3>Training....</h3>", 2);
  // (No per-sample console.log: each sample is a full canvas pixel array.)
  for (const cls of classes) {
    for (const image of cls.data) {
      doodleClassifier.addData({ image: image }, { label: cls.name });
    }
  }
  doodleClassifier.normalizeData();
  doodleClassifier.train({ epochs: 50 }, finishedTraining);
}
// Training callback: report completion in box 2, then start classifying the
// current canvas (gotResults handles the outcome).
function finishedTraining() {
  AB.msg("<h3>Training Done</h3>", 2);
  loadPixels();
  const snapshot = Array.from(pixels);
  doodleClassifier.classify(snapshot, gotResults);
}
// Create and download the dataset (the `classes` array) as dataset.json.
function exportJson() {
  const file = new Blob([JSON.stringify(classes)], { type: 'application/json' });
  const url = URL.createObjectURL(file);
  const link = document.createElement("a");
  link.href = url;
  link.download = 'dataset.json';
  link.click();
  // Release the object URL once the download has started. Without this, each
  // export keeps a copy of the whole dataset alive until the page unloads.
  setTimeout(() => URL.revokeObjectURL(url), 0);
}
// Read the JSON file chosen in #selectFiles and replace the current dataset
// with it, then refresh the dropdown. Returns false when no file is selected.
// Files that are not valid JSON, or not an array of { name, data: [...] }
// entries, are reported on the console and the current dataset is kept
// (the parse runs asynchronously inside FileReader.onload).
function importJson() {
  const files = document.getElementById('selectFiles').files;
  if (!files || files.length <= 0) {
    return false;
  }
  const reader = new FileReader();
  reader.onload = function (e) {
    let result;
    try {
      result = JSON.parse(e.target.result);
    } catch (err) {
      console.error("importJson: selected file is not valid JSON", err);
      return;
    }
    // updateDropdown() and trainModel() read `.name` and `.data` of every entry.
    const isDataset = Array.isArray(result) && result.every(
      (entry) => entry !== null && typeof entry === 'object' && Array.isArray(entry.data));
    if (!isDataset) {
      console.error("importJson: expected an array of { name, data: [...] } classifications");
      return;
    }
    classes = result;
    updateDropdown();
  };
  reader.onerror = function () {
    console.error("importJson: could not read the selected file", reader.error);
  };
  reader.readAsText(files.item(0));
}
// Fetch and parse the JSON dataset at the given URL (used by the example
// dataset buttons) and replace the current dataset with it, then refresh the
// dropdown. Request/parse failures and payloads that are not an array of
// { name, data: [...] } entries are reported on the console instead of being
// silently ignored; in those cases the current dataset is kept.
function loadInJson(choice) {
  $.getJSON(choice, function (data) {
    // updateDropdown() and trainModel() read `.name` and `.data` of every entry.
    const isDataset = Array.isArray(data) && data.every(
      (entry) => entry !== null && typeof entry === 'object' && Array.isArray(entry.data));
    if (!isDataset) {
      console.error("loadInJson: " + choice + " is not an array of { name, data: [...] } classifications");
      return;
    }
    classes = data;
    updateDropdown();
  }).fail(function (jqXHR, textStatus, errorThrown) {
    console.error("loadInJson: could not load " + choice, textStatus, errorThrown);
  });
}