/* Identify handwritten digit doodles using neural networks that were pre-trained in:
https://run.ancientbrain.com/run.php?world=7029631059
*/
// Display confidence with a colour: highlight colours used for prediction table cells.
// Frozen because it is a shared lookup table that nothing should modify at runtime.
const colourMap = Object.freeze({
    GREEN: "#00FF00",
    YELLOW: "#FFFF00",
    ORANGE: "#FFA500",
    RED: "#B20000",
    WHITE: "#FFFFFF",
});
// Clear the canvas (Clear button or spacebar).
// Forgets every recorded point and repaints via redraw(). Besides blanking the doodle
// canvas, redraw() also refreshes the 28x28 mini canvas. Previously that preview kept
// showing the old drawing after a clear.
function clearCanvas()
{
    clickX = [];
    clickY = [];
    clickDrag = [];
    redraw();
}
// Build the page: Bootstrap CSS, slider/table styles, the 280x280 drawing canvas,
// the 28x28 model-input preview (shown enlarged to 100x100 px), the brush sliders
// and the (initially empty) prediction table.
document.write(`
<!-- Bootstrap CSS -->
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<style>
table, td, th {
border: 1px solid black;
text-align: center;
}
td:hover {
background-color: lightblue;
}
.table {
border-collapse: collapse;
width: 100%;
}
.slider {
-webkit-appearance: none;
height: 25px;
background: #d3d3d3;
outline: none;
opacity: 0.7;
-webkit-transition: .2s;
transition: opacity .2s;
}
.slider:hover {
opacity: 1;
}
.slider::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 25px;
height: 25px;
background: #4CAF50;
cursor: pointer;
}
.slider::-moz-range-thumb {
width: 25px;
height: 25px;
background: #4CAF50;
cursor: pointer;
}
</style>
<div style="text-align: center;">
<canvas id="doodleCanvas" width="280" height="280" style="border:1px solid black;"></canvas>
<div>
<button type="button" style="margin:5px" class="btn btn-secondary" onclick="clearCanvas()">Clear or click spacebar</button>
</div>
<div>
<canvas id="smallCanvas" width="28" height="28" style="height:100px;width:100px; border:1px solid black;"></canvas>
</div>
<p>Brush Size:</p>
<input type="range" min="1" max="100" value="25" class="slider" id="brushSize" onchange="updateBrushSize(this.value)">
<p>Brush Type:</p>
<input type="range" min="0" max="2" value="0" class="slider" id="brushType" onchange="updateBrushType(this.value)">
<div>
<h2>Select Correct Result To Retrain</h2>
<table class="table" id="predictionTable">
</table>
</div>
</div>
`);
// ---- Drawing state ----
// Recorded mouse points: x, y, and whether each point continues a drag
var clickX = [];
var clickY = [];
var clickDrag = [];
// True while the mouse button is held down over the doodle canvas
var paint;

// ---- Brush settings (changed by the sliders) ----
// Canvas lineJoin values; the "Brush Type" slider value indexes this list
var lineJoin = ['round', 'bevel', 'miter'];
var BRUSH_STYLE = 0;
var BRUSH_SIZE = 25;

