import { Vector2, Vector3, Matrix4, Matrix3, Raycaster, MathUtils } from 'three';
import { throttle } from 'lodash';
import { cacheManager } from '../../cache-manager';
import { default as logger } from '../../logger';
import { get_cam_to_abs_tx, lin_mult, getCamToUiTx } from './utils/threeUtils';
import { commonConstants } from './const_params';
import { toolsEvents } from '../../event-bus/supportedKeys';
import { utils } from '../../utils';
import Rect from './classes/rect';
import { getImageCenter } from '../niri-manager.logic';

// Tuning constants for image selection / projection, read once from the shared constants module.
const max_distance_to_selected_pt_mm = commonConstants.max_distance_to_selected_pt_mm;
const max_image_rotation_threshold = commonConstants.max_image_rotation_threshold;
const degreValues = commonConstants.degreValues;
const distortionRadialThreshold = commonConstants.distortionRadialThreshold;
const distance_between_view_rotation_center_and_its_projection =
  commonConstants.distance_between_view_rotation_center_and_its_projection;

// Module-level mutable state shared between calls:
// - pointInsideSurface: written on loupe drag by calculateIntersectPointForImageSelection and used as
//   the ray origin on model rotation; starts at (0, 0, 0).
// - rotation: NOTE(review): not read anywhere in this file. onImageSelectedUpdateRotation reads/writes
//   `cache.prevRotation` instead, which stays undefined until its first write — confirm whether this
//   key was meant to be `prevRotation`.
const cache = {
  rotation: 0,
  pointInsideSurface: new Vector3(),
};

/**
 * Normalizes an ROI descriptor into a Rect.
 *
 * Accepts either the row/col naming (`row_min`, `row_max`, `col_min`, `col_max`) or the x/y naming
 * (`x_min`, `x_max`, `y_min`, `y_max`). Row bounds become the Rect's x range and column bounds its
 * y range. A row/col key wins whenever it is not `undefined`, so an explicit `null` is kept.
 *
 * @param {object} roi - ROI descriptor.
 * @returns {Rect}
 */
const getRectOfImage = (roi) => {
  const preferDefined = (preferred, fallback) => (preferred === undefined ? fallback : preferred);

  return new Rect({
    x_min: preferDefined(roi.row_min, roi.x_min),
    x_max: preferDefined(roi.row_max, roi.x_max),
    y_min: preferDefined(roi.col_min, roi.y_min),
    y_max: preferDefined(roi.col_max, roi.y_max),
  });
};

/**
 * Tells whether any image of `jawName` carries a NIRI caries detection score.
 *
 * An image counts when its `niri.caries_detection_score` compares `>= 0`. Missing `niri` or a
 * missing score does not count. Because of JS coercion, a `null` score does count.
 *
 * @param {object} jawsObject - map of jaw name to `{ images: [...] }`.
 * @param {string} jawName - key into `jawsObject`.
 * @returns {boolean}
 */
export const niriHasCariesParam = (jawsObject, jawName) => {
  const { images } = jawsObject[jawName];

  return images.some((image) => {
    const niriInfo = image.niri || {};
    return niriInfo.caries_detection_score >= 0;
  });
};

/**
 * Chooses which per-image info block drives image selection for `jawName`: 'niri' or 'color'.
 *
 * 'niri' is chosen only when the NG NIRI selection algorithm is enabled AND at least one image of the
 * jaw has a NIRI caries score (see niriHasCariesParam). That choice is also reported through the
 * logger to 'analytics' and 'host'.
 *
 * @param {boolean} isEnableNGNiriImageSelectionAlgo - feature flag.
 * @param {object} jawsObject - map of jaw name to `{ images: [...] }`.
 * @param {string} jawName - key into `jawsObject`.
 * @param {string} modelType - included in the log message only.
 * @returns {'niri'|'color'}
 */
export const getImageInfoSource = (isEnableNGNiriImageSelectionAlgo, jawsObject, jawName, modelType) => {
  // Always evaluated, even when the flag is off, so a missing jaw still throws here.
  const jawHasCariesScores = niriHasCariesParam(jawsObject, jawName);
  const useNiri = isEnableNGNiriImageSelectionAlgo && jawHasCariesScores;

  if (!useNiri) {
    return 'color';
  }

  const message = `Caries detection algo is on for: ${modelType}`;
  logger
    .info(message)
    .to(['analytics', 'host'])
    .data({ module: 'lumina algorithm', data: message })
    .end();

  return 'niri';
};

