import {
  GetConfusionMatrixResultCombinationsParams,
  NumberOrString,
} from '@tensorleap/api-client';
import { XYChartVizProps } from './interfaces';
import { MultiCharts } from '../../../../ui/charts/visualizers/MultiCharts';
import {
  chartSplitToSplitAgg,
  getSplitLabels,
  selectedSessionRunToSessionRunsToEpochs,
  toIntervalOrLimit,
} from './utils';
import api from '../../../../core/api-client';
import { useGetChart } from '../../../../core/data-fetching/getChart';
import { useMergedObject } from '../../../../core/useMergedObject';
import { useCallback, useMemo } from 'react';
import { ConfusionMatrixTypeEnum } from '../form/ConfusionMatrix';
import { ChartRequestData } from '../../../../ui/charts/common/interfaces';
import { NoDataChart } from '../../../../ui/charts/common/NoDataChart';

// Keys of every confusion-matrix cell record:
//  - X_AXIS_NAME: the predicted label (x axis),
//  - Y_AXIS_NAME: the ground-truth label (y axis),
//  - POWER_NAME: the raw sample count of the cell,
//  - PERCENT_POWER_NAME: the cell's share of the samples that have the same
//    ground-truth label (used as the color field when percentages are shown).
const X_AXIS_NAME = 'prediction',
  Y_AXIS_NAME = 'ground truth',
  POWER_NAME = 'samples',
  PERCENT_POWER_NAME = 'percent';

/**
 * Renders a confusion matrix (ground truth × prediction) for the selected
 * session runs, optionally split into several charts by `firstSplit` /
 * `secondSplit`.
 *
 * The backend response is expanded so that every non-empty chart becomes a
 * dense square matrix over the union of labels seen in *all* charts. All
 * charts therefore share the same axes, and (ground truth, prediction) pairs
 * that are absent from a chart are rendered as zero-count cells. Each cell
 * also carries a `percent` value: its count divided by the number of samples
 * with the same ground-truth label in the same chart.
 */
