import {useEffect, useMemo, useRef, useState, MutableRefObject} from "react";
import {Pose, VERSION} from '@mediapipe/pose';
import {Camera} from '@mediapipe/camera_utils';
import {drawConnectors, drawLandmarks} from '@mediapipe/drawing_utils';
import * as tf from "@tensorflow/tfjs";
import {
  Chart as ChartJS,
  CategoryScale,
  LinearScale,
  BarElement,
  Title,
  Tooltip,
  Legend,
} from 'chart.js';
import { Bar } from 'react-chartjs-2';

import useStateRef from '../hooks/useStateRef';
import useCountdown from '../hooks/useCountdown';
import {CUSTOM_POSE_CONNECTIONS, CUSTOM_POSE_LANDMARKS} from "../constants/POSE";

// Register the scales, the bar element and the title/tooltip/legend plugins
// with Chart.js at module load. Presumably these are what the <Bar> chart
// rendered by the component below needs — confirm before removing any.
ChartJS.register(
  CategoryScale,
  LinearScale,
  BarElement,
  Title,
  Tooltip,
  Legend,
);

// Short aliases for the DOM types that the component keeps in refs.
type Video = HTMLVideoElement;
type Canvas = HTMLCanvasElement;
type Context = CanvasRenderingContext2D;
// A mutable React ref whose value may be null (it is initialised with null below).
type RefType<T> = MutableRefObject<T | null>;

/**
 * One pose landmark as consumed by this file: only the three coordinates are
 * read (see convertLandmarks).
 * NOTE(review): presumably these are MediaPipe's normalized coordinates — confirm.
 */
interface LandmarkType {
  x: number;
  y: number;
  z: number;
}