/**
 * Resolves the transforms and default ROI used to project one camera's image.
 *
 * Starts from the scan-wide `scan_to_cam_tx` / `camera_to_pixel` stored on `dataJson` and a
 * 960x540 ROI. When `dataJson.calibration_data` is present:
 *   - `camera_to_pixel` is replaced by the calibration entry whose `type` equals `imageInfoSource`
 *     and whose `camera_id` matches. If there is no such entry, the scan-wide value is kept rather
 *     than throwing.
 *   - for 'niri' images the ROI becomes 480x270.
 *
 * @param {object} dataJson - parsed scan data (fields `scan_to_cam_tx`, `camera_to_pixel`, optional `calibration_data`).
 * @param {'niri'|'color'} imageInfoSource - image type to select calibration for.
 * @param {*} camera_id - camera identifier, compared with `===`.
 * @returns {{scan_to_cam_tx: *, camera_to_pixel: *, roi: {col_max: number, col_min: number, row_max: number, row_min: number}}}
 */
export const getCameraTransforms = (dataJson, imageInfoSource, camera_id) => {
  const { scan_to_cam_tx: averageScanToCamTx, camera_to_pixel: averageCameraToPixelTx, calibration_data } = dataJson;

  const cameraTransforms = {
    scan_to_cam_tx: averageScanToCamTx,
    camera_to_pixel: averageCameraToPixelTx,
    roi: {
      col_max: 960,
      col_min: 0,
      row_max: 540,
      row_min: 0,
    },
  };

  if (!calibration_data) {
    return cameraTransforms;
  }

  const cameraCalibrationData = Object.values(calibration_data).find(
    (cData) => cData.type === imageInfoSource && cData.camera_id === camera_id
  );

  // The per-camera scan_to_cam_tx is reserved for a future NIRI improvement; for now the scan-wide
  // value is kept:
  // cameraTransforms.scan_to_cam_tx = cameraCalibrationData.scan_to_cam_tx;
  if (cameraCalibrationData) {
    cameraTransforms.camera_to_pixel = cameraCalibrationData.camera_to_pixel;
  }

  if (imageInfoSource === 'niri') {
    cameraTransforms.roi = {
      col_max: 480,
      col_min: 0,
      row_max: 270,
      row_min: 0,
    };
  }

  return cameraTransforms;
};

/**
 * Returns the direction used to cast a ray from the camera onto the model.
 *
 * With the NG NIRI selection algorithm off, `camera_dir` itself (same object) is returned.
 *
 * With the algorithm on, the direction passes through the centre of the image ROI:
 *   - the first two rows of `cameraToPixel` (presumably a row-major 3x3 — TODO confirm) are embedded
 *     into a 4x4 and inverted;
 *   - the ROI centre is taken as `(center.y, center.x, 1)`. getRectOfImage stores rows in x and
 *     columns in y, so this is (col, row, 1);
 *   - that vector is mapped through the inverse and then through `cam_to_abs_tx` via lin_mult.
 *
 * @returns {Vector3}
 */
export const getCameraDirection = (
  isEnableNGNiriImageSelectionAlgo,
  cam_to_abs_tx,
  cameraToPixel,
  rectOfImage,
  camera_dir
) => {
  if (!isEnableNGNiriImageSelectionAlgo) {
    return camera_dir;
  }

  const firstRow = cameraToPixel.slice(0, 3);
  const secondRow = cameraToPixel.slice(3, 6);
  // Matrix4#set takes row-major arguments; rows 3 and 4 are identity so the matrix is invertible.
  const pixelToCamera = new Matrix4()
    .set(...firstRow, 0, ...secondRow, 0, 0, 0, 1, 0, 0, 0, 0, 1)
    .invert();

  const { x: roiX, y: roiY } = rectOfImage.center();
  const directionInCamera = lin_mult(new Vector3(roiY, roiX, 1), pixelToCamera);
  return lin_mult(directionInCamera, cam_to_abs_tx);
};

