//Code made by Vincent Tran, written from scratch with the help of The Coding Train's video: https://www.youtube.com/watch?v=KhogNPC24eI&t=10069s&ab_channel=TheCodingTrain
//and from the original digit recognizer lab.
//--- Front end ----------------------------------------------------------------
// The page is built from numbered AB.msg slots:
//   1 doodle header     2 doodle guesses (written by guessUserLetter)
//   3 training header   4 training counters, 5 training label/guess (written by train)
//   6 testing header    7 testing results (written by testing)
// Scratch variable holding the HTML pushed into those slots.
let thehtml;
// 1 Doodle header (the button clears the doodle buffer)
thehtml = "<hr> <h1> 1. Doodle </h1> Doodle on the left <br> " +
  " Draw your doodle in there. <br>" + "<button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
AB.msg ( thehtml, 1 );
// 2 is filled in at run time with the doodle guesses
// 3 Training header (the button sets do_training to false)
thehtml = "<hr> <h1> 2. Training </h1> training images on the right <br> " +
  " <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
AB.msg ( thehtml, 3 );
// 4 and 5 are filled in at run time with the training progress
// 6 Testing header
thehtml = "<h3> Hidden tests </h3> " ;
AB.msg ( thehtml, 6 );
// 7 is filled in at run time with the testing success rate
// Opening tags used to style the per-image result messages.
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> " ;
const redspan = "<span style='font-weight:bold; font-size:x-large; color:darkred'> " ;
const span="<span style='font-weight:bold; font-size:x-large'> " ;
//--- End of front end ---------------------------------------------------------
//---- normal P5 code -------------------------------------------------------
// Modifiable parameters
let nb_hidden_nodes = 32;
let learningrate = 0.1;
// Cleared by the "Stop training" button to stop training the neural network.
let do_training = true;
// Holds test_images, test_labels, train_images and train_labels once loaded.
let mnist = {};
// The neural network; assigned in setup() once nn.js has been loaded.
let nn;
// Letter for each network output index (0..25). EMNIST letter labels are
// 1-based, so dataset label L corresponds to alphabet[L - 1].
const alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z'];
// Since the EMNIST loader is implemented directly in this file, this flag
// keeps draw() idle until all of the data has been loaded.
let loaded = false;
// Variables for the doodle recognizer
let user_letter;
let doodle_exist = false;
// Position within, and number of passes over, the training and testing sets
let train_index = 0;
let test_index = 0;
let train_set = 1;
let test_set = 1;
// Counters used to compute the testing success percentage
let total_tests = 0;
let total_correct = 0;
let train_image;
let training_images;
let training_labels;
let testing_images;
let testing_labels;
// asynchronous function to load the EMNIST data ----------------------------------------------
/**
 * Downloads the EMNIST "letters" test/train images and labels into the
 * global `mnist` object, then calls `callback(mnist)` and sets `loaded`.
 *
 * The four files are independent, so they are fetched in parallel. The
 * offsets passed to loadFile skip each IDX file's header (16 bytes for the
 * idx3 image files, 8 bytes for the idx1 label files).
 *
 * @param {function(Object): void} callback - receives the filled `mnist` object.
 * @returns {Promise<void>} settles when loading has finished or failed; a
 *   failure is logged and leaves `loaded` false.
 */
function loadMNIST(callback){
  return Promise.all([
    loadFile('/uploads/archyyu321/emnist-letters-test-images-idx3-ubyte.bin', 16),
    loadFile('/uploads/archyyu321/emnist-letters-test-labels-idx1-ubyte.bin', 8),
    loadFile('/uploads/archyyu321/emnist-letters-train-images-idx3-ubyte.bin', 16),
    loadFile('/uploads/archyyu321/emnist-letters-train-labels-idx1-ubyte.bin', 8),
  ])
    .then(([testImages, testLabels, trainImages, trainLabels]) => {
      mnist.test_images = testImages;
      mnist.test_labels = testLabels;
      mnist.train_images = trainImages;
      mnist.train_labels = trainLabels;
      callback(mnist);
      loaded = true;
    })
    .catch((err) => {
      // Without this, a failed download would leave the sketch waiting forever with no hint why.
      console.error('Failed to load the EMNIST data:', err);
    });
}
/**
 * Fetches a binary file and returns its bytes after the first `offset` bytes.
 *
 * @param {string} file - URL of the file to fetch.
 * @param {number} offset - number of leading header bytes to skip.
 * @returns {Promise<Uint8Array>} a view of the remaining bytes. It uses
 *   subarray rather than slice, so the large dataset buffers are not copied.
 * @throws {Error} if the server answers with a non-2xx status.
 */