export default function Mediapipe() {
  // console.log('init.');
  // Handles that are all null on the first render. setupCamera() fills `camera`
  // and `video`; setupCanvas() fills `canvas` and `context`.
  const camera: RefType<Camera> = useRef<Camera | null>(null);
  const video: RefType<Video> = useRef<Video | null>(null);
  const canvas: RefType<Canvas> = useRef<Canvas | null>(null);
  const context: RefType<Context> = useRef<Context | null>(null);

  // Countdown hook called with 5. `counter` is rendered while it is > 0 and
  // `start` is invoked with a completion callback in startCountdown().
  const {counter, start} = useCountdown(5);

  /*
   * setup
   */

  /**
   * Looks up the `#input-video` element, stores it in the `video` ref and
   * starts a MediaPipe `Camera` that invokes `handleCameraByFrame` per frame.
   *
   * Does nothing (besides logging) when the element is not in the document,
   * and logs instead of leaving an unhandled rejection when the camera cannot
   * be started.
   */
  function setupCamera(): void {
    const element = document.getElementById('input-video') as HTMLVideoElement | null;
    if (!element) {
      // Without a video element there is nothing to attach the camera to.
      console.error('setupCamera: #input-video element was not found.');
      return;
    }

    video.current = element;
    camera.current = new Camera(element, {
      onFrame: handleCameraByFrame,
      width: 1920 / 2,
      height: 1080 / 2,
    });

    // start() is asynchronous; surface failures instead of dropping the promise.
    camera.current.start().catch((error: unknown) => {
      console.error('setupCamera: failed to start the camera.', error);
    });
  }

  /** Resolves `#output-canvas` and caches both it and its 2D context in refs. */
  function setupCanvas() {
    const element = document.getElementById('output-canvas') as HTMLCanvasElement;
    const drawingContext = element.getContext('2d') as CanvasRenderingContext2D;

    canvas.current = element;
    context.current = drawingContext;
  }

  // One-time setup after mount: start the camera and resolve the canvas.
  // The cleanup stops the camera so it does not keep streaming after unmount
  // (or run twice when the effect is re-executed).
  useEffect(() => {
    setupCamera();
    setupCanvas();

    // Capture the instance created above so the cleanup stops exactly that one.
    const activeCamera = camera.current;

    return () => {
      if (!activeCamera) return;
      activeCamera.stop().catch((error: unknown) => {
        console.error('Failed to stop the camera.', error);
      });
    };
  }, []); // eslint-disable-line react-hooks/exhaustive-deps

  /*
   * camera
   */

  // Number of burst captures still to take. The value is rendered in the
  // capture button; handlePoseResult reads and decrements it through the ref.
  const [burstCount, burstCountRef, setBurstCount] = useStateRef<number>(0);
  // Date.now() timestamp of the most recent burst capture (0 before the first).
  const lastBurstTime = useRef<number>(0);

  /**
   * Click handler of the capture button: asks for the name of the pose about
   * to be recorded, registers it, and starts the countdown. When the countdown
   * callback fires, a burst of 10 captures is requested.
   *
   * Clicks are ignored while a burst is still being captured; otherwise a new
   * name would be registered and the burst counter reset mid-capture, leaving
   * the number of recorded samples out of step with the recorded names.
   * NOTE(review): a click while the countdown itself is running is not
   * guarded here — confirm how `useCountdown` reports an active countdown.
   */
  function startCountdown() {
    if (burstCountRef.current && burstCountRef.current > 0) return;

    const name = prompt('你即將要擺出的動作名稱是：');
    // Cancelled prompt or empty input: nothing to record.
    if (name === null || name.length === 0) return;

    setActionNames([...actionNames, name]);
    start(() => {
      setBurstCount(10);
    });
  }

  /** Per-frame camera callback: forwards the current video element to the pose model. */
  async function handleCameraByFrame() {
    const frame = video.current;
    if (frame) {
      await pose.send({image: frame});
    }
  }

  /*
   * mediapipe:pose
   */

  // Single Pose instance for the lifetime of the component (empty dependency
  // list). Its asset files are resolved against the jsDelivr CDN for the
  // imported package VERSION.
  const pose: Pose = useMemo<Pose>(
    () => new Pose({
      locateFile: (file) => `https://cdn.jsdelivr.net/npm/@mediapipe/pose@${VERSION}/${file}`,
    }),
    [],
  );

  /**
   * Pose result callback. For every result it:
   *  1. redraws the current video frame onto the output canvas;
   *  2. if a model has been trained, classifies the landmarks and publishes
   *     the predicted name;
   *  3. if a burst is active, records the landmarks as a training sample at
   *     most once every 100 ms and decrements the burst counter;
   *  4. draws the skeleton connectors and landmarks.
   * Steps 2-4 are skipped when the result carries no `poseLandmarks`.
   *
   * @param result Pose result object; only `poseLandmarks` is read here.
   */
  function handlePoseResult(result: any) {
    const frame = video.current;
    const surface = canvas.current;
    const ctx = context.current;
    if (!frame || !surface || !ctx) return;

    ctx.save();
    try {
      ctx.clearRect(0, 0, surface.width, surface.height);
      ctx.drawImage(frame, 0, 0, surface.width, surface.height);

      if (!result.poseLandmarks) {
        return;
      }

      // predict

      const trainedModel = model.current;
      if (isTrainedRef.current && trainedModel) {
        // tidy() disposes the input and output tensors created for this frame;
        // without it every frame would leak two tensors.
        const scores = tf.tidy(() => {
          const predictTensor = tf.tensor2d([convertLandmarks(result.poseLandmarks)]);
          const resultTensor = trainedModel.predict(predictTensor) as tf.Tensor;
          return resultTensor.dataSync();
        });
        setPredictResult(getResult(scores));
      }

      // burst photos

      if (burstCountRef.current && burstCountRef.current > 0) {
        // Throttle captures to one per 100 ms.
        if (Date.now() - lastBurstTime.current > 100) {
          lastBurstTime.current = Date.now();
          addDataToInput(result.poseLandmarks);
          setBurstCount(burstCountRef.current - 1);
        }
      }

      drawConnectors(
        ctx,
        result.poseLandmarks,
        CUSTOM_POSE_CONNECTIONS,
        {visibilityMin: 0.65, color: 'white'},
      );
      drawLandmarks(
        ctx,
        CUSTOM_POSE_LANDMARKS.map(index => result.poseLandmarks[index]),
        {visibilityMin: 0.65, color: 'white', fillColor: 'rgb(255,138,0)'},
      );
    } finally {
      // Pair every save() with a restore(), including on the early return
      // above, so the canvas state stack does not grow with each frame.
      ctx.restore();
    }
  }

  // Configure the pose instance once after mount. The handler is registered a
  // single time, so it is the first render's handlePoseResult that keeps
  // receiving results.
  useEffect(() => {
    const options = {modelComplexity: 0 as const};

    pose.setOptions(options);
    pose.onResults(handlePoseResult);
  }, []); // eslint-disable-line react-hooks/exhaustive-deps

  /*
   * tensorflow
   */

  // Names entered through the prompt, one entry per recording (duplicates
  // possible); the ref variant is read from getResult().
  const [actionNames, actionNamesRef, setActionNames] = useStateRef<string[]>([]);
  // Set to true once trainModel() has finished fitting; handlePoseResult reads the ref.
  const [isTrained, isTrainedRef, setIsTrained] = useStateRef<boolean>(false);
  // Name returned by getResult() for the latest classified frame, or null before any.
  const [predictResult, setPredictResult] = useState<string | null>(null);
  // Recorded samples, one flattened coordinate row per burst capture.
  const trainingInput = useRef<number[][]>([]);
  // The model created by trainModel(); null until training is started.
  const model = useRef<tf.Sequential | null>(null);

  /** Appends one flattened landmark sample to the recorded training input. */
  function addDataToInput(landmarks: LandmarkType[]) {
    const sample = convertLandmarks(landmarks);
    trainingInput.current.push(sample);
  }

  /**
   * Flattens the landmarks whose index is listed in CUSTOM_POSE_LANDMARKS into
   * a single `[x, y, z, x, y, z, ...]` row, in the order they appear in
   * `landmarks`. All other landmarks are dropped.
   */
  function convertLandmarks(landmarks: LandmarkType[]) {
    return landmarks.flatMap((point, position) => (
      CUSTOM_POSE_LANDMARKS.includes(position)
        ? [point.x, point.y, point.z]
        : []
    ));
  }

  /** Returns the distinct strings of `input`, keeping first-occurrence order. */
  function getUnique(input: string[]) {
    const seen = new Set<string>(input);
    return Array.from(seen);
  }

  /**
   * Builds the one-hot label rows for training: every recorded action name
   * contributes 10 identical rows, each with a 1 in the column of that name
   * within the de-duplicated name list and 0 elsewhere.
   */
  function getTrainingOutput(): number[][] {
    const labels = getUnique(actionNames);

    return actionNames.flatMap(name => {
      const oneHot = labels.map(label => (name === label ? 1 : 0));
      // One independent row per expected sample of this recording.
      return Array.from({length: 10}, () => oneHot.slice());
    });
  }

  /**
   * Click handler of the Train button: builds a small dense network over the
   * recorded samples, fits it for 50 epochs and then marks the component as
   * trained.
   *
   * Training is skipped (with a logged error) when no samples exist or when
   * the number of samples does not match the number of label rows, which
   * happens if Train is pressed before a burst has finished. The temporary
   * training tensors are disposed whether or not fitting succeeds.
   */
  async function trainModel() {
    const trainingOutput = getTrainingOutput();
    const trainingOutputCount = getUnique(actionNames).length;
    const sampleCount = trainingInput.current.length;

    if (sampleCount === 0 || sampleCount !== trainingOutput.length) {
      console.error(
        `trainModel: expected ${trainingOutput.length} samples but recorded ${sampleCount}; training skipped.`,
      );
      return;
    }

    const trainingInputTensor = tf.tensor2d(trainingInput.current);
    const trainingOutputTensor = tf.tensor2d(trainingOutput);

    try {
      const network = tf.sequential();
      model.current = network;

      network.add(tf.layers.dense({
        inputShape: [CUSTOM_POSE_LANDMARKS.length * 3],
        activation: 'sigmoid',
        units: 100,
      }));
      // Only the first layer declares an input shape; later layers take the
      // output of the layer before them.
      network.add(tf.layers.dense({
        activation: 'sigmoid',
        units: trainingOutputCount,
      }));
      network.add(tf.layers.dense({
        activation: 'sigmoid',
        units: trainingOutputCount,
      }));
      network.compile({
        loss: 'meanSquaredError',
        optimizer: tf.train.adam(.06),
      });

      await network.fit(trainingInputTensor, trainingOutputTensor, {epochs: 50});
      setIsTrained(true);
    } finally {
      // The tensors are only needed for fit(); release their memory.
      trainingInputTensor.dispose();
      trainingOutputTensor.dispose();
    }
  }

  /**
   * Publishes the raw scores to the chart and returns the action name whose
   * score is highest (the first one on ties).
   *
   * No confidence threshold is applied: the top-scoring name is always
   * returned.
   */
  function getResult(array: Float32Array | Int32Array | Uint8Array): string | null {
    const scores: number[] = Array.from(array);
    const labels = getUnique(actionNamesRef.current);

    setChartData(scores);

    const best = scores.indexOf(Math.max(...scores));
    return labels[best];
  }

  /*
   * Chart
   */

  // Scores of the most recent prediction (written by getResult); empty until
  // the first prediction, which keeps the chart hidden.
  const [chartData, setChartData] = useState<number[]>([]);

  // Layout, in order: the hidden <video> used as the camera source and the
  // output <canvas>; a centered overlay with the countdown (while counter > 0)
  // and the latest prediction; the round capture button (only while the model
  // is untrained) showing the remaining burst count; the Train button (only
  // while untrained and at least two names were recorded); and a bar chart of
  // the latest scores once any exist.
  return (
    <div className="bg-black absolute inset-0 flex justify-center">
      <video id="input-video" className="h-screen hidden" />
      <canvas id="output-canvas" className="h-screen" width="1280px" height="720px" />

      <div className="absolute inset-0 flex flex-col justify-center items-center">
        {counter > 0 && (
          <div className="text-white text-9xl drop-shadow">{counter}</div>
        )}
        <div className="text-white text-5xl drop-shadow">{predictResult}</div>
      </div>

      {isTrained || (
        <div className="absolute bottom-10 inset-x-10 flex flex-row justify-center items-center">
          <div className="p-2 border border-4 border-white rounded-full" onClick={startCountdown}>
            <div className="bg-white h-20 w-20 rounded-full flex justify-center items-center">
              {burstCount > 0 && (
                <div className="text-5xl text-black font-bold">{burstCount}</div>
              )}
            </div>
          </div>
        </div>
      )}

      {(!isTrained && actionNames.length >= 2) && (
        <div className="absolute bottom-10 right-10">
          <div
            onClick={trainModel}
            className="bg-white text-xl px-7 py-3 rounded-md cursor-pointer select-none"
          >
            Train
          </div>
        </div>
      )}

      {chartData.length > 0 && (
        <div className="absolute bottom-10 left-10 bg-white/60 backdrop-blur px-5 py-2 rounded-md">
          <Bar
            options={{
              responsive: true,
              plugins: {
                legend: {display: false},
                title: {display: false},
              },
              scales: {y: {min: 0,max: 1}}
            }}
            data={{
              labels: getUnique(actionNames),
              datasets: [
                {data: chartData, backgroundColor: '#ea5545'},
              ],
            }}
          />
        </div>
      )}
    </div>
  );
}