/**
 * Builds per-image metadata for the images of `jawName` and appends one record per image to
 * `images_meta_data_array`.
 *
 * For each index below `currentActiveJaw.length`, the image's `imageInfoSource` block ('niri' or
 * 'color', chosen by getImageInfoSource) is combined with the camera transforms from the cached data
 * JSON. The record holds:
 *   - the camera position and direction in absolute coordinates;
 *   - the point where the (possibly ROI-corrected) camera ray hits `mesh`;
 *   - that point projected back into the raw image.
 *
 * @param {object} params
 * @param {Array<object>} params.currentActiveJaw - per-image photo objects. Its length bounds the
 *   loop, and `currentActiveJaw[img_idx]` is passed to projectPointToImage.
 * @param {string} params.jawName - 'upper_jaw' or 'lower_jaw' (the only keys of `jawsObject`).
 * @param {object} params.mesh - model mesh. It is raycast against and queried via `mesh.geometry.boundsTree`.
 * @param {Array<object>} params.images_meta_data_array - output array, mutated in place (records are pushed).
 * @param {*} params.dataJsonCacheKey - cacheManager key of the scan data JSON.
 * @param {boolean} params.isEnableNGNiriImageSelectionAlgo - enables NIRI source selection and the
 *   ROI-centre ray direction.
 * @param {string} params.modelType - only used for logging in getImageInfoSource.
 * @returns {Array<object>} the same `images_meta_data_array` instance.
 */
export const initializeImagesMetaData = ({
  currentActiveJaw,
  jawName,
  mesh,
  images_meta_data_array,
  dataJsonCacheKey,
  isEnableNGNiriImageSelectionAlgo,
  modelType,
}) => {
  const dataJson = cacheManager.get(dataJsonCacheKey);
  const { K_vector, P_vector } = dataJson;
  const jawsObject = {
    upper_jaw: dataJson.jaws.upper_jaw,
    lower_jaw: dataJson.jaws.lower_jaw,
  };

  const imageInfoSource = getImageInfoSource(isEnableNGNiriImageSelectionAlgo, jawsObject, jawName, modelType);
  // Distortion coefficients are stored as JSON strings; parse once and share them across all images.
  const kVector = JSON.parse(K_vector);
  const pVector = JSON.parse(P_vector);

  // The loop is bounded by currentActiveJaw, but the image info is read from the cached data JSON.
  for (let img_idx = 0; img_idx < currentActiveJaw.length; img_idx++) {
    const image_info = jawsObject[jawName].images[img_idx];
    // NOTE(review): a skipped image pushes nothing, so after a gap the entries of
    // images_meta_data_array no longer line up with currentActiveJaw indices. Confirm that callers
    // don't index one by the other.
    if (!image_info) continue;

    // Throws if this image has no entry for imageInfoSource.
    const {
      timestamp,
      camera_id,
      scan_id,
      local_to_world_tx,
      roi,
      caries_detection_score = 0,
      caries_reference_points = [],
    } = image_info[imageInfoSource];

    // 1. Camera points and directions
    const { scan_to_cam_tx, camera_to_pixel, roi: defaultRoi } = getCameraTransforms(
      dataJson,
      imageInfoSource,
      camera_id
    );
    const cameraToPixel = JSON.parse(camera_to_pixel);

    const scanToCamTransform = new Matrix4().fromArray(JSON.parse(scan_to_cam_tx));
    // Falsy entries (null/undefined/NaN/0) become 0 before transposing.
    const cameraToPixelTransform = new Matrix3()
      .fromArray(new Matrix3().fromArray(cameraToPixel).elements.map((x) => x || 0))
      .transpose();
    const localToWorldTransform = new Matrix4().fromArray(JSON.parse(local_to_world_tx));
    const cam_to_abs_tx = get_cam_to_abs_tx(scanToCamTransform, localToWorldTransform);
    // Camera origin and +Z axis mapped into absolute coordinates.
    const camera_pt = new Vector3(0, 0, 0).applyMatrix4(cam_to_abs_tx);
    const camera_dir = lin_mult(new Vector3(0, 0, 1), cam_to_abs_tx);

    // The image's own ROI wins over the per-source default.
    const ROI = roi || defaultRoi;
    const rectOfImage = getRectOfImage(ROI);
    // When the NG flag is off this is camera_dir itself (same object).
    const roiCorrectionDirection = getCameraDirection(
      isEnableNGNiriImageSelectionAlgo,
      cam_to_abs_tx,
      cameraToPixel,
      rectOfImage,
      camera_dir
    );

    // NOTE(review): receives the unparsed camera_to_pixel value (cameraToPixel is the parsed one).
    // Confirm that getImageCenter expects that form.
    const imagesCenter = getImageCenter(camera_to_pixel);

    // 2. Projection of camera to surface
    // getRayIntersectCameraToModelSurface normalizes the direction IN PLACE. The hit point below
    // relies on roiCorrectionDirection being unit-length. When the flag is off, this also normalizes
    // the camera_dir stored in the record.
    const intersect = getRayIntersectCameraToModelSurface(camera_pt, roiCorrectionDirection, mesh);

    let was_cam_projected = false;
    let img_cen_on_surf_pt = camera_pt;
    let dist_from_cam_to_surf = Number.MAX_VALUE;
    let imageProjectedCenterDirection = new Vector2(0, 0);

    if (intersect) {
      was_cam_projected = true;
      img_cen_on_surf_pt = new Vector3().addVectors(
        camera_pt,
        roiCorrectionDirection.clone().multiplyScalar(intersect.distance)
      );

      // imageProjectedCenterDirection in raw image coordinates :: testing purpose
      imageProjectedCenterDirection = projectPointToImage(
        { point: img_cen_on_surf_pt },
        isEnableNGNiriImageSelectionAlgo,
        currentActiveJaw[img_idx],
        cameraToPixelTransform,
        kVector,
        pVector
      );

      // used in rail algo only
      // Distance from the camera to the closest mesh point, not the distance along the ray.
      const min_dist_from_surface = mesh.geometry.boundsTree.closestPointToPoint(camera_pt).distance;
      dist_from_cam_to_surf = min_dist_from_surface;
    }

    // reference points for caries detection
    const cariesReferencePoints = caries_reference_points.map(({ point, score }) => ({
      point: new Vector3(point.x, point.y, point.z),
      score,
    }));

    const images_meta_data = {
      // 1. Camera points and directions
      camera_pt,
      camera_dir,

      // 2. Projection of camera to surface
      was_cam_projected,
      img_cen_on_surf_pt,
      dist_from_cam_to_surf,
      cam_to_abs_tx,
      is_visible: false,
      scan_role: jawName,
      timestamp: timestamp,
      camera_id: camera_id,
      scan_id: scan_id,
      rect_of_image: rectOfImage,
      caries_detection_score,
      scan_to_cam_tx: scanToCamTransform,
      camera_to_pixel: cameraToPixelTransform,
      imagesCenter,
      K_vector: kVector,
      P_vector: pVector,
      // NOTE(review): unique only if camera_id is always in [0, 6). Confirm the camera count per scan.
      image_id: scan_id * 6 + camera_id,
      imageProjectedCenterDirection,
      caries_reference_points: cariesReferencePoints,
    };

    images_meta_data_array.push(images_meta_data);
  }

  return images_meta_data_array;
};

