// External libraries
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu';

// Internal modules
import { 
  keypointNames, 
  PipelineType, 
  MODEL_CONFIGS,
  PoseModelOutputType
} from '../utils/constants';
import { KeypointTracker } from './smoothing';


///////////////////////////////////////////////////////////////////////////////
/// UTILS
///////////////////////////////////////////////////////////////////////////////

function collateLatencyInfo(totalLatency, detectionLatency, poseEstimationLatency) {
  /// Packs the measured per-frame latencies (milliseconds) into the latency report object.
  /// Stages this pipeline does not time separately (transform, post-process, draw) are reported
  /// as 0, and fps is derived from the total latency.
  const fps = Math.round(1000 / totalLatency);

  return {
    detectionLatency,
    transformLatency: 0,
    poseEstimationLatency,
    postProcessLatency: 0,
    drawLatency: 0,
    totalLatency,
    fps,
  };
}

function bboxToROI(bbox, frameHeight, frameWidth) {
  /// Converts a pixel-space bbox [x, y, width, height] (top-left origin) into an ROI object.
  /// `position` holds a single box [y1, x1, y2, x2] normalized by the frame size (this is the box
  /// list handed to tf.image.cropAndResize in preparePoseEstimatorInputs); `width` and `height`
  /// are passed through unchanged, in pixels.
  const [left, top, boxWidth, boxHeight] = bbox;

  const normalizedBox = [
    top / frameHeight,
    left / frameWidth,
    (top + boxHeight) / frameHeight,
    (left + boxWidth) / frameWidth,
  ];

  return { position: [normalizedBox], width: boxWidth, height: boxHeight };
}


///////////////////////////////////////////////////////////////////////////////
/// PREPROCESSING
///////////////////////////////////////////////////////////////////////////////

function toInputTensor(input) {
  /// Returns `input` unchanged when it is already a tf.Tensor; otherwise reads its pixels into a
  /// new tensor with tf.browser.fromPixels.
  if (input instanceof tf.Tensor) {
    return input;
  }
  return tf.browser.fromPixels(input);
}

function preparePoseEstimatorInputs(video, roi, modelSettings) {
  /// Prepares the input tensor for the pose estimator model.
  ///
  /// Crops `roi.position` (normalized [y1, x1, y2, x2] box) out of the frame, resizes it to the
  /// model's configured input size, converts it to channels-first and normalizes it with the
  /// per-channel means/stds below.
  ///
  /// @param video - A tf.Tensor or any source accepted by tf.browser.fromPixels.
  /// @param roi - ROI object produced by bboxToROI.
  /// @param modelSettings - Reads `modelName` to look up the input size in MODEL_CONFIGS.
  /// @returns Tensor of shape [1, 3, inputHeight, inputWidth].
  const { inputHeight, inputWidth } = MODEL_CONFIGS[modelSettings.modelName];

  // tf.tidy disposes every intermediate tensor created here (pixels, crop, transpose, constants,
  // partial normalizations) as soon as the final tensor is built, instead of relying on the
  // caller's engine scope. A tensor passed in as `video` was created outside the tidy and is
  // therefore not disposed.
  return tf.tidy(() => {
    const frame = tf.expandDims(toInputTensor(video), 0);
    const crop = tf.image.cropAndResize(frame, roi.position, [0], [inputHeight, inputWidth]);

    // Transposes from [1,img_height,img_width,3] to [1,3,img_height,img_width]. Channels-first is
    // the PyTorch convention, and the model was originally built in PyTorch.
    const channelsFirst = tf.transpose(crop, [0, 3, 1, 2]);

    // Scale to [0, 1], then normalize to the expected per-channel means and stds.
    const means = tf.tensor1d([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]);
    const stds = tf.tensor1d([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]);
    return channelsFirst.div(255.0).sub(means).div(stds);
  });
}


///////////////////////////////////////////////////////////////////////////////
/// POSTPROCESSING
///////////////////////////////////////////////////////////////////////////////


