// Module-level placeholders for the math.js shortcuts. Nothing in this file
// assigns them: setup() declares same-named local consts once math.js has
// loaded, and those locals shadow these bindings inside setup().
let mmap;
let rand;
let transp;
let mat;
let e;
let sub;
let sqr;
let sum;

// Receives the data set handed to loadMNIST's callback in loadData().
// Declared here so that loadData() does not create an implicit global.
let mnist;
// Make the run header taller so all message sections below fit on screen.
AB.headerCSS ( { "max-height": "95vh" } );

//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msgs to put data at various places in the run header.
// Each AB.msg call targets a numbered slot; the even slots (2, 4, 6, 8, 9) are
// left for variable data written later.
// NOTE(review): the onclick handlers below call wipeDoodle(), makeDemo() and set
// do_training, none of which are defined in this file — confirm they exist elsewhere.
let thehtml;

// 1 Doodle header
thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> " +
  " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg ( thehtml, 1 );

// 2 Doodle variable data (guess)

// 3 Training header
thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br> " +
  " <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
AB.msg ( thehtml, 3 );

// 4 variable training data

// 5 Testing header
thehtml = "<h3> Hidden tests </h3> " ;
AB.msg ( thehtml, 5 );

// 6 variable testing data

// 7 Demo header
thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and original (right). <br>" +
  " The network is <i>not</i> trained on any of these images. <br> " +
  " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
AB.msg ( thehtml, 7 );

// 8 Demo variable data (random demo ID)
// 9 Demo variable data (changing guess)

// Opening tag for highlighted (bold, large, dark green) text.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> " ;
//--- end of AB.msgs structure: ---------------------------------------------------------
//---- normal P5 code -------------------------------------------------------
/**
 * p5 entry point.
 *
 * Loads math.js and then mnist.js from the uploads area. Once both are
 * available it defines the neural network class, its hyper-parameters and the
 * train / test / predict helpers, and then starts loading the CSV data sets.
 *
 * NOTE(review): the helpers below use `status`, `download`, `prediction`,
 * `canvas`, `trainButton`, `testButton` and `loadWeightsButton`. None of them
 * is declared in this file; they are presumably DOM elements reachable as
 * globals by id. In browsers `status` normally resolves to the string
 * `window.status`, so `status.innerHTML += ...` may silently do nothing —
 * confirm.
 */