/**
 * Casts a ray from `origin` along `direction` and returns the first hit on `mesh`, or null.
 *
 * NOTE: `direction` is normalized IN PLACE (Vector3#normalize mutates it). initializeImagesMetaData
 * relies on the direction being unit-length afterwards when it scales it by the hit distance.
 * `firstHitOnly` is a raycaster flag honored by BVH-accelerated raycasting (presumably
 * three-mesh-bvh; TODO confirm) to stop after the nearest hit.
 *
 * @returns {object|null} the first intersection record, or null when nothing was hit.
 */
export const getRayIntersectCameraToModelSurface = (origin, direction, mesh) => {
  const unitDirection = direction.normalize();
  const rayCaster = new Raycaster(origin, unitDirection, 0, Infinity);
  rayCaster.firstHitOnly = true;

  const hits = rayCaster.intersectObjects([mesh]);
  const hasHit = Boolean(hits) && hits.length > 0;
  return hasHit ? hits[0] : null;
};

/**
 * Re-projects `point` onto the mesh by casting a ray that starts at `pointInsideSurface`.
 *
 * Despite its name it returns a point, not a boolean:
 *   - no hit: `point` unchanged;
 *   - first hit farther than max_distance_to_selected_pt_mm: `pointInsideSurface`;
 *   - otherwise: the hit point.
 *
 * Wrapped in lodash `throttle`. The wait is `utils.getValueByBrowser()`, evaluated once at module
 * load. Calls inside the throttle window get the result of the last real invocation.
 *
 * NOTE(review): the ray direction is `point` normalized as a position vector (i.e. the direction from
 * the world origin), not the direction from `pointInsideSurface` towards `point`. Confirm this is
 * intended.
 */