function getHeatmapCoordsAndScores(heatmaps) {
  /// Gets the coordinates and scores of the highest activated pixel in each keypoint heatmap.
  ///
  /// @param heatmaps - Tensor whose shape is read as [N, K, H, W] (N batches, K keypoints,
  ///   one H x W heatmap per keypoint).
  /// @returns {{heatmapCoords, scores}}
  ///   heatmapCoords: float tensor [N, K, 2] holding (x, y) = (column, row) of each peak, or
  ///     (-1, -1) where the peak value is not greater than 0.
  ///   scores: tensor [N, K, 1] with the peak value of each heatmap.
  const [N, K] = heatmaps.shape;
  const W = heatmaps.shape[3];

  const flattened = heatmaps.reshape([N, K, -1]);
  const scores = flattened.max(2).reshape([N, K, 1]);

  // Row-major flat index of each peak. Both coordinates are derived from the same float tensor
  // so the concat below joins tensors of a single dtype.
  const flatIdx = flattened.argMax(2).reshape([N, K, 1]).toFloat();
  const xs = flatIdx.mod(W);
  const ys = flatIdx.div(W).floor();
  const coords = tf.concat([xs, ys], 2);

  // Keypoints whose peak activation is not positive are marked with -1 coordinates.
  const hasPeak = tf.tile(scores, [1, 1, 2]).greater(0);
  const heatmapCoords = tf.where(hasPeak, coords, -1);

  return { heatmapCoords, scores };
}

async function heatmapToKeypoints(heatmaps, roi, frameWidth, frameHeight) {
  /// Converts the heatmap output of the model into a list of keypoints in full-frame pixels.
  ///
  /// Each heatmap peak is scaled from heatmap cells to ROI pixels (roi.width / heatmap width,
  /// roi.height / heatmap height), offset by the ROI's top-left corner in the frame and rounded.
  /// Only batch 0 is read.
  ///
  /// @returns {Array<{position: {x, y}, part, score}>} One entry per keypoint, in model order.
  const { heatmapCoords, scores } = getHeatmapCoordsAndScores(heatmaps);
  const numKeypoints = heatmapCoords.shape[1];
  const heatmapHeight = heatmaps.shape[2];
  const heatmapWidth = heatmaps.shape[3];

  // Top-left corner of the ROI in frame pixels (roi.position is normalized [y1, x1, y2, x2]).
  const originX = roi.position[0][1] * frameWidth;
  const originY = roi.position[0][0] * frameHeight;

  const xColumn = heatmapCoords.slice([0, 0, 0], [1, numKeypoints, 1]);
  const yColumn = heatmapCoords.slice([0, 0, 1], [1, numKeypoints, 1]);
  const pixelX = xColumn.mul(roi.width / heatmapWidth).add(originX).round();
  const pixelY = yColumn.mul(roi.height / heatmapHeight).add(originY).round();

  // Interleaved per keypoint as [x, y, score], then downloaded in one read.
  const packed = tf.concat([pixelX, pixelY, scores], 2);
  const values = await packed.squeeze().data();

  return Array.from({ length: numKeypoints }, (_, k) => ({
    "position": {
      "x": values[3 * k],
      "y": values[3 * k + 1],
    },
    "part": keypointNames[k],
    "score": values[3 * k + 2],
  }));
}

