// Global state for the letter-recognition demo (p5.js + convnet.js, hosted on Ancient Brain).
// convnet.js objects: layer definitions, the network and its SGD trainer (all assigned in setup()).
var layer_defs, net, trainer;
// Training progress: pass number (odd passes use training file 1, even passes file 2),
// index into the current training file, and number of train() calls so far.
let trainrun = 1;
let train_index = 0;
let total_trains = 0;
// -1 disables the automatic report; any other value N triggers it whenever total_trains % N === 0 (see draw()).
let autoTestsAt = -1; //5000; //After this number of trains have been run, output some stats and auto run the doodle tests.
// Moving windows over the last 100 data losses (x) and L2 weight-decay losses (w) reported by the trainer.
// ErrorStore is a function declaration further down the file; hoisting makes it usable here.
var xLossWindow = new ErrorStore(100);
var wLossWindow = new ErrorStore(100);
// Testing progress against the held-out testing data.
// NOTE(review): testrun is only ever reset, never incremented or read.
let testrun = 1;
let test_index = 0;
let total_tests = 0;
let total_correct = 0;
const PIXELS = 28; // images in data set are tiny
const PIXELSSQUARED = PIXELS * PIXELS;
// Magnification used when drawing images on the canvas.
const ZOOMFACTOR = 7;
const ZOOMPIXELS = ZOOMFACTOR * PIXELS;
// Canvas: magnified image + original-size image side by side, and room below for images and charts.
const canvaswidth = ( PIXELS + ZOOMPIXELS ) + 50;
const canvasheight = ( ZOOMPIXELS * 4 ) + 150;
const DOODLE_THICK = 18; // thickness of doodle lines
const DOODLE_BLUR = 3; // blur factor applied to doodles
// should we train every timestep or not
let do_training = true;
// how many to train and test per timestep
const TRAINPERSTEP = 30;
// Set in setup() once convnet.js and LineChart.js have been loaded; draw() does nothing until then.
var jsLoaded = false;
// Doodle state (the drawing area in the top-left of the canvas):
let doodle;
let doodle_exists = false;
let mousedrag = false; // are we in the middle of a mouse drag drawing?
// NOTE(review): doodle_inputs is never used in this file.
var doodle_inputs;
// LineChart instances rebuilt by createLossChart()/createTestChart() and drawn by draw().
var lossChart;
var testChart;
var showGraphs = true; //Turn them off for performance reasons.
// Class index -> letter: network output k corresponds to letters[k].
const letters = ["A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z"];
// make run header bigger
AB.headerCSS ( { "max-height": "95vh" } );
//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Each AB.msg ( html, N ) call fills slot N. The element ids created here (guess_*, doodlePred,
// doodles_Results, pauseBtn, training_*, testing_*, ctrlTest_*) are read or written by the
// functions further down this file, so they must not be renamed independently.
var thehtml;
// 1 Doodle header
thehtml = "<hr> <h2> 1. Doodle </h2> " +
" Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button><br> ";
AB.msg ( thehtml, 1 );
// 2 Top-six guess table (filled by guessDoodle) plus the save/test controls for labelled doodles
thehtml = "<table> <tr><th></th><th>Letter</th><th>Confidence</th><th></th><th>Letter</th><th>Confidence</th></tr>";
thehtml = thehtml + "<tr><td>First Guess</td><td id='guess_letter_1'></td><td id='guess_percent_1'></td><td>Fourth Guess</td><td id='guess_letter_4'></td><td id='guess_percent_4'></td></tr>";
thehtml = thehtml + "<tr><td>Second Guess</td><td id='guess_letter_2'></td><td id='guess_percent_2'></td><td>Fifth Guess</td><td id='guess_letter_5'></td><td id='guess_percent_5'></td></tr>";
thehtml = thehtml + "<tr><td>Third Guess</td><td id='guess_letter_3'></td><td id='guess_percent_3'></td><td>Sixth Guess</td><td id='guess_letter_6'></td><td id='guess_percent_6'></td></tr></table>";
thehtml = thehtml + "<hr>Draw a doodle. Type the correct answer in the box and then click 'Save doodle'<input type='text' id='doodlePred' size='1' value='A'>"+
"<button onclick='saveDoodle();' class='normbutton' >Save doodle</button> <button onclick='loadDoodlesAndTest();' class='normbutton' >Load doodles and test</button>"+
" Results: <div id='doodles_Results' style='color:darkgreen'>results will be displayed here</div>";
AB.msg ( thehtml, 2 );
// 3 Training & testing status (filled by train and test) with the pause/resume button
thehtml = "<hr> <h2> 2. Training & Testing </h2> " +
" <button id='pauseBtn' onclick='pauseTraining()' class='normbutton' >Pause training</button> <br> ";
thehtml = thehtml + "<table> <tr><th><div style='width: 150px;'>Training</div></th><th>Testing (Last 100)</th><th>Testing (Best 100)</th></tr>";
thehtml = thehtml + "<tr><td id = 'training_run'></td><td id='testing_last'></td><td id='testing_best'></td></tr>";
thehtml = thehtml + "<tr><td id = 'training_num'></td><td></td><td></td></tr></table>";
AB.msg ( thehtml, 3 );
// 4 Control test: one header cell per letter, and a row of result cells ctrlTest_0..25 plus ctrlTest_Result
thehtml = "<hr> A set from the test data. Useful for comparing results.<br> " +
" <button onclick='controlTest();' class='normbutton' >Run control test</button> <br> ";
thehtml = thehtml + "<table> <tr>";
var rowHtml = "";
for (var l=0; l<26; l++){
thehtml = thehtml + "<th>" + letters[l] + "</th>";
rowHtml = rowHtml + "<td id='ctrlTest_"+ l + "'>0</td>"
}
thehtml = thehtml + "<th>Result</th>"
rowHtml = rowHtml + "<td id='ctrlTest_Result'>0%</td>"
thehtml = thehtml + "</tr><tr>" + rowHtml + "</tr></table>";
AB.msg ( thehtml, 4 );
// 5 Save/load snapshot controls (local snapshot in localStorage, or download the pre-trained one)
thehtml = "<hr> <h2> 3. Save/Load Snapshot </h2> Save current state to a snapshot. Load a previously saved snapshot <br> " +
" <button onclick='saveSnapshot();' class='normbutton' >Save snapshot</button> <button onclick='loadSnapshot();' class='normbutton' >Load snapshot</button><br> "+
"<br> Download a snapshot - 60k trainings <br>" +
" <button onclick='downloadSnapshot();' class='normbutton' >Download snapshot</button>"
AB.msg ( thehtml, 5 );
//--- end of AB.msgs structure: ---------------------------------------------------------
//---- normal P5 code -------------------------------------------------------
// The training data is split across several files so each one could be uploaded to Ancient Brain.
var trainingData_1;
var trainingData_2;
var testingData;
var controlTestData;
// p5 preload hook: fetch the four JSON data files strictly one after another (each request is
// issued from inside the previous file's callback), then remove the loading screen.
// The third training batch is used as the testing data.
function preload()
{
  const sources = [
    { path: '/uploads/paul79/training_images_letters_1.json', keep: (d) => { trainingData_1 = d; } },
    { path: '/uploads/paul79/training_images_letters_2.json', keep: (d) => { trainingData_2 = d; } },
    { path: '/uploads/paul79/training_images_letters_3.json', keep: (d) => { testingData = d; } },
    { path: '/uploads/paul79/control_set_letters_26.json', keep: (d) => { controlTestData = d; } },
  ];

  const fetchFrom = (position) => {
    if (position === sources.length) {
      AB.removeLoading(); // if no loading screen exists, this does nothing
      return;
    }
    const { path, keep } = sources[position];
    loadJSON(path, (data) => {
      const fileNumber = position + 1;
      console.log("JSON Data " + fileNumber + " finished reading.");
      console.log("JSON Data " + fileNumber + " size: " + data.size);
      keep(data);
      fetchFrom(position + 1);
    });
  };

  fetchFrom(0);
}
//Download the necessary JS and setup the CNN
// p5 setup hook: create the canvas and the off-screen doodle buffer, then load convnet.js and
// LineChart.js (in that order) and build the network. jsLoaded is only set once the network and
// trainer exist, so draw() never runs against a half-built model. Script load failures are logged.
function setup()
{
  createCanvas ( canvaswidth, canvasheight );
  doodle = createGraphics ( ZOOMPIXELS, ZOOMPIXELS ); // doodle on larger canvas
  doodle.pixelDensity(1); // one buffer pixel per logical pixel, so doodle.get() has nominal size

  // jQuery calls .fail handlers with (jqXHR, settings, exception).
  const reportFailure = (url) => (jqxhr, settings, exception) => {
    console.error("Failed to load " + url, exception);
  };

  $.getScript ( "/uploads/paul79/convnet.js", function() //See https://cs.stanford.edu/people/karpathy/convnetjs/docs.html
  {
    $.getScript ( "/uploads/paul79/LineChart.js", function()
    {
      console.log ("All JS loaded");
      buildNetwork();
      jsLoaded = true;
    }).fail(reportFailure("/uploads/paul79/LineChart.js"));
  }).fail(reportFailure("/uploads/paul79/convnet.js"));
}

