// Cloned by sheeba on 2 Dec 2022 from World "Character recognition neural network" by "Coding Train" project
// ---- Configuration -------------------------------------------------------
const numEpochs = 3; // passes over the training set when training from scratch
const pixels = 28; // network input images are pixels × pixels
const pixelsSquared = pixels * pixels;
const zoomFactor = 10; // on-screen magnification of the doodle/demo images
const zoomPixels = zoomFactor * pixels;
// Doodle area at the top, demo area at the bottom, with a 50px gap between.
const canvasHeight = 2 * zoomPixels + 50;
// NOTE(review): drawDemo()/drawDoodle() also draw an unzoomed copy at
// x = zoomPixels + 50, which lies outside a canvas this wide — confirm
// whether those copies are meant to be visible.
const canvasWidth = zoomPixels;
const doodleBlur = 3; // parameter passed to doodle.filter(BLUR, ...) when a stroke ends
const doodleThickness = 18; // stroke weight of doodle lines
// Opening tag used when displaying a predicted letter; callers close it with </span>.
const greenSpan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'>";
const splashMessage = "<h1>Training in Progress...</h1></br>";

// ---- Mutable state shared by setup(), draw() and the button handlers ----
let demo; // pixel data of the currently selected test image
let doodle; // off-screen p5.Graphics the user draws into
let theHTML; // scratch buffer for AB.msg() content
let networkModel;
let encodedTestLabels;
let encodedTrainingLabels;
let testImages;
let testLabels;
let trainingImages;
let trainingLabels;
let currentEpoch = 0;
let demoExists = false;
let demoGuessed = false;
let doodleExists = false;
let doodleGuessed = false;
let modelTrained = false;
let mouseDrag = false; // true while a doodle stroke is in progress
/**
 * p5.js setup callback.
 *
 * Creates the canvas and a pixelDensity(1) off-screen buffer for the doodle,
 * writes the instruction panels to message areas 1, 3 and 6, shows a splash
 * screen, and then loads TensorFlow.js. Once TF.js is available it either
 *   - restores the model previously saved to localStorage (and fetches only
 *     the test set, for the demo panel), or
 *   - downloads the training and test IDX files, trains a small CNN and saves
 *     it to localStorage.
 * `modelTrained` becomes true once a usable model exists; draw() does nothing
 * until then. Load/training failures are logged and shown on the splash
 * instead of leaving it up indefinitely.
 */