export const getIsProjectedPointOnSurface = throttle((point, pointInsideSurface, mesh) => {
  const rayDirection = point.clone().normalize();
  const rayCaster = new Raycaster(pointInsideSurface, rayDirection, 0, Infinity);
  rayCaster.firstHitOnly = true;

  const hits = rayCaster.intersectObjects([mesh]);
  if (!hits || hits.length === 0) {
    return point;
  }

  const nearest = hits[0];
  return nearest.distance > max_distance_to_selected_pt_mm ? pointInsideSurface : nearest.point;
}, utils.getValueByBrowser());

/**
 * Computes a point just beneath the surface hit by `ray`. It is used as the pivot (view rotation
 * centre) for later model-rotation picks.
 *
 * The point is `intersect[0].point` moved along `ray.direction` by
 * `distance_between_view_rotation_center_and_its_projection`. When the ray crosses the mesh again,
 * the step is capped at half the distance between the first and second hits.
 *
 * @param {Array<{point: Vector3, distance: number}>} intersect - raycast hits, presumably nearest
 *   first. Must be non-empty.
 * @param {{direction: Vector3}} ray - the picking ray. It is NOT modified.
 * @returns {Vector3} a new vector.
 */
export const calculatePointInsideSurface = (intersect, ray) => {
  const defaultDepth = distance_between_view_rotation_center_and_its_projection;
  const depth =
    intersect.length > 1
      ? Math.min((intersect[1].distance - intersect[0].distance) / 2, defaultDepth)
      : defaultDepth;

  // Clone first: multiplyScalar mutates in place. Scaling the caller's ray direction would leave it
  // non-unit for any later raycast that reuses the same ray.
  const offset = ray.direction.clone().multiplyScalar(depth);
  return new Vector3().addVectors(intersect[0].point, offset);
};

/**
 * Picks the model point used for image selection, depending on what triggered the pick.
 *
 * - LOUPE_DRAG: stores a point just under the surface in `cache.pointInsideSurface` and returns the
 *   first hit point.
 * - MODEL_ROTATION: re-projects the first hit point from the stored inside point (throttled; see
 *   getIsProjectedPointOnSurface).
 * - anything else: logs an error and returns null.
 *
 * @param {*} eventOrigin - one of `toolsEvents.EVENT_ORIGINS`.
 * @param {Array<{point: Vector3, distance: number}>} intersect - raycast hits; must be non-empty.
 * @param {{direction: Vector3}} ray - picking ray (used only for LOUPE_DRAG).
 * @param {object} mesh - model mesh (used only for MODEL_ROTATION).
 * @returns {Vector3|null}
 */
export const calculateIntersectPointForImageSelection = (eventOrigin, intersect, ray, mesh) => {
  const { EVENT_ORIGINS } = toolsEvents;
  const hitPoint = intersect[0].point;

  switch (eventOrigin) {
    case EVENT_ORIGINS.LOUPE_DRAG:
      cache.pointInsideSurface = calculatePointInsideSurface(intersect, ray);
      return hitPoint;
    case EVENT_ORIGINS.MODEL_ROTATION:
      return getIsProjectedPointOnSurface(hitPoint, cache.pointInsideSurface, mesh);
    default:
      logger
        .error('error')
        .data({ module: 'lumina-scanner-type.logic', errorMessage: 'Invalid eventOrigin' })
        .to(['analytics', 'host'])
        .end();
      return null;
  }
};

/**
 * Builds the display props for the selected 2D image. The result is the image's own fields merged
 * with its on-screen rotation, the selected point, the original image size and its ROI.
 *
 * @param {{img_idx: number, image: object, selected2DPointOnImage: *}} selected_image - `img_idx === -1`
 *   means no valid image. In that case (or when no metadata exists for the index) only defaults are
 *   returned.
 * @param {Matrix4} camMatrixWorldInverse - inverse world matrix of the viewer camera.
 * @param {Array<object>} luminaImageMetaData - per-image records (presumably from
 *   initializeImagesMetaData), indexed by `img_idx`.
 * @param {boolean} viewer360Align2DImages - forwarded to calculateRotationAngle (continuous vs. snapped angle).
 * @param {boolean} isuminaBestScoreAlgorithmAvaliable - selects the 480x270 original image size
 *   instead of 960x540.
 * @returns {object} `image` spread together with `rotation`, `selected_pt_on_image`, `originalImageSize`,
 *   `imageProjectedCenterDirection`, `shouldTransform` and `roi`.
 */
