// Code viewer for World: New World
// math.js shortcuts, assigned in setup() once math.js has loaded
let mmap;
let rand;
let transp;
let mat;
let e;
let sub;
let sqr;
let sum;

let mnist;      // MNIST data set, filled by loadMNIST() in the global loadData() below
 
// make run header bigger
AB.headerCSS ( { "max-height": "95vh" } );




//--- start of AB.msgs structure: ---------------------------------------------------------
// We output a series of AB.msg calls to put data at various places in the run header 
var thehtml;

  // 1 Doodle header 
  thehtml = "<hr> <h1> 1. Doodle </h1> Top row: Doodle (left) and shrunk (right). <br> " +
        " Draw your doodle in top LHS. <button onclick='wipeDoodle();' class='normbutton' >Clear doodle</button> <br> ";
   AB.msg ( thehtml, 1 );

  // 2 Doodle variable data (guess)
  
  // 3 Training header
  thehtml = "<hr> <h1> 2. Training </h1> Middle row: Training image magnified (left) and original (right). <br>  " +
        " <button onclick='do_training = false;' class='normbutton' >Stop training</button> <br> ";
  AB.msg ( thehtml, 3 );
     
  // 4 variable training data 
  
  // 5 Testing header
  thehtml = "<h3> Hidden tests </h3> " ;
  AB.msg ( thehtml, 5 );
           
  // 6 variable testing data 
  
  // 7 Demo header 
  thehtml = "<hr> <h1> 3. Demo </h1> Bottom row: Test image magnified (left) and  original (right). <br>" +
        " The network is <i>not</i> trained on any of these images. <br> " +
        " <button onclick='makeDemo();' class='normbutton' >Demo test image</button> <br> ";
   AB.msg ( thehtml, 7 );
   
  // 8 Demo variable data (random demo ID)
  // 9 Demo variable data (changing guess)
  
const greenspan = "<span style='font-weight:bold; font-size:x-large; color:darkgreen'> "  ;

//--- end of AB.msgs structure: ---------------------------------------------------------

//---- normal P5 code -------------------------------------------------------

function setup()       
{
     $.getScript ( "/uploads/kcarolan96/math.js", function()
 {
     $.getScript ( "/uploads/kcarolan96/mnist.js", function()
        {
        console.log ("All JS loaded");
       /* shortcuts for math.js functions, assigned to the globals declared at the top */
mmap = math.map;        // apply a function to each element of a matrix
rand = math.random;
transp = math.transpose;
mat = math.matrix;
e = math.evaluate;
sub = math.subtract;
sqr = math.square;
sum = math.sum;

class NeuralNetwork {
  constructor(inputnodes, hiddennodes, outputnodes, learningrate, wih, who) {
    this.inputnodes = inputnodes;
    this.hiddennodes = hiddennodes;
    this.outputnodes = outputnodes;
    this.learningrate = learningrate;

    /* initialise the weights either randomly or, if passed in as arguments, with pretrained values */
    /* wih = weights of input-to-hidden layer */
    /* who = weights of hidden-to-output layer */
    this.wih = wih || sub(mat(rand([hiddennodes, inputnodes])), 0.5);
    this.who = who || sub(mat(rand([outputnodes, hiddennodes])), 0.5);

    /* the sigmoid activation function */
    this.act = (matrix) => mmap(matrix, (x) => 1 / (1 + Math.exp(-x)));
  }

  /* scale raw grey values (0-255) into the range 0.01 to 1.0 */
  static normalizeData = (data) => {
    return data.map((px) => (px / 255) * 0.99 + 0.01);
  };
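
  /* Worked example (illustrative values): normalizeData([0, 128, 255]) gives roughly
     [0.01, 0.51, 1.0]. Keeping every input strictly above zero means no pixel can
     completely switch off the weight updates for its input connections. */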

  cache = { loss: [] };

  forward = (input) => {
    const wih = this.wih;
    const who = this.who;
    const act = this.act;

    input = transp(mat([input]));

    /* hidden layer */
    const h_in = e("wih * input", { wih, input });
    const h_out = act(h_in);

    /* output layer */
    const o_in = e("who * h_out", { who, h_out });
    const actual = act(o_in);

    /* these values are needed later in "backward" */
    this.cache.input = input;
    this.cache.h_out = h_out;
    this.cache.actual = actual;

    return actual;
  };
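
  /* Shape sketch for the forward pass, using the hyperparameters set further down
     (inputnodes = 784, hiddennodes = 100, outputnodes = 10):
       input  : 784 x 1 column vector (after the transpose above)
       wih    : 100 x 784   =>  h_in, h_out  : 100 x 1
       who    : 10  x 100   =>  o_in, actual : 10 x 1
     so "actual" is a 10 x 1 column of sigmoid activations, one per digit. */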

  backward = (target) => {
    const who = this.who;
    const input = this.cache.input;
    const h_out = this.cache.h_out;
    const actual = this.cache.actual;

    target = transp(mat([target]));

    // calculate the gradient of the error function (E) w.r.t the activation function (A)
    const dEdA = sub(target, actual);

    // calculate the gradient of the activation function (A) w.r.t the weighted sums (Z) of the output layer
    const o_dAdZ = e("actual .* (1 - actual)", {
      actual,
    });

    // calculate the error gradient of the loss function w.r.t the weights of the hidden-to-output layer
    const dwho = e("(dEdA .* o_dAdZ) * h_out'", {
      dEdA,
      o_dAdZ,
      h_out,
    });

    // calculate the weighted error for the hidden layer
    const h_err = e("who' * (dEdA .* o_dAdZ)", {
      who,
      dEdA,
      o_dAdZ,
    });

    // calculate the gradient of the activation function (A) w.r.t the weighted sums (Z) of the hidden layer
    const h_dAdZ = e("h_out .* (1 - h_out)", {
      h_out,
    });

    // calculate the error gradient of the loss function w.r.t the weights of the input-to-hidden layer
    const dwih = e("(h_err .* h_dAdZ) * input'", {
      h_err,
      h_dAdZ,
      input,
    });

    this.cache.dwih = dwih;
    this.cache.dwho = dwho;
    this.cache.loss.push(sum(sqr(dEdA)));
  };
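
  /* The gradients above, written out in the same notation as the code
     (".*" = element-wise product, "'" = transpose):
       dEdA   = target - actual              derivative of the squared error (up to a constant factor)
       o_dAdZ = actual .* (1 - actual)       sigmoid derivative at the output layer
       dwho   = (dEdA .* o_dAdZ) * h_out'    gradient for the hidden-to-output weights
       h_err  = who' * (dEdA .* o_dAdZ)      error propagated back to the hidden layer
       h_dAdZ = h_out .* (1 - h_out)         sigmoid derivative at the hidden layer
       dwih   = (h_err .* h_dAdZ) * input'   gradient for the input-to-hidden weights
     The value pushed onto cache.loss is the sum of squared errors for this sample. */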

  update = () => {
    const wih = this.wih;
    const who = this.who;
    const dwih = this.cache.dwih;
    const dwho = this.cache.dwho;
    const r = this.learningrate;

    /* update the current weights of each layer with their corresponding error gradients */
    /* because dEdA was computed as (target - actual), the usual minus sign of gradient descent becomes a plus */
    this.wih = e("wih + (r .* dwih)", { wih, r, dwih });
    this.who = e("who + (r .* dwho)", { who, r, dwho });
  };

  predict = (input) => {
    return this.forward(input);
  };

  train = (input, target) => {
    this.forward(input);
    this.backward(target);
    this.update();
  };
}
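
/* Minimal usage sketch (illustrative only, not called anywhere in this World;
   the names below are made up for the example):
     const nn = new NeuralNetwork(784, 100, 10, 0.2);
     const pixels = NeuralNetwork.normalizeData(new Array(784).fill(0));
     const target = Array(10).fill(0);
     target[3] = 0.99;                 // one-hot target for the digit 3
     nn.train(pixels, target);         // one forward + backward + update pass
     const out = nn.predict(pixels);   // 10 x 1 math.js matrix of activations
*/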

/* neural network's hyperparameters */
const inputnodes = 784;
const hiddennodes = 100;
const outputnodes = 10;
const learningrate = 0.2;
const threshold = 0.5;
let iter = 0;
const iterations = 5;

/* path to the data sets */
const trainingDataPath = "./mnist/mnist_train.csv";
const testDataPath = "./mnist/mnist_test.csv";
const weightsFilename = "weights.json";
const savedWeightsPath = `./dist/${weightsFilename}`;

/* these constants will be filled during data loading and preparation */
const trainingData = [];
const trainingLabels = [];
const testData = [];
const testLabels = [];
const savedWeights = {};

/* how many trained samples between progress messages in the log */
const printSteps = 1000;

let myNN;

window.onload = async () => {
  /* Instantiate an entity from the NeuralNetwork class */
  myNN = new NeuralNetwork(inputnodes, hiddennodes, outputnodes, learningrate);

  trainButton.disabled = true;
  testButton.disabled = true;
  loadWeightsButton.disabled = true;

  status.innerHTML = "Loading the data sets. Please wait ...<br>";

  /* get all the data set files and do the preparations */

  const trainCSV = await loadData(trainingDataPath, "CSV");

  if (trainCSV) {
    prepareData(trainCSV, trainingData, trainingLabels);
    status.innerHTML += "Training data successfully loaded...<br>";
  }

  const testCSV = await loadData(testDataPath, "CSV");

  if (testCSV) {
    prepareData(testCSV, testData, testLabels);
    status.innerHTML += "Test data successfully loaded...<br>";
  }

  if (!trainCSV || !testCSV) {
    status.innerHTML +=
      "Error loading train/test data set. Please check your file path! If you run this project locally, it needs to be on a local server.";
    return;
  }

  trainButton.disabled = false;
  testButton.disabled = false;

  const weightsJSON = await loadData(savedWeightsPath, "JSON");

  /* if a saved JSON file with pretrained weights exists, keep its contents for later loading */
  if (weightsJSON) {
    savedWeights.wih = weightsJSON.wih;
    savedWeights.who = weightsJSON.who;
    loadWeightsButton.disabled = false;
  }

  status.innerHTML += "Ready.<br><br>";
};

async function loadData(path, type) {
  try {
    const result = await fetch(path, {
      mode: "no-cors",
    });

    switch (type) {
      case "CSV":
        return await result.text();
      case "JSON":
        return await result.json();
      default:
        return false;
    }
  } catch {
    return false;
  }
}

function prepareData(rawData, target, labels) {
  rawData = rawData.split("\n"); // create an array where each element corresponds to one line in the CSV file
  rawData.pop(); // remove the last element, which is empty because the CSV file ends with a blank line

  rawData.forEach((current) => {
    let sample = current.split(",").map((x) => +x); // convert each comma-separated field to a number (grey value 0-255)

    labels.push(sample[0]); // the first element of each row is the label (the digit drawn)
    sample.shift(); // remove the label, leaving the 784 pixel values

    sample = NeuralNetwork.normalizeData(sample);

    target.push(sample);
  });
}
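
/* Example of what prepareData does to a single CSV line (values illustrative):
     "5,0,0, ... ,255"  becomes
       labels: [..., 5]                          the digit drawn
       target: [..., [0.01, 0.01, ..., 1.0]]     784 normalized grey values
*/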

function train() {
  trainButton.disabled = true;
  testButton.disabled = true;
  loadWeightsButton.disabled = true;
  download.innerHTML = "";

  if (iter < iterations) {
    iter++;
    status.innerHTML += "Starting training ...<br>";
    status.innerHTML += "Iteration " + iter + " of " + iterations + "<br>";

    trainingData.forEach((current, index) => {
      setTimeout(() => {
        /* create one-hot encoding of the label */
        const label = trainingLabels[index];
        const oneHotLabel = Array(10).fill(0);
        oneHotLabel[label] = 0.99;

        myNN.train(current, oneHotLabel);

        /* check if the defined interval for showing a message on the training progress is reached */
        if (index > 0 && !((index + 1) % printSteps)) {
          status.innerHTML += `finished  ${index + 1}  samples ... <br>`;
        }

        /* check if the end of the training iteration is reached */
        if (index === trainingData.length - 1) {
          status.innerHTML += `Loss:  ${
            sum(myNN.cache.loss) / trainingData.length
          }<br><br>`;
          myNN.cache.loss = [];

          test("", true); // true to signal "test" that it is called from within training
        }
      }, 0);
    });
  }
}
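
/* One-hot encoding example: a label of 7 becomes
     [0, 0, 0, 0, 0, 0, 0, 0.99, 0, 0]
   0.99 is used instead of 1.0 because the sigmoid can never actually output 1,
   so a target of exactly 1 would keep pushing the weights larger without limit. */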

function test(_, inTraining = false) {
  // skip the first parameter as it includes data from the event listener which we don't need here
  trainButton.disabled = true;
  testButton.disabled = true;
  loadWeightsButton.disabled = true;

  status.innerHTML += "Starting testing ...<br>";

  let correctPredicts = 0;
  testData.forEach((current, index) => {
    setTimeout(() => {
      const actual = testLabels[index];

      const predict = formatPrediction(myNN.predict(current));
      if (predict === actual) correctPredicts++;

      /* check if the defined interval for showing a message on the testing progress is reached */
      if (index > 0 && !((index + 1) % printSteps)) {
        status.innerHTML += " finished " + (index + 1) + " samples ...<br>";
      }

      /* check if testing is complete */
      if (index >= testData.length - 1) {
        status.innerHTML +=
          "Accuracy: " +
          Math.round((correctPredicts / testData.length) * 100) +
          " %<br><br>";

        /* check if training is complete */
        if (iter + 1 > iterations) {
          createDownloadLink();
          enableAllButtons();
          status.innerHTML += "Finished training.<br><br>";
          iter = 0;
        } else if (inTraining) {
          // if test is called from within training and the training is not complete yet, continue training
          train();
        } else {
          enableAllButtons();
        }
      }
    }, 0);
  });
}

function predict() {
  /* resize the canvas to the training data image size  */
  const tempCanvas = document.createElement("canvas");
  const tempCtx = tempCanvas.getContext("2d");
  tempCtx.drawImage(canvas, 0, 0, 150, 150, 0, 0, 28, 28);

  /* convert the canvas image */
  const img = tempCtx.getImageData(0, 0, 28, 28);

  /* remove the alpha channel and convert to grayscale */
  let sample = [];
  for (let i = 0, j = 0; i < img.data.length; i += 4, j++) {
    sample[j] = (img.data[i + 0] + img.data[i + 1] + img.data[i + 2]) / 3;
  }

  sample = NeuralNetwork.normalizeData(sample); // normalize the drawn sample the same way as the training data

  const predict = formatPrediction(myNN.predict(sample));
  prediction.innerHTML = predict;
}

function formatPrediction(prediction) {
  const flattened = prediction.toArray().map((x) => x[0]);

  /* get the index of the highest number in the flattened array */
  return flattened.indexOf(Math.max(...flattened));
}
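
/* Example (illustrative): if the network output is the 10 x 1 matrix
     [[0.02], [0.01], [0.88], [0.05], ...]
   formatPrediction flattens it to [0.02, 0.01, 0.88, 0.05, ...] and returns 2,
   the index of the largest activation, i.e. the predicted digit. */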

function loadWeights() {
  myNN.wih = savedWeights.wih;
  myNN.who = savedWeights.who;
  status.innerHTML += "Weights successfully loaded.";
}

function createDownloadLink() {
  const wih = myNN.wih.toArray();
  const who = myNN.who.toArray();
  const weights = { wih, who };
  download.innerHTML = `<a download="${weightsFilename}" id="downloadLink" href="data:text/json;charset=utf-8,${encodeURIComponent(
    JSON.stringify(weights)
  )}">Download model weights</a>`;
}

/* UI helper functions */

function enableAllButtons() {
  trainButton.disabled = false;
  testButton.disabled = false;
  loadWeightsButton.disabled = false;
}
         
        });
     
 });
}

// Loads the full MNIST data set via mnist.js (this global loadData is separate
// from the CSV loadData defined inside the getScript callback above).
function loadData()    
{
  loadMNIST ( function(data)    
  {
    mnist = data;
    console.log ("All data loaded");
    console.log(mnist);
    AB.removeLoading();     // if no loading screen exists, this does nothing 
  });
}


// quick sanity check that math.js chaining works
function testing(){
  const test = math.chain(3)
    .add(4)
    .multiply(2)
    .done();     // (3 + 4) * 2 = 14

  console.log(test);
}

// P5 draw loop (not used in this World)
function draw()             
{
}