async function loadFile(file, offset){
  const response = await fetch(file);
  if (!response.ok) {
    // Otherwise an HTML error page would be silently parsed as image data.
    throw new Error(`Failed to load ${file}: HTTP ${response.status}`);
  }
  const buffer = await response.arrayBuffer();
  return new Uint8Array(buffer).subarray(offset);
}
// p5 setup ------------------------------------------------------------------------------------
/**
 * p5.js entry point, run once before draw().
 *
 * Creates the 400x200 canvas and shows the AB loading screen. It then starts
 * two independent asynchronous jobs:
 *   - loading matrix.js, then nn.js, and finally creating the global network `nn`;
 *   - loading the EMNIST data, which removes the loading screen when done.
 * It also allocates the 200x200 doodle buffer and the 28x28 image used to
 * display training samples.
 */
function setup()
{
  // Runs once nn.js is available: builds the 784-input, 26-output network.
  const buildNetwork = function ()
  {
    console.log ("All JS loaded");
    nn = new NeuralNetwork( 784, nb_hidden_nodes, 26 );
    nn.setLearningRate ( learningrate );
  };
  // Runs once matrix.js is available. nn.js is only fetched after it,
  // presumably because nn.js relies on matrix.js.
  const loadNetworkScript = function ()
  {
    $.getScript ( "/uploads/vilonart/nn.js", buildNetwork );
  };
  // Receives the EMNIST data once every file has been loaded.
  const onDataReady = function (data)
  {
    mnist = data;
    console.log(mnist);
    AB.removeLoading();
  };

  createCanvas(400, 200);
  AB.loadingScreen();
  // matrix.js and nn.js are taken from The Coding Train and repurposed for the practical.
  $.getScript ( "/uploads/vilonart/matrix.js", loadNetworkScript );
  user_letter = createGraphics(200, 200);
  user_letter.pixelDensity(1);
  train_image = createImage(28, 28);
  loadMNIST(onDataReady);
}
//Functions to get the prediction-----------------------------
/**
 * Index of the largest element of `arr` (argmax).
 *
 * Ties resolve to the first occurrence. Negative and zero values are handled
 * correctly, and NaN entries are ignored.
 *
 * @param {ArrayLike<number>} arr - values to scan, e.g. the network outputs.
 * @returns {number} index of the maximum, or -1 if `arr` is empty.
 */
function findMax(arr){
  let record = -Infinity;
  let index = arr.length > 0 ? 0 : -1;
  for (let i = 0; i < arr.length; i++){
    if (arr[i] > record){
      record = arr[i];
      index = i;
    }
  }
  return index;
}
/**
 * Indices of the largest and second-largest values of `a`.
 * Adapted from the Character recognition neural network lab.
 *
 * The running maxima start at -Infinity rather than 0, so arrays containing
 * zero or negative values are ranked correctly. For all-positive input the
 * result is unchanged.
 *
 * @param {ArrayLike<number>} a - values to rank, e.g. the network outputs.
 * @returns {number[]} [first, second]; both are 0 when `a` has fewer than two
 *   elements.
 */
function find12 (a)
{
  let no1 = 0;
  let no2 = 0;
  let no1value = -Infinity;
  let no2value = -Infinity;
  for (let i = 0; i < a.length; i++)
  {
    if (a[i] > no1value)
    {
      // The old first place drops to second place.
      no2 = no1;
      no2value = no1value;
      no1 = i;
      no1value = a[i];
    }
    else if (a[i] > no2value)
    {
      no2 = i;
      no2value = a[i];
    }
  }
  return [ no1, no2 ];
}
// train, testing and doodle test functions--------------------------------------
/**
 * Trains the network on one EMNIST training image and reports the result.
 *
 * Reads image number `train_index` (784 bytes) from mnist.train_images and
 * scales it to 0..1. It records the network's current guess, then runs one
 * training step towards the one-hot target. Progress goes to AB.msg slot 4
 * and the label/guess to slot 5 (green if correct, red otherwise).
 * Advances train_index, wrapping to 0 and incrementing train_set at the end
 * of the set.
 *
 * @param {boolean} show - when true, also draws the image on the right half
 *   of the canvas.
 */
