import api from '../../../../core/api-client';
import { useMemo } from 'react';
import { XYChartVizProps } from './interfaces';
import { MultiCharts } from '../../../../ui/charts/visualizers/MultiCharts';
import { AggregationMethod, MultiChartsParams } from '@tensorleap/api-client';
import {
  chartSplitToSplitAgg,
  getSplitLabels,
  selectedSessionRunToSessionRunsToEpochs,
  toIntervalOrLimit,
} from './utils';
import { useGetChart } from '../../../../core/data-fetching/getChart';
import { ChartRequestData } from '../../../../ui/charts/common/interfaces';

/** Props for `XYViz`: the shared XY chart visualization props plus an optional preview flag. */
type XYVizProps = {
  /** Forwarded to `MultiCharts` as its `preview` prop. */
  preview?: boolean;
} & XYChartVizProps;

/**
 * Renders an XY chart for the given graph configuration.
 *
 * Builds a `MultiChartsParams` request from `graphParams`, the selected
 * `sessionRuns`, the active `filters` and `projectId`, fetches it through
 * `useGetChart` using `api.getXYChart`, and passes the response (plus loading
 * and error state) to `MultiCharts`.
 *
 * @param props.graphParams - Axis, distribution, ordering, split, epoch and
 *   aggregation settings used to build the request; `autoScaleY` is forwarded
 *   to the chart.
 * @param props.filters - Active filters; sent with the request and forwarded
 *   to `MultiCharts`.
 * @param props.sessionRuns - Selected session runs, converted into the
 *   request's `sessionRunsToEpochs` field.
 * @param props.onFiltersChange - Forwarded to `MultiCharts`.
 * @param props.chartType - Forwarded to `MultiCharts`.
 * @param props.projectId - Included in the chart request.
 * @param props.preview - Forwarded to `MultiCharts`.
 * @param props.className - Forwarded to `MultiCharts`.
 */
export function XYViz({
  graphParams,
  filters,
  sessionRuns,
  onFiltersChange,
  chartType,
  projectId,
  preview,
  className,
}: XYVizProps) {
  const params = useMemo<MultiChartsParams>(() => {
    const {
      xAxis: xField,
      yAxis: yField,
      dataDistribution: dataDistributionType,
      orderBy,
      order,
      xAxisSizeInterval,
      modelIdPosition,
      firstSplit,
      secondSplit,
      showAllEpochs,
      aggregation: aggregationMethod,
    } = graphParams;

    const { verticalSplit, horizontalSplit, innerSplit } = getSplitLabels(
      modelIdPosition,
      firstSplit,
      secondSplit
    );

    // The sentinel orderBy value '1' means "order by the y-axis field";
    // any other value is used as the order field as-is.
    const orderField = orderBy === '1' ? yField : orderBy;

    const sessionRunsToEpochs = selectedSessionRunToSessionRunsToEpochs(
      sessionRuns
    );

    const p: MultiChartsParams = {
      projectId,
      x: chartSplitToSplitAgg(
        {
          field: xField,
          distribution: dataDistributionType,
          orderField,
          order,
          // NOTE(review): xAxisSizeInterval is coerced with Number(); it is
          // presumably stored as a string in graphParams — confirm. An
          // unset value becomes NaN here.
          ...toIntervalOrLimit(dataDistributionType, Number(xAxisSizeInterval)),
        },
        null
      ),
      y: {
        field: yField,
        // Fall back to averaging when no aggregation method is configured.
        aggregation: aggregationMethod || AggregationMethod.Average,
      },
      filters,
      verticalSplit,
      horizontalSplit,
      innerSplit,
      sessionRunsToEpochs,
      showAllEpochs,
    };
    return p;
  }, [filters, graphParams, projectId, sessionRuns]);

  // Memoized like `params` so MultiCharts receives a referentially stable
  // object instead of a fresh literal on every render. It only changes when
  // the graph configuration or the derived request params change.
  const chartRequestData = useMemo<ChartRequestData>(
    () => ({
      xField: graphParams.xAxis,
      yField: graphParams.yAxis,
      innerSplit: params.innerSplit,
      dataDistribution: graphParams.dataDistribution,
      orderByParam: graphParams.orderBy,
      orderParams: params.x.order,
      xSizeInterval: graphParams.xAxisSizeInterval,
    }),
    [graphParams, params]
  );

  const { multiChartsResponse, isLoading, error } = useGetChart({
    params,
    // NOTE(review): this closure is recreated on every render. If useGetChart
    // lists `func` in an effect/dependency array, it should be stabilized
    // (e.g. hoisted to module scope) — confirm against useGetChart.
    func: async (x) => await api.getXYChart(x),
  });

  return (
    <MultiCharts
      xyChartsResponse={multiChartsResponse}
      chartRequestData={chartRequestData}
      onFiltersChange={onFiltersChange}
      filters={filters}
      chartType={chartType}
      autoScaleY={graphParams.autoScaleY}
      isLoading={isLoading}
      horizontalSplit={params.horizontalSplit ?? null}
      verticalSplit={params.verticalSplit ?? null}
      error={error}
      preview={preview}
      className={className}
    />
  );
}