function setup() {
  const modelName = "drummk2-tensorflow-network-model";
  const modelUrl = `localstorage://${modelName}`;
  // localStorage entry whose presence indicates a previously saved model.
  const modelInfoKey = `tensorflowjs_models/${modelName}/info`;

  // Two conv(5x5, tanh) + 2x2 max-pool stages, then a 26-way softmax.
  const buildModel = (inputShape) => {
    const model = tf.sequential();
    model.add(tf.layers.conv2d({
      activation: "tanh",
      filters: 8,
      inputShape,
      kernelInitializer: "glorotUniform",
      kernelSize: 5,
      strides: 1,
    }));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
    // inputShape is only used by the first layer of a sequential model, so it
    // is not repeated here.
    model.add(tf.layers.conv2d({
      activation: "tanh",
      filters: 16,
      kernelInitializer: "glorotUniform",
      kernelSize: 5,
      strides: 1,
    }));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({ activation: "softmax", kernelInitializer: "glorotUniform", units: 26 }));
    model.compile({ loss: "categoricalCrossentropy", optimizer: tf.train.sgd(0.0001) });
    return model;
  };

  // Restore the saved model; the test set is fetched in the background for the demo panel.
  const restoreSavedModel = () => {
    tf.loadLayersModel(modelUrl)
      .then((model) => {
        console.log("Loading test data");
        loadTestData();
        console.log("Pre-existing model loaded from local storage.");
        networkModel = model;
        modelTrained = true;
      })
      .catch((err) => {
        console.error("Failed to load the saved model:", err);
        AB.splashHtml("<h1>Could not load the saved model</h1><p>See the browser console for details.</p>");
      });
  };

  // Fetch all four IDX files, train from scratch, then try to save the result.
  const trainNewModel = () => {
    console.log("No model was found in local storage; commencing with training.");
    const files = {
      trainingImages: "/uploads/vrushali/emnist-letters-train-images-idx3-ubyte.bin",
      trainingLabels: "/uploads/vrushali/emnist-letters-train-labels-idx1-ubyte.bin",
      testImages: "/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin",
      testLabels: "/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin",
    };
    const data = {};

    Promise.all(
      Object.keys(files).map(async (key) => {
        data[key] = await loadBinaryFile(files[key]);
      }),
    )
      .then(() => {
        // Image count comes from the file header (previously hard-coded as 124800).
        // Pixel values are used unscaled (0-255); the guess functions must feed
        // the network on the same scale.
        const numTraining = data.trainingImages.length;
        trainingImages = tf.reshape(
          tf.tensor(data.trainingImages, [numTraining, pixelsSquared], "float32"),
          [numTraining, pixels, pixels, 1],
        );
        trainingLabels = tf.tensor1d(data.trainingLabels, "int32");
        testImages = data.testImages;
        testLabels = data.testLabels;
        // NOTE(review): EMNIST-letters label files are commonly numbered 1-26
        // rather than 0-25. If these uploads follow that convention,
        // oneHot(..., 26) turns label 26 into an all-zero row and the letters
        // reported via String.fromCharCode(97 + index) are shifted by one.
        // Confirm against the data before relying on the displayed letters.
        encodedTestLabels = tf.oneHot(testLabels, 26);
        encodedTrainingLabels = tf.oneHot(trainingLabels, 26);

        networkModel = buildModel(trainingImages.shape.slice(1));
        console.log("size is: " + trainingImages.shape.slice(1));

        const fitOptions = {
          callbacks: {
            onBatchEnd: (batch, logs) => {
              AB.splashHtml(`${splashMessage}<p><b>Epoch:</b> ${currentEpoch}/${fitOptions.epochs}</br><b>Batch:</b> ${batch}</br><b>Loss:</b> ${logs.loss}`);
            },
            onEpochBegin: (epoch) => {
              currentEpoch = epoch + 1;
            },
          },
          batchSize: 512,
          epochs: numEpochs,
          shuffle: true,
          validationSplit: 0.1,
        };
        return networkModel.fit(trainingImages, encodedTrainingLabels, fitOptions);
      })
      // A failed save is reported but does not make the trained model unusable.
      .then(() =>
        networkModel.save(modelUrl).then(
          () => console.log("Model saved to local storage."),
          (err) => console.warn("Trained model could not be saved to local storage:", err),
        ),
      )
      .then(() => {
        modelTrained = true;
      })
      .catch((err) => {
        console.error("Training failed:", err);
        AB.splashHtml("<h1>Training failed</h1><p>See the browser console for details.</p>");
      });
  };

  createCanvas(canvasWidth, canvasHeight);
  doodle = createGraphics(zoomPixels, zoomPixels);
  doodle.pixelDensity(1);

  theHTML = "<hr><h1>1. Doodle</h1>Please draw your doodle in the top-left.</br></br>To clear the doodle, click here: <button onclick='wipeDoodle();' class='normbutton'>Clear</button></br>";
  AB.msg(theHTML, 1);
  theHTML = "<hr><h1>2. Demo</h1>Demonstrations of the trained network can be viewed here.</br></br>The network is <i>not</i> trained on any of these images.</br></br>To test out an image, click here: <button onclick='makeDemo();' class='normbutton'>Test Image</button></br>";
  AB.msg(theHTML, 3);
  theHTML = "<hr><h1>3. Model Settings</h1>To delete the currently trained model, click here: <button onclick='localStorage.clear();' class='normbutton'>Delete Trained Model</button></br>";
  AB.msg(theHTML, 6);

  AB.newSplash();
  AB.splashHtml("<h1>Training in Progress...</h1>");

  $.getScript("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js", () => {
    // Declared locally: assigning to an undeclared name would create an
    // implicit global (and throw in strict mode).
    const existingModel = localStorage.getItem(modelInfoKey);
    if (existingModel) {
      restoreSavedModel();
    } else {
      trainNewModel();
    }
  }).fail(() => {
    console.error("Could not load TensorFlow.js.");
    AB.splashHtml("<h1>Could not load TensorFlow.js</h1><p>Check your connection and reload the page.</p>");
  });
}
/**
 * p5.js per-frame callback.
 *
 * Does nothing until a model is available. Then it clears the canvas, draws
 * the demo and doodle images (classifying each once after it changes), and
 * turns mouse drags inside the drawing region into white strokes on the
 * doodle buffer. When a drag ends, the doodle is blurred once.
 */