function train(show){
  const offset = train_index * 784;
  const inputs = [];
  if (show){
    train_image.loadPixels();
  }
  for (let i = 0; i < 784; i++){
    const bright = mnist.train_images[offset + i];
    inputs[i] = bright / 255;
    if (show){
      // RGBA pixel: the grey value in R, G and B, fully opaque.
      const p = i * 4;
      train_image.pixels[p] = bright;
      train_image.pixels[p + 1] = bright;
      train_image.pixels[p + 2] = bright;
      train_image.pixels[p + 3] = 255;
    }
  }
  if (show){
    train_image.updatePixels();
    image(train_image, 200, 0, 200, 200);
  }
  // Letter labels are 1-based (1 = 'a') but output nodes are 0-based, so
  // label L is trained on, and compared against, output index L - 1.
  const label = mnist.train_labels[train_index];
  const targetIndex = label - 1;
  const targets = Array(26).fill(0);
  targets[targetIndex] = 1;
  // Predict before training so the displayed guess is the pre-update one.
  const guessIndex = findMax(nn.predict(inputs));
  nn.train(inputs, targets);
  AB.msg(" train set: " + train_set + "<br> no: " + train_index, 4);
  const colour = guessIndex === targetIndex ? greenspan : redspan;
  AB.msg(colour + "<br> label: " + alphabet[targetIndex] + "<br> guess: " + alphabet[guessIndex] + "</span>", 5);
  train_index++;
  if (train_index === mnist.train_labels.length){
    train_index = 0;
    console.log('finished train set');
    train_set++;
  }
}
/**
 * Evaluates the network on one EMNIST test image, without training.
 *
 * Reads image number `test_index` from mnist.test_images and updates the
 * running success rate. That rate is shown in AB.msg slot 7. At the end of
 * the test set the counters reset, test_index wraps to 0 and test_set is
 * incremented.
 */
function testing(){
  const offset = test_index * 784;
  const inputs = [];
  for (let i = 0; i < 784; i++){
    inputs[i] = mnist.test_images[offset + i] / 255;
  }
  // Labels are 1-based (1 = 'a'); the network's output index is 0-based.
  const label = mnist.test_labels[test_index];
  const guessIndex = findMax(nn.predict(inputs));
  total_tests += 1;
  if (guessIndex === label - 1){
    total_correct += 1;
  }
  const percent = 100 * (total_correct / total_tests);
  AB.msg("test set: " + test_set + "<br> no: " + test_index + "<br> Success rate: " + nf(percent, 2, 2) + "%", 7);
  test_index++;
  if (test_index === mnist.test_labels.length){
    test_index = 0;
    console.log('finished test set');
    console.log(percent);
    total_tests = 0;
    total_correct = 0;
    test_set++;
  }
}
/**
 * Classifies the user's doodle and shows the top two guesses in AB.msg slot 2.
 *
 * The 200x200 doodle buffer is copied and downscaled to 28x28. The red
 * channel of each pixel (scaled to 0..1) is used as the network input.
 *
 * @returns {p5.Image} the 28x28 downscaled copy that was fed to the network.
 */
function guessUserLetter(){
  const img = user_letter.get();
  img.resize(28, 28);
  img.loadPixels();
  // NOTE(review): EMNIST images are commonly reported to be stored transposed
  // relative to an upright drawing. If doodle accuracy is poor, check whether
  // these inputs need transposing to match the training images — TODO confirm.
  const inputs = [];
  for (let i = 0; i < 784; i++){
    inputs[i] = img.pixels[i * 4] / 255;
  }
  // Output index k corresponds directly to alphabet[k].
  const [first, second] = find12(nn.predict(inputs));
  AB.msg(span + "First guess: " + alphabet[first] + "<br> Second guess: " + alphabet[second] + "</span>", 2);
  return img;
}
// Doodle helpers ------------------------------------------------------------------------------
/**
 * Handler for the "Clear doodle" button.
 * Paints the doodle buffer black and marks that there is no doodle to classify.
 */
function wipeDoodle()
{
  user_letter.background(0);
  doodle_exist = false;
}
/**
 * p5.js per-frame loop.
 *
 * Once both the data and the network are ready, each frame:
 *   - trains on 20 images, drawing only the last one (skipped while
 *     do_training is false);
 *   - tests on 100 images;
 *   - classifies and shows the doodle, if one exists;
 *   - extends the doodle with the current mouse stroke.
 */
function draw()
{
  background(0);
  // The EMNIST data and nn.js load independently in setup(), so either may
  // finish first; wait for both before touching the network.
  if (!loaded || typeof nn === 'undefined') return;

  if (do_training){
    const trainPerFrame = 20;
    for (let i = 0; i < trainPerFrame; i++){
      train(i === trainPerFrame - 1);
    }
  }

  const testsPerFrame = 100;
  for (let i = 0; i < testsPerFrame; i++){
    testing();
  }

  if (doodle_exist){
    const thumbnail = guessUserLetter();
    image(thumbnail, 0, 0);
    image(user_letter, 0, 0);
  }

  // Only strokes inside the doodle area count. Otherwise clicking the
  // "Clear doodle" button or the training panel would flag a doodle again.
  const inDoodleArea = mouseX >= 0 && mouseX < user_letter.width &&
    mouseY >= 0 && mouseY < user_letter.height;
  if (mouseIsPressed && inDoodleArea){
    doodle_exist = true;
    user_letter.stroke(255);
    user_letter.strokeWeight(16);
    user_letter.line(mouseX, mouseY, pmouseX, pmouseY);
  }
}