async function simccToKeypoints(simccOutputs, roi, frameWidth, frameHeight) {
  /// Converts SimCC model outputs into a list of keypoints in full-frame pixel coordinates.
  ///
  /// @param simccOutputs - [xOutput, yOutput]. Both are reduced over axis 2, which is treated as
  ///   the coordinate-bin axis (presumably [N, K, bins]; TODO confirm against the model export).
  /// @param roi - ROI object from bboxToROI: normalized `position` [[y1, x1, y2, x2]] plus
  ///   pixel `width` / `height`.
  /// @param frameWidth - Frame width in pixels.
  /// @param frameHeight - Frame height in pixels.
  /// @returns {Array<{position: {x, y}, part, score}>} One entry per keypoint.
  const outputX = simccOutputs[0];
  const outputY = simccOutputs[1];

  // Number of SimCC bins along each axis, read from the output shapes so the bin -> pixel
  // scaling matches whichever SimCC model is loaded.
  const numBinsX = outputX.shape[2];
  const numBinsY = outputY.shape[2];

  // X Coordinates
  // Argmax picks the most probable bin; the multiply maps bins to ROI pixels, the addition
  // shifts from ROI to full-image pixels, and round keeps results on the pixel grid.
  const simccX = outputX
    .argMax(2)
    .mul(roi.width / numBinsX)
    .add(roi.position[0][1] * frameWidth)
    .round();

  // Y Coordinates - Same as above
  const simccY = outputY
    .argMax(2)
    .mul(roi.height / numBinsY)
    .add(roi.position[0][0] * frameHeight)
    .round();

  // Confidence Scores
  // Each keypoint has a confidence for its x and y coordinate; the mean of the two is used.
  const scores = outputX.max(2).add(outputY.max(2)).div(2);

  // Fetch from GPU. Layout: [all x..., all y..., all scores...]
  const keypointData = await tf.concat([simccX, simccY, scores], 0).data();

  // Reformat and return
  const keypoints = [];
  const numKeypoints = keypointData.length / 3;
  for (let kpt = 0; kpt < numKeypoints; kpt++) {
    keypoints.push({
      "position": {
        "x": keypointData[kpt],
        "y": keypointData[kpt + numKeypoints]
      },
      "part": keypointNames[kpt],
      "score": keypointData[kpt + (numKeypoints * 2)]
    });
  }

  return keypoints;
}

function updateKeypointTracks(keypointTracker, keypoints, prevKeypoints, keypointConfidenceThreshold) {
  /// Feeds the current keypoints through the KeypointTracker.
  ///
  /// On the first tracked frame (no tracker or no previous keypoints) a new tracker is seeded
  /// with the current positions. Otherwise, for each keypoint index, the track is updated when
  /// the previous frame's keypoint scored at least `keypointConfidenceThreshold` (and the
  /// smoothed position is written back into `keypoints` in place); if not, the track is reset
  /// to the current position.
  ///
  /// @returns [tracker, keypoints]
  const positionOf = (kpt) => [kpt.position.x, kpt.position.y];

  if (keypointTracker === null || prevKeypoints === null) {
    return [new KeypointTracker(keypoints.map(positionOf)), keypoints];
  }

  for (const [idx, previous] of prevKeypoints.entries()) {
    const current = keypoints[idx];

    if (previous.score >= keypointConfidenceThreshold) {
      const smoothed = keypointTracker.updateKeypointTrack(idx, positionOf(current));
      current.position.x = smoothed[0];
      current.position.y = smoothed[1];
    } else {
      // Keypoint was not confidently seen last frame: restart its track from here.
      keypointTracker.resetKeypointTrack(idx, positionOf(current));
    }
  }

  return [keypointTracker, keypoints];
}


///////////////////////////////////////////////////////////////////////////////
/// MODEL EXECUTION
///////////////////////////////////////////////////////////////////////////////


async function runDetector(video, detector) {
  /// Runs the object detector on `video` and returns the bbox of the most confident "person"
  /// detection, or null when no person was detected.
  ///
  /// The bbox is of the form [x, y, width, height] where x,y is the top-left corner.
  /// Reminder: image origin here is top left (0, 0).
  const allDetections = await detector.detect(video);
  const people = allDetections.filter((detection) => detection.class === "person");

  if (people.length === 0) {
    return null;
  }

  // Highest confidence first. `people` is a fresh array from filter, so sorting in place is safe.
  people.sort((a, b) => b.score - a.score);
  return people[0].bbox;
}

