import { GenericDataResponse, NumberOrString } from '@tensorleap/api-client';
import clsx from 'clsx';
import { memo, useMemo } from 'react';
import { chartFormatIfNumber } from '../chart.utils';
import { GRAPH_STYLE } from '../common/constants';
import { ChartRequestData, XYChartProps } from '../common/interfaces';
import chroma from 'chroma-js';
import { Portal } from 'react-portal';
import {
  HeatmapCellData,
  useHeatmapMode,
  UseParseHeatmapData,
  useParseHeatmapData,
} from '../hooks/useHeatmap';
import { getLabelRangeByFromTo } from '../hooks/utils';
import { ScaleType, useScaleFunction } from './ChartBlocks/scale';

// Number of grid columns placed before the first data column: column 1 holds
// the rotated y-axis title and column 2 the y tick labels (the `20px 80px`
// tracks at the start of `gridTemplateColumns` in Heatmap).
const COLUMN_OFFSET = 2;

/**
 * Interactive heatmap chart.
 *
 * Lays out one CSS-grid cell per (x label, y label) pair, followed by the axis
 * labels, a vertical color legend and — while hovering a populated cell — a
 * tooltip showing that cell's record. Hover highlighting and drag-selection
 * filtering are driven by `useHeatmapMode`.
 */
export function Heatmap({
  graphData,
  chartRequestData,
  onFiltersChange,
  localFilters,
  className,
  scaleY = 'linear',
}: XYChartProps): JSX.Element {
  const graphMetadata = useParseHeatmapData(
    graphData,
    chartRequestData,
    scaleY,
  );
  // `scaleType` is forwarded to HeatmapContent through the metadata spread.
  const { columnByLabel, rowByLabel, xLabels, yLabels, maxValue, minValue } =
    graphMetadata;
  const allowFiltering = chartRequestData.allowFiltering ?? true;

  const colorRange = useMemo((): [string, string, string] => {
    // No custom range: use a built-in scheme; `flipColorRange` selects the
    // green→yellow→red variant instead of red→yellow→green.
    if (!chartRequestData.colorRange) {
      return chartRequestData.flipColorRange
        ? GRAPH_STYLE.heatmap.colorSchemeGreenYellowRed
        : GRAPH_STYLE.heatmap.colorSchemeRedYellowGreen;
    }
    // Custom range: used as-is when flipped, reversed otherwise.
    return chartRequestData.flipColorRange
      ? chartRequestData.colorRange
      : ([...chartRequestData.colorRange].reverse() as [
          string,
          string,
          string,
        ]);
  }, [chartRequestData.colorRange, chartRequestData.flipColorRange]);

  // Data tracks come first, followed by fixed tracks for the axes: two 20px
  // rows below the cells (x tick labels, x title) and two leading columns
  // (y title, y tick labels). CSS `repeat()` requires a count >= 1 —
  // `repeat(0, 1fr)` invalidates the whole declaration — so the data tracks
  // are omitted entirely when there are no rows/columns.
  const dataTracks = (count: number): string[] =>
    count > 0 ? [`repeat(${count}, 1fr)`] : [];
  const gridStyle = {
    gridTemplateRows: [...dataTracks(rowByLabel.size), '20px', '20px'].join(
      ' ',
    ),
    gridTemplateColumns: [
      '20px',
      '80px',
      ...dataTracks(columnByLabel.size),
    ].join(' '),
  };

  const {
    handleMouseDown,
    handleMouseLeave,
    handleMouseMove,
    handleMouseUp,
    mode,
  } = useHeatmapMode({
    xLabels,
    yLabels,
    chartRequestData,
    onFiltersChange,
    localFilters,
    allowFiltering,
  });

  let hoverXValue, hoverYValue, selectFromX, selectToX, selectFromY, selectToY;
  let tooltip:
    | { positions: [number, number]; record: Record<string, NumberOrString> }
    | undefined;
  if (mode.mode === 'hover') {
    hoverXValue = mode.data.xValue;
    hoverYValue = mode.data.yValue;
    if (mode.data.index !== null) {
      // Look the record up defensively: if `graphData` changes while the
      // pointer is still over a cell, the hover index may be stale.
      const record = graphData.data[mode.data.index as number]?.data;
      if (record) {
        tooltip = { positions: [mode.data.mouseY, mode.data.mouseX], record };
      }
    }
  } else if (mode.mode === 'select') {
    selectFromX = mode.data.fromX;
    selectToX = mode.data.toX;
    selectFromY = mode.data.fromY;
    selectToY = mode.data.toY;
  }

  return (
    <div className={clsx(className, 'flex overflow-auto h-full pb-4')}>
      <div
        onMouseDown={handleMouseDown}
        onMouseUp={handleMouseUp}
        onMouseLeave={handleMouseLeave}
        onMouseMove={handleMouseMove}
        className={clsx(
          'grid gap-[2px] flex-1 overflow-auto',
          mode.mode === 'select' && 'cursor-crosshair select-none',
        )}
        style={gridStyle}
      >
        <HeatmapContent
          {...graphMetadata}
          colorRange={colorRange}
          hoverXValue={hoverXValue}
          hoverYValue={hoverYValue}
          selectFromX={selectFromX}
          selectToX={selectToX}
          selectFromY={selectFromY}
          selectToY={selectToY}
          graphData={graphData}
          chartRequestData={chartRequestData}
        />
        <HeatmapAxis
          xLabel={chartRequestData.xField}
          yLabel={chartRequestData.yField}
          xLabels={xLabels}
          yLabels={yLabels}
        />
      </div>
      <HeatmapLegend
        className="h-full max-h-36"
        maxValue={maxValue}
        minValue={minValue}
        colorRange={colorRange}
      />
      {tooltip && (
        <HeatmapTooltip
          positions={tooltip.positions}
          record={tooltip.record}
          colorField={chartRequestData.colorField}
          xField={chartRequestData.xField}
          yField={chartRequestData.yField}
        />
      )}
    </div>
  );
}

