import {
  ESFilter,
  GenericDataResponse,
  NumberOrString,
  OrderType,
} from '@tensorleap/api-client';
import {
  MouseEvent,
  MouseEventHandler,
  useCallback,
  useMemo,
  useState,
} from 'react';
import { VisualizationFilter } from '../../../core/types/filters';
import { mapToVisualizationFilters } from '../../../model-tests/modelTestHelpers';
import { ChartRequestData } from '../common/interfaces';
import { max, min, orderBy } from 'lodash';
import { getOnFirstClickFilter, getOnlastClickFilter } from './utils';
import { ScaleType } from '../visualizers/ChartBlocks/scale';

/**
 * Inputs to `useHeatmapMode`.
 */
export type UseHeatmapProps = {
  /**
   * Chart configuration. The hook reads `xField`, `yField`,
   * `dataDistribution`, `yDataDistribution`, `xSizeInterval` and
   * `ySizeInterval` from it when building selection filters.
   */
  chartRequestData: ChartRequestData;
  /**
   * Filters already applied; they are mapped with `mapToVisualizationFilters`
   * and the new selection filters are appended after them.
   */
  localFilters?: ESFilter[];
  /** Ordered x-axis labels, passed to `getOnlastClickFilter` for range selections. */
  xLabels: NumberOrString[];
  /** Ordered y-axis labels, passed to `getOnlastClickFilter` for range selections. */
  yLabels: NumberOrString[];
  /** Receives the full filter list (existing + new) when a selection completes. */
  onFiltersChange?: (_: VisualizationFilter[]) => void;
  /**
   * When false, mouse-down/mouse-up neither start a selection nor emit
   * filters (hovering still works). `useHeatmapMode` defaults it to true.
   */
  allowFiltering?: boolean;
};

/** State of the `hover` mode: the cell under the pointer and where the pointer is. */
export type HeatmapHoverData = {
  /** Cell index taken from the cell's `data-cell` payload; may be null. */
  index: number | null;
  /** x label of the hovered cell. */
  xValue: NumberOrString;
  /** y label of the hovered cell. */
  yValue: NumberOrString;
  /** Pointer `clientX` captured from the mouse-move event. */
  mouseX: number;
  /** Pointer `clientY` captured from the mouse-move event. */
  mouseY: number;
};

/**
 * State of the `select` mode: a rectangular drag selection. `fromX`/`fromY`
 * are the labels of the cell where the mouse went down; `toX`/`toY` follow
 * the cell currently under the pointer.
 */
export type HeatmapSelectData = {
  fromX: NumberOrString;
  toX: NumberOrString;
  fromY: NumberOrString;
  toY: NumberOrString;
};

/** Payload JSON-encoded in a heatmap cell element's `data-cell` attribute. */
export type HeatmapCellData = {
  index: number | null;
  xValue: NumberOrString;
  yValue: NumberOrString;
};

/**
 * Pointer interaction state of the heatmap, discriminated by `mode`:
 * - `none`: idle; the initial state and the state after a selection ends or
 *   the pointer leaves the heatmap.
 * - `hover`: the pointer is over a cell while no selection is in progress.
 * - `select`: a drag selection is in progress.
 */
export type HeatmapMode =
  | {
      mode: 'none';
    }
  | {
      mode: 'hover';
      data: HeatmapHoverData;
    }
  | {
      mode: 'select';
      data: HeatmapSelectData;
    };

/** Shared idle state: the initial mode and the value every reset returns to. */
const NONE_MODE: Readonly<{ mode: 'none' }> = { mode: 'none' };

/**
 * Result of `useHeatmapMode`: mouse handlers to attach to the heatmap
 * container, plus the current interaction mode.
 */
export type UseHeatmapMode = {
  /** Starts a drag selection on the cell under the pointer (if filtering is allowed). */
  handleMouseDown?: MouseEventHandler<HTMLDivElement>;
  /** Updates the hovered cell, or extends the selection while one is in progress. */
  handleMouseMove?: MouseEventHandler<HTMLElement>;
  /** Completes the selection: emits filters (if filtering is allowed) and resets the mode. */
  handleMouseUp?: MouseEventHandler<HTMLDivElement>;
  /** Resets the mode to `none`. */
  handleMouseLeave?: MouseEventHandler<HTMLDivElement>;
  mode: HeatmapMode;
};

/**
 * Tracks the heatmap's pointer interaction state and turns a completed drag
 * selection into filters. The state is one of: idle, hovering a cell, or
 * drag-selecting a rectangle of cells.
 *
 * Cells are identified through the JSON `data-cell` attribute of the event
 * target; mouse-down/mouse-move events on elements without one are ignored.
 *
 * On mouse-up, each axis gets its own filter:
 * - an axis whose selection covers a single label (start === end) uses
 *   `getOnFirstClickFilter`;
 * - an axis spanning several labels uses `getOnlastClickFilter` over that
 *   axis's labels.
 *
 * Both filters are appended to `localFilters`, and the whole list is passed
 * to `onFiltersChange`. If either single-label filter cannot be built,
 * nothing is emitted.
 *
 * @returns Mouse handlers for the heatmap container and the current mode.
 */