async function predictOrInferBbox(video, detector, modelSettings, prevKeypoints) {
  /// Chooses the bounding box for the current frame in the PoseAndTrack pipeline.
  ///
  /// If the previous frame has keypoints scoring at least `keypointConfidenceThreshold`, the
  /// mean x of those keypoints is used as the box center. Otherwise the detector is run and the
  /// center of its bbox is used. In both cases the returned box spans the full frame height, has
  /// the model's input aspect ratio and is centered horizontally on that x position.
  ///
  /// @returns {number[] | null} [x, y, width, height] (top-left origin), or null when the
  ///   detector had to run and found no person.
  const confidentKeypoints = prevKeypoints?.filter(
    (keypoint) => keypoint.score >= modelSettings.keypointConfidenceThreshold
  ) ?? [];

  let centerX;
  if (confidentKeypoints.length === 0) {
    // No usable track from the previous frame: fall back to the detector.
    const detectionBbox = await runDetector(video, detector);

    // No person detected. Return null (not an array) so the callers' `bbox === null`
    // check ends the pipeline for this frame.
    if (detectionBbox == null) {
      return null;
    }
    centerX = detectionBbox[0] + detectionBbox[2] / 2;
  } else {
    // Track from the previous frame: center on the mean x of the confident keypoints.
    centerX = confidentKeypoints.reduce((sum, keypoint) => sum + keypoint.position.x, 0)
      / confidentKeypoints.length;
  }

  // Only the horizontal center matters: the box is stretched to the full frame height with the
  // model's input aspect ratio. It is not clamped, so x may be negative or the box may run past
  // the right edge of the frame.
  const { inputHeight, inputWidth } = MODEL_CONFIGS[modelSettings.modelName];
  const bboxHeight = video.videoHeight;
  const bboxWidth = (video.videoHeight / inputHeight) * inputWidth;

  return [Math.floor(centerX - bboxWidth / 2), 0, bboxWidth, bboxHeight];
}

async function runPoseEstimator(video, poseModel, bbox, modelSettings) {
  /// Crops the bbox region out of the frame, runs the pose model on it and decodes the model
  /// output (heatmap or SimCC, per MODEL_CONFIGS) into keypoints.
  ///
  /// @returns The keypoint list, or undefined when the model's outputType is neither HEATMAP
  ///   nor SIMCC.
  const frameWidth = video.videoWidth;
  const frameHeight = video.videoHeight;

  // ROI in the form [y1, x1, y2, x2], normalized to the frame size
  const roi = bboxToROI(bbox, frameHeight, frameWidth);

  const modelInput = preparePoseEstimatorInputs(video, roi, modelSettings);
  const modelOutputs = poseModel.execute(modelInput);

  const { outputType } = MODEL_CONFIGS[modelSettings.modelName];
  if (outputType == PoseModelOutputType.HEATMAP) {
    return heatmapToKeypoints(modelOutputs, roi, frameWidth, frameHeight);
  }
  if (outputType == PoseModelOutputType.SIMCC) {
    return simccToKeypoints(modelOutputs, roi, frameWidth, frameHeight);
  }
  return undefined;
}