type HeatmapAxisProps = {
  /** Column labels, in render order. */
  xLabels: NumberOrString[];
  /** Row labels, in render order. */
  yLabels: NumberOrString[];
  /** Optional x-axis title, rendered below the x tick labels. */
  xLabel?: string;
  /** Optional y-axis title, rendered rotated in the first grid column. */
  yLabel?: string;
};

/**
 * Axis decorations for the heatmap grid.
 *
 * - x tick labels sit in the row directly below the data rows, one per data
 *   column; the x title sits one row below that, spanning all data columns.
 * - y tick labels sit in grid column 2, one per data row; the rotated y title
 *   spans all rows of grid column 1.
 */
const HeatmapAxis = memo<HeatmapAxisProps>(
  ({ xLabel, xLabels, yLabel, yLabels }) => {
    // Grid column of the first data column (columns before it are reserved
    // for the y-axis title and tick labels).
    const firstDataColumn = COLUMN_OFFSET + 1;
    const xTickRow = yLabels.length + 1;
    const xTitleRow = yLabels.length + 2;

    return (
      <>
        {xLabel && (
          <div
            className="text-xs text-gray-500 text-center"
            style={{
              gridRow: xTitleRow,
              gridColumn: `${firstDataColumn} / span ${xLabels.length}`,
            }}
          >
            {xLabel}
          </div>
        )}
        {xLabels.map((label, index) => (
          <div
            className="text-center text-xs"
            // Same column derivation as the title above, so ticks stay
            // aligned with the data cells if COLUMN_OFFSET ever changes.
            style={{ gridColumn: firstDataColumn + index, gridRow: xTickRow }}
            key={`x-${label}`}
          >
            {chartFormatIfNumber(label)}
          </div>
        ))}
        {yLabel && (
          <div
            className="row-span-full text-xs text-gray-500 flex items-center justify-start"
            style={{ gridColumn: 1 }}
          >
            <div>
              <div className="-rotate-90 text-center pt-2 -ml-[100%]">
                {yLabel}
              </div>
            </div>
          </div>
        )}
        {yLabels.map((label, index) => (
          <div
            className="flex items-center text-xs leading-[12px] justify-end pr-2"
            style={{ gridRow: index + 1, gridColumn: 2 }}
            key={`y-${label}`}
          >
            {chartFormatIfNumber(label)}
          </div>
        ))}
      </>
    );
  },
);