// ---- Model input ----
const PIXELS = 28; // images in data set are tiny
const PIXELSSQUARED = PIXELS ** 2;
// Loaded tf.LayersModel instances, in prediction-table row order
var models = [];
// Last normalised drawing fed to the models (set by makePredictions)
var inputs;
// Let the user change brush size (the "Brush Size" slider's onchange handler).
// Redraws the doodle with the new line width and refreshes the predictions.
function updateBrushSize(newBrushSize)
{
    // Slider values arrive as strings; store a number so BRUSH_SIZE keeps the
    // numeric type of its initial value (25)
    BRUSH_SIZE = Number(newBrushSize);
    redraw();
    makePredictions();
}
// Let the user change brush type (the "Brush Type" slider's onchange handler).
// The value selects an entry of `lineJoin`; redraws and refreshes the predictions.
function updateBrushType(value)
{
    // Slider values arrive as strings; store a number so BRUSH_STYLE keeps the
    // numeric type of its initial value (0)
    BRUSH_STYLE = Number(value);
    redraw();
    makePredictions();
}
// Load our external libraries asynchronously: TensorFlow.js, tfjs-vis and the
// MNIST data helper are fetched in parallel. Returns the jQuery $.when(...) promise
// that resolves once all three scripts have loaded.
function loadTensorFlowLibraries()
{
    const scriptUrls = [
        "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js",
        "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js",
        "uploads/michaelryan/data.js"
    ];
    const requests = scriptUrls.map(url => $.getScript(url));
    return $.when(...requests);
}
// Return a highlight colour for a prediction confidence:
// >= 0.8 green, >= 0.6 yellow, >= 0.4 orange, >= 0.2 red.
// Anything lower (or NaN) gets white, i.e. no highlight. Previously this case
// returned undefined.
function getColourForPredictionValue(value)
{
    // Thresholds are checked from highest to lowest, so no upper bounds are needed
    if (value >= .80)
    {
        return colourMap["GREEN"];
    }
    if (value >= .60)
    {
        return colourMap["YELLOW"];
    }
    if (value >= .40)
    {
        return colourMap["ORANGE"];
    }
    if (value >= .20)
    {
        return colourMap["RED"];
    }
    return colourMap["WHITE"];
}
// Write one model's prediction scores into its table row and highlight the top score.
// modelIndex: position of the model in `models`; predictions: one score per digit 0-9.
function fillPredictionTable(modelIndex, predictions)
{
    // Cell ids are "<row>_<digit>"; row 0 is the header, so model N lives in row N + 1
    const row = modelIndex + 1;
    let bestDigit = 0;
    for (let digit = 0; digit < predictions.length; digit++)
    {
        const score = predictions[digit];
        // Strict '>' keeps the first digit on ties
        if (score > predictions[bestDigit])
        {
            bestDigit = digit;
        }
        // Reset colour, then show the score to two decimals
        const cell = $(`#${row}_${digit}`);
        cell.css('background-color', colourMap["WHITE"]);
        cell.html(score.toFixed(2));
    }
    const highlight = getColourForPredictionValue(predictions[bestDigit]);
    $(`#${row}_${bestDigit}`).css('background-color', highlight);
}
/* Retrain the models based on the user's choice.
 * correctLabel: the digit (0-9) the user clicked in the table header.
 * The batch has two samples: one from the training data, with its 2nd slot
 * overwritten by the current drawing labelled as correctLabel. Every model is
 * retrained on that batch, then the predictions are refreshed. */