///////////////////////////////////////////////////////////////////////////////
/// PIPELINES
///////////////////////////////////////////////////////////////////////////////


  export async function infer(video, detector, poseModel, modelSettings, prevPredictions) {
  /// Runs one frame through the configured pipeline: bbox selection, pose estimation and
  /// optional keypoint smoothing.
  ///
  /// @param video - Frame source; `videoWidth` / `videoHeight` are read from it.
  /// @param detector - Object detector used by the DetectAndPose and PoseAndTrack pipelines.
  /// @param poseModel - Pose model exposing `execute(tensor)`.
  /// @param modelSettings - Reads `pipelineType`, `modelName`, `keypointSmoothingMode` and
  ///   `keypointConfidenceThreshold`.
  /// @param prevPredictions - The object returned for the previous frame, or null/undefined.
  /// @returns {{bbox, keypoints, keypointTracker, latencyInfo}} All four fields are null when no
  ///   person was found (or no usable bbox could be produced).
  const prevKeypoints = prevPredictions ? prevPredictions.keypoints : null;
  const keypointTracker = prevPredictions ? prevPredictions.keypointTracker : null;

  // Every tensor created while processing this frame is released by endScope(). The try/finally
  // guarantees the scope is closed even if detection or pose estimation throws; otherwise the
  // frame's tensors would leak and the engine's scope stack would stay unbalanced.
  tf.engine().startScope();

  const start = performance.now();
  let bbox = null;
  let endDetection = start;
  let personFound = false;
  let tracker = null;
  let keypoints = null;

  try {
    //
    // Detection
    if (modelSettings.pipelineType == PipelineType.DetectAndPose) {
      // Runs detection on every frame for a tighter box around the person, at some latency cost.
      bbox = await runDetector(video, detector);
    } else if (modelSettings.pipelineType == PipelineType.PoseAndTrack) {
      // Infers the box from the previous frame's keypoints, running the detector only if needed.
      bbox = await predictOrInferBbox(video, detector, modelSettings, prevKeypoints);
    } else if (modelSettings.pipelineType == PipelineType.PoseOnly) {
      // Uses the entire image as input into the pose model, skipping the detection stage.
      bbox = [0, 0, video.videoWidth, video.videoHeight];
    }

    endDetection = performance.now();

    // A missing bbox, or one with missing/non-finite entries, means there is nothing to estimate.
    personFound = bbox != null && bbox.every(Number.isFinite);

    if (personFound) {
      //
      // Pose Estimation
      const rawKeypoints = await runPoseEstimator(video, poseModel, bbox, modelSettings);

      //
      // Keypoint Smoothing (optional)
      [tracker, keypoints] = modelSettings.keypointSmoothingMode
        ? updateKeypointTracks(keypointTracker, rawKeypoints, prevKeypoints, modelSettings.keypointConfidenceThreshold)
        : [null, rawKeypoints];
    }
  } finally {
    tf.engine().endScope();
  }

  // If no person was detected, the pipeline has completed for this frame.
  if (!personFound) {
    return {
      "bbox": null,
      "keypoints": null,
      "keypointTracker": null,
      "latencyInfo": null
    };
  }

  const end = performance.now();

  return {
    "bbox": bbox,
    "keypoints": keypoints,
    "keypointTracker": tracker,
    "latencyInfo": collateLatencyInfo(
      Math.round(end - start),
      Math.round(endDetection - start),
      Math.round(end - endDetection)
    )
  };
}

///////////////////////////////////////////////////////////////////////////////
/// DEBUG - TODO: Remove this section
///////////////////////////////////////////////////////////////////////////////

async function runPoseEstimatorDebug(video, poseModel, bbox, modelSettings) {
  /// Debug variant of runPoseEstimator. It awaits `.data()` after preprocessing and after
  /// inference so the per-stage timings logged to the console are measured after each stage's
  /// values are available, not just after the ops were queued.
  const t0 = performance.now();

  // ROI in the form [y1, x1, y2, x2], normalized to the frame size
  const roi = bboxToROI(bbox, video.videoHeight, video.videoWidth);

  const modelInput = preparePoseEstimatorInputs(video, roi, modelSettings);
  await modelInput.data();
  const tPreprocessed = performance.now();

  const modelOutputs = poseModel.execute(modelInput);

  let keypoints;
  let tInferred;
  if (MODEL_CONFIGS[modelSettings.modelName].outputType == PoseModelOutputType.HEATMAP) {
    await modelOutputs.data();
    tInferred = performance.now();
    keypoints = await heatmapToKeypoints(modelOutputs, roi, video.videoWidth, video.videoHeight);
  } else if (MODEL_CONFIGS[modelSettings.modelName].outputType == PoseModelOutputType.SIMCC) {
    await modelOutputs[0].data();
    tInferred = performance.now();
    keypoints = await simccToKeypoints(modelOutputs, roi, video.videoWidth, video.videoHeight);
  }
  // No extra sync needed: the keypoint decoders already await their tensor data.
  const tPostprocessed = performance.now();

  const stageLatencies = {
    'pose-preprocess': Math.round(tPreprocessed - t0),
    'pose-inference': Math.round(tInferred - tPreprocessed),
    'pose-postprocess': Math.round(tPostprocessed - tInferred),
  };
  console.log("Pose Latencies:", stageLatencies);

  return keypoints;
}