HeatmapAxis.displayName = 'HeatmapAxis';

type HeatmapContentProps = UseParseHeatmapData & {
  /** Color stops handed to `useValueToColor`. */
  colorRange: string[];
  graphData: GenericDataResponse;
  chartRequestData: ChartRequestData;
  /** Row label of the hovered cell (hover mode only). */
  hoverYValue?: NumberOrString;
  /** Column label of the hovered cell (hover mode only). */
  hoverXValue?: NumberOrString;
  /** Label endpoints of the current drag selection (select mode only). */
  selectFromX?: NumberOrString;
  selectToX?: NumberOrString;
  selectFromY?: NumberOrString;
  selectToY?: NumberOrString;
};

/**
 * Renders one grid cell per (y label, x label) pair.
 *
 * Pairs without an entry in `labelsToIndexData` render as an empty cell; the
 * rest are colored by their `chartRequestData.colorField` value. While a
 * non-empty drag selection exists, cells outside the selected label ranges are
 * faded. Otherwise, while hovering, every cell that is in neither the hovered
 * row nor the hovered column is faded.
 */
const HeatmapContent = memo<HeatmapContentProps>(
  ({
    colorRange,
    scaleType,
    minValue,
    maxValue,
    graphData,
    labelsToIndexData,
    columnByLabel,
    rowByLabel,
    xLabels,
    yLabels,
    chartRequestData,
    hoverYValue,
    hoverXValue,
    selectFromX,
    selectToX,
    selectFromY,
    selectToY,
  }) => {
    const valueToColor = useValueToColor(
      minValue,
      maxValue,
      colorRange,
      scaleType,
    );

    const selectedSetX = useMemo(() => {
      if (selectFromX === undefined || selectToX === undefined)
        return undefined;
      return new Set(getLabelRangeByFromTo(selectFromX, selectToX, xLabels));
    }, [selectFromX, selectToX, xLabels]);

    const selectedSetY = useMemo(() => {
      if (selectFromY === undefined || selectToY === undefined)
        return undefined;
      return new Set(getLabelRangeByFromTo(selectFromY, selectToY, yLabels));
    }, [selectFromY, selectToY, yLabels]);

    // None of these depend on the individual cell, so compute them once per
    // render instead of once per cell.
    const allowFiltering = chartRequestData.allowFiltering ?? true;
    const colorField = chartRequestData.colorField as string;
    const isSelectMode =
      selectedSetX !== undefined &&
      selectedSetX.size > 0 &&
      selectedSetY !== undefined &&
      selectedSetY.size > 0;
    // Selection takes precedence over hover highlighting.
    const isHoverMode =
      hoverYValue !== undefined && hoverXValue !== undefined && !isSelectMode;

    return (
      <>
        {yLabels.map((yValue) =>
          xLabels.map((xValue) => {
            const column =
              (columnByLabel.get(xValue) as number) + COLUMN_OFFSET;
            const row = rowByLabel.get(yValue) as number;
            const key = `${xValue}-${yValue}`;
            const index = labelsToIndexData[yValue]?.[xValue];

            if (index === undefined) {
              return (
                <HeatmapEmptyCell
                  key={key}
                  xValue={xValue}
                  yValue={yValue}
                  row={row}
                  column={column}
                  allowFiltering={allowFiltering}
                />
              );
            }

            const value = graphData.data[index]?.data[colorField] as number;

            // `?.` only satisfies the type checker: isSelectMode already
            // guarantees both sets are defined.
            const isNotSelected =
              isSelectMode &&
              !(selectedSetX?.has(xValue) && selectedSetY?.has(yValue));
            const isNotHover =
              isHoverMode && yValue !== hoverYValue && xValue !== hoverXValue;

            return (
              <HeatmapCell
                key={key}
                xValue={xValue}
                yValue={yValue}
                row={row}
                column={column}
                dataIndex={index}
                value={value}
                bgColor={valueToColor(value).hex()}
                feedOut={isNotHover || isNotSelected}
                allowFiltering={allowFiltering}
              />
            );
          }),
        )}
      </>
    );
  },
);

