import React, { useEffect, useRef, useState } from 'react';
import './App.css';

// URL that every backend request in this file is POSTed to (dataset selection
// and training runs both use it; the JSON body tells them apart).
const home_endpoint = 'https://backend-5v3d3gapha-ue.a.run.app/home'
// Loopback alternative, presumably for running against a local backend:
// const home_endpoint = 'http://127.0.0.1:5001/home'

function NeuralNetwork({layers, linewidths, neuroncolors, losses, maxlosses, datalabels, accuracy, test_samples}) {
    const canvasRef = useRef(null);
    const [frameIndex, setFrameIndex] = useState(0);
    var [canvasSize, setCanvasSize] = useState([0,0]);
    const [windowWidth, setWindowWidth] = useState(0);
    // const [firstRender, setFirstRender] = useState(0);

    // const [width, setWidth] = useState(0);
    // const [height, setHeight] = useState(0);   
  
    useEffect(() => {

        const canvas = canvasRef.current;
        const ctx = canvas.getContext('2d');

        const scale = window.devicePixelRatio;

        // const scale = window.devicePixelRatio || 1;
        const parent = canvas.parentElement;
        const rect = parent.getBoundingClientRect();

        const width = rect.width * scale;
        const height = rect.height * scale;

        // if (canvasSize[0] == 0) {
        //     setCanvasSize([width, height])
        //     canvas.width = width;
        //     canvas.height = height;
        //     setWindowWidth(width/scale)
        // } 
        // else {
        //     canvas.width = canvasSize[0];
        //     canvas.height = canvasSize[1];
        // }

        canvas.width = 2700
        canvas.height = height

        setWindowWidth(width/scale)

        ctx.scale(scale, scale);

        // Define the positions of neurons in each layer
        const neuronPositions = [];

        // Calculate the maximum number of neurons in any layer
        const maxNeurons = Math.max(...layers);

        // Calculate the vertical spacing between layers
        //   const verticalSpacing = canvas.height / (maxNeurons + 1)*1.5;
        const verticalSpacing = 90
        // const horizontalSpacing = canvas.width / (layers.length + 1);
        const horizontalSpacing = 1800/(layers.length - 1)

        // Calculate the required height for the canvas based on the number of neurons
        const requiredHeight = (maxNeurons - 1) * verticalSpacing + 300; // Add some padding
        const requiredWidth = 2000

        // Set the canvas height
        canvasRef.current.height = requiredHeight + 400;
        // canvasRef.current.Width = requiredWidth;

        const abs = (num) => {
          if (num < 0) {
            return num*-1;
          }
          else {
            return num
          };
        };

        const progressText = (frame) => {
            let iteration = frame % 4
            let epoch = (frame - iteration)/4
            let text = `Epoch ${epoch + 1} / 10 | Iteration ${iteration + 1} / 4`;

            ctx.beginPath();
            ctx.moveTo(0, 50); // Adjusted startY for neuron centering
            ctx.lineTo(((frame+1)/40)*canvas.width, 50);
            ctx.strokeStyle = 'green';
            ctx.lineWidth = 100;
            ctx.stroke();
            
            // String.format("Epoch %d / 10 | Iteration %d / 4", epoch + 1, iteration + 1)

            ctx.font = '40px Arial'; // Set font size and family

            // Set text properties
            ctx.fillStyle = 'white'; // Set text color

            // Display text
            ctx.fillText(text, 50, 50);

        }

        // Draw neurons
        const drawNeurons = (neuroncolors, layer) => {

            ctx.fillStyle = 'black'; // Set background color
            ctx.fillRect(0, 0, canvas.width, canvas.height); // Fill the canvas with black

            for (let i = 0; i < layers.length; i++) {

                // Calculate the starting position for the first neuron in the layer
                const startY = (requiredHeight - ((layers[i]) * verticalSpacing)) / 2;
                const startX = (canvas.width - ((layers.length - 1) * horizontalSpacing)) / 2 - 50;

                for (let j = 0; j < layers[i]; j++) {
                const x = startX + i * horizontalSpacing;

                // const x = (i + 1) * horizontalSpacing; // Position in X direction
                // const x = (i + 1) * (canvas.width / (layers.length + 1)); // Position in X direction
                const y = startY + (j + 1) * verticalSpacing; // Position in Y direction
                
                const neuron_radius = 30

                ctx.lineWidth = 1

                // Draw neuron
                ctx.beginPath();
                ctx.arc(x, y, neuron_radius, 0, Math.PI * 2);
                
                if (layer.includes(i)) {
                    ctx.fillStyle = neuroncolors[i][j];
                } else {
                    ctx.fillStyle = 'black';
                }

                
                ctx.fill();
                ctx.strokeStyle = 'white';
                ctx.stroke();

                if (i == 0) {
                  ctx.font = '30px Arial'; // Set font size and family

                  // Set text properties
                  ctx.fillStyle = 'white'; // Set text color

                  // Display text
                  ctx.fillText(datalabels[j], startX-300, y);
                }

                if (i == layers.length-1) {
                  ctx.font = '30px Arial'; // Set font size and family

                  // Set text properties
                  ctx.fillStyle = 'white'; // Set text color

                  // Display text
                  ctx.fillText(datalabels[datalabels.length - 1], x+60, y);
                }
            }
            }
      }
  
        // Draw weights/biases
        const drawWeightsBiases = (linewidths) => {


            for (let i = 1; i < layers.length; i++) {
            // const horizontalSpacing = 300

            const startY = (requiredHeight - ((layers[i]) * verticalSpacing)) / 2;
        
            for (let j = 0; j < layers[i]; j++) {

                const startX = (canvas.width - ((layers.length - 1) * horizontalSpacing)) / 2 + i*horizontalSpacing - 50;

                // const startX = 50 + i * horizontalSpacing;
        
                for (let k = 0; k < layers[i - 1]; k++) {
                    const endX = startX - horizontalSpacing;
                    const endY = (requiredHeight - (layers[i - 1] * verticalSpacing)) / 2 + (k + 1) * verticalSpacing; // Adjusted endY for neuron centering
                    
                    const neuron_radius = 30

                    var line_color = '';

                    if (linewidths[i-1][j][k] < 0) {
                      line_color = '#A1C7FF';
                    } else {
                      line_color = '#FFA1A1';
                    };

                    // Draw line between neurons
                    ctx.beginPath();
                    ctx.moveTo(startX - neuron_radius, startY + (j+1) * verticalSpacing); // Adjusted startY for neuron centering
                    ctx.lineTo(endX + neuron_radius, endY);
                    ctx.strokeStyle = line_color;
                    ctx.lineWidth = abs(linewidths[i-1][j][k]/50);
                    ctx.stroke();
                }
            }
            }
        }

        const drawlosses = (losses, maxlosses, iteration) => {

            const startY = canvas.height - 2.5
            const startX = 110

            let roundedloss = +losses[iteration].toFixed(4)
            let diff = (losses[iteration] - losses[0])/losses[0]*100
            let roundeddiff = +diff.toFixed(2)

            let text = `Loss: ${roundedloss} (${roundeddiff}%)`
            
            ctx.font = '40px Arial'; // Set font size and family
            ctx.fillStyle = 'red'; // Set text color
            ctx.fillText(text, startX, startY - 310);

            for (let i = 0; i < iteration + 1; i++) {
            // const horizontalSpacing = 300

            const endY = startY - losses[i]/maxlosses*280

            const barwidth = 25

            // ctx.beginPath();
            // ctx.moveTo(startX-25-7, startY+7); // Adjusted startY for neuron centering
            // ctx.lineTo(startX-25-7, startY+7 - 250);
            // ctx.strokeStyle = 'white';
            // ctx.lineWidth = 5;
            // ctx.stroke();

            // ctx.beginPath();
            // ctx.moveTo(startX-25-7, startY +7); // Adjusted startY for neuron centering
            // ctx.lineTo(startX-25-7 + 2150, startY +7);
            // ctx.strokeStyle = 'white';
            // ctx.lineWidth = 5;
            // ctx.stroke();

            ctx.beginPath();
            ctx.moveTo(startX + i*(barwidth + 2.5), startY); // Adjusted startY for neuron centering
            ctx.lineTo(startX + i*(barwidth + 2.5), endY);
            ctx.strokeStyle = 'red';
            ctx.lineWidth = barwidth;
            ctx.stroke();

            }
        }

        const draw_test_samples = (accuracy) => {

          let text = `Test Case Accuracy: ${accuracy}%`;
          
          ctx.font = '40px Arial'; // Set font size and family
          ctx.fillStyle = 'white'; // Set text color
          ctx.fillText(text, 1350, canvas.height - 310);

          if (accuracy == 'running...') {
              // Calculate the end angle based on the percentage
              const endAngle = 2 * Math.PI;

              const centerX = 1600;
              const centerY = canvas.height - 140;
              const outerRadius = 100;
              const innerRadius = 50;

              // Draw the outer arc
              ctx.beginPath();
              ctx.arc(centerX, centerY, outerRadius, 0, endAngle);
              ctx.lineWidth = outerRadius - innerRadius;
              ctx.strokeStyle = '#7d7d7d';
              ctx.stroke();

          } else 
          {
            // Calculate the end angle based on the percentage
          const endAngle = (accuracy / 100) * 2 * Math.PI;

          const centerX = 1600;
          const centerY = canvas.height - 140;
          const outerRadius = 100;
          const innerRadius = 50;

          // Draw the outer arc
          ctx.beginPath();
          ctx.arc(centerX, centerY, outerRadius, 0, endAngle);
          ctx.lineWidth = outerRadius - innerRadius;
          ctx.strokeStyle = 'green';
          ctx.stroke();

          ctx.beginPath();
          ctx.arc(centerX, centerY, outerRadius, endAngle, 0);
          ctx.lineWidth = outerRadius - innerRadius;
          ctx.strokeStyle = 'red';
          ctx.stroke();
          }

          
          
          // var x = 900;
          // var y = canvas.height - 310;

          // for (let i = 0; i < test_samples.length; i++) {

          // let text = `Test Case ${i + 1} : ${test_samples[i]}`;

          // if (i % 8 == 0) {
          //   x = x + 300;
          //   y = canvas.height - 310;
          // }
          // else {
          //   y = y + 50
          // }

          // ctx.font = '30px Arial'; // Set font size and family
          // ctx.fillStyle = 'red'; // Set text color
          // ctx.fillText(text, x, y);
          };

        const draw_legend = () => {
          // let text = `Test Case Accuracy: ${accuracy}%`;
          
          ctx.font = '40px Arial'; // Set font size and family
          ctx.fillStyle = 'white'; // Set text color
          ctx.fillText('Legend:', 2200, canvas.height - 310);
          
          ctx.font = '30px Arial';
          ctx.fillStyle = '#FFA1A1'; // Set text color
          ctx.fillText('red - positive weight', 2200, canvas.height - 230);
          
          ctx.font = '30px Arial';
          ctx.fillStyle = '#A1C7FF'; // Set text color
          ctx.fillText('blue - negative weight', 2200, canvas.height - 180);

          ctx.font = '30px Arial';
          ctx.fillStyle = 'white'; // Set text color
          ctx.fillText('white circle - neuron = 1', 2200, canvas.height - 130);

          ctx.font = '30px Arial';
          ctx.fillStyle = '#7d7d7d'; // Set text color
          ctx.fillText('black circle - neuron = 0', 2200, canvas.height - 80);
        }

        const classifyloss = (loss) => {
          var text = '';
      
          if (loss == '') {
            text = '';
          }
          else if (loss <= 0.0025) {
            text = 'excellent';
          }
          else if (loss <= 0.01) {
            text = 'good';
          }
          else if (loss <= 0.04) {
            text = 'okay';
          }
          else {
            text = 'poor';
          }
      
          return text;
      
        };

        const drawaccuracy = (accuracy, classification) => {
          ctx.font = '40px Arial'; // Set font size and family
          ctx.fillStyle = 'white'; // Set text color
          ctx.fillText(`Accuracy (MSE): ${accuracy} - ${classification}`, canvas.width/2-780, canvas.height - 310);
        }

        // Draw neurons
        const drawstaticNeurons = () => {
            ctx.fillStyle = 'black'; // Set background color
            ctx.fillRect(0, 0, canvas.width, canvas.height); // Fill the canvas with black


            for (let i = 0; i < layers.length; i++) {

                // Calculate the starting position for the first neuron in the layer
                const startY = (requiredHeight - ((layers[i]) * verticalSpacing)) / 2;
                const startX = (canvas.width - ((layers.length - 1) * horizontalSpacing)) / 2 -50;

                for (let j = 0; j < layers[i]; j++) {
                  const x = startX + i * horizontalSpacing;

                  // const x = (i + 1) * horizontalSpacing; // Position in X direction
                  // const x = (i + 1) * (canvas.width / (layers.length + 1)); // Position in X direction
                  const y = startY + (j + 1) * verticalSpacing; // Position in Y direction
                  
                  const neuron_radius = 30

                  ctx.lineWidth = 1

                  // Draw neuron
                  ctx.beginPath();
                  ctx.arc(x, y, neuron_radius, 0, Math.PI * 2);
                  ctx.fillStyle = 'black';
                  ctx.fill();
                  ctx.strokeStyle = 'white';
                  ctx.stroke();

                  if (i == 0) {
                    ctx.font = '30px Arial'; // Set font size and family

                    // Set text properties
                    ctx.fillStyle = 'white'; // Set text color

                    // Display text
                    ctx.fillText(datalabels[j], startX-300, y);
                  }

                  if (i == layers.length-1) {
                    ctx.font = '30px Arial'; // Set font size and family

                    // Set text properties
                    ctx.fillStyle = 'white'; // Set text color

                    // Display text
                    ctx.fillText(datalabels[datalabels.length - 1], x-90, y-70);
                  }
            }
            }
          }
  
        // Draw weights/biases
        const drawstaticWeightsBiases = () => {
            for (let i = 0; i < layers.length - 1; i++) {
            // const horizontalSpacing = 300

            const startY = (requiredHeight - ((layers[i]) * verticalSpacing)) / 2;
        
            for (let j = 0; j < layers[i]; j++) {

                const startX = (canvas.width - ((layers.length - 1) * horizontalSpacing)) / 2 + i*horizontalSpacing -50;

                // const startX = 50 + i * horizontalSpacing;
        
                for (let k = 0; k < layers[i + 1]; k++) {
                    const endX = startX + horizontalSpacing;
                    const endY = (requiredHeight - (layers[i + 1] * verticalSpacing)) / 2 + (k + 1) * verticalSpacing; // Adjusted endY for neuron centering
                    
                    const neuron_radius = 30

                    // Draw line between neurons
                    ctx.beginPath();
                    ctx.moveTo(startX + neuron_radius, startY + (j + 1) * verticalSpacing); // Adjusted startY for neuron centering
                    ctx.lineTo(endX - neuron_radius, endY);
                    ctx.strokeStyle = 'white';
                    ctx.lineWidth = 1;
                    ctx.stroke();
                }
            }
            }
        };
        
        
        if (neuroncolors.length == 0) {
            drawstaticNeurons();
            drawstaticWeightsBiases();
        } else {
            let frameCount = 0;
            let layerCount = 0;
            let lastRenderTime = 0;
            let iterationCount = 0;
            let animationFrameId
            const fps = 4; // Desired frames per second

            const render = (timestamp) => {

                // Calculate the time elapsed since the last frame
                const deltaTime = timestamp - lastRenderTime;

                // Check if enough time has elapsed to render the next frame based on desired FPS
                if (deltaTime > 1000 / fps) {
                    layerCount = frameCount % 5
                    iterationCount = Math.min(39, (frameCount - layerCount)/5)

                    drawNeurons(neuroncolors[iterationCount], [...Array(layerCount).keys()]);
                    drawWeightsBiases(linewidths[iterationCount]);
                    progressText(iterationCount)
                    drawlosses(losses, maxlosses, iterationCount)

                    if (frameCount >= 199) {
                      drawaccuracy(accuracy, classifyloss(accuracy));
                      draw_test_samples(test_samples);
                    }
                    else {
                      drawaccuracy('running...', '');
                      draw_test_samples('running...');
                    }

                    draw_legend();

                    // draw_test_samples(frameCount);

                    // if (layerCount == 3) {
                    //     drawNeurons(neuroncolors[frameCount]);
                    //     drawWeightsBiases(linewidths[frameCount]);
                    // }
                        
                    lastRenderTime = timestamp;
                    frameCount++;
                }

                // Continue the animation loop
                if (frameCount < 200) {
                    animationFrameId = window.requestAnimationFrame(render);
                }
            };

            render();
            
        }
  
    }, [layers, linewidths, neuroncolors, losses, maxlosses, datalabels, accuracy, test_samples]);
  
    return <canvas ref={canvasRef} style={{ width: "100%"}}></canvas> ;
  }