async function selectCorrectValue(correctLabel)
{
    // Nothing has been drawn/predicted yet, so there is no image to train on
    if (!inputs)
    {
        console.warn("Draw a digit before selecting the correct value");
        return;
    }
    // One-hot label for the user's drawing
    const labelArray = new Float32Array(NUM_CLASSES);
    labelArray[correctLabel] = 1;
    // Need to train with at least 2 images, so take a training sample and put our
    // input in its 2nd slot. nextTrainArray() comes from data.js (not shown here);
    // the [2, ...] tensor shapes below assume it returns 2-sample buffers.
    const nextTrainArray = data.nextTrainArray();
    nextTrainArray.batchImagesArray.set(inputs, IMAGE_SIZE);
    nextTrainArray.batchLabelsArray.set(labelArray, NUM_CLASSES);
    const inputImages = tf.tensor2d(nextTrainArray.batchImagesArray, [2, IMAGE_SIZE]);
    const inputLabels = tf.tensor2d(nextTrainArray.batchLabelsArray, [2, NUM_CLASSES]);
    try
    {
        console.log("Retraining models....");
        // Retrain all our models
        await retrainModels(inputImages, inputLabels);
        console.log("Complete....");
    }
    finally
    {
        // TF.js tensors are not garbage collected; release them even if training fails
        inputImages.dispose();
        inputLabels.dispose();
    }
    // Output new predictions
    makePredictions();
}
// Retrain every loaded model on the given batch, all in parallel.
// inputImages: flat image tensor, one row of PIXELS*PIXELS values per sample
// (as built by selectCorrectValue); inputLabels: the matching one-hot label tensor.
async function retrainModels(inputImages, inputLabels)
{
    // The models take PIXELS x PIXELS x 1 images; -1 lets the batch size follow the input.
    // Reshape once for all models and dispose afterwards, since TF.js tensors must be freed manually.
    const batch = inputImages.reshape([-1, PIXELS, PIXELS, 1]);
    try
    {
        await Promise.all(models.map(model => model.fit(batch, inputLabels)));
    }
    finally
    {
        batch.dispose();
    }
}
// Draw our table of predictions for the models.
// Row 0 is a header of clickable digits 0-9; clicking one calls selectCorrectValue
// with that digit. Below it is one row per loaded model, with cells "<row>_<digit>"
// that fillPredictionTable fills in later.
function createPredictionTable()
{
    const $table = $('#predictionTable');
    const digits = Array.from({ length: 10 }, (_, d) => d);

    const headerCells = digits
        .map(d => `<td onclick="selectCorrectValue(${d})">${d}</td>`)
        .join('');
    $table.append(`<tr><td></td>${headerCells}</tr>`);

    // Row labels assume the models were loaded in the order 500, 5000, 50000, ...
    let label = "Model 500";
    for (let row = 1; row <= models.length; row++)
    {
        const scoreCells = digits
            .map(d => `<td id=${row}_${d}>0.00</td>`)
            .join('');
        $table.append(`<tr><td>${label}</td>${scoreCells}</tr>`);
        label += "0";
    }
}
// Run every model on the current drawing and show the results in the prediction table.
// Reads the 28x28 mini canvas and stores the normalised pixels in the global `inputs`,
// which selectCorrectValue uses for retraining. Then fills one table row per model.
function makePredictions()
{
    // Scale down our image: take pixels from the mini canvas
    const smallContext = document.getElementById("smallCanvas").getContext("2d");
    const img = smallContext.getImageData(0, 0, smallContext.canvas.width, smallContext.canvas.height);
    // The drawing is white-on-black, so the red channel of each RGBA pixel is enough; scale to 0..1
    inputs = new Float32Array(PIXELSSQUARED);
    for (let i = 0; i < PIXELSSQUARED; i++)
    {
        inputs[i] = img.data[i * 4] / 255;
    }
    // The brush sliders can fire before setup() has loaded TensorFlow.js and the models
    if (typeof tf === 'undefined' || models.length === 0)
    {
        return;
    }
    // One input tensor shared by all models. Every tensor created here is disposed,
    // because TF.js does not garbage-collect tensor memory.
    const imageTensor = tf.tensor4d(inputs, [1, PIXELS, PIXELS, 1]);
    try
    {
        for (let i = 0; i < models.length; i++)
        {
            const output = tf.tidy(() => models[i].predict(imageTensor));
            const predictions = output.dataSync();
            output.dispose();
            fillPredictionTable(i, predictions);
        }
    }
    finally
    {
        imageTensor.dispose();
    }
}
/* Canvas code from http://www.williammalone.com/articles/create-html5-canvas-javascript-drawing-app/ */
// Prepare both canvases, draw the prediction table and wire up mouse/keyboard drawing.
async function setupCanvas()
{
    const canvas = document.getElementById('doodleCanvas');
    const context = canvas.getContext("2d");
    const smallContext = document.getElementById("smallCanvas").getContext("2d");
    // The mini canvas is 1/10 the size of the doodle canvas (28 vs 280 px)
    smallContext.scale(0.1, 0.1);
    context.fillStyle = "black";
    context.fillRect(0, 0, context.canvas.width, context.canvas.height);
    // Draw prediction table
    createPredictionTable();

    // Record one point; `dragging` is true when the point continues the current stroke
    function addClick(x, y, dragging)
    {
        clickX.push(x);
        clickY.push(y);
        clickDrag.push(dragging);
    }

    // Mouse position in canvas pixels, measured from the canvas content box.
    // getBoundingClientRect stays correct when the canvas is inside a positioned
    // container, where offsetLeft/offsetTop are relative to that container, not the page.
    function canvasPoint(e)
    {
        const rect = canvas.getBoundingClientRect();
        return [
            e.clientX - rect.left - canvas.clientLeft,
            e.clientY - rect.top - canvas.clientTop
        ];
    }

    $('#doodleCanvas').on('mousedown', e =>
    {
        paint = true;
        const [x, y] = canvasPoint(e);
        addClick(x, y);
        redraw(context);
    });
    $('#doodleCanvas').on('mousemove', e =>
    {
        if (paint)
        {
            const [x, y] = canvasPoint(e);
            addClick(x, y, true);
            redraw(context);
        }
    });
    // When user lets go of mouse button in canvas, execute predictions
    $('#doodleCanvas').on('mouseup', () =>
    {
        paint = false;
        makePredictions();
    });
    // Leaving the canvas mid-stroke ends the stroke. No mouseup will reach the canvas
    // for it, so predict here too.
    $('#doodleCanvas').on('mouseleave', () =>
    {
        if (paint)
        {
            paint = false;
            makePredictions();
        }
    });
    // Spacebar clears the drawing
    $('body').on('keyup', e =>
    {
        if (e.key === ' ' || e.keyCode === 32)
        {
            clearCanvas();
        }
    });
}
// Update the Canvas when the user draws or changes the brush.
// Repaints the whole doodle from the recorded points, then refreshes the 28x28 mini canvas.
// `context` is optional; without it the doodle canvas's 2D context is used.
function redraw(context)
{
    const ctx = context || document.getElementById('doodleCanvas').getContext("2d");
    const { width, height } = ctx.canvas;

    // Background: clear and paint black with no filter, so the edges stay fully opaque
    ctx.filter = 'none';
    ctx.clearRect(0, 0, width, height);
    ctx.fillStyle = "black";
    ctx.fillRect(0, 0, width, height);

    // Strokes: white, using the selected join style and width, softened by a 2px blur.
    // The blur is set before stroking so every redraw (including the first) looks the same.
    ctx.strokeStyle = "white";
    ctx.lineJoin = lineJoin[BRUSH_STYLE];
    ctx.lineWidth = BRUSH_SIZE;
    ctx.filter = 'blur(2px)';
    for (let i = 0; i < clickX.length; i++)
    {
        ctx.beginPath();
        if (clickDrag[i] && i)
        {
            // Continuing a drag: join from the previous point
            ctx.moveTo(clickX[i - 1], clickY[i - 1]);
        }
        else
        {
            // Start of a stroke: a 1px segment so a single click still leaves a dot
            ctx.moveTo(clickX[i] - 1, clickY[i]);
        }
        ctx.lineTo(clickX[i], clickY[i]);
        ctx.closePath();
        ctx.stroke();
    }
    // Don't let the blur leak into other drawing on this context (e.g. clearCanvas)
    ctx.filter = 'none';
    drawSmallCanvas(ctx.canvas);
}
// Draw the mini Canvas: a copy of `canvas` scaled down to fill the 28x28 mini canvas.
// `canvas` defaults to the doodle canvas. makePredictions reads this mini canvas as
// the model input.
function drawSmallCanvas(canvas)
{
    const smallContext = document.getElementById("smallCanvas").getContext("2d");
    const { width, height } = smallContext.canvas;
    const source = canvas || document.getElementById('doodleCanvas');
    // setupCanvas scales this context by 0.1. Work in plain pixel coordinates here so
    // the black background covers the whole mini canvas; under the 0.1 scale it only
    // covered the top-left corner.
    smallContext.save();
    smallContext.setTransform(1, 0, 0, 1, 0, 0);
    smallContext.clearRect(0, 0, width, height);
    smallContext.fillStyle = "black";
    smallContext.fillRect(0, 0, width, height);
    // Scale the source image down to exactly fill the mini canvas
    smallContext.drawImage(source, 0, 0, width, height);
    smallContext.restore();
}
// Compile the loaded neural networks; compiling is required before they can be re-trained.
// Each model gets its own Adam optimizer. An optimizer keeps per-variable state
// (moment estimates and step count), so one instance must not be shared by several models.
function compileModels(models)
{
    console.log("compiling...");
    models.forEach(model =>
    {
        model.compile(
        {
            optimizer: tf.train.adam(),
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy'],
        });
    });
    console.log("Neural Networks loaded...");
}
// Training-data helper (presumably MnistData from data.js); selectCorrectValue draws samples from it
let data;
// Load our scripts, models and training data, then set up the canvas.
async function setup()
{
    // Load the TF libs first; everything below needs `tf`
    console.log("Load external scripts...");
    await loadTensorFlowLibraries();
    console.log("Scripts loaded...");
    // Load the models in parallel. Promise.all keeps this order, which the prediction
    // table's "Model 500/5000/50000" row labels rely on.
    const modelUrls = [
        'uploads/michaelryan/500-model.json',
        'uploads/michaelryan/5000-model.json',
        'uploads/michaelryan/50000-model.json'
    ];
    const loadedModels = await Promise.all(modelUrls.map(url => tf.loadLayersModel(url)));
    models.push(...loadedModels);
    // Compile models to allow retraining
    compileModels(models);
    // Load training data, needed to retrain models
    data = new MnistData();
    await data.load();
    setupCanvas();
    console.log("Ready to go");
}
// Report start-up failures (e.g. a script or model that fails to download)
// instead of leaving an unhandled promise rejection
setup().catch(err => console.error("Setup failed:", err));