HeatmapContent.displayName = 'HeatmapContent';

/** Props for a heatmap cell that has a backing record. */
type HeatmapCellProps = HeatmapCellBaseProps & {
  /** Value printed inside the cell (formatted with `chartFormatIfNumber`). */
  value: NumberOrString;
  /** CSS background of the cell. */
  bgColor?: string;
  /** Record index embedded in the cell's `data-cell` attribute. */
  dataIndex: number;
  /** Dims the cell to 50% opacity when true. */
  feedOut: boolean;
  allowFiltering?: boolean;
};

/**
 * A populated heatmap cell. Its label pair and record index are serialized
 * into `data-cell` (presumably read back by the `useHeatmapMode` mouse
 * handlers — TODO confirm).
 */
const HeatmapCell = memo<HeatmapCellProps>((props) => {
  const {
    row,
    column,
    xValue,
    yValue,
    dataIndex,
    value,
    bgColor,
    feedOut,
    allowFiltering,
  } = props;

  const cellData: HeatmapCellData = { xValue, yValue, index: dataIndex };
  const serializedCell = JSON.stringify(cellData);
  const cursorClass = allowFiltering ? 'cursor-crosshair' : 'cursor-default';
  const classes = clsx(
    'text-white flex-1 overflow-hidden min-w-[20px] justify-center flex items-center text-xs leading-[12px] select-none',
    feedOut && 'opacity-50',
    cursorClass,
  );
  const placement = { gridColumn: column, gridRow: row, background: bgColor };

  return (
    <div data-cell={serializedCell} className={classes} style={placement}>
      {chartFormatIfNumber(value)}
    </div>
  );
});
HeatmapCell.displayName = 'HeatmapCell';

/** Placement and identity shared by every heatmap grid cell. */
type HeatmapCellBaseProps = {
  /** CSS `grid-row` of the cell. */
  row: number;
  /** CSS `grid-column` of the cell. */
  column: number;
  /** Column (x) label of the cell. */
  xValue: NumberOrString;
  /** Row (y) label of the cell. */
  yValue: NumberOrString;
  /** Crosshair cursor when true, default cursor otherwise. */
  allowFiltering?: boolean;
};

/**
 * Placeholder for a label pair with no backing record. It still carries a
 * `data-cell` payload (with `index: null`) so the pair stays addressable.
 */
const HeatmapEmptyCell = memo<HeatmapCellBaseProps>((props) => {
  const { row, column, xValue, yValue, allowFiltering } = props;
  const cellData: HeatmapCellData = { xValue, yValue, index: null };
  const cursorClass = allowFiltering ? 'cursor-crosshair' : 'cursor-default';
  const placement = { gridColumn: column, gridRow: row };

  return (
    <div
      className={cursorClass}
      data-cell={JSON.stringify(cellData)}
      style={placement}
    />
  );
});

HeatmapEmptyCell.displayName = 'HeatmapEmptyCell';