function draw() {
  if (!modelTrained) return;

  AB.removeSplash();
  background("black");

  if (demoExists) {
    drawDemo();
    if (!demoGuessed) guessDemo();
  }

  if (doodleExists) {
    drawDoodle();
    if (!doodleGuessed) {
      guessDoodle();
      doodleGuessed = true;
    }
  }

  if (mouseIsPressed) {
    // Drawing region: the doodle area plus a 50px margin. p5 reports mouseX /
    // mouseY relative to the canvas even for presses elsewhere on the page,
    // so coordinates left of / above the canvas (negative) must be rejected
    // too; otherwise such a press starts a stroke and marks the doodle as existing.
    const limit = zoomPixels + 50;
    const inRegion = (x, y) => x >= 0 && y >= 0 && x < limit && y < limit;
    if (inRegion(mouseX, mouseY) && inRegion(pmouseX, pmouseY)) {
      mouseDrag = true;
      doodleExists = true;
      doodleGuessed = false;
      doodle.stroke("white");
      doodle.strokeWeight(doodleThickness);
      doodle.line(mouseX, mouseY, pmouseX, pmouseY);
    }
  } else if (mouseDrag) {
    mouseDrag = false;
    doodle.filter(BLUR, doodleBlur);
  }
}
/**
 * Draw the current demo image twice in the bottom part of the canvas: scaled
 * up to zoomPixels × zoomPixels at the left edge, and at its native
 * pixels × pixels size starting 50px to the right of the zoomed copy.
 * NOTE(review): with canvasWidth === zoomPixels the native-size copy is
 * drawn outside the canvas — confirm that is intended.
 */
function drawDemo() {
  const demoImage = getImage(demo);
  const top = canvasHeight - zoomPixels;
  const thumbX = zoomPixels + 50;
  image(demoImage, 0, top, zoomPixels, zoomPixels);
  image(demoImage, thumbX, top, pixels, pixels);
}
/**
 * Draw a snapshot of the doodle buffer twice at the top of the canvas: at
 * zoomPixels × zoomPixels on the left, and shrunk to pixels × pixels starting
 * 50px to the right of it.
 */
function drawDoodle() {
  const snapshot = doodle.get();
  const thumbX = zoomPixels + 50;
  image(snapshot, 0, 0, zoomPixels, zoomPixels);
  image(snapshot, thumbX, 0, pixels, pixels);
}
/**
 * Build a pixels × pixels greyscale p5.Image from a flat array of
 * intensities (one value per pixel, in p5 pixel order). R, G and B all take
 * the intensity; alpha is always 255.
 *
 * @param e - indexable sequence of at least pixelsSquared intensities.
 * @returns the new p5.Image, with its pixels already updated.
 */
function getImage(e) {
  const img = createImage(pixels, pixels);
  img.loadPixels();
  for (let idx = 0; idx < pixelsSquared; idx++) {
    const grey = e[idx];
    const base = idx * 4; // RGBA quadruple for this pixel
    img.pixels[base] = grey;
    img.pixels[base + 1] = grey;
    img.pixels[base + 2] = grey;
    img.pixels[base + 3] = 255;
  }
  img.updatePixels();
  return img;
}
/**
 * Scale intensities down by 255 (presumably mapping 0–255 bytes to [0, 1]).
 *
 * @param e - indexable sequence of at least pixelsSquared numbers.
 * @returns a new array of the first pixelsSquared values of e, each / 255.
 */
function getInputs(e) {
  return Array.from({ length: pixelsSquared }, (_, i) => e[i] / 255);
}
/**
 * Classify the current demo (test-set) image with the trained network and
 * show the predicted letter in message area 5, then mark the demo as guessed.
 *
 * The input is built from the raw pixel values as float32, matching the
 * training tensor in setup(), which is float32 and not rescaled. The previous
 * code divided by 255 and then built an "int32" tensor. That truncated every
 * pixel below 255 to 0, so the network saw an almost blank image.
 */
function guessDemo() {
  // tf.tidy disposes the input and prediction tensors; only the plain
  // JS score array escapes.
  const scores = tf.tidy(() => {
    const input = tf.tensor(Float32Array.from(demo), [1, pixels, pixels, 1], "float32");
    return networkModel.predict(input).arraySync()[0];
  });
  const bestIndex = scores.indexOf(Math.max(...scores));
  const letter = String.fromCharCode(97 + bestIndex);

  theHTML = `The network classified it as: ${greenSpan}${letter}</span>`;
  AB.msg(theHTML, 5);
  demoGuessed = true;
}
/**
 * Classify the user's doodle and show the predicted letter in message area 2.
 *
 * The doodle buffer is downsampled to pixels × pixels and the red channel of
 * each RGBA pixel is used as the intensity (strokes are drawn in white).
 * Values are passed unscaled as float32 to match the training tensor built in
 * setup(). The previous code divided by 255 and then built an "int32" tensor,
 * which truncated nearly every pixel to 0.
 */