export function useHeatmapMode({
  chartRequestData,
  onFiltersChange,
  xLabels,
  yLabels,
  localFilters,
  allowFiltering = true,
}: UseHeatmapProps): UseHeatmapMode {
  const [mode, setMode] = useState<HeatmapMode>(NONE_MODE);

  const {
    xField,
    yField,
    dataDistribution,
    yDataDistribution,
    xSizeInterval,
    ySizeInterval,
  } = chartRequestData;

  const resetSelection = useCallback(() => {
    setMode(NONE_MODE);
  }, []);

  // Emits filters for the current selection. Does not reset the mode itself;
  // handleMouseUp always resets right after calling this.
  const setLocalFilter = useCallback((): void => {
    if (
      !onFiltersChange ||
      mode.mode !== 'select' ||
      xField === undefined ||
      yField === undefined
    ) {
      return;
    }

    const { fromX, toX, fromY, toY } = mode.data;

    // A missing end coordinate (e.g. a cell payload lacking that value) is
    // treated like a click on the start cell.
    const isSingleLabel = (
      from: NumberOrString,
      to: NumberOrString | undefined,
    ): boolean => to === undefined || from === to;

    // Decide per axis: a drag inside one column can still span several rows
    // (and vice versa), so deciding from the x axis alone would drop the
    // other axis's range.
    const xFilter = isSingleLabel(fromX, toX)
      ? getOnFirstClickFilter({
          from: fromX,
          field: xField,
          dataDistribution,
          sizeInterval: xSizeInterval,
        })
      : getOnlastClickFilter({
          dataDistribution,
          from: fromX,
          to: toX,
          field: xField,
          labels: xLabels,
        });

    const yFilter = isSingleLabel(fromY, toY)
      ? getOnFirstClickFilter({
          from: fromY,
          field: yField,
          dataDistribution: yDataDistribution,
          sizeInterval: ySizeInterval,
        })
      : getOnlastClickFilter({
          dataDistribution: yDataDistribution,
          from: fromY,
          to: toY,
          field: yField,
          labels: yLabels,
        });

    // getOnFirstClickFilter may return nothing; never emit half a selection.
    if (!xFilter || !yFilter) return;

    const updatedFilters = (localFilters || []).map(mapToVisualizationFilters);
    updatedFilters.push(xFilter, yFilter);
    onFiltersChange(updatedFilters);
  }, [
    onFiltersChange,
    mode,
    xField,
    yField,
    localFilters,
    dataDistribution,
    yDataDistribution,
    xSizeInterval,
    ySizeInterval,
    xLabels,
    yLabels,
  ]);

  const handleMouseDown = useCallback<MouseEventHandler<HTMLElement>>(
    (e) => {
      if (!allowFiltering) return;
      const data = getDataFromEventElement(e);
      if (data === null) return;
      // Anchor the selection on the pressed cell; it starts as a 1x1 range.
      setMode({
        mode: 'select',
        data: {
          fromX: data.xValue,
          toX: data.xValue,
          fromY: data.yValue,
          toY: data.yValue,
        },
      });
    },
    [allowFiltering],
  );

  const handleMouseUp = useCallback(() => {
    if (!allowFiltering) return;
    setLocalFilter();
    // Always leave selection mode, including when setLocalFilter bailed out.
    resetSelection();
  }, [allowFiltering, setLocalFilter, resetSelection]);

  const handleMouseLeave = useCallback(() => {
    resetSelection();
  }, [resetSelection]);

  const handleMouseMove = useCallback<MouseEventHandler<HTMLElement>>((e) => {
    const data = getDataFromEventElement(e);
    if (!data) return;

    // Capture the pointer position synchronously rather than inside the state
    // updater below, which React may run after this handler returns.
    const mouseX = e.clientX;
    const mouseY = e.clientY;

    setMode((prev): HeatmapMode => {
      if (prev.mode === 'select') {
        // Keep the anchor; move the selection's far corner to this cell.
        return {
          mode: 'select',
          data: { ...prev.data, toX: data.xValue, toY: data.yValue },
        };
      }
      return {
        mode: 'hover',
        data: {
          index: data.index,
          xValue: data.xValue,
          yValue: data.yValue,
          mouseX,
          mouseY,
        },
      };
    });
  }, []);

  return {
    handleMouseDown,
    handleMouseUp,
    handleMouseLeave,
    handleMouseMove,
    mode,
  };
}

