/*TensorFlow.js Demo*/
/*This demo is a modified port of a tutorial by codingthesmartway.com, which can be found at the following link:
https://codingthesmartway.com/tensorflow-js-crash-course-machine-learning-for-the-web-handwriting-recognition/
That tutorial is itself a condensed version of the official TensorFlow tutorial found here:
https://www.tensorflow.org/js/tutorials/training/handwritten_digit_cnn
When run, the demo loads the MNIST data, then creates and trains a TensorFlow.js convolutional neural network.
This port is simpler than the official TensorFlow tutorial and much easier to digest.
*/
// Inject the demo's page markup:
//  - a small stylesheet for the prediction output,
//  - a "Log Output" card whose #log div receives createLogEntry() messages,
//  - a "Predict" card holding #selectTestDataButton (rendered disabled; setup()
//    enables it once training finishes) and the #predictionResult container
//    that predict() fills.
// NOTE(review): document.write only appends while the document is still being
// parsed; if this script ever runs after page load it would replace the whole page.
// NOTE(review): the `card`/`btn` class names look like Bootstrap; presumably the
// host page loads its CSS. Confirm.
document.write ( `
<style>
.prediction-canvas{
width: 100px;
margin: 20px;
}
.prediction-div{
display: inline-block;
margin: 10px;
}
</style>
<div class='container' style='padding-top: 20px'>
<div class='card'>
<div class='card-header'>
<strong>TensorFlow.js Demo - Handwriting Recognition</strong>
</div>
<div class='card-body'>
<div class='card'>
<div class='card-body'>
<h5 class='card-title'>Log Output:</h5>
<div id='log'></div>
</div>
</div>
<div class='card'>
<div class='card-body'>
<h5 class='card-title'>Predict</h5>
<button type='button' class='btn btn-primary' id='selectTestDataButton' disabled>Please wait until model is ready ...</button>
<div id='predictionResult'></div>
</div>
</div>
</div>
</div>
</div>
` );
// A function to output progress to the UI
/**
 * Append one progress message to the #log element, preceded by a line break.
 *
 * The message is inserted as a text node, so it is shown verbatim and never
 * parsed as HTML. Appending nodes also avoids `innerHTML +=`, which
 * re-serializes and re-parses the entire log on every call.
 *
 * @param {string} entry - Message to display.
 */
function createLogEntry(entry) {
  const log = document.getElementById('log');
  log.append(document.createElement('br'), String(entry));
}
// Load our external libs and call back this function when finished.
/**
 * Fetch and execute the external scripts the demo needs: TensorFlow.js, tfjs-vis,
 * and the MNIST data helper (presumably the source of the `MnistData` class that
 * load() instantiates; confirm).
 *
 * TensorFlow.js is loaded first, and the other two are requested only after it has
 * executed. Cross-origin $.getScript calls run each script as soon as it arrives.
 * Requesting all three at once therefore lets a script that expects the global
 * `tf` (the tfjs-vis UMD bundle, and possibly data.js) run before `tf` exists.
 *
 * @returns {Promise} jQuery promise; resolves once all three scripts have loaded
 *   and rejects if any of them fails.
 */
function loadTensorFlowLibraries() {
  const TFJS_URL = "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js";
  const TFJS_VIS_URL = "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js";
  const MNIST_DATA_URL = "uploads/michaelryan/data.js";

  return $.getScript(TFJS_URL).then(() => $.when(
    $.getScript(TFJS_VIS_URL),
    $.getScript(MNIST_DATA_URL)
  ));
}
// Create our NN model
/**
 * Build and compile the digit-classification CNN:
 *   conv2d(5x5, 8 filters, relu) -> maxPool(2x2, stride 2)
 *   -> conv2d(5x5, 16 filters, relu) -> maxPool(2x2, stride 2)
 *   -> flatten -> dense(10 units, softmax).
 * Inputs have shape [28, 28, 1].
 * The model is compiled with SGD (learning rate 0.15) and categorical cross-entropy.
 * It also tracks accuracy, so fit() reports how often the digit is classified
 * correctly, not just the loss.
 *
 * @returns {tf.Sequential} the compiled model.
 */