export async function inferDebug(video, detector, poseModel, modelSettings, prevPredictions) {
  /// Debug variant of `infer`: same pipeline, but it logs per-stage latencies to the console and
  /// runs pose estimation through runPoseEstimatorDebug.
  ///
  /// @param modelSettings - Reads `pipelineType`, `modelName`, `keypointSmoothingMode` and
  ///   `keypointConfidenceThreshold`.
  /// @param prevPredictions - The object returned for the previous frame, or null/undefined.
  /// @returns {{bbox, keypoints, keypointTracker, latencyInfo}} All four fields are null when no
  ///   person was found (or no usable bbox could be produced).
  const prevKeypoints = prevPredictions ? prevPredictions.keypoints : null;
  const keypointTracker = prevPredictions ? prevPredictions.keypointTracker : null;

  // try/finally guarantees endScope() runs even if a stage throws, so the frame's tensors are
  // released and the engine's scope stack stays balanced.
  tf.engine().startScope();

  const startTick = performance.now();
  let bbox = null;
  let detectionTick = startTick;
  let personFound = false;
  let tracker = null;
  let keypoints = null;

  try {
    //
    // Detection
    if (modelSettings.pipelineType == PipelineType.DetectAndPose) {
      // Runs detection on every frame for a tighter box around the person, at some latency cost.
      bbox = await runDetector(video, detector);
    } else if (modelSettings.pipelineType == PipelineType.PoseAndTrack) {
      // Infers the box from the previous frame's keypoints, running the detector only if needed.
      bbox = await predictOrInferBbox(video, detector, modelSettings, prevKeypoints);
    } else if (modelSettings.pipelineType == PipelineType.PoseOnly) {
      // Uses the entire image as input into the pose model, skipping the detection stage.
      bbox = [0, 0, video.videoWidth, video.videoHeight];
    }

    detectionTick = performance.now();
    console.log("Detection Latency:", Math.round(detectionTick - startTick));

    // A missing bbox, or one with missing/non-finite entries, means there is nothing to estimate.
    personFound = bbox != null && bbox.every(Number.isFinite);

    if (personFound) {
      //
      // Pose Estimation
      const rawKeypoints = await runPoseEstimatorDebug(video, poseModel, bbox, modelSettings);

      //
      // Keypoint Smoothing (optional)
      // The confidence threshold must be passed: updateKeypointTracks compares
      // `score >= threshold`, which is always false for an undefined threshold, so every
      // track would be reset each frame and no smoothing would ever be applied.
      const startSmoothTick = performance.now();
      [tracker, keypoints] = modelSettings.keypointSmoothingMode
        ? updateKeypointTracks(keypointTracker, rawKeypoints, prevKeypoints, modelSettings.keypointConfidenceThreshold)
        : [null, rawKeypoints];
      console.log("Keypoint smoothing:", Math.round(performance.now() - startSmoothTick));
    }
  } finally {
    tf.engine().endScope();
  }

  // If no person was detected, the pipeline has completed for this frame.
  if (!personFound) {
    return {
      "bbox": null,
      "keypoints": null,
      "keypointTracker": null,
      "latencyInfo": null
    };
  }

  const endTick = performance.now();

  return {
    "bbox": bbox,
    "keypoints": keypoints,
    "keypointTracker": tracker,
    "latencyInfo": collateLatencyInfo(
      Math.round(endTick - startTick),
      Math.round(detectionTick - startTick),
      Math.round(endTick - detectionTick)
    )
  };
}