/** Grid layout derived from a heatmap query response by `useParseHeatmapData`. */
export type UseParseHeatmapData = {
  /** x label → 1-based column position (its index in `xLabels` plus one). */
  columnByLabel: Map<NumberOrString, number>;
  /** y label → 1-based row position (its index in `yLabels` plus one). */
  rowByLabel: Map<NumberOrString, number>;
  /** Distinct x labels, ordered per `orderParams`; numeric-looking labels compare as numbers. */
  xLabels: NumberOrString[];
  /** Distinct y labels, ordered per `yOrderParams`; numeric-looking labels compare as numbers. */
  yLabels: NumberOrString[];
  /** Largest colour-field value across cells; falls back to 0 when there is none. */
  maxValue: number;
  /** Smallest colour-field value across cells; falls back to 0 when there is none. */
  minValue: number;
  /** The `scaleType` argument, passed through unchanged. */
  scaleType: ScaleType;
  /** `[yLabel][xLabel]` → position of that cell in the response's data array. */
  labelsToIndexData: Record<string, Record<string, number>>;
};

/**
 * Derives the heatmap grid layout from a query response. The layout covers
 * the ordered axis labels, label → position maps, the colour value range, and
 * a label → cell-index lookup.
 *
 * Each entry of `graphData.data` is one cell. Its x label, y label and colour
 * value are read from `cell.data[xField]`, `cell.data[yField]` and
 * `cell.data[colorField]`.
 *
 * The result is memoised; it is recomputed only when `graphData`,
 * `chartRequestData` or `scaleType` change identity.
 */
export function useParseHeatmapData(
  graphData: GenericDataResponse,
  chartRequestData: ChartRequestData,
  scaleType: ScaleType,
): UseParseHeatmapData {
  return useMemo(() => {
    const { orderParams, yOrderParams, xField, yField, colorField } =
      chartRequestData;
    const xLabelSet = new Set<NumberOrString>();
    const yLabelSet = new Set<NumberOrString>();
    const values = new Set<number>();
    const labelsToIndexData: Record<string, Record<string, number>> = {};

    // forEach instead of for...in: for...in over an array yields string keys
    // and would also visit any inherited enumerable properties.
    graphData.data.forEach((cell, index) => {
      const xValue = cell.data[xField];
      const yValue = cell.data[yField];
      xLabelSet.add(xValue);
      yLabelSet.add(yValue);
      values.add(cell.data[colorField as string] as number);

      let row = labelsToIndexData[yValue];
      if (!row) {
        row = {};
        labelsToIndexData[yValue] = row;
      }
      row[xValue] = index;
    });

    const valueList = Array.from(values);
    const maxValue = max(valueList) || 0;
    const minValue = min(valueList) || 0;

    const xLabels = orderStringAndNumber(Array.from(xLabelSet), orderParams);
    const yLabels = orderStringAndNumber(Array.from(yLabelSet), yOrderParams);

    // Positions are 1-based (index + 1).
    const columnByLabel = new Map(
      xLabels.map((label, index) => [label, index + 1]),
    );
    const rowByLabel = new Map(
      yLabels.map((label, index) => [label, index + 1]),
    );

    return {
      rowByLabel,
      columnByLabel,
      xLabels,
      yLabels,
      labelsToIndexData,
      maxValue,
      minValue,
      scaleType,
    };
  }, [graphData, chartRequestData, scaleType]);
}

/**
 * Reads the heatmap cell payload from the event target's `data-cell`
 * attribute.
 *
 * @returns The parsed payload, or null in any of these cases:
 *   - there is no target;
 *   - the target has no `getAttribute`;
 *   - the `data-cell` attribute is missing or empty.
 */
function getDataFromEventElement(
  e: MouseEvent<HTMLElement>,
): null | HeatmapCellData {
  const elm = e.target as HTMLElement;
  // Optional chaining produces undefined (not null) when the target or its
  // getAttribute is absent, so reject every falsy value here:
  // JSON.parse(undefined) and JSON.parse('') both throw.
  const value = elm?.getAttribute?.('data-cell');

  if (!value) return null;

  // NOTE(review): the payload is trusted to match HeatmapCellData; it is not validated.
  return JSON.parse(value) as HeatmapCellData;
}

/**
 * Sorts axis labels, comparing numeric-looking labels (e.g. `"10"`, `10`)
 * as numbers so that `"10"` sorts after `"9"`. Labels that do not convert
 * to a number are compared as-is.
 *
 * Note: `Number('')` is 0, so an empty-string label sorts as 0.
 *
 * @param labels Labels to sort; not mutated.
 * @param orderParams Sort direction forwarded to lodash `orderBy`.
 * @returns A new, sorted array.
 */
function orderStringAndNumber(
  labels: NumberOrString[],
  orderParams?: OrderType,
): NumberOrString[] {
  const sortKey = (label: NumberOrString): NumberOrString => {
    const asNumber = Number(label);
    return Number.isNaN(asNumber) ? label : asNumber;
  };
  return orderBy(labels, sortKey, orderParams);
}
