// template is blank
// define your own World here
const html = '<div class="digit-demo-container">' +
'<h3 id="digit-recognizer-live-demo">Digit Recognizer using TensorFlow.js Demo</h3>' +
'<div class="flex-two" style="margin-top: 20px;">' +
'<button id="clear_canvas" class="material-button-pink" onclick="clearCanvas(this.id)">Clear</button>' +
'<select id="select_model">' +
'<option>MLP</option>' +
'<option>CNN</option>' +
'</select>' +
'<button id="predict_canvas" class="material-button-pink" onclick="predict(this.id)">Predict</button>' +
'</div>' +
'<div class="flex-two">' +
'<div id="canvas_box_wrapper" class="canvas-box-wrapper">' +
'<div id="canvas_box" class="canvas-box">' +
'<button id="clear_canvas" class="material-button-pink" onclick="clearCanvas(this.id)">Clear</button>' +
'</div>' +
'</div>' +
'<div id="result_box">' +
'<canvas id="chart_box" width="100" height="100"></canvas>' +
'</div>' +
'</div>' +
'</div>'
document.body.innerHTML += html;
// function clearCanvas(event) {
// console.log(event);
// }
const tfurl = "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"
const chartjsurl = "https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.4.0/Chart.min.js";
$.getScript ( chartjsurl, function() {
console.log('chartjs loaded');
});
$.getScript ( tfurl, function() {
console.log('tfjs loaded');
});
// $.getScript ( tfurl, function()
// {
// $.getScript ( "https://code.jquery.com/jquery-2.1.1.min.js", function()
// {
// $.getScript ( "/uploads/codingtrain/mnist.js", function()
// {
// console.log ("All JS loaded");
// });
// });
// });
// }
// GLOBAL variables
var modelName = "digitrecognizermlp";
let model;
// canvas related variables
// you can change these variables
var canvasWidth = 150;
var canvasHeight = 150;
var canvasStrokeStyle = "white";
var canvasLineJoin = "round";
var canvasLineWidth = 12;
var canvasBackgroundColor = "black";
var canvasId = "canvas";
// variables to hold coordinates and dragging boolean
var clickX = new Array();
var clickY = new Array();
var clickD = new Array();
var drawing;
document.getElementById('chart_box').innerHTML = "";
document.getElementById('chart_box').style.display = "none";
//---------------------
// Create canvas
//---------------------
var canvasBox = document.getElementById('canvas_box');
var canvas = document.createElement("canvas");
canvas.setAttribute("width", canvasWidth);
canvas.setAttribute("height", canvasHeight);
canvas.setAttribute("id", canvasId);
canvas.style.backgroundColor = canvasBackgroundColor;
canvasBox.appendChild(canvas);
if(typeof G_vmlCanvasManager != 'undefined') {
canvas = G_vmlCanvasManager.initElement(canvas);
}
ctx = canvas.getContext("2d");
//---------------------
// MOUSE DOWN function
//---------------------
$("#canvas").mousedown(function(e) {
var mouseX = e.pageX - this.offsetLeft;
var mouseY = e.pageY - this.offsetTop;
drawing = true;
addUserGesture(mouseX, mouseY);
drawOnCanvas();
});
//---------------------
// TOUCH START function
//---------------------
canvas.addEventListener("touchstart", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
var rect = canvas.getBoundingClientRect();
var touch = e.touches[0];
var mouseX = touch.clientX - rect.left;
var mouseY = touch.clientY - rect.top;
drawing = true;
addUserGesture(mouseX, mouseY);
drawOnCanvas();
}, false);
// MOUSE MOVE function
//---------------------
$("#canvas").mousemove(function(e) {
if(drawing) {
var mouseX = e.pageX - this.offsetLeft;
var mouseY = e.pageY - this.offsetTop;
addUserGesture(mouseX, mouseY, true);
drawOnCanvas();
}
});
//---------------------
// TOUCH MOVE function
//---------------------
canvas.addEventListener("touchmove", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
if(drawing) {
var rect = canvas.getBoundingClientRect();
var touch = e.touches[0];
var mouseX = touch.clientX - rect.left;
var mouseY = touch.clientY - rect.top;
addUserGesture(mouseX, mouseY, true);
drawOnCanvas();
}
}, false);
//---------------------
// MOUSE UP function
//---------------------
$("#canvas").mouseup(function(e) {
drawing = false;
});
//---------------------
// TOUCH END function
//---------------------
canvas.addEventListener("touchend", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
drawing = false;
}, false);
//----------------------
// MOUSE LEAVE function
//----------------------
$("#canvas").mouseleave(function(e) {
drawing = false;
});
//---------------------
// TOUCH LEAVE function
//---------------------
canvas.addEventListener("touchleave", function (e) {
if (e.target == canvas) {
e.preventDefault();
}
drawing = false;
}, false);
//----------------------
// ADD CLICK function
//----------------------
function addUserGesture(x, y, dragging) {
clickX.push(x);
clickY.push(y);
clickD.push(dragging);
}
//----------------------
// RE DRAW function
//----------------------
function drawOnCanvas() {
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
ctx.strokeStyle = canvasStrokeStyle;
ctx.lineJoin = canvasLineJoin;
ctx.lineWidth = canvasLineWidth;
for (var i = 0; i < clickX.length; i++) {
ctx.beginPath();
if(clickD[i] && i) {
ctx.moveTo(clickX[i-1], clickY[i-1]);
} else {
ctx.moveTo(clickX[i]-1, clickY[i]);
}
ctx.lineTo(clickX[i], clickY[i]);
ctx.closePath();
ctx.stroke();
}
}
//----------------------
// CLEAR CANVAS function
//----------------------
function clearCanvas(id) {
ctx.clearRect(0, 0, canvasWidth, canvasHeight);
clickX = new Array();
clickY = new Array();
clickD = new Array();
}
//-----------------------
// select model handler
//-----------------------
$("#select_model").change(function() {
var select_model = document.getElementById("select_model");
var select_option = select_model.options[select_model.selectedIndex].value;
if (select_option == "MLP") {
modelName = "digitrecognizermlp";
} else if (select_option == "CNN") {
modelName = "digitrecognizercnn";
} else {
modelName = "digitrecognizermlp";
}
loadModel(modelName);
});
//-------------------------------------
// loader for digitrecognizermlp model
//-------------------------------------
async function loadModel() {
console.log("model loading..");
// clear the model variable
model = undefined;
// load the model using a HTTPS request (where you have stored your model files)
model = await tf.loadLayersModel("uploads/lukescales/processed_model_v2.json");
console.log("model loaded..");
}
loadModel();
//-----------------------------------------------
// preprocess the canvas to be DNN friendly
//-----------------------------------------------
function preprocessCanvas(image, modelName) {
// if model is not available, send the tensor with expanded dimensions
if (modelName === undefined) {
alert("No model defined..")
}
// if model is digitrecognizermlp, perform all the preprocessing
else if (modelName === "digitrecognizermlp") {
// resize the input image to digitrecognizermlp's target size of (784, )
let tensor = tf.browser.fromPixels(image)
.resizeNearestNeighbor([28, 28])
.mean(2)
.toFloat()
.reshape([1 , 784]);
return tensor.div(255.0);
}
// if model is digitrecognizercnn, perform all the preprocessing
else if (modelName === "digitrecognizercnn") {
// resize the input image to digitrecognizermlp's target size of (1, 28, 28, 1)
let tensor = tf.browser.fromPixels(image)
.resizeNearestNeighbor([28, 28])
.mean(2)
.expandDims(2)
.expandDims()
.toFloat();
console.log(tensor.shape);
return tensor.div(255.0);
}
// else throw an error
else {
alert("Unknown model name..")
}
}
//----------------------------
// Bounding box for centering
//----------------------------
function boundingBox() {
var minX = Math.min.apply(Math, clickX) - 20;
var maxX = Math.max.apply(Math, clickX) + 20;
var minY = Math.min.apply(Math, clickY) - 20;
var maxY = Math.max.apply(Math, clickY) + 20;
var tempCanvas = document.createElement("canvas"),
tCtx = tempCanvas.getContext("2d");
tempCanvas.width = maxX - minX;
tempCanvas.height = maxY - minY;
tCtx.drawImage(canvas, minX, minY, maxX - minX, maxY - minY, 0, 0, maxX - minX, maxY - minY);
var imgBox = document.getElementById("canvas_image");
imgBox.src = tempCanvas.toDataURL();
return tempCanvas;
}
//--------------------------------------------
// predict function for digit recognizer mlp
//--------------------------------------------
async function predict() {
// get the user drawn region alone cropped
croppedCanvas = boundingBox();
// show the cropped image
document.getElementById("canvas_output").style.display = "block";
// preprocess canvas
let tensor = preprocessCanvas(croppedCanvas, modelName);
// make predictions on the preprocessed image tensor
let predictions = await model.predict(tensor).data();
// get the model's prediction results
let results = Array.from(predictions)
// display the predictions in chart
displayChart(results)
console.log(results);
}
//------------------------------
// Chart to display predictions
//------------------------------
var chart = "";
var firstTime = 0;
function loadChart(label, data, modelSelected) {
var context = document.getElementById('chart_box').getContext('2d');
chart = new Chart(context, {
// we are in need of a bar chart
type: 'bar',
// we feed in data dynamically using data variable
// that is passed as an argument to this function
data: {
labels: label,
datasets: [{
label: modelSelected + " prediction",
backgroundColor: '#f50057',
borderColor: 'rgb(255, 99, 132)',
data: data,
}]
},
// you can also play around with options for the
// chart if you find time!
options: {}
});
}
//----------------------------
// display chart with updated
// drawing from canvas
//----------------------------
function displayChart(data) {
var select_model = document.getElementById("select_model");
var select_option = select_model.options[select_model.selectedIndex].value;
label = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"];
if (firstTime == 0) {
loadChart(label, data, select_option);
firstTime = 1;
} else {
chart.destroy();
loadChart(label, data, select_option);
}
document.getElementById('chart_box').style.display = "block";
}