export function ConfusionMatrixByLabelVis({
  graphParams: {
    xAxis: xField,
    metricName,
    dataDistribution,
    orderBy,
    order,
    xAxisSizeInterval,
    modelIdPosition,
    firstSplit,
    secondSplit,
    autoScaleY,
    showPercentages,
    threshold,
    flipColorRange,
    confusionMatrixLabelsFilter,
  },
  filters,
  localFilters,
  sessionRuns,
  onFiltersChange,
  chartType,
  projectId,
}: XYChartVizProps) {
  const { verticalSplit, horizontalSplit } = useMemo(
    () => getSplitLabels(modelIdPosition, firstSplit, secondSplit),
    [modelIdPosition, firstSplit, secondSplit],
  );

  const sessionRunsToEpochs = useMemo(
    () => selectedSessionRunToSessionRunsToEpochs(sessionRuns),
    [sessionRuns],
  );

  const xAxis = useMemo(
    () =>
      chartSplitToSplitAgg(
        {
          field: xField,
          distribution: dataDistribution,
          orderField: orderBy,
          order,
          ...toIntervalOrLimit(dataDistribution, Number(xAxisSizeInterval)),
        },
        null,
      ),
    [xField, dataDistribution, orderBy, order, xAxisSizeInterval],
  );

  const params = useMergedObject<GetConfusionMatrixResultCombinationsParams>({
    projectId,
    x: xAxis,
    sessionRunsToEpochs,
    customMetricName: metricName,
    filters,
    verticalSplit,
    horizontalSplit,
    threshold: threshold ? Number(threshold) : undefined,
    filterLabels: confusionMatrixLabelsFilter,
  });

  const chartRequestData = useMergedObject<ChartRequestData>({
    xField: X_AXIS_NAME,
    yField: Y_AXIS_NAME,
    colorField: showPercentages ? PERCENT_POWER_NAME : POWER_NAME,
    flipColorRange,
    allowFiltering: false,
    colorRange: ['#1D4ED8', '#1E3A8A', '#0F1C42'],
  });

  const { multiChartsResponse, isLoading, error } = useGetChart({
    params,
    func: async (x) => await api.getConfusionMatrixResultCombinations(x),
  });

  // Orders axis labels: numbers first (ascending), then strings (locale
  // order). Sorts a copy so the caller's array is never mutated.
  const sortLabels = useCallback(
    (labels: NumberOrString[]): NumberOrString[] =>
      [...labels].sort((a, b) => {
        if (typeof a === 'number' && typeof b === 'number') {
          return a - b;
        }
        if (typeof a === 'number') {
          return -1;
        }
        if (typeof b === 'number') {
          return 1;
        }
        return String(a).localeCompare(String(b));
      }),
    [],
  );

  const parsedResponse = useMemo(() => {
    if (!multiChartsResponse) return multiChartsResponse;

    // Union of every label on either axis, across all charts, so all charts
    // are rendered over the same (square) axes.
    const allLabels = new Set<NumberOrString>();
    multiChartsResponse.charts.forEach((chart) => {
      chart.data.data.forEach((el) => {
        allLabels.add(el.data[Y_AXIS_NAME]);
        allLabels.add(el.data[X_AXIS_NAME]);
      });
    });

    const sortedLabels = sortLabels(Array.from(allLabels));
    // Label -> axis position. Map uses the same key equality (SameValueZero)
    // as the Set above, and avoids an O(labels) indexOf per cell.
    const labelIndex = new Map<NumberOrString, number>(
      sortedLabels.map((label, index) => [label, index]),
    );

    const parsed = multiChartsResponse.charts.map((chart) => {
      if (chart.data.data.length === 0) return chart;

      const matrix: number[][] = sortedLabels.map(() =>
        new Array<number>(sortedLabels.length).fill(0),
      );

      chart.data.data.forEach((el) => {
        const yIndex = labelIndex.get(el.data[Y_AXIS_NAME]);
        const xIndex = labelIndex.get(el.data[X_AXIS_NAME]);
        // Every label was collected above, so this guard should never trip.
        if (yIndex === undefined || xIndex === undefined) return;
        // Accumulate so repeated (truth, prediction) rows stay consistent
        // with the row totals below.
        matrix[yIndex][xIndex] += Number(el.data[POWER_NAME]);
      });

      // Percentages are normalized per ground-truth row *within this chart*.
      // Totals must not be shared across split charts, or each chart's rows
      // would no longer sum to 1.
      const rowTotals = matrix.map((row) =>
        row.reduce((sum, count) => sum + count, 0),
      );

      const data = matrix.flatMap((row, yIndex) =>
        row.map((count, xIndex) => ({
          data: {
            [Y_AXIS_NAME]: sortedLabels[yIndex],
            [X_AXIS_NAME]: sortedLabels[xIndex],
            [POWER_NAME]: count,
            count,
            [PERCENT_POWER_NAME]: rowTotals[yIndex]
              ? count / rowTotals[yIndex]
              : 0,
          },
        })),
      );

      return {
        ...chart,
        data: {
          ...chart.data,
          data,
        },
      };
    });

    return {
      ...multiChartsResponse,
      charts: parsed,
    };
  }, [multiChartsResponse, sortLabels]);

  // NOTE(review): while no response is available (presumably during the
  // initial load or after a failed request) this renders NoDataChart, so the
  // `isLoading` / `error` props passed to MultiCharts below are not reached
  // in that state — confirm this is intended.
  if (
    !parsedResponse ||
    parsedResponse.charts.length === 0 ||
    parsedResponse.charts.every((c) => c.data.data.length === 0)
  )
    return <NoDataChart />;

  return (
    <MultiCharts
      xyChartsResponse={parsedResponse}
      chartRequestData={chartRequestData}
      onFiltersChange={onFiltersChange}
      localFilters={localFilters}
      chartType={chartType}
      autoScaleY={autoScaleY}
      isLoading={isLoading}
      error={error}
      horizontalSplit={horizontalSplit ?? null}
      verticalSplit={verticalSplit ?? null}
      chartSubType={ConfusionMatrixTypeEnum.ConfusionMatrixByLabelVis}
      showLegend={false}
    />
  );
}