export const onImageSelectedUpdateRotation = (
  selected_image,
  camMatrixWorldInverse,
  luminaImageMetaData,
  viewer360Align2DImages,
  isuminaBestScoreAlgorithmAvaliable
) => {
  const { img_idx, image, selected2DPointOnImage } = selected_image;

  // 1. Handle invalid_image_index: -1 (or an index without a record) has no metadata, so fall back
  //    to an empty object instead of destructuring `undefined`.
  const imageMetadata = img_idx === -1 ? undefined : luminaImageMetaData[img_idx];
  const { cam_to_abs_tx, scan_to_cam_tx, rect_of_image, imageProjectedCenterDirection } = imageMetadata || {};

  const closestPhotoObjectParams = {
    rotation: 0,
    selected_pt_on_image: null,
    originalImageSize: isuminaBestScoreAlgorithmAvaliable ? { width: 480, height: 270 } : { width: 960, height: 540 },
    imageProjectedCenterDirection: null,
    shouldTransform: false,
    roi: rect_of_image,
  };

  // 2. Calculate rotation angle and set image cache
  if (imageMetadata) {
    const cam_to_ui_tx = getCamToUiTx(camMatrixWorldInverse, cam_to_abs_tx);
    const rotation = calculateRotationAngle(cam_to_ui_tx, scan_to_cam_tx, viewer360Align2DImages);

    closestPhotoObjectParams.rotation = rotation || 0;
    closestPhotoObjectParams.selected_pt_on_image = selected2DPointOnImage;
    // True when |rotation| moved by less than max_image_rotation_threshold since the previous
    // selection. cache.prevRotation is undefined before the first write, which makes this false.
    closestPhotoObjectParams.shouldTransform =
      Math.abs(Math.abs(cache.prevRotation) - Math.abs(rotation)) < max_image_rotation_threshold;
    closestPhotoObjectParams.imageProjectedCenterDirection = imageProjectedCenterDirection;
    cache.prevRotation = rotation;
  }

  return {
    ...image,
    ...closestPhotoObjectParams,
  };
};

/**
 * Computes the angle by which the 2D image is rotated so it lines up with the current 3D view.
 *
 * The camera->UI transform is first turned by PI around X when `scan_to_cam_tx.elements[0]` is
 * exactly 1 (the identity's (0,0) entry), and around Y otherwise. The x/y entries of the first
 * column of the result are then normalized to a unit 2D direction `(ax, ay)`.
 *
 * @param {Matrix4} cam_to_ui_tx - camera-to-UI transform.
 * @param {Matrix4} scan_to_cam_tx - scan-to-camera transform (only elements[0] is inspected).
 * @param {boolean} viewer360Align2DImages - when truthy, returns a continuous signed angle in degrees:
 *   acos(ax), negated unless ay > 0. Otherwise returns the nearest of degreValues.e_0_deg / e_90_deg /
 *   e_180_deg / e_270_deg.
 * @returns {number}
 */
export const calculateRotationAngle = (cam_to_ui_tx, scan_to_cam_tx, viewer360Align2DImages) => {
  // TODO remove the condition when we will get the correct scan_to_cam_tx from NG scanner (should be identity matrix)
  const halfTurn =
    scan_to_cam_tx.elements[0] === 1
      ? new Matrix4().makeRotationX(Math.PI)
      : new Matrix4().makeRotationY(Math.PI);

  const turned = new Matrix4().multiplyMatrices(halfTurn, cam_to_ui_tx);

  // Matrix4#elements is column-major: elements[0] and elements[1] are entries (0,0) and (1,0).
  const { elements } = turned;
  const axis = new Vector3(elements[0], elements[1], 0);
  axis.divideScalar(axis.length());
  const { x: ax, y: ay } = axis;

  let snappedAngle;
  if (Math.abs(ax) > Math.abs(ay)) {
    snappedAngle = ax < 0 ? degreValues.e_180_deg : degreValues.e_0_deg;
  } else {
    snappedAngle = ay < 0 ? degreValues.e_270_deg : degreValues.e_90_deg;
  }

  if (!viewer360Align2DImages) {
    return snappedAngle;
  }

  const degrees = MathUtils.radToDeg(Math.acos(ax));
  return ay > 0 ? degrees : -degrees;
};