function NNsimulator() {
    const [layers, setLayers] = useState([1, 2, 2, 1]); // Number of neurons in each layer
    const [dataset, setDataset] = useState('')
    const [iolayers, setIOLayers] = useState([1,1])
    const [hiddenlayers, setHiddenLayers] = useState([2, 2])
    const [accuracy, setAccuracy] = useState(1)
    const [colorswidths, setColorsWidths] = useState([[],[]])
    const [dataSent, setDataSent] = useState(false);
    const [simulatorReady, setSimulatorReady] = useState(true);
    const [staticmodel, setStaticModel] = useState(true);
    const [dataLabels, setDataLabels] = useState(['(input type)', '(output type)']);

    const handleLayerChange = (event, index) => {
      if (!dataSent){
        let newValue = event.target.value.trim() === '' ? 2 : parseInt(event.target.value);
        
        if (newValue > 20) {
          newValue -= 20
        }

        newValue = Math.min(Math.max(newValue, 2), 16);
  
        const newHiddenLayers = [...hiddenlayers];
        newHiddenLayers[index] = newValue;
        setHiddenLayers(newHiddenLayers);
  
        const newLayers = [...layers];
        newLayers[1] = newHiddenLayers[0];
        newLayers[2] = newHiddenLayers[1];
        setLayers(newLayers);
      }
    };

    const changeNeuron = (difference, index) => {
        if (!dataSent){
        const newHiddenLayers = [...hiddenlayers];
        newHiddenLayers[index] = newHiddenLayers[index] + difference;
        if (newHiddenLayers[index] < 2) {
            newHiddenLayers[index] = 2
        }
        if (newHiddenLayers[index] > 16) {
            newHiddenLayers[index] = 16
        }
        setHiddenLayers(newHiddenLayers);
  
        const newLayers = [...layers];
        newLayers[1] = newHiddenLayers[0];
        newLayers[2] = newHiddenLayers[1];
        setLayers(newLayers);
        };
    };
  
    const sendDatasetToBackend = (dataset) => {
  
      fetch(home_endpoint, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({ dataset }),
      })
        .then(response => response.json())
        .then(data => setIOLayers(data.iolayers))
        .catch(error => console.error('Error:', error));

        if (dataset === 'iris_dataset') {
          setDataLabels(['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)', 'class of iris plant'])
        }

      // setInputLayer(input);
      // const new_layers = [...layers];
      // new_layers[0] = input;
      // new_layers[3] = output
      // setLayers(new_layers);
    };
  
    const handleDatasetChange = (event) => {
  
      if (!dataSent){
      const new_dataset = event.target.value;
      setDataset(event.target.value);
  
      if (new_dataset != 'none') {
        sendDatasetToBackend(new_dataset);
  
        // const newLayers = [...layers];
        // newLayers[0] = inputlayer
        // newLayers[3] = outputlayer
        // setLayers(newLayers);
  
        // sendDatasetToBackend(new_dataset);
  
        setSimulatorReady(false);
      } else {
        setSimulatorReady(true);
        setIOLayers([1, 1]);
      }
  
      }
    };
  
    const sendLayersDataToBackend = () => {
      layers[0] = iolayers[0];
      layers[3] = iolayers[1];
  
      fetch(home_endpoint, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({ layers, dataset }),
      })
        .then(response => response.json())
        .then(data => setColorsWidths(data.colorswidths))
        .catch(error => console.error('Error:', error));
  
      setDataSent(true);
      setStaticModel(false);
  
    };

    return(
      <div className="App">
      <h2>Neural Network Simulator:</h2>
      <br />
      <br />
      <div>
      <label>Select dataset:</label>
        <select value={dataset} onChange={(event) => {handleDatasetChange(event)}} disabled={dataSent}>
          <option value="none">None</option>
          <option value="iris_dataset">Iris Dataset</option>
        </select>
      </div>
      <br />
      <br />
      <div className = "inline-container">
        <label>Input Layer: {iolayers[0]}</label>
        <div className = "component" index = {hiddenlayers[0]}>
        <label> Hidden Layer 1:</label>
        <br />
        <button className = "circularbutton" onClick={(event) => changeNeuron(-1, 0)}>-</button>
        <label> {hiddenlayers[0]} </label>
        <button className = "circularbutton" onClick={(event) => changeNeuron(1, 0)}>+</button>
        </div>
        <div className = "component" index = {hiddenlayers[1]}>
        <label> Hidden Layer 2:</label>
        <br />
        <button className = "circularbutton" onClick={(event) => changeNeuron(-1, 1)}>-</button>
        <label> {hiddenlayers[1]} </label>
        <button className = "circularbutton" onClick={(event) => changeNeuron(1, 1)}>+</button>
        </div>
        <div className = "component">
        <label>Output Layer: {iolayers[1]} </label>
        </div>
      </div>
      <br />
      <br />
      <div>
      <button onClick={sendLayersDataToBackend} disabled={dataSent || simulatorReady}>Run Neural Network</button>
      </div>
      <br />
      <div style={{ width: '100%', overflowX: 'scroll' }}>
      <NeuralNetwork layers={[iolayers[0], hiddenlayers[0], hiddenlayers[1], iolayers[1]]} linewidths = {colorswidths[1]} neuroncolors = {colorswidths[0]} losses = {colorswidths[3]} maxlosses = {colorswidths[4]} datalabels = {dataLabels} accuracy = {colorswidths[2]/10000} test_samples = {colorswidths[5]}/>
      </div>
      <br />
      <br />
      <br />
      <h2>About this simulator</h2>
      <p>This simulator visualizes the training process of a basic Neural Network. It is composed of four layers: an input layer, two hidden layers, and an output layer. At the top, select a dataset to train your Neural Network on, and then specify the number of neurons in each of the hidden layers (between 2 and 16). Then, click Run Neural Network, and watch as your Neural Network trains itself. The green progress bar at the top of the visual shows the progression of the training process. The Neural Network itself is visualized, where the circles represents neurons and the lines represent the weights connecting each neuron. The intensity that each of the neuron lights up to shows the activation of the neuron, while the widths of each connecting line shows the magnitude of the weight. At the bottom, the red bars show the loss yielded by the Neural Network at each iteration. Below, you can find the final accuracy yielded by the Neural Network. Try to see if you can make this value as low as you can.</p>
    </div>
    );
  };
  
  // Default export is the page-level simulator component; NeuralNetwork is not
  // exported anywhere in the visible source.
  export default NNsimulator;