import React, {
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import { OrientationType } from '../../../ui/charts/common/types';
import clsx from 'clsx';
import { NumberOrString } from '@tensorleap/api-client';
import { XYChartProps } from '../common/interfaces';
import { NoDataChart } from '../common/NoDataChart';

/**
 * Thickness of a title bar relative to a cell. Title bars are one third of a
 * cell thick; this ratio is used when fitting the grid and offsetting titles.
 */
const TITLE_TO_CELL_RATIO = 1 / 3;

/**
 * Row/column indices currently highlighted by hover. A missing or empty list
 * means "no restriction" on that axis.
 *
 * The misspelled name is kept exported for backward compatibility; new code
 * should prefer {@link HeatmapTableLocation}.
 */
export interface HeatmpTableLocation {
  rowIndex?: number[];
  colIndex?: number[];
}
/** Correctly spelled alias of {@link HeatmpTableLocation}. */
export type HeatmapTableLocation = HeatmpTableLocation;

/** Inclusive range (`start` and `end` both included) of row or column indices. */
export interface HeatmapTableIndexRange {
  start: number;
  end: number;
}

/**
 * One grid cell. `name` and `value` are rendered as labels (the value is only
 * shown together with a name); `color` is applied as a CSS class, falling back
 * to `bg-transparent`.
 */
export interface HeatmapTableCell {
  name?: string;
  value?: NumberOrString;
  color?: string;
}

/**
 * A title bar spanning `indexRange`: columns when `position` is horizontal
 * (drawn above the grid), rows when vertical (drawn to its left). `color` is
 * applied as a CSS class.
 */
export interface HeatmapTableTitle {
  name: string;
  color?: string;
  indexRange: HeatmapTableIndexRange;
  position: OrientationType;
}

/**
 * Props of {@link HeatmapTable}. When `parentRef` is provided, the table is
 * additionally constrained to the parent's size minus a fixed margin.
 */
export interface HeatmapTableProps {
  data: HeatmapTableCell[][];
  titles: HeatmapTableTitle[];
  className?: string;
  parentRef?: React.RefObject<HTMLDivElement>;
}

/** Lower bound for the computed cell size, in px. */
const MIN_CELL_SIZE_PX = 64;
/**
 * Pixels subtracted from the parent's width/height before fitting the table.
 * NOTE(review): presumably leaves room for surrounding chart chrome — confirm.
 */
const PARENT_RESERVED_PX = 75;

/**
 * Renders `data` as a grid of equally sized square cells with optional group
 * titles: horizontal titles in a row above the grid, vertical titles in a
 * column to its left. Hovering a cell or a title dims everything outside the
 * hovered row/column range.
 *
 * The cell size is chosen so the grid plus its title bars fit the container
 * (and, when `parentRef` is given, the parent minus PARENT_RESERVED_PX), but it
 * never drops below MIN_CELL_SIZE_PX. It is recomputed when the inputs change
 * and on window resize.
 */
export const HeatmapTable: React.FC<HeatmapTableProps> = ({
  data,
  titles,
  className,
  parentRef,
}) => {
  const containerRef = useRef<HTMLDivElement>(null);
  const [cellSize, setCellSize] = useState(96);
  const [highlightedCells, setHighlightedCells] = useState<HeatmpTableLocation>(
    {},
  );

  const horizontalTitles = useMemo(
    () => titles.filter((t) => t.position === 'horizontal'),
    [titles],
  );
  const verticalTitles = useMemo(
    () => titles.filter((t) => t.position === 'vertical'),
    [titles],
  );

  useEffect(() => {
    const updateCellSize = () => {
      const container = containerRef.current;
      // Nothing to measure while the "no data" placeholder is rendered.
      if (!container || !data?.length) return;

      const { clientWidth, clientHeight } = container;
      const parentWidth = parentRef?.current?.clientWidth;
      const parentHeight = parentRef?.current?.clientHeight;

      const availableWidth = parentWidth
        ? Math.min(clientWidth, parentWidth - PARENT_RESERVED_PX)
        : clientWidth;
      const availableHeight = parentHeight
        ? Math.min(clientHeight, parentHeight - PARENT_RESERVED_PX)
        : clientHeight;

      // Horizontal titles live in an extra (shorter) row above the grid, so
      // they add to the height; vertical titles live in an extra (narrower)
      // column to the left, so they add to the width.
      const titleRowSpace = horizontalTitles.length ? TITLE_TO_CELL_RATIO : 0;
      const titleColSpace = verticalTitles.length ? TITLE_TO_CELL_RATIO : 0;

      const widestRow = data.reduce(
        (max, row) => Math.max(max, row.length),
        0,
      );
      const numRows = data.length + titleRowSpace;
      const numCols = widestRow + titleColSpace;

      const fittedSize = Math.min(
        Math.floor(availableHeight / numRows),
        Math.floor(availableWidth / numCols),
      );
      setCellSize(Math.max(fittedSize, MIN_CELL_SIZE_PX));
    };

    updateCellSize();
    window.addEventListener('resize', updateCellSize);
    return () => window.removeEventListener('resize', updateCellSize);
  }, [data, horizontalTitles.length, parentRef, verticalTitles.length]);

  if (!data || data.length === 0)
    return (
      <div className="flex flex-col flex-1 h-full w-full bg-gray-900">
        <NoDataChart />
      </div>
    );

  return (
    <div ref={containerRef} className="flex w-full h-full">
      <div className="w-fit h-fit mx-auto">
        <div className={clsx('flex flex-col w-full h-full ', className)}>
          <div className="flex">
            <RenderTitles
              allPlaneTitles={horizontalTitles}
              cellSize={cellSize}
              highlightedCells={highlightedCells}
              setHighlightedCells={setHighlightedCells}
            />
          </div>
          <div className="flex flex-row justify-start items-start">
            <div className="flex flex-col">
              <RenderTitles
                allPlaneTitles={verticalTitles}
                cellSize={cellSize}
                highlightedCells={highlightedCells}
                setHighlightedCells={setHighlightedCells}
              />
            </div>
            <div className="flex flex-col justify-start items-start">
              {data.map((row, rowIndex) => (
                <div className="flex" key={rowIndex}>
                  {row.map((cell, colIndex) => (
                    <RenderCell
                      data={data}
                      titles={titles}
                      cell={cell}
                      cellSize={cellSize}
                      key={colIndex}
                      rowIndex={rowIndex}
                      colIndex={colIndex}
                      highlightedCells={highlightedCells}
                      setHighlightedCells={setHighlightedCells}
                    />
                  ))}
                </div>
              ))}
            </div>
          </div>
        </div>
      </div>
    </div>
  );
};

interface RenderTitlesProps {
  allPlaneTitles: HeatmapTableTitle[];
  cellSize: number;
  highlightedCells: HeatmpTableLocation;
  setHighlightedCells: React.Dispatch<
    React.SetStateAction<HeatmpTableLocation>
  >;
}

/**
 * Renders every title of one axis in order. Each title also receives its
 * predecessor (undefined for the first one) so it can offset itself past the
 * previous title's range.
 */
function RenderTitles({
  allPlaneTitles,
  ...sharedProps
}: RenderTitlesProps): JSX.Element {
  const renderedTitles = allPlaneTitles.map((title, idx) => (
    <RenderTitle
      key={idx}
      title={title}
      // Index -1 yields undefined for the first title.
      previousTitle={allPlaneTitles[idx - 1]}
      {...sharedProps}
    />
  ));

  return <>{renderedTitles}</>;
}

interface RenderCellProps {
  data: HeatmapTableCell[][];
  titles: HeatmapTableTitle[];
  cell: HeatmapTableCell;
  rowIndex: number;
  colIndex: number;
  cellSize: number;
  highlightedCells: HeatmpTableLocation;
  setHighlightedCells: React.Dispatch<
    React.SetStateAction<HeatmpTableLocation>
  >;
}

/** Text classes shared by a cell's name and value labels. */
const CELL_LABEL_CLASS =
  'font-bold justify-center items-center break-word text-center';

/**
 * Renders one square cell of the heatmap grid.
 *
 * - Size is fixed to `cellSize` px through inline styles. A dynamic Tailwind
 *   `w-N`/`h-N` class is deliberately not used: Tailwind cannot generate
 *   class names built at runtime, and its `w-N` scale is not in pixels.
 * - Corner cells get rounded corners unless a title bar is attached there.
 * - Outer borders are 2px and inner borders 1px.
 * - Hovering highlights this cell's row/column; cells outside the current
 *   highlight are dimmed.
 */
function RenderCell({
  data,
  titles,
  cell,
  rowIndex,
  colIndex,
  cellSize,
  highlightedCells,
  setHighlightedCells,
}: RenderCellProps): JSX.Element {
  // True when a title of the given orientation is attached to this cell's edge
  // (first row for horizontal titles, first column for vertical ones).
  const titleOverlaps = useCallback(
    (position: OrientationType, range: HeatmapTableIndexRange) => {
      if (position === 'horizontal') {
        return (
          colIndex >= range.start && colIndex <= range.end && rowIndex === 0
        );
      }
      return rowIndex >= range.start && rowIndex <= range.end && colIndex === 0;
    },
    [colIndex, rowIndex],
  );

  const lastRowIndex = data.length - 1;
  const lastColIndex = data[rowIndex].length - 1;

  const borderRadiusClass = useMemo(() => {
    const hasVerticalTitle = titles.some(
      (t) =>
        t.position === 'vertical' &&
        titleOverlaps(OrientationType.Vertical, t.indexRange),
    );
    const hasHorizontalTitle = titles.some(
      (t) =>
        t.position === 'horizontal' &&
        titleOverlaps(OrientationType.Horizontal, t.indexRange),
    );

    const isTopRow = rowIndex === 0;
    const isBottomRow = rowIndex === lastRowIndex;
    const isLeftCol = colIndex === 0;
    const isRightCol = colIndex === lastColIndex;

    return clsx({
      'rounded-tl-xl':
        isTopRow && isLeftCol && !hasVerticalTitle && !hasHorizontalTitle,
      'rounded-tr-xl': isTopRow && isRightCol && !hasHorizontalTitle,
      'rounded-bl-xl': isBottomRow && isLeftCol && !hasVerticalTitle,
      'rounded-br-xl': isBottomRow && isRightCol,
    });
  }, [titles, rowIndex, colIndex, lastRowIndex, lastColIndex, titleOverlaps]);

  const isDimmed = useMemo(() => {
    // An empty/missing index list places no restriction on that axis.
    const isHovered =
      highlightedCells &&
      (!highlightedCells.rowIndex?.length ||
        highlightedCells.rowIndex.includes(rowIndex)) &&
      (!highlightedCells.colIndex?.length ||
        highlightedCells.colIndex.includes(colIndex));

    return highlightedCells && !isHovered;
  }, [colIndex, highlightedCells, rowIndex]);
  const fontSize = useCellFontSize(cellSize);

  const onMouseEnter = useCallback(() => {
    setHighlightedCells({
      rowIndex: [rowIndex],
      colIndex: [colIndex],
    });
  }, [colIndex, rowIndex, setHighlightedCells]);
  const onMouseLeave = useCallback(() => {
    setHighlightedCells({});
  }, [setHighlightedCells]);

  const sizePx = `${cellSize}px`;

  return (
    <div
      className={clsx(
        'flex justify-center items-center overflow-hidden cursor-default',
        cell.color || 'bg-transparent',
        isDimmed && 'opacity-40',
        borderRadiusClass,
      )}
      style={{
        width: sizePx,
        height: sizePx,
        minWidth: sizePx,
        minHeight: sizePx,
        maxWidth: sizePx,
        maxHeight: sizePx,
        // Longhand only: mixing the `border` shorthand with border longhands
        // makes React warn when one of them changes between renders.
        borderStyle: 'solid',
        borderColor: 'gray',
        borderLeftWidth: colIndex > 0 ? '1px' : '2px',
        borderRightWidth: colIndex < lastColIndex ? '1px' : '2px',
        borderTopWidth: rowIndex > 0 ? '1px' : '2px',
        borderBottomWidth: rowIndex < lastRowIndex ? '1px' : '2px',
      }}
      onMouseEnter={onMouseEnter}
      onMouseLeave={onMouseLeave}
    >
      {cell.name && (
        <div className="flex flex-col p-1">
          <span className={clsx(CELL_LABEL_CLASS, fontSize)}>{cell.name}</span>
          <span className={clsx(CELL_LABEL_CLASS, fontSize)}>{cell.value}</span>
        </div>
      )}
    </div>
  );
}

interface RenderTitleProps {
  previousTitle?: HeatmapTableTitle;
  title: HeatmapTableTitle;
  cellSize: number;
  highlightedCells: HeatmpTableLocation;
  setHighlightedCells: React.Dispatch<
    React.SetStateAction<HeatmpTableLocation>
  >;
}

/**
 * Positions a single title along its axis and delegates drawing to
 * VerticalTitle / HorizontalTitle.
 *
 * The offset (a top margin for vertical titles, a left margin otherwise) skips
 * the cells between the end of `previousTitle` and this title's start. The
 * first horizontal title is shifted by one extra title thickness.
 * NOTE(review): that extra shift assumes a vertical title column is rendered
 * to the left; a table with only horizontal titles would be offset by
 * cellSize * TITLE_TO_CELL_RATIO too far — confirm.
 */
function RenderTitle({
  title,
  previousTitle,
  cellSize,
  highlightedCells,
  setHighlightedCells,
}: RenderTitleProps): JSX.Element {
  const { indexRange, position } = title;
  const isVertical = position === 'vertical';

  const offset = useMemo(() => {
    // First index not covered by the previous title (0 for the first title).
    const firstFreeIndex = previousTitle
      ? (previousTitle.indexRange.end || 0) + 1
      : 0;
    const gapPx = (indexRange.start - firstFreeIndex) * cellSize;
    const leadingPx =
      position === 'horizontal' && !previousTitle
        ? cellSize * TITLE_TO_CELL_RATIO
        : 0;
    return gapPx + leadingPx;
  }, [cellSize, indexRange, position, previousTitle]);

  const titleProps = { title, cellSize, highlightedCells, setHighlightedCells };
  const wrapperStyle = isVertical
    ? { marginTop: `${offset}px` }
    : { marginLeft: `${offset}px` };

  return (
    <div style={wrapperStyle}>
      {isVertical ? (
        <VerticalTitle {...titleProps} />
      ) : (
        <HorizontalTitle {...titleProps} />
      )}
    </div>
  );
}

/**
 * Title bar spanning rows `indexRange.start`..`indexRange.end` (inclusive) in
 * the column left of the grid. Its height covers those rows and its width is a
 * third of a cell. The label uses a vertical writing mode rotated by 180°.
 *
 * Hovering highlights exactly the spanned rows. The bar is dimmed when every
 * highlighted row index lies outside its range.
 */
function VerticalTitle({
  title: { name, indexRange, color },
  cellSize,
  highlightedCells,
  setHighlightedCells,
}: RenderTitleProps): JSX.Element {
  const size = useMemo(
    () =>
      indexRange ? (indexRange.end - indexRange.start + 1) * cellSize : 0,
    [cellSize, indexRange],
  );

  const isDimmed = useMemo(() => {
    const isOutside = (row: number) =>
      row < indexRange.start || row > indexRange.end;
    return highlightedCells?.rowIndex?.every(isOutside);
  }, [highlightedCells, indexRange]);

  const onMouseEnter = useCallback(() => {
    const rows: number[] = [];
    for (let row = indexRange.start; row <= indexRange.end; row += 1) {
      rows.push(row);
    }
    setHighlightedCells({ rowIndex: rows });
  }, [indexRange, setHighlightedCells]);

  const onMouseLeave = useCallback(
    () => setHighlightedCells({}),
    [setHighlightedCells],
  );

  const containerClass = clsx(
    'flex flex-col justify-center items-center w-full h-full rounded-l-xl overflow-hidden cursor-default border-l border-y border-gray-450',
    color || 'bg-transparent',
    isDimmed && 'opacity-40',
  );

  return (
    <div
      className={containerClass}
      style={{ width: `${cellSize / 3}px`, height: `${size}px` }}
      onMouseEnter={onMouseEnter}
      onMouseLeave={onMouseLeave}
    >
      <span
        className="text-sm font-bold fa-flip-vertical break-word text-center"
        style={{ writingMode: 'vertical-rl', transform: 'rotate(180deg)' }}
      >
        {name}
      </span>
    </div>
  );
}

/**
 * Title bar spanning columns `indexRange.start`..`indexRange.end` (inclusive)
 * in the row above the grid. Its width covers those columns, its height is a
 * third of a cell, and its font size follows useCellFontSize.
 *
 * Hovering highlights exactly the spanned columns. The bar is dimmed when
 * every highlighted column index lies outside its range.
 */
function HorizontalTitle({
  title: { name, indexRange, color },
  cellSize,
  highlightedCells,
  setHighlightedCells,
}: RenderTitleProps): JSX.Element {
  const size = useMemo(
    () =>
      indexRange ? (indexRange.end - indexRange.start + 1) * cellSize : 0,
    [cellSize, indexRange],
  );

  const isDimmed = useMemo(() => {
    const isOutside = (col: number) =>
      col < indexRange.start || col > indexRange.end;
    return highlightedCells?.colIndex?.every(isOutside);
  }, [highlightedCells, indexRange]);

  const fontSize = useCellFontSize(cellSize);

  const onMouseEnter = useCallback(() => {
    const cols: number[] = [];
    for (let col = indexRange.start; col <= indexRange.end; col += 1) {
      cols.push(col);
    }
    setHighlightedCells({ colIndex: cols });
  }, [indexRange, setHighlightedCells]);

  const onMouseLeave = useCallback(
    () => setHighlightedCells({}),
    [setHighlightedCells],
  );

  const containerClass = clsx(
    'flex justify-center items-center w-full h-full rounded-t-xl overflow-hidden cursor-default border-t border-x border-gray-450',
    color || 'bg-transparent',
    isDimmed && 'opacity-40',
  );
  const labelClass = clsx(
    'font-bold justify-center items-center break-word text-center',
    fontSize,
  );

  return (
    <div
      className={containerClass}
      style={{ width: `${size}px`, height: `${cellSize / 3}px` }}
      onMouseEnter={onMouseEnter}
      onMouseLeave={onMouseLeave}
    >
      <span className={labelClass}>{name}</span>
    </div>
  );
}

/**
 * Maps a cell size in px to a Tailwind text-size class. Each threshold is an
 * inclusive upper bound; anything larger than 128px gets `text-lg`.
 */
function useCellFontSize(cellSize: number): string {
  return useMemo(() => {
    const thresholds: ReadonlyArray<readonly [number, string]> = [
      [60, 'text-3xs'],
      [64, 'text-xs'],
      [96, 'text-sm'],
      [128, 'text-base'],
    ];
    const match = thresholds.find(([maxSize]) => cellSize <= maxSize);
    return match ? match[1] : 'text-lg';
  }, [cellSize]);
}

/**
 * Binary confusion-matrix metrics as read from the chart payload (the payload
 * is cast to this shape without runtime validation).
 */
interface ConfusionMatrixTable {
  // Counts; tp/fp/tn/fn may be shown as fractions of totalPopulation.
  totalPopulation: number;
  tp: number;
  fn: number;
  fp: number;
  tn: number;
  // Derived metrics, displayed as-is (rounded).
  sensitivity: number;
  fallout: number;
  precision: number;
  falseOmissionRate: number;
  informedness: number;
  positiveLikelihoodRatio: number;
  // NOTE(review): misspelling of "prevalence"; presumably matches the payload
  // key — confirm with the producer before renaming.
  prevelance: number;
}

/**
 * Group titles of the confusion-matrix grid. Both span the two inner indices
 * (1..2): the predicted-condition columns and the actual-condition rows.
 */
const CONFUSION_MATRIX_TABLE_TITLES: HeatmapTableTitle[] = [
  {
    position: OrientationType.Horizontal,
    indexRange: { start: 1, end: 2 },
    name: 'Predicted Condition',
    color: 'bg-cyan-700',
  },
  {
    position: OrientationType.Vertical,
    indexRange: { start: 1, end: 2 },
    name: 'Actual Condition',
    color: 'bg-insight-700',
  },
];

/**
 * Shared fallback used when the payload carries no confusion-matrix data. It
 * is a module-level constant so the display-values memo is not invalidated on
 * every render. The display hook treats a missing `totalPopulation` as "no
 * data".
 */
const EMPTY_CONFUSION_MATRIX = {} as ConfusionMatrixTable;

/**
 * Renders the metrics found in the first entry of `graphData.data` as a
 * confusion-matrix heatmap table. It falls back to the "no data" view when the
 * entry or its data is missing.
 */
export function ConfusionMatrixSingleHeatmapTable({
  graphData,
  multiChartRef,
  chartRequestData: { showPercentages },
}: XYChartProps): JSX.Element {
  // The payload is untyped here; assume the first entry carries the metrics.
  const firstEntryData = graphData?.data?.[0]?.data as unknown as
    | ConfusionMatrixTable
    | undefined;
  const unparsedData = firstEntryData ?? EMPTY_CONFUSION_MATRIX;
  const data = useConfusionMatrixTableDisplayValues(
    unparsedData,
    showPercentages,
  );
  return (
    <div className="w-full h-full py-2">
      <HeatmapTable
        data={data}
        titles={CONFUSION_MATRIX_TABLE_TITLES}
        parentRef={multiChartRef}
      />
    </div>
  );
}

/**
 * Builds the 4x4 cell grid shown by ConfusionMatrixSingleHeatmapTable.
 * Rows are the actual condition and columns are the predicted condition:
 *
 *   Total Population | Predicted Positive | Predicted Negative | Informedness
 *   Actual Positive  | TP                 | FN                 | Sensitivity
 *   Actual Negative  | FP                 | TN                 | Fallout
 *   Prevalence       | Precision          | False Omission     | Positive LR
 *
 * TP/FN/FP/TN are shown as fractions of `totalPopulation` when
 * `showPercentages` is set. All values are rounded to 4 decimals, and missing
 * metrics display as 0.
 *
 * @returns an empty grid when `totalPopulation` is missing or 0.
 */
function useConfusionMatrixTableDisplayValues(
  unparsedData: ConfusionMatrixTable,
  showPercentages = false,
): HeatmapTableCell[][] {
  const data: HeatmapTableCell[][] = useMemo(() => {
    if (!unparsedData.totalPopulation) return [];

    const round4 = (value: number | undefined): number =>
      parseFloat((value ?? 0).toFixed(4));
    const count = (value: number): number =>
      percentageIfToggle({
        showPercentages,
        value,
        total: unparsedData.totalPopulation,
      });

    const tp = count(unparsedData.tp);
    const fp = count(unparsedData.fp);
    const tn = count(unparsedData.tn);
    const fn = count(unparsedData.fn);

    return [
      [
        {
          name: 'Total Population',
          value: round4(unparsedData.totalPopulation),
          color: 'bg-gray-800',
        },
        { name: 'Predicted Positive', color: 'bg-cyan-900' },
        { name: 'Predicted Negative', color: 'bg-cyan-700' },
        {
          name: 'Informedness',
          value: round4(unparsedData.informedness),
          color: 'bg-gray-800',
        },
      ],
      [
        { name: 'Actual Positive', color: 'bg-insight-700' },
        // Actual positive, predicted positive.
        { name: 'TP', value: tp, color: 'bg-success-700' },
        // Actual positive, predicted negative: a false negative.
        { name: 'FN', value: fn, color: 'bg-error-700' },
        {
          name: 'Sensitivity',
          value: round4(unparsedData.sensitivity),
          color: 'bg-success-900',
        },
      ],
      [
        { name: 'Actual Negative', color: 'bg-insight-600' },
        // Actual negative, predicted positive: a false positive.
        { name: 'FP', value: fp, color: 'bg-error-600' },
        // Actual negative, predicted negative.
        { name: 'TN', value: tn, color: 'bg-success-600' },
        {
          name: 'Fallout',
          value: round4(unparsedData.fallout),
          color: 'bg-error-900',
        },
      ],
      [
        {
          name: 'Prevalence',
          value: round4(unparsedData.prevelance),
          color: 'bg-gray-800',
        },
        {
          name: 'Precision',
          value: round4(unparsedData.precision),
          color: 'bg-success-950',
        },
        {
          name: 'False Omission',
          value: round4(unparsedData.falseOmissionRate),
          color: 'bg-error-950',
        },
        {
          name: 'Positive Likelihood Ratio',
          value: round4(unparsedData.positiveLikelihoodRatio),
          color: 'bg-primary-950',
        },
      ],
    ];
  }, [showPercentages, unparsedData]);

  return data;
}

interface PercentageIfToggleProps {
  showPercentages: boolean;
  value: number;
  total: number;
}

/**
 * Returns `value / total` when `showPercentages` is set and `total` is
 * positive, otherwise `value` itself. The result is rounded to 4 decimals.
 *
 * A missing `value` is treated as 0 in both modes. Before this change, the
 * percentage branch divided `undefined` and returned NaN. The confusion-matrix
 * payload is cast without validation, so fields may be absent at runtime.
 */
function percentageIfToggle({
  showPercentages,
  value,
  total,
}: PercentageIfToggleProps): number {
  const safeValue = value ?? 0;
  const result = showPercentages && total > 0 ? safeValue / total : safeValue;
  return parseFloat(result.toFixed(4));
}