function setup()
{
  $.getScript ( "/uploads/kcarolan96/math.js", function()
  {
    $.getScript ( "/uploads/kcarolan96/mnist.js", function()
    {
      console.log ("All JS loaded");

      /** these short cuts of mathJS functions make our life easier */
      const mmap = math.map; // applies a function to every element of a matrix
      const rand = math.random;
      const transp = math.transpose;
      const mat = math.matrix;
      const e = math.evaluate;
      const sub = math.subtract;
      const sqr = math.square;
      const sum = math.sum;

      /**
       * Fully connected network with one hidden layer and sigmoid activations.
       * It is trained one sample at a time by forward() -> backward() -> update().
       */
      class NeuralNetwork {
        /**
         * @param {number} inputnodes - size of the input layer
         * @param {number} hiddennodes - size of the hidden layer
         * @param {number} outputnodes - size of the output layer
         * @param {number} learningrate - step size applied in update()
         * @param {*} [wih] - optional pretrained input-to-hidden weights (math.js matrix)
         * @param {*} [who] - optional pretrained hidden-to-output weights (math.js matrix)
         */
        constructor(inputnodes, hiddennodes, outputnodes, learningrate, wih, who) {
          this.inputnodes = inputnodes;
          this.hiddennodes = hiddennodes;
          this.outputnodes = outputnodes;
          this.learningrate = learningrate;
          // Use the pretrained weights if given, otherwise random values shifted by -0.5.
          this.wih = wih || sub(mat(rand([hiddennodes, inputnodes])), 0.5);
          this.who = who || sub(mat(rand([outputnodes, hiddennodes])), 0.5);
          /* the sigmoid activation function, applied element-wise */
          this.act = (matrix) => mmap(matrix, (x) => 1 / (1 + Math.exp(-x)));
        }

        /** Scale 0-255 grey values into the range [0.01, 1.0]. */
        static normalizeData = (data) => {
          return data.map((value) => (value / 255) * 0.99 + 0.01);
        };

        /* intermediate values shared by forward/backward/update; `loss` gets one entry per trained sample */
        cache = { loss: [] };

        /** Propagate one sample (flat array) through the network; returns the output column vector. */
        forward = (input) => {
          const wih = this.wih;
          const who = this.who;
          const act = this.act;
          input = transp(mat([input])); // row array -> column vector
          /* hidden layer */
          const h_in = e("wih * input", { wih, input });
          const h_out = act(h_in);
          /* output layer */
          const o_in = e("who * h_out", { who, h_out });
          const actual = act(o_in);
          /* these values are needed later in "backward" */
          this.cache.input = input;
          this.cache.h_out = h_out;
          this.cache.actual = actual;
          return actual;
        };

        /** Compute weight gradients for the sample last passed to forward(). */
        backward = (target) => {
          const who = this.who;
          const input = this.cache.input;
          const h_out = this.cache.h_out;
          const actual = this.cache.actual;
          target = transp(mat([target]));
          // gradient of the error function (E) w.r.t. the activation function (A)
          const dEdA = sub(target, actual);
          // gradient of the activation (A) w.r.t. the weighted sums (Z) of the output layer
          const o_dAdZ = e("actual .* (1 - actual)", { actual });
          // error gradient w.r.t. the hidden-to-output weights
          const dwho = e("(dEdA .* o_dAdZ) * h_out'", { dEdA, o_dAdZ, h_out });
          // error propagated back to the hidden layer
          const h_err = e("who' * (dEdA .* o_dAdZ)", { who, dEdA, o_dAdZ });
          // gradient of the activation (A) w.r.t. the weighted sums (Z) of the hidden layer
          const h_dAdZ = e("h_out .* (1 - h_out)", { h_out });
          // error gradient w.r.t. the input-to-hidden weights
          const dwih = e("(h_err .* h_dAdZ) * input'", { h_err, h_dAdZ, input });
          this.cache.dwih = dwih;
          this.cache.dwho = dwho;
          this.cache.loss.push(sum(sqr(dEdA)));
        };

        /** Apply the gradients computed by backward() to both weight matrices. */
        update = () => {
          const wih = this.wih;
          const who = this.who;
          const dwih = this.cache.dwih;
          const dwho = this.cache.dwho;
          const r = this.learningrate;
          /* dEdA was computed as target - actual, so adding the gradients moves against the error */
          this.wih = e("wih + (r .* dwih)", { wih, r, dwih });
          this.who = e("who + (r .* dwho)", { who, r, dwho });
        };

        predict = (input) => {
          return this.forward(input);
        };

        /** One full training step on a single sample. */
        train = (input, target) => {
          this.forward(input);
          this.backward(target);
          this.update();
        };
      }

      /* neural network's hyper parameters */
      const inputnodes = 784;
      const hiddennodes = 100;
      const outputnodes = 10;
      const learningrate = 0.2;
      let iter = 0;
      const iterations = 5;
      /* side length of the square input images (imageSize * imageSize === inputnodes) */
      const imageSize = 28;

      /* path to the data sets */
      const trainingDataPath = "./mnist/mnist_train.csv";
      const testDataPath = "./mnist/mnist_test.csv";
      const weightsFilename = "weights.json";
      const savedWeightsPath = `./dist/${weightsFilename}`;

      /* these containers are filled during data loading and preparation */
      const trainingData = [];
      const trainingLabels = [];
      const testData = [];
      const testLabels = [];
      const savedWeights = {};

      /* states after how many processed samples a log message should appear */
      const printSteps = 1000;

      let myNN;

      /**
       * Create the network, load both CSV data sets and (if present) the saved
       * weights file, enabling the matching buttons once data is available.
       */
      const init = async () => {
        myNN = new NeuralNetwork(inputnodes, hiddennodes, outputnodes, learningrate);
        disableAllButtons();
        status.innerHTML = "Loading the data sets. Please wait ...<br>";

        /* the two data sets are independent, so fetch them in parallel */
        const [trainCSV, testCSV] = await Promise.all([
          fetchData(trainingDataPath, "CSV"),
          fetchData(testDataPath, "CSV"),
        ]);
        if (trainCSV) {
          prepareData(trainCSV, trainingData, trainingLabels);
          status.innerHTML += "Training data successfully loaded...<br>";
        }
        if (testCSV) {
          prepareData(testCSV, testData, testLabels);
          status.innerHTML += "Test data successfully loaded...<br>";
        }
        if (!trainCSV || !testCSV) {
          status.innerHTML +=
            "Error loading train/test data set. Please check your file path! If you run this project locally, it needs to be on a local server.";
          return;
        }
        trainButton.disabled = false;
        testButton.disabled = false;

        /* pretrained weights are optional; only enable "load weights" if the file exists */
        const weightsJSON = await fetchData(savedWeightsPath, "JSON");
        if (weightsJSON) {
          savedWeights.wih = weightsJSON.wih;
          savedWeights.who = weightsJSON.who;
          loadWeightsButton.disabled = false;
        }
        status.innerHTML += "Ready.<br><br>";
      };

      /**
       * Fetch a resource and decode it as text ("CSV") or JSON ("JSON").
       * Returns false on network errors, non-2xx responses, parse errors or an
       * unknown type, so callers can treat any falsy value as "not available".
       */
      async function fetchData(path, type) {
        try {
          const response = await fetch(path, { mode: "no-cors" });
          // Without this check a 404 page would be parsed as if it were data.
          if (!response.ok) {
            return false;
          }
          switch (type) {
            case "CSV":
              return await response.text();
            case "JSON":
              return await response.json();
            default:
              return false;
          }
        } catch (err) {
          console.warn(`Could not load ${path}:`, err);
          return false;
        }
      }

      /**
       * Parse MNIST CSV text (one sample per line: label followed by the pixel
       * grey values) and append normalised samples to `target` and their labels
       * to `labels`.
       */
      function prepareData(rawData, target, labels) {
        rawData
          .split("\n")
          // skip blank lines (e.g. a trailing newline) instead of blindly dropping the last line
          .filter((line) => line.trim() !== "")
          .forEach((line) => {
            const [label, ...pixels] = line.split(",").map(Number);
            labels.push(label);
            target.push(NeuralNetwork.normalizeData(pixels));
          });
      }

      /**
       * Run one training iteration over the whole training set. Each sample is
       * processed in its own setTimeout so the page stays responsive; after the
       * last sample the loss is reported and test() is started, which calls
       * train() again until `iterations` is reached.
       */
      function train() {
        disableAllButtons();
        download.innerHTML = "";
        if (iter >= iterations) {
          return;
        }
        iter++;
        status.innerHTML += "Starting training ...<br>";
        status.innerHTML += "Iteration " + iter + " of " + iterations + "<br>";
        trainingData.forEach((sample, index) => {
          setTimeout(() => {
            /* one-hot encoding of the label */
            const oneHotLabel = Array(outputnodes).fill(0);
            oneHotLabel[trainingLabels[index]] = 0.99;
            myNN.train(sample, oneHotLabel);
            if (index > 0 && (index + 1) % printSteps === 0) {
              status.innerHTML += `finished ${index + 1} samples ... <br>`;
            }
            /* end of this training iteration */
            if (index === trainingData.length - 1) {
              status.innerHTML += `Loss: ${
                sum(myNN.cache.loss) / trainingData.length
              }<br><br>`;
              myNN.cache.loss = [];
              test("", true); // true signals that test() is called from within training
            }
          }, 0);
        });
      }

      /**
       * Evaluate the network on the test set and report the accuracy.
       * @param {*} _ - ignored (receives the click event when used as a listener)
       * @param {boolean} [inTraining=false] - true when called from train()
       */
      function test(_, inTraining = false) {
        disableAllButtons();
        status.innerHTML += "Starting testing ...<br>";
        let correctPredicts = 0;
        testData.forEach((sample, index) => {
          setTimeout(() => {
            const guess = formatPrediction(myNN.predict(sample));
            if (guess === testLabels[index]) {
              correctPredicts++;
            }
            if (index > 0 && (index + 1) % printSteps === 0) {
              status.innerHTML += " finished " + (index + 1) + " samples ...<br>";
            }
            /* testing complete */
            if (index >= testData.length - 1) {
              status.innerHTML +=
                "Accuracy: " +
                Math.round((correctPredicts / testData.length) * 100) +
                " %<br><br>";
              if (iter >= iterations) {
                /* all training iterations are done */
                createDownloadLink();
                enableAllButtons();
                status.innerHTML += "Finished training.<br><br>";
                iter = 0;
              } else if (inTraining) {
                // training is not complete yet: continue with the next iteration
                train();
              } else {
                enableAllButtons();
              }
            }
          }, 0);
        });
      }

      /**
       * Classify the drawing on `canvas`: shrink it to imageSize x imageSize,
       * convert it to normalised greyscale and show the predicted digit.
       */
      function predict() {
        const tempCanvas = document.createElement("canvas");
        tempCanvas.width = imageSize;
        tempCanvas.height = imageSize;
        const tempCtx = tempCanvas.getContext("2d");
        tempCtx.drawImage(canvas, 0, 0, 150, 150, 0, 0, imageSize, imageSize);
        const img = tempCtx.getImageData(0, 0, imageSize, imageSize);
        /* drop the alpha channel and average RGB into one grey value per pixel */
        const grey = [];
        for (let i = 0; i < img.data.length; i += 4) {
          grey.push((img.data[i] + img.data[i + 1] + img.data[i + 2]) / 3);
        }
        // Normalise the sample itself; ImageData.data is read-only, so the
        // previous `img.data = ...` left the network with raw 0-255 values.
        const sample = NeuralNetwork.normalizeData(grey);
        prediction.innerHTML = formatPrediction(myNN.predict(sample));
      }

      /** Return the index of the largest value in a column-vector network output. */
      function formatPrediction(prediction) {
        const flattened = prediction.toArray().map((row) => row[0]);
        return flattened.indexOf(Math.max(...flattened));
      }

      /** Install the weights read from the saved JSON file into the network. */
      function loadWeights() {
        // Wrap the plain JSON arrays like the constructor does, so code calling
        // .toArray() on the weights keeps working.
        myNN.wih = mat(savedWeights.wih);
        myNN.who = mat(savedWeights.who);
        status.innerHTML += "Weights successfully loaded.";
      }

      /** Offer the current weights as a downloadable JSON file. */
      function createDownloadLink() {
        const wih = myNN.wih.toArray();
        const who = myNN.who.toArray();
        const weights = { wih, who };
        download.innerHTML = `<a download="${weightsFilename}" id="downloadLink" href="data:text/json;charset=utf-8,${encodeURIComponent(
          JSON.stringify(weights)
        )}">Download model weights</a>`;
      }

      /* UI helper functions */
      function enableAllButtons() {
        trainButton.disabled = false;
        testButton.disabled = false;
        loadWeightsButton.disabled = false;
      }

      function disableAllButtons() {
        trainButton.disabled = true;
        testButton.disabled = true;
        loadWeightsButton.disabled = true;
      }

      // p5 normally calls setup() after the window load event, and this callback
      // runs later still (after two asynchronous script loads), so a
      // window.onload handler assigned here would never fire. Start right away
      // if the page has already loaded, otherwise wait for the load event.
      const start = () => {
        init().catch((err) => console.error("Neural network initialisation failed:", err));
      };
      if (document.readyState === "complete") {
        start();
      } else {
        window.addEventListener("load", start, { once: true });
      }
    });
  });
}
/**
 * Ask mnist.js's loadMNIST() for the MNIST data set. When it arrives, store it
 * in the global `mnist`, log it, and remove Ancient Brain's loading screen.
 */
function loadData() {
  const onLoaded = (data) => {
    mnist = data;
    console.log("All data loaded");
    console.log(mnist);
    AB.removeLoading(); // does nothing if no loading screen exists
  };
  loadMNIST(onLoaded);
}
/**
 * Quick sanity check that math.js is available: evaluates (3 + 4) * 2 with a
 * math.js chain, logs the result and returns it.
 *
 * @returns {*} the value produced by the chain (14 with math.js loaded)
 */
function testing() {
  const result = math.chain(3)
    .add(4)
    .multiply(2)
    .done(); // (3 + 4) * 2 = 14
  console.log(result);
  return result;
}
/**
 * p5 per-frame hook (p5 calls draw() once per frame).
 * Intentionally empty: this sketch renders nothing on a per-frame basis.
 */
function draw() {
  // no per-frame work
}