import {
  ConfusionMatrixParams,
  MultiChartsResponse,
} from '@tensorleap/api-client';
import { useMemo } from 'react';
import { XYChartVizProps } from './interfaces';
import { MultiCharts } from '../../../../ui/charts/visualizers/MultiCharts';
import {
  chartSplitToSplitAgg,
  getSplitLabels,
  selectedSessionRunToSessionRunsToEpochs,
  toIntervalOrLimit,
} from './utils';
import api from '../../../../core/api-client';
import { useGetChart } from '../../../../core/data-fetching/getChart';
import { useMergedObject } from '../../../../core/useMergedObject';

/**
 * Inner split handed to `MultiCharts` (as `chartRequestData.innerSplit`) so
 * that series are presumably grouped by each point's `label` field.
 * Built once at module load because it does not depend on any props.
 * NOTE(review): the meaning of the `null` second argument is defined by
 * `chartSplitToSplitAgg` in `./utils` — confirm there.
 */
const INTERNAL_SPLIT = chartSplitToSplitAgg({ field: 'label' }, null);

/** Label of the aggregate series kept when individual average precisions are hidden. */
const MEAN_AVERAGE_PRECISION_LABEL = 'Mean Average Precision';

/**
 * Visualizes mean average precision (mAP) for the selected session runs.
 *
 * Data is fetched with `api.getMeanAveragePrecision` and rendered through
 * `MultiCharts`. Vertical/horizontal splits are derived from `modelIdPosition`
 * and `firstSplit`. The x axis is built from the field, distribution, ordering
 * and interval settings in `graphParams`.
 *
 * When `showAllAvgPrecisions` is set to a falsy value other than `undefined`,
 * each chart keeps only the points labeled `'Mean Average Precision'`.
 * Otherwise every series returned by the API is shown.
 */
export function MapVis({
  graphParams: {
    xAxis: xField,
    metricName,
    dataDistribution,
    orderBy,
    order,
    xAxisSizeInterval,
    modelIdPosition,
    firstSplit,
    autoScaleY,
    showAllAvgPrecisions,
  },
  filters,
  localFilters,
  sessionRuns,
  onFiltersChange,
  chartType,
  projectId,
}: XYChartVizProps) {
  const { verticalSplit, horizontalSplit } = getSplitLabels(
    modelIdPosition,
    firstSplit,
  );

  const sessionRunsToEpochs = useMemo(
    () => selectedSessionRunToSessionRunsToEpochs(sessionRuns),
    [sessionRuns],
  );

  const xAxis = useMemo(
    () =>
      chartSplitToSplitAgg(
        {
          field: xField,
          distribution: dataDistribution,
          orderField: orderBy,
          order,
          // NOTE(review): Number(undefined) is NaN — confirm toIntervalOrLimit
          // handles a missing xAxisSizeInterval.
          ...toIntervalOrLimit(dataDistribution, Number(xAxisSizeInterval)),
        },
        null,
      ),
    [xField, dataDistribution, orderBy, order, xAxisSizeInterval],
  );

  const params = useMergedObject<ConfusionMatrixParams>({
    projectId,
    filters,
    sessionRunsToEpochs,
    customMetricName: metricName,
    verticalSplit,
    horizontalSplit,
    x: xAxis,
  });

  // Memoized so MultiCharts receives a referentially stable object. Without
  // this, a fresh literal on every render would invalidate any downstream
  // memoization or effects keyed on it.
  const chartRequestData = useMemo(
    () => ({
      xField,
      yField: 'mAP',
      dataDistribution,
      orderByParam: orderBy,
      orderParams: order,
      innerSplit: INTERNAL_SPLIT,
    }),
    [xField, dataDistribution, orderBy, order],
  );

  // NOTE(review): `func` is a new function on every render; confirm that
  // useGetChart does not refetch based on its identity.
  const { multiChartsResponse, isLoading, error } = useGetChart({
    params,
    func: async (x) => await api.getMeanAveragePrecision(x),
  });

  const filteredMultiChartsResponse: MultiChartsResponse | undefined =
    useMemo(() => {
      // Show everything unless the flag was explicitly turned off
      // (an `undefined` flag is treated as "show all").
      if (showAllAvgPrecisions || showAllAvgPrecisions === undefined) {
        return multiChartsResponse;
      }
      if (!multiChartsResponse) {
        return undefined;
      }
      return {
        charts: multiChartsResponse.charts.map((chart) => {
          const filteredPoints = chart.data.data.filter(
            (d) => d.data.label === MEAN_AVERAGE_PRECISION_LABEL,
          );
          return {
            ...chart,
            data: {
              ...chart.data,
              data: filteredPoints,
            },
          };
        }),
      };
    }, [multiChartsResponse, showAllAvgPrecisions]);

  return (
    <MultiCharts
      xyChartsResponse={filteredMultiChartsResponse}
      chartRequestData={chartRequestData}
      onFiltersChange={onFiltersChange}
      localFilters={localFilters}
      chartType={chartType}
      autoScaleY={autoScaleY}
      isLoading={isLoading}
      error={error}
      horizontalSplit={horizontalSplit ?? null}
      verticalSplit={verticalSplit ?? null}
    />
  );
}