/**
 * Applies the first three rows of `worldToImageMatrix.elements` to `point` and perspective-divides
 * by the third. The 16 elements are read as row-major rows starting at offsets 0, 4 and 8.
 *
 * A zero (or NaN) depth is replaced by 1 to avoid dividing by zero.
 *
 * @param {Iterable<number>} point - [x, y, z] (e.g. a Vector3).
 * @param {{elements: number[]}} worldToImageMatrix
 * @returns {[number, number]} lens-plane coordinates [x, y].
 */
const getLensPoints = (point, worldToImageMatrix) => {
  const [px, py, pz] = point;
  const e = worldToImageMatrix.elements;
  const applyRow = (offset) => px * e[offset] + py * e[offset + 1] + pz * e[offset + 2] + e[offset + 3];

  const depth = applyRow(8) || 1;
  return [applyRow(0) / depth, applyRow(4) / depth];
};

/**
 * Applies radial (K_vector[0..2]) and tangential (P_vector[0..1]) lens distortion to a lens-plane point:
 *
 *   r2 = x^2 + y^2
 *   radial = K0*r2 + K1*r2^2 + K2*r2^3
 *   x' = x*(1 + radial) + 2*P0*x*y + P1*(r2 + 2*x^2)
 *   y' = y*(1 + radial) + P0*(r2 + 2*y^2) + 2*P1*x*y
 *
 * If |radial| exceeds distortionRadialThreshold / 100, the model is considered unreliable and the
 * point is returned undistorted.
 *
 * @returns {[number, number]}
 */
const applyDistortion = (xLens, yLens, K_vector, P_vector) => {
  const p1 = P_vector[0];
  const p2 = P_vector[1];
  const twoP1 = 2 * p1;
  const twoP2 = 2 * p2;

  const xx = xLens * xLens;
  const yy = yLens * yLens;
  const xy = xLens * yLens;
  const r2 = xx + yy;

  const radial = K_vector[2] * Math.pow(r2, 3) + K_vector[1] * Math.pow(r2, 2) + K_vector[0] * r2;

  if (Math.abs(radial) > distortionRadialThreshold / 100) {
    return [xLens, yLens];
  }

  const scale = 1 + radial;
  return [
    xx * twoP2 + r2 * p2 + xy * twoP1 + xLens * scale,
    yy * twoP1 + r2 * p1 + xy * twoP2 + yLens * scale,
  ];
};

/**
 * Projects a 3D point into raw image pixel coordinates of `photoObj`.
 *
 * Steps:
 *   1. Build world -> camera as the inverse of the transposed `rawImageMatrix`.
 *   2. Perspective-divide (getLensPoints).
 *   3. Apply lens distortion (applyDistortion).
 *   4. Map through `cameraToPixel`.
 *
 * The returned `x` uses elements[1], [4], [7] of `cameraToPixel` and `y` uses elements[0], [3], [6].
 *
 * @param {{point: Vector3}} intersect - object carrying the point to project (e.g. a raycast hit).
 * @param {boolean} isEnableNGNiriImageSelectionAlgo - prefer the `niri` entry of `photoObj` over `color`.
 * @param {object} photoObj - has `niri` and/or `color` entries that carry `rawImageMatrix`.
 * @param {Matrix3} cameraToPixel - camera-to-pixel transform.
 * @param {number[]} K_vector - radial distortion coefficients.
 * @param {number[]} P_vector - tangential distortion coefficients.
 * @returns {{x: number, y: number}}
 */
export const projectPointToImage = (
  intersect,
  isEnableNGNiriImageSelectionAlgo,
  photoObj,
  cameraToPixel,
  K_vector,
  P_vector
) => {
  const { point } = intersect;

  let imageInfo;
  if (isEnableNGNiriImageSelectionAlgo) {
    // No `{}` fallback here: with neither entry present, reading rawImageMatrix throws.
    imageInfo = photoObj.niri || photoObj.color;
  } else {
    imageInfo = photoObj.color || photoObj.niri || {};
  }

  const worldToCam = new Matrix4().fromArray(imageInfo.rawImageMatrix);
  worldToCam.transpose().invert();

  const lensCoords = getLensPoints(point, worldToCam);
  const [xDistorted, yDistorted] = applyDistortion(lensCoords[0], lensCoords[1], K_vector, P_vector);

  const { elements } = cameraToPixel;
  const xPixel = xDistorted * elements[1] + yDistorted * elements[4] + elements[7];
  const yPixel = xDistorted * elements[0] + yDistorted * elements[3] + elements[6];

  return { x: xPixel, y: yPixel };
};