function guessDoodle() {
  const img = doodle.get();
  img.resize(pixels, pixels);
  img.loadPixels();

  const values = new Float32Array(pixelsSquared);
  for (let i = 0; i < pixelsSquared; i++) {
    values[i] = img.pixels[4 * i]; // red channel of pixel i
  }

  // tf.tidy disposes the tensors; this runs on every frame while drawing,
  // so leaking them would accumulate quickly.
  const scores = tf.tidy(() => {
    const input = tf.tensor(values, [1, pixels, pixels, 1], "float32");
    return networkModel.predict(input).arraySync()[0];
  });
  const letter = String.fromCharCode(97 + scores.indexOf(Math.max(...scores)));

  theHTML = `Classification: ${greenSpan}${letter}</span>`;
  AB.msg(theHTML, 2);
}
/**
 * Fetch and parse an IDX file (the MNIST/EMNIST binary format).
 *
 * Header fields are big-endian uint32 values. Two layouts are recognised:
 *   - magic 2049 (labels): [magic, count]; data starts at byte 8.
 *     Returns an Int32Array with one value per data byte.
 *   - magic 2051 (images): [magic, count, rows, cols]; data starts at byte 16.
 *     Returns an Array of `count` Int32Array views, each rows*cols long,
 *     all sharing one backing buffer.
 *
 * @param {string} e - URL of the file.
 * @returns {Promise<Int32Array|Int32Array[]>}
 * @throws {Error} when the HTTP request is not successful or the magic number
 *   is not recognised. Previously an error page or unknown file was silently
 *   returned as bytes, and a label file shorter than 16 bytes threw a RangeError.
 */
async function loadBinaryFile(e) {
  const response = await fetch(e);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${e}: HTTP ${response.status}`);
  }
  const buffer = await response.arrayBuffer();
  const header = new DataView(buffer);
  const magic = header.getUint32(0, false);

  if (magic === 2049) {
    return Int32Array.from(new Uint8Array(buffer, 8));
  }

  if (magic === 2051) {
    const count = header.getUint32(4, false);
    const itemSize = header.getUint32(8, false) * header.getUint32(12, false);
    // Widened to Int32 so downstream tf.tensor calls infer an integer dtype.
    const values = Int32Array.from(new Uint8Array(buffer, 16));
    const images = [];
    for (let i = 0; i < count; i++) {
      images.push(values.subarray(itemSize * i, itemSize * (i + 1)));
    }
    return images;
  }

  throw new Error(`Unrecognised IDX magic number ${magic} in ${e}`);
}
/**
 * Fetch the test-set IDX files and store them in the testImages / testLabels
 * globals, plus depth-26 one-hot labels in encodedTestLabels.
 *
 * Used when a saved model is restored and the training data is not needed.
 *
 * @returns {Promise<void>} resolves once the globals are set, or after a load
 *   failure has been logged. Failures no longer surface as unhandled
 *   rejections. The original returned undefined; callers that ignore the
 *   return value are unaffected.
 */
async function loadTestData() {
  const files = {
    testImages: "/uploads/vrushali/emnist-letters-test-images-idx3-ubyte.bin",
    testLabels: "/uploads/vrushali/emnist-letters-test-labels-idx1-ubyte.bin",
  };
  const loaded = {};
  try {
    await Promise.all(
      Object.keys(files).map(async (key) => {
        loaded[key] = await loadBinaryFile(files[key]);
      }),
    );
    testImages = loaded.testImages;
    testLabels = loaded.testLabels;
    encodedTestLabels = tf.oneHot(testLabels, 26);
  } catch (err) {
    console.error("Failed to load test data:", err);
  }
}
/**
 * "Test Image" button handler: pick a random test image, show its index and
 * true label in message area 4, and queue it for drawing and classification.
 *
 * When a saved model is restored, the test set is fetched asynchronously after
 * modelTrained is already true, so the button can be clicked before the data
 * arrives. Previously that threw after setting demoExists, leaving draw()
 * calling getImage(undefined) every frame. Now the click only shows a
 * "still loading" message.
 */
function makeDemo() {
  if (!testImages || testImages.length === 0) {
    theHTML = "Test images are still loading; please try again in a moment.<br>";
    AB.msg(theHTML, 4);
    return;
  }

  const index = AB.randomIntAtoB(0, testImages.length - 1);
  demo = testImages[index];
  demoExists = true;

  const label = testLabels[index];
  theHTML = `Test image no: ${index}<br>Classification: ${String.fromCharCode(97 + label)}<br>`;
  AB.msg(theHTML, 4);
  demoGuessed = false;
}
/**
 * "Clear" button handler: hide the doodle, paint its buffer black and reset
 * the doodle classification line in message area 2.
 */
function wipeDoodle() {
  doodleExists = false;
  doodle.background("black");
  theHTML = "Classification:</span>";
  AB.msg(theHTML, 2);
}