/**
 * Returns a memoized mapper from a data value to a chroma color.
 *
 * The value is first normalized with the scale from `useScaleFunction`
 * (configured by `scaleType` over `[minValue, maxValue]` with output range
 * `[0, 1]`), then interpolated across `colors` in RGB space.
 *
 * @param minValue - lower bound of the data range
 * @param maxValue - upper bound of the data range
 * @param colors - color stops, spread evenly over `[0, 1]`
 * @param scaleType - scale used to normalize values before coloring
 */
function useValueToColor(
  minValue: number,
  maxValue: number,
  colors: string[],
  scaleType: ScaleType,
) {
  const scale = useScaleFunction(scaleType, minValue, maxValue, 0, 1);

  return useMemo(() => {
    // The interpolation space is set with `.mode()`. `domain()` only takes the
    // breakpoints; in chroma-js's signature its extra parameters concern class
    // breaks, not color interpolation, so 'rgb' does not belong there.
    const colorScale = chroma.scale(colors).mode('rgb').domain([0, 1]);
    colorScale.cache(true);
    return (value: number) => colorScale(scale(value));
  }, [scale, colors]);
}

/** Props for `HeatmapTooltip`. */
export interface HeatmapTooltipProps {
  /** `[top, left]` offsets, applied with `position: fixed`. */
  positions: [number, number];
  /** Record whose fields are displayed. */
  record: Record<string, NumberOrString>;
  /** Field shown on the first line, when provided. */
  xField?: string;
  /** Field shown on the second line, when provided. */
  yField?: string;
  /** Field shown last, with its value in bold, when provided. */
  colorField?: string;
}

/**
 * Floating tooltip listing the x, y and color fields of a heatmap record,
 * rendered into a `Portal` at a fixed position.
 */
export function HeatmapTooltip(
  props: HeatmapTooltipProps,
): JSX.Element | null {
  const { positions, record, xField, yField, colorField } = props;
  const [top, left] = positions;

  // A single "<field> : <value>" line.
  const fieldLine = (field: string) => (
    <span className="capitalize">
      {field} : {record[field]}
    </span>
  );

  return (
    <Portal>
      <div
        className="custom-tooltip z-10 fixed flex flex-col bg-gray-800 border-2 border-gray-500 rounded-xl justify-center items-start w-fit"
        style={{ position: 'fixed', top, left }}
      >
        <div className="flex flex-col gap-2 p-2 ">
          {xField && fieldLine(xField)}
          {yField && fieldLine(yField)}
          {colorField && (
            <span className="capitalize">
              {colorField}
              {' :'}
              <span className="font-bold">
                {' '}
                {record[colorField]}
              </span>
            </span>
          )}
        </div>
      </div>
    </Portal>
  );
}

type HeatmapLegendProps = {
  /** Color stops of the gradient, drawn top to bottom. */
  colorRange: string[];
  className?: string;
  /** Value printed at the top of the legend. */
  minValue: number;
  /** Value printed at the bottom of the legend. */
  maxValue: number;
};

/**
 * Vertical color legend: a gradient bar running from the first color in
 * `colorRange` (top) to the last (bottom), labelled with `minValue` at the
 * top and `maxValue` at the bottom.
 */
function HeatmapLegend({
  minValue,
  maxValue,
  colorRange,
  className,
}: HeatmapLegendProps) {
  const gradient = useMemo(
    // Join the stops explicitly rather than relying on Array#toString inside
    // the template literal. Spell out the default `to bottom` direction so it
    // visibly matches the min (top) / max (bottom) label order below.
    () => `linear-gradient(to bottom, ${colorRange.join(', ')})`,
    [colorRange],
  );
  return (
    <div className={clsx('flex p-2 ', className)}>
      <div
        style={{ backgroundImage: gradient }}
        className="rounded-md w-2 m-1"
      />
      <div className="flex text-xs flex-col justify-between">
        <span>{chartFormatIfNumber(minValue)}</span>
        <span>{chartFormatIfNumber(maxValue)}</span>
      </div>
    </div>
  );
}
