Code viewer for World: TensorFlow.js Demo #2
/*TensorFlow.js Demo*/

/*This demo is a mod / port of a tutorial by codingthesmartway.com which can be found at the following link:

    https://codingthesmartway.com/tensorflow-js-crash-course-machine-learning-for-the-web-handwriting-recognition/
    
Which itself is a condensed version of a TensorFlow tutorial found here:

    https://www.tensorflow.org/js/tutorials/training/handwritten_digit_cnn
    
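When run, the MNIST data is loaded, then a TensorFlow convolutional neural network is created and trained on it. Once training finishes, a button lets you classify a randomly chosen test image, and the trained model is offered as a download.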

This port is a simpler version of the TensorFlow tutorial and is much easier to digest.

*/

// Inject the demo's stylesheet and page skeleton.
// The element ids written here are looked up later in this file:
//   'log'                  - progress output (createLogEntry)
//   'selectTestDataButton' - enabled once training finishes (setup)
//   'predictionResult'     - where a prediction is rendered (predict)
// NOTE(review): document.write only inserts content while the page is still being
// parsed; the 'container'/'card'/'btn' classes look like Bootstrap, which this file
// does not load itself - confirm the host page provides it.
document.write ( `

<style>
	.prediction-canvas{
		width: 100px;
		margin: 20px;
	}
	.prediction-div{
		display: inline-block;
		margin: 10px;
	}
</style>

<div class='container' style='padding-top: 20px'>
	<div class='card'>
		<div class='card-header'>
			<strong>TensorFlow.js Demo - Handwriting Recognition</strong>
		</div>
		<div class='card-body'>
			<div class='card'>
				<div class='card-body'>
					<h5 class='card-title'>Log Output:</h5>
					<div id='log'></div>
				</div>
			</div>

			<div class='card'>
				<div class='card-body'>
					<h5 class='card-title'>Predict</h5>
					<button type='button' class='btn btn-primary' id='selectTestDataButton' disabled>Please wait until model is ready ...</button>
					<div id='predictionResult'></div>
				</div>
			</div>
		</div>
	</div>
</div>
` );

// A function to output progress to the UI
// Append one progress line to the on-page log (the element with id 'log').
// `entry` is inserted as HTML, preceded by a <br> line break.
function createLogEntry(entry) {
    const logElement = document.getElementById('log');
    logElement.innerHTML = logElement.innerHTML + '<br>' + entry;
}

// Load our external libs and call back this function when finished.
// Load our external libs. Returns a (jQuery) promise that settles once every
// script has been fetched and executed; its resolution value is not used.
//
// tf.min.js is requested first because the other two scripts rely on the global
// `tf` it defines (tfjs-vis is a tfjs add-on; data.js presumably builds tensors
// with it - confirm). $.getScript runs each script as soon as it arrives, so
// requesting all three at once could run a dependent script before `tf` exists.
// Once tf is in place, the two dependents are fetched in parallel.
function loadTensorFlowLibraries() {
    return $.getScript( "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js" )
        .then(() => $.when(
            $.getScript( "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js" ),
            $.getScript( "uploads/michaelryan/data.js")
        ));
}

// Create our NN model
// Create our NN model.
//
// Builds a small convolutional network for 28x28x1 images:
//   conv2d(5x5, 8 filters, relu) -> maxPool(2x2)
//   -> conv2d(5x5, 16 filters, relu) -> maxPool(2x2)
//   -> flatten -> dense(10 units, softmax)
// and compiles it with SGD (learning rate 0.15) and categorical cross-entropy,
// also tracking accuracy.
//
// Returns the compiled, untrained tf.Sequential model.
function createModel() {
    console.log('Create model ...');

    const model = tf.sequential();

    console.log('Model created');
    console.log('Add layers ...');

    // Only the first layer declares inputShape; later layers infer theirs.
    model.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));

    model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));

    model.add(tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    model.add(tf.layers.flatten());

    // One output unit per digit class; softmax yields class probabilities.
    model.add(tf.layers.dense({
        units: 10,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
    }));

    console.log('Layers created');

    console.log('Start compiling ...');

    model.compile({
        optimizer: tf.train.sgd(0.15),
        loss: 'categoricalCrossentropy',
        // Report accuracy during fit(), not just the loss.
        metrics: ['accuracy']
    });

    console.log('Compiled');

    return model;
}

// Load the MNIST Data
// Load the MNIST Data.
//
// Creates an MnistData instance (a class presumably defined by the externally
// loaded data.js script) and waits for its asynchronous load() to finish.
// Resolves to that loaded instance.
async function load() {
    console.log('Loading MNIST data ...');

    const mnist = new MnistData();
    await mnist.load();

    console.log('Data loaded successfully');
    return mnist;
}

// Train the NN
// Train the NN.
//
// Takes TRAIN_DATA_SIZE training samples and TEST_DATA_SIZE validation samples
// from `data` and reshapes each image row to [N, 28, 28, 1] to match the model's
// input layer. It then fits `model` for 10 epochs in shuffled batches of
// BATCH_SIZE.
//
// data  - presumably the MnistData instance from load(); must provide
//         nextTrainBatch(n) / nextTestBatch(n) returning { xs, labels } tensors.
// model - a compiled tf model.
// Resolves to the History object returned by model.fit(). The sampled tensors
// are disposed once fitting finishes, whether it succeeds or fails.
async function train(data, model) {
  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  // tidy() frees tensors allocated inside the callback that are not returned
  // (e.g. the flat d.xs); only the reshaped images and labels survive.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  try {
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 10,
      shuffle: true
    });
  } finally {
    // These tensors are only needed for this fit; release their memory.
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}

// Predict function
// Predict function.
//
// Runs `model` on the images in `batch`. It then renders the FIRST sample's image,
// its true label and the model's predicted label into the 'predictionResult'
// element, replacing whatever was shown there before.
//
// batch - { xs, labels } tensors, presumably from MnistData.nextTestBatch():
//         xs has one flattened image per row (reshaped to [-1, 28, 28, 1] here);
//         labels are per-class rows whose argMax(1) gives the digit.
// model - a trained tf model.
async function predict(batch, model) {
    console.log("Predict...");

    // tidy() releases every tensor created inside the callback (argMax results,
    // the reshaped input, the model output and the image slice).
    tf.tidy(() => {
        // Only sample 0 is drawn below, so the label and prediction come from
        // sample 0 too. The previous array subtraction compared correctly only
        // for a batch of exactly one sample.
        const inputValue = batch.labels.argMax(1).dataSync()[0];

        const output = model.predict(batch.xs.reshape([-1, 28, 28, 1]));
        const predictionValue = output.argMax(1).dataSync()[0];

        const image = batch.xs.slice([0, 0], [1, batch.xs.shape[1]]);

        const canvas = document.createElement('canvas');
        canvas.className = 'prediction-canvas';
        draw(image.flatten(), canvas);

        const label = document.createElement('div');
        label.innerHTML = 'Original Value: ' + inputValue;
        label.innerHTML += '<br>Prediction Value: ' + predictionValue;
        label.innerHTML += predictionValue === inputValue
            ? '<br>Value recognized successfully'
            : '<br>Recognition failed!';

        const div = document.createElement('div');
        div.className = 'prediction-div';
        div.appendChild(canvas);
        div.appendChild(label);

        const result = document.getElementById('predictionResult');
        result.innerHTML = '';
        result.appendChild(div);
    });
}

// Draw onto canvas
// Draw onto canvas.
//
// Paints a 28x28 greyscale image onto `canvas`, setting the canvas backing size
// to 28x28 (CSS may scale it up). `image` is a tensor whose dataSync() yields
// at least 784 intensities, presumably in [0, 1]. Each one is scaled by 255 into
// an opaque grey pixel (r = g = b = value * 255, alpha = 255).
// NOTE(review): the "normal P5 code" header in this file suggests p5.js. In p5
// global mode a global draw() is called every frame with no arguments, which
// would throw here - confirm p5 is not running in global mode.
function draw(image, canvas) {
    const SIDE = 28;
    canvas.width = SIDE;
    canvas.height = SIDE;

    const context = canvas.getContext('2d');
    const frame = new ImageData(SIDE, SIDE);
    const intensities = image.dataSync();
    const rgba = frame.data;

    let offset = 0;
    for (let pixel = 0; pixel < SIDE * SIDE; pixel++) {
        const grey = intensities[pixel] * 255;
        rgba[offset] = grey;       // red
        rgba[offset + 1] = grey;   // green
        rgba[offset + 2] = grey;   // blue
        rgba[offset + 3] = 255;    // alpha: fully opaque
        offset += 4;
    }

    context.putImageData(frame, 0, 0);
}


//---- normal P5 code -------------------------------------------------------


// Entry point.
//
// Loads the external scripts and the MNIST data, then builds and trains the
// network. After that it enables the predict button and offers the trained model
// as a download. Progress is reported through createLogEntry(). A failure is
// logged there and to the console instead of escaping as an unhandled rejection.
async function setup() {
    try {
        // Load the TF libs, then on complete - create the model.
        createLogEntry("Load external scripts...");
        await loadTensorFlowLibraries();
        createLogEntry("Scripts loaded...");

        createLogEntry("Create convolutional neural network...");
        const model = createModel();
        // To load an existing model instead:
        // const model = await tf.loadLayersModel('uploads/michaelryan/my-model.json');
        createLogEntry("Neural network created...");

        createLogEntry("Load MNIST Data...");
        const data = await load();
        createLogEntry("MNIST data loaded...");

        createLogEntry("Train neural network...");
        // Comment out if loading an existing, trained model
        await train(data, model);
        createLogEntry("Neural network trained...");

        const button = document.getElementById('selectTestDataButton');
        button.disabled = false;
        button.innerText = "Randomly Select Test Data And Predict";

        // Each click samples one test image and classifies it. The batch tensors
        // are created outside predict()'s tidy(), so they are freed here.
        button.addEventListener('click', async () => {
            const batch = data.nextTestBatch(1);
            try {
                await predict(batch, model);
            } finally {
                tf.dispose(batch);
            }
        });

        // Save a model ('downloads://' asks the browser to download it).
        await model.save('downloads://my-model');
    } catch (err) {
        console.error(err);
        // jQuery's getScript rejects with a jqXHR object, not an Error, so
        // fall back to its statusText.
        const message = (err && (err.message || err.statusText)) || String(err);
        createLogEntry('Error: ' + message);
    }
}

// Kick off the demo. setup() is async, so attach a handler that reports a
// failure instead of leaving an unhandled promise rejection.
// NOTE(review): the "normal P5 code" header suggests p5.js. If p5 global mode is
// active it also calls setup() itself - confirm the demo is not started twice.
setup().catch((err) => console.error(err));