function createModel() {
  console.log('Create model ...');
  const model = tf.sequential();
  console.log('Model created');
  console.log('Add layers ...');
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
  }));
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    // Same documented identifier as the first layer (previously spelled 'VarianceScaling').
    kernelInitializer: 'varianceScaling',
  }));
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
  }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({
    units: 10,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax',
  }));
  console.log('Layers created');
  console.log('Start compiling ...');
  model.compile({
    optimizer: tf.train.sgd(0.15),
    loss: 'categoricalCrossentropy',
    // Also report accuracy (training and validation) during fit(), not only the loss.
    metrics: ['accuracy'],
  });
  console.log('Compiled');
  return model;
}
// Load the MNIST Data
/**
 * Create an `MnistData` helper (a global supplied by an external script, not
 * defined in this file) and wait for its `load()` to finish.
 *
 * @returns {Promise<MnistData>} the dataset, ready for nextTrainBatch/nextTestBatch.
 */
async function load() {
  console.log('Loading MNIST data ...');
  const mnist = new MnistData();
  await mnist.load();
  console.log('Data loaded successfully');
  return mnist;
}
// Train the NN
/**
 * Train `model` on 5500 MNIST training samples for 10 epochs (batch size 512),
 * validating on 1000 test samples.
 *
 * The dataset tensors are created outside any tidy scope, so they are disposed
 * explicitly once fit() settles, whether it succeeds or throws. Without that,
 * every call would leak their backend memory.
 *
 * @param {{nextTrainBatch: Function, nextTestBatch: Function}} data - dataset whose
 *   batch methods return `{xs, labels}` tensors; xs rows are reshaped to [28, 28, 1] here.
 * @param {tf.LayersModel} model - a compiled model.
 * @returns {Promise<tf.History>} the history object returned by model.fit().
 */
async function train(data, model) {
  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;
  // tidy() frees the flat intermediate tensors and keeps only the returned ones.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels,
    ];
  });
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels,
    ];
  });
  try {
    // `return await` so the finally block runs only after training has finished.
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 10,
      shuffle: true,
    });
  } finally {
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}
/**
 * Report whether each predicted class index equals the label at the same position.
 *
 * @param {number[]} predicted - Predicted class indices.
 * @param {number[]} expected - True class indices.
 * @returns {boolean} true when both arrays have the same length and match element-wise.
 */
function predictionsMatch(predicted, expected) {
  return predicted.length === expected.length &&
    predicted.every((value, index) => value === expected[index]);
}

// Predict function
/**
 * Classify `batch` with `model` and replace the contents of #predictionResult
 * with the first image of the batch, the true label(s), the predicted class(es)
 * and a success/failure line.
 *
 * Every tensor created here is freed by tf.tidy. batch.xs and batch.labels belong
 * to the caller and are not disposed.
 *
 * @param {{xs: tf.Tensor, labels: tf.Tensor}} batch - xs rows must reshape to
 *   [28, 28, 1]. Presumably labels are one-hot with the class on axis 1, as
 *   produced by MnistData.nextTestBatch; confirm.
 * @param {tf.LayersModel} model - trained classifier.
 */
async function predict(batch, model) {
  console.log("Predict...");
  tf.tidy(() => {
    // argMax over axis 1 turns one-hot labels / softmax scores into class indices.
    const inputValue = Array.from(batch.labels.argMax(1).dataSync());
    const output = model.predict(batch.xs.reshape([-1, 28, 28, 1]));
    const predictionValue = Array.from(output.argMax(1).dataSync());

    // Only the first sample of the batch is drawn.
    const image = batch.xs.slice([0, 0], [1, batch.xs.shape[1]]);
    const canvas = document.createElement('canvas');
    canvas.className = 'prediction-canvas';
    draw(image.flatten(), canvas);

    const label = document.createElement('div');
    label.innerHTML = 'Original Value: ' + inputValue;
    label.innerHTML += '<br>Prediction Value: ' + predictionValue;
    // Compare element-wise. Subtracting the arrays coerces them to strings, which
    // yields NaN (a false "failed") as soon as the batch has more than one sample.
    if (predictionsMatch(predictionValue, inputValue)) {
      label.innerHTML += '<br>Value recognized successfully';
    } else {
      label.innerHTML += '<br>Recognition failed!';
    }

    const div = document.createElement('div');
    div.className = 'prediction-div';
    div.appendChild(canvas);
    div.appendChild(label);
    const result = document.getElementById('predictionResult');
    result.innerHTML = '';
    result.appendChild(div);
  });
}
// Draw onto canvas
/**
 * Paint a flattened 28x28 greyscale image onto `canvas` at native resolution.
 * The .prediction-canvas CSS rule stretches it to 100px wide for display.
 *
 * Each source value is scaled by 255 and written to R, G and B, with alpha fixed
 * at 255. Presumably the values lie in [0, 1]; the ImageData buffer clamps
 * anything outside 0..255.
 *
 * @param {{dataSync: function(): ArrayLike<number>}} image - tensor holding at least 784 values.
 * @param {HTMLCanvasElement} canvas - target canvas; it is resized to 28x28.
 */
function draw(image, canvas) {
  const side = 28;
  canvas.width = side;
  canvas.height = side;
  const context = canvas.getContext('2d');
  const frame = new ImageData(side, side);
  const pixels = image.dataSync();
  const rgba = frame.data;
  for (let px = 0, offset = 0; px < side * side; px += 1, offset += 4) {
    const grey = pixels[px] * 255;
    rgba.set([grey, grey, grey, 255], offset);
  }
  context.putImageData(frame, 0, 0);
}
//---- Entry point ----------------------------------------------------------
// NOTE(review): named like a p5.js `setup()` but invoked explicitly at the bottom.
// If p5.js is ever loaded in global mode it would call this a second time. Confirm.
/**
 * Drive the whole demo:
 *  1. load the external scripts,
 *  2. build the model,
 *  3. load MNIST,
 *  4. train,
 *  5. enable the predict button,
 *  6. offer the trained model as a browser download.
 * Progress is reported through createLogEntry().
 *
 * @returns {Promise<void>} rejects if any step fails.
 */
async function setup() {
  // Load the TF libs, then on complete - create the model.
  createLogEntry("Load external scripts...");
  await loadTensorFlowLibraries();
  createLogEntry("Scripts loaded...");
  createLogEntry("Create convolutional neural network...");
  const model = createModel();
  // To reuse a previously saved model instead, replace the line above with:
  //const model = await tf.loadLayersModel('uploads/michaelryan/my-model.json');
  createLogEntry("Neural network created...");
  createLogEntry("Load MNIST Data...");
  const data = await load();
  createLogEntry("MNIST data loaded...");
  createLogEntry("Train neural network...");
  // Comment out if loading an existing, trained model
  await train(data, model);
  createLogEntry("Neural network trained...");

  const button = document.getElementById('selectTestDataButton');
  button.disabled = false;
  button.innerText = "Randomly Select Test Data And Predict";
  // Attach button listener
  button.addEventListener('click', async () => {
    const batch = data.nextTestBatch(1);
    try {
      await predict(batch, model);
    } catch (err) {
      console.error(err);
      createLogEntry(`Prediction failed: ${err && err.message ? err.message : String(err)}`);
    } finally {
      // The batch tensors are created outside any tidy scope; free them so each
      // click does not leak backend memory.
      tf.dispose([batch.xs, batch.labels]);
    }
  });

  // Save a model: triggers a browser download of the trained model files.
  await model.save('downloads://my-model');
}

// Report failures instead of leaving an unhandled rejection and a UI stuck on "Please wait".
setup().catch((err) => {
  console.error(err);
  createLogEntry(`Setup failed: ${err && err.message ? err.message : String(err)}`);
});