// Define the CNN and its trainer, storing them in the globals layer_defs, net and trainer:
// PIXELSxPIXELSx1 input -> conv 5x5 (8 filters) -> pool 2 -> conv 5x5 (16 filters) -> pool 3
// -> softmax over one class per letter.
function buildNetwork()
{
  layer_defs = [
    // Input layer - full-size single-channel images
    {type:'input', out_sx:PIXELS, out_sy:PIXELS, out_depth:1},
    // First convolutional layer - 5x5, 8 filters
    {type:'conv', sx:5, filters:8, stride:1, pad:2, activation:'relu'},
    // First pool layer
    {type:'pool', sx:2, stride:2},
    // Second convolutional layer - 5x5, 16 filters
    {type:'conv', sx:5, filters:16, stride:1, pad:2, activation:'relu'},
    // Second pool layer
    {type:'pool', sx:3, stride:3},
    // Output layer - one per character
    {type:'softmax', num_classes:letters.length},
  ];
  net = new convnetjs.Net();
  net.makeLayers(layer_defs);
  trainer = new convnetjs.SGDTrainer(net, {method:'adadelta', batch_size:20, l2_decay:0.001});
}
// Build a PIXELS x PIXELS greyscale p5 image from a flat array of 0..255 brightness values.
function getImage ( img ) // make a P5 image object from a raw data array
{
  const result = createImage(PIXELS, PIXELS);
  result.loadPixels();
  for (let p = 0; p < PIXELSSQUARED; p++)
  {
    const grey = img[p];
    const base = 4 * p; // RGBA: four bytes per pixel
    result.pixels[base] = grey;
    result.pixels[base + 1] = grey;
    result.pixels[base + 2] = grey;
    result.pixels[base + 3] = 255; // fully opaque
  }
  result.updatePixels();
  return result;
}
// Turn a flat array of 0..255 brightness values into a single-channel convnetjs volume scaled to 0..1.
function getInputs ( img ) // convert img array into normalised input array
{
  const vol = new convnetjs.Vol(PIXELS, PIXELS, 1, 0.0);
  for (let p = 0; p < PIXELSSQUARED; p += 1)
  {
    vol.w[p] = img[p] / 255;
  }
  return vol;
}
// Train the network on the exemplar at train_index (optionally drawing it), update the loss
// windows and status cells, advance to the next exemplar/pass, and every 200 trains plot the
// smoothed total loss.
function train (show) // train the network with a single exemplar, from global var "train_index", show visual on or off
{
  // Odd-numbered passes read the first training file, even-numbered passes the second.
  const trainingData = (trainrun % 2 === 1) ? trainingData_1 : trainingData_2;
  const { img, label } = trainingData.images[train_index];

  if (show)
  {
    const preview = getImage(img);
    image(preview, 0, ZOOMPIXELS + 50, ZOOMPIXELS, ZOOMPIXELS); // magnified
    image(preview, ZOOMPIXELS + 50, ZOOMPIXELS + 50, PIXELS, PIXELS); // original
  }

  const stats = trainer.train(getInputs(img), label);
  xLossWindow.add(stats.cost_loss);
  wLossWindow.add(stats.l2_decay_loss);

  document.getElementById("training_run").innerHTML = `Train Run: ${trainrun}`;
  document.getElementById("training_num").innerHTML = `Train Tests: ${train_index}`;

  train_index++;
  if (train_index >= trainingData.images.length)
  {
    // End of this file: rewind and start the next pass.
    train_index = 0;
    console.log( "finished trainrun: " + trainrun );
    trainrun++;
  }

  if (showGraphs && total_trains % 200 === 0)
  {
    const dataLoss = xLossWindow.get_average();
    const weightLoss = wLossWindow.get_average();
    // get_average() gives -1 until a window has enough samples for an estimate.
    if (dataLoss >= 0 && weightLoss >= 0)
    {
      createLossChart(dataLoss + weightLoss, total_trains);
    }
  }
  total_trains++;
}
// Flip the training switch read by draw() and relabel the pause/resume button to match.
function pauseTraining(){
  do_training = !do_training;
  const buttonText = do_training ? "Pause Training" : "Resume Training";
  document.getElementById("pauseBtn").innerHTML = buttonText;
}
// Hidden-test outcomes grouped into chunks of 100; feeds the test chart and the "Testing" columns.
// TestResultsSquash is a function declaration later in the file (hoisted). Its second argument
// (minsize) is stored but not read by any of its methods.
var testResults = new TestResultsSquash(100, 10);
// Best chunk accuracy (%) shown so far in the "Testing (Best 100)" column.
var testingBest = 0;
// Evaluate the network on one held-out image (global test_index), update the running accuracy
// counters, and refresh the "Testing (Last/Best 100)" columns whenever a chunk of results completes.
// Results are always recorded; showGraphs only controls whether the chart is rebuilt.
function test(){
  const { img, label } = testingData.images[test_index];
  net.forward(getInputs(img)); // feed forward, then read back the arg-max class
  const y = net.getPrediction();
  // NOTE(review): loose == kept deliberately - the label type in the JSON (number vs string) is not visible here.
  const correct = (y == label);

  test_index++;
  total_tests++;
  if (correct){
    total_correct++;
  }

  testResults.add(correct ? 1 : 0);
  if (showGraphs && total_tests % 10 === 0){
    createTestChart();
  }

  // A chunk completes every testResults.size results; show its accuracy and track the best one.
  if (total_tests % testResults.size === 0){
    const last = testResults.getLast();
    if (last !== -1){
      const pct = Math.round(last);
      document.getElementById("testing_last").innerHTML = pct + "%";
      if (pct > testingBest){
        testingBest = pct;
        document.getElementById("testing_best").innerHTML = pct + "%";
      }
    }
  }

  if (test_index >= testingData.images.length){
    test_index = 0;
  }
}
// Classify each image of the fixed control set and fill the per-letter 1/0 row (ctrlTest_i)
// and the overall percentage (ctrlTest_Result). Does nothing until the data and network exist.
function controlTest(){
  if (typeof controlTestData === 'undefined' || typeof net === 'undefined'){
    console.warn("Control test unavailable until the data and network have loaded.");
    return;
  }
  // One result cell per letter exists in the header, so never test more images than that.
  const count = Math.min(letters.length, controlTestData.images.length);
  let correct = 0;
  for (let i = 0; i < count; i++){
    const { img, label } = controlTestData.images[i];
    net.forward(getInputs(img));
    // NOTE(review): loose == matches test(); the JSON label type is not visible here.
    const hit = net.getPrediction() == label ? 1 : 0;
    correct += hit;
    document.getElementById("ctrlTest_" + i).innerHTML = hit;
  }
  const pct = count > 0 ? Math.round((correct / count) * 100) : 0;
  document.getElementById("ctrlTest_Result").innerHTML = pct + "%";
}
// p5 main loop, called once per frame: trains and tests the network (while not paused), draws the
// charts and the doodle area, guesses the doodle, and turns mouse drags inside the top-left doodle
// square into white strokes (blurred once when the mouse is released).
function draw()
{
  // check if libraries and data loaded yet:
  if ( jsLoaded === false || typeof testingData == 'undefined' ) return;
  background ('black');

  if ( do_training )
  {
    // do some training (and one hidden test per exemplar) per step
    for (let step = 0; step < TRAINPERSTEP; step++)
    {
      train(step === 0); // show only one per step - still flashes by
      test();
      if (autoTestsAt !== -1 && total_trains % autoTestsAt === 0){
        runAutoTests();
      }
    }
  }

  if (typeof lossChart !== 'undefined'){
    lossChart.show();
  }
  if (typeof testChart !== 'undefined'){
    testChart.show();
  }

  if ( doodle_exists )
  {
    drawDoodle();
    guessDoodle();
  }

  // detect doodle drawing
  // (restriction) the following assumes doodle starts at 0,0
  if ( mouseIsPressed ) // also true when clicking buttons, hence the bounds check
  {
    const MAX = ZOOMPIXELS + 20; // can draw up to this pixels in corner
    if ( (mouseX < MAX) && (mouseY < MAX) && (pmouseX < MAX) && (pmouseY < MAX) )
    {
      doodle_exists = true;
      mousedrag = true; // start a mouse drag
      doodle.stroke('white');
      doodle.strokeWeight( DOODLE_THICK );
      doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
  }
  else if ( mousedrag )
  {
    // Mouse released after drawing: blur the finished stroke once.
    mousedrag = false;
    doodle.filter (BLUR, DOODLE_BLUR);
  }
}

// Log the saved-doodle accuracy and the mean of the last five 100-test chunks (autoTestsAt report).
// The network has no stochastic layers (conv/pool/softmax only), so one doodle pass presumably
// gives the same score the old 10-run average did, without repeating the "no doodles" alert.
function runAutoTests(){
  const doodleScore = loadDoodlesAndTest();
  console.log("Auto tests after " + total_trains + " trains:");
  // A negative (or NaN) score means there was nothing to test.
  console.log("Average Doodle percentage: " + (doodleScore >= 0 ? doodleScore : "n/a"));
  const last500Tests = testResults.getLastN(5);
  const testAverage = last500Tests.length > 0
    ? last500Tests.reduce((a, b) => a + b, 0) / last500Tests.length
    : "n/a";
  console.log("Average Test percentage: " + testAverage);
}
//--- doodle -------------------------------------------------------------
// Paint the current doodle twice: at full drawing size in the top-left corner, and shrunk to
// data-set size to its right.
function drawDoodle()
{
  // doodle is createGraphics not createImage, so take a snapshot image first
  const snapshot = doodle.get();
  const shrunkX = ZOOMPIXELS + 50;
  image ( snapshot, 0, 0, ZOOMPIXELS, ZOOMPIXELS ); // original
  image ( snapshot, shrunkX, 0, PIXELS, PIXELS ); // shrunk
}
// Classify the current doodle and fill the guess table with the network's six most probable
// letters and their confidence (rounded percent).
function guessDoodle()
{
  // doodle is createGraphics not createImage: snapshot it and shrink to data-set size
  const img = doodle.get();
  img.resize ( PIXELS, PIXELS );
  img.loadPixels();

  // Strokes are drawn in white, so the red channel (every 4th RGBA byte) serves as brightness.
  const x = new convnetjs.Vol(PIXELS, PIXELS, 1, 0.0);
  for (let i = 0; i < PIXELSSQUARED ; i++)
  {
    x.w[i] = img.pixels[i * 4] / 255;
  }

  const output = net.forward(x); // probability of each output is given in output.w
  const preds = [];
  for (let k = 0; k < output.w.length; k++) {
    preds.push({k: k, p: output.w[k]});
  }
  // Most probable first; a proper numeric comparator (returns 0 on ties).
  preds.sort((a, b) => b.p - a.p);

  const shown = Math.min(6, preds.length);
  for (let wi = 0; wi < shown; wi++){
    const roundedPC = Math.round(preds[wi].p * 100);
    document.getElementById("guess_percent_" + (wi+1)).innerHTML = roundedPC + "%";
    document.getElementById("guess_letter_" + (wi+1)).innerHTML = letters[preds[wi].k];
  }
}
// Clear the doodle: flag it as empty first (so draw() stops rendering/guessing it), then paint
// the buffer black.
function wipeDoodle()
{
  doodle_exists = false;
  doodle.background ( 'black' );
}
//Save a doodle to local storage
//Extract the pixels and store in JSON object along with the inputted label
// The typed label is trimmed and upper-cased and must be one of `letters`; otherwise the user
// is alerted and nothing is saved (such a doodle could never be scored correct when tested).
function saveDoodle(){
  const label = String(document.getElementById("doodlePred").value).trim().toUpperCase();
  if (!letters.includes(label)){
    alert("Please type a single letter A-Z before saving the doodle.");
    return;
  }

  // doodle is createGraphics not createImage: snapshot it and shrink to data-set size
  const img = doodle.get();
  img.resize ( PIXELS, PIXELS );
  img.loadPixels();
  // All RGBA bytes are kept; loadDoodlesAndTest() reads the red channel back out.
  const jsonimage = {'pixels': [...img.pixels], 'label': label};

  const stored = localStorage.getItem('doodles');
  const doodleSaveList = stored === null ? {'doodles': []} : JSON.parse(stored);
  doodleSaveList.doodles.push(jsonimage);
  localStorage.setItem('doodles', JSON.stringify(doodleSaveList));
}
//Load all the saved doodles and test each of them
// Returns the percentage (0-100, rounded) of saved doodles the network labels correctly and writes
// a summary into #doodles_Results. Returns -1 when there is nothing to test (nothing saved, or an
// empty list) instead of throwing or producing NaN.
function loadDoodlesAndTest(){
  const stored = localStorage.getItem('doodles');
  if (stored === null){
    alert("Save doodle(s), then you can load.");
    return -1;
  }
  const saved = JSON.parse(stored).doodles;
  const resultsEl = document.getElementById("doodles_Results");
  if (saved.length === 0){
    resultsEl.innerHTML = "Doodles tested: 0";
    return -1;
  }

  let numCorrect = 0;
  for (const record of saved){
    // Rebuild the network input from the red channel (every 4th RGBA byte), scaled to 0..1.
    const x = new convnetjs.Vol(PIXELS, PIXELS, 1, 0.0);
    for (let p = 0; p < PIXELSSQUARED; p++){
      x.w[p] = record.pixels[p * 4] / 255;
    }
    net.forward(x);
    // Older saves may hold lower-case or padded labels; normalise before mapping to a class index.
    const expected = letters.indexOf(String(record.label).trim().toUpperCase());
    if (net.getPrediction() === expected){
      numCorrect++;
    }
  }

  const pc = Math.round(numCorrect / saved.length * 100);
  resultsEl.innerHTML = "Doodles tested: " + saved.length + " Correct: " + pc + "%";
  return pc;
}
//Show a chart of training error
// History of (train count, loss) samples; cleared when a snapshot is loaded.
var graphData = {xData:[], yData:[]};
// Append one (trainNumber, loss) sample and rebuild the training-loss chart from the whole
// history, storing it in the global lossChart (drawn by draw()). y axis runs from 0 to the max loss.
function createLossChart(loss, trainNumber){
  graphData.xData.push(trainNumber);
  graphData.yData.push(loss);
  // All chart inputs are locals (they previously leaked as implicit globals).
  const points = graphData.xData.map((x, j) => createVector(x, graphData.yData[j]));
  const series = [points];
  const colors = ['#ff0000'];
  const lineLabels = ["Loss (training)"];
  lossChart = new LineChart(series, colors, lineLabels, 250, 250, 5, canvasheight-250,
    [min(graphData.xData), max(graphData.xData)], [0, max(graphData.yData)]);
}
//Show a chart of testing success
// Rebuild the hidden-test accuracy chart (0-100%) from testResults and store it in the global
// testChart, which draw() renders every frame.
function createTestChart(){
  const results = testResults.getData();
  // All chart inputs are locals (they previously leaked as implicit globals).
  const points = results.xData.map((x, j) => createVector(x, results.yData[j]));
  const series = [points];
  const colors = ['#0000ff'];
  const lineLabels = ["Hidden Tests"];
  testChart = new LineChart(series, colors, lineLabels, 250, 250, 5, canvasheight-500,
    [min(results.xData), max(results.xData)], [0, 100]);
}
// Sliding window over the most recent `size` numbers (default 100) with a running sum, so the
// mean is available in O(1). get_average() reports -1 until at least `minsize` (default 20)
// values have been added. Kept as a function declaration so it is hoisted for early `new` calls.
function ErrorStore (size, minsize){
  this.v = []; // window contents, oldest first
  this.size = size === undefined ? 100 : size;
  this.minsize = minsize === undefined ? 20 : minsize;
  this.sum = 0; // running total of this.v

  this.add = function (x) {
    this.v.push(x);
    this.sum += x;
    // Over capacity: evict the oldest value and take it back out of the total.
    if (this.v.length > this.size) {
      this.sum -= this.v.shift();
    }
  };

  this.get_average = function () {
    return this.v.length < this.minsize ? -1 : this.sum / this.v.length;
  };

  this.reset = function () {
    this.v = [];
    this.sum = 0;
  };
}
// Collects 0/1 test outcomes in chunks of `size` (default 100) and records each completed chunk's
// accuracy as a percentage. Backs the hidden-test chart and the "Testing (Last/Best 100)" columns.
function TestResultsSquash (size, minsize){
  this.vals = []; // results of the current, not-yet-complete chunk
  this.size = typeof(size)==='undefined' ? 100 : size;
  // NOTE(review): minsize is stored for API compatibility but not read by any method.
  this.minsize = typeof(minsize)==='undefined' ? 10 : minsize;
  this.hundredTest = []; // accuracy (%) of each completed chunk, oldest first
  this.sum = 0; // running total of this.vals

  // Record one result (1 = correct, 0 = wrong); closes the chunk once it holds `size` results.
  this.add = function(res){
    this.vals.push(res);
    this.sum += res;
    if (this.vals.length >= this.size){
      this.hundredTest.push((this.sum/this.size)*100);
      this.vals = [];
      this.sum = 0;
    }
  };

  // Chart series: one point per completed chunk (x = results recorded up to the end of that chunk),
  // plus a point for the in-progress chunk when it holds at least one result. The partial point
  // is skipped when empty (it used to be 0/0 = NaN), and its x no longer depends on a previous
  // point existing (it used to be undefined + n = NaN before the first chunk completed).
  this.getData = function (){
    const yData = [...this.hundredTest];
    const xData = yData.map((_, i) => this.size * (i + 1));
    if (this.vals.length > 0){
      xData.push(this.hundredTest.length * this.size + this.vals.length);
      yData.push((this.sum / this.vals.length) * 100);
    }
    return {"xData": xData, "yData": yData};
  };

  // Accuracy (%) of the most recent completed chunk, or -1 if none has completed yet.
  this.getLast = function (){
    if (this.hundredTest.length > 0){
      return this.hundredTest[this.hundredTest.length-1];
    }
    return -1;
  };

  // Accuracies of the last n completed chunks (fewer if not enough yet). n <= 0 (or missing)
  // gives [] - previously slice(-0) returned the whole history.
  this.getLastN = function (n){
    return n > 0 ? this.hundredTest.slice(-n) : [];
  };

  this.reset = function(){
    this.vals = [];
    this.hundredTest = [];
    this.sum = 0;
  };
}
// Serialise the current network via net.toJSON() and store it in localStorage under 'cnnSnapshot'.
function saveSnapshot(){
  const serialised = JSON.stringify(net.toJSON());
  localStorage.setItem('cnnSnapshot', serialised);
  console.log(`Save snaphot ${serialised}`);
}
// Restore the network from the snapshot saved by saveSnapshot(). If nothing has been saved yet
// the user is told so, and the current network is left alone (previously null was passed on
// and the live network was replaced by an empty one).
function loadSnapshot(){
  const stored = localStorage.getItem('cnnSnapshot');
  if (stored === null){
    alert("No snapshot saved yet - click 'Save snapshot' first.");
    return;
  }
  const snap = JSON.parse(stored);
  console.log("Loading snaphot " + snap);
  resetNetFromSnapshot(snap);
}
// Fetch the pre-trained snapshot shipped with the app and make it the current network.
function downloadSnapshot(){
  const snapshotUrl = '/uploads/paul79/snapshot.json';
  loadJSON(snapshotUrl, (snapshot) => {
    console.log("Downloaded snapshot " + snapshot);
    resetNetFromSnapshot(snapshot);
  });
}
//Reset the network state from a snapshot
// Replace the global network with one rebuilt from a convnetjs JSON snapshot, give the trainer
// the new network, and clear all training/testing statistics and chart history.
// If Net.fromJSON throws (e.g. for a null or malformed snapshot) nothing is changed.
function resetNetFromSnapshot(snap){
  // Build the replacement first so a bad snapshot cannot leave a half-initialised global net.
  const loaded = new convnetjs.Net();
  loaded.fromJSON(snap);
  net = loaded;

  // The trainer was constructed with (and keeps updating) the previous network object. Without a
  // fresh trainer, training would keep modifying the old net while test()/guessDoodle() use the
  // new one. Copy the existing trainer's hyper-parameters so setup()'s settings are kept.
  trainer = new convnetjs.SGDTrainer(net, {
    learning_rate: trainer.learning_rate,
    l1_decay: trainer.l1_decay,
    l2_decay: trainer.l2_decay,
    batch_size: trainer.batch_size,
    method: trainer.method,
    momentum: trainer.momentum,
    ro: trainer.ro,
    eps: trainer.eps,
    beta1: trainer.beta1,
    beta2: trainer.beta2,
  });

  xLossWindow.reset();
  wLossWindow.reset();
  testResults.reset();
  trainrun = 1;
  train_index = 0;
  total_trains = 0;
  testrun = 1;
  test_index = 0;
  total_tests = 0;
  total_correct = 0;
  // The best/last accuracies belong to the previous network.
  testingBest = 0;
  graphData = {xData:[], yData:[]};
  document.getElementById("testing_last").innerHTML = "";
  document.getElementById("testing_best").innerHTML = "";
}