import { ChartLoading } from './charts/common/ChartLoading';
import clsx from 'clsx';
import {
  AggregationMethod,
  AnalyticsDashletType,
  MultiChartsParams,
  OrderType,
} from '@tensorleap/api-client';
import { MultiCharts } from './charts/visualizers/MultiCharts';
import { useMemo } from 'react';
import { useGetChart } from '../core/data-fetching/getChart';
import api from '../core/api-client';
import { sessionRunDataAndEpochToSessionRunsToEpochs } from '../dashboard/dashlet/Analytics/ElasticVis/utils';
import { useVersionControl } from '../core/VersionControlContext';

// Fixed chart configuration for the "average loss over epoch" card below.

// --- x axis ---------------------------------------------------------------

/** Field plotted on the x axis (`x.field` / `xField`). */
const XFIELD = 'epoch';
/**
 * Distribution of the x-axis values (`x.distribution` / `dataDistribution`).
 * `as const` keeps the literal type `'continuous'` when embedded in objects.
 */
const DATA_DISTRIBUTION_TYPE = 'continuous' as const;
/** Interval requested for the x axis (`x.interval`). */
const X_AXIS_SIZE_INTERVAL = 1;

// --- y axis ---------------------------------------------------------------

/** Field plotted on the y axis (`y.field` / `yField`). */
const YFIELD = 'metrics.loss';
/** Aggregation applied to the y field (`y.aggregation`). */
const AGGREGATION_METHOD = AggregationMethod.Average;

/** Props accepted by {@link SelectEpochLineChart}. */
export type SelectEpochLineChartProps = {
  /** Project whose data is charted; forwarded as the request's `projectId`. */
  projectId: string;
  /**
   * Session run to plot, looked up in the version-control context's
   * selected-session-run map. When omitted (or not found), the chart is
   * requested with no session runs.
   */
  sessionRunId?: string;
  /** Extra CSS classes merged onto the outer container. */
  className?: string;
};

/**
 * Axis description handed to `MultiCharts`. It depends only on module-level
 * constants, so it is built once here rather than on every render; this keeps
 * the prop reference stable, so any memoization or effect in `MultiCharts`
 * that depends on it is not invalidated on each re-render.
 */
const CHART_REQUEST_DATA = {
  xField: XFIELD,
  yField: YFIELD,
  dataDistribution: DATA_DISTRIBUTION_TYPE,
};

/**
 * Card showing a line chart of the average `metrics.loss` per `epoch` for a
 * single session run.
 *
 * The session run is resolved from the version-control context's
 * `selectedSessionRunMap`, converted to the `sessionRunsToEpochs` request shape,
 * and the chart data is fetched through `useGetChart` / `api.getXYChart`.
 * A loading placeholder is shown until a response is available.
 *
 * @param props.className - Extra classes merged onto the outer container.
 * @param props.sessionRunId - Session run to plot; when absent or not found in
 *   the selected map, the request is made with an empty `sessionRunsToEpochs`.
 * @param props.projectId - Project the chart request is scoped to.
 */
export function SelectEpochLineChart({
  className,
  sessionRunId,
  projectId,
}: SelectEpochLineChartProps) {
  const { selectedSessionRunMap } = useVersionControl();

  // Resolve the requested session run; an unknown or missing id yields [].
  const sessionRunsToEpochs = useMemo(() => {
    const sessionRun = sessionRunId
      ? selectedSessionRunMap.get(sessionRunId)
      : undefined;

    return sessionRun
      ? sessionRunDataAndEpochToSessionRunsToEpochs(sessionRun)
      : [];
  }, [selectedSessionRunMap, sessionRunId]);

  // Memoized so the request params keep their identity between renders unless
  // the project or the resolved session runs change.
  // NOTE(review): a request is still issued when sessionRunsToEpochs is empty
  // (no session run selected) — confirm this is intended.
  const params = useMemo<MultiChartsParams>(
    () => ({
      projectId,
      x: {
        field: XFIELD,
        distribution: DATA_DISTRIBUTION_TYPE,
        order: OrderType.Asc,
        interval: X_AXIS_SIZE_INTERVAL,
        limit: null,
      },
      y: { field: YFIELD, aggregation: AGGREGATION_METHOD },
      sessionRunsToEpochs,
      showAllEpochs: true,
    }),
    [projectId, sessionRunsToEpochs],
  );

  const { multiChartsResponse, isLoading } = useGetChart({
    params,
    func: async (x) => await api.getXYChart(x),
  });

  return (
    <div
      className={clsx(
        'border-gray-500 border border-solid p-1 pt-2 flex flex-col bg-gray-850 rounded-lg h-full w-full',
        className,
      )}
    >
      <span className="text-sm font-bold text-center uppercase">
        average loss over epoch
      </span>
      {isLoading || !multiChartsResponse ? (
        <ChartLoading />
      ) : (
        <MultiCharts
          xyChartsResponse={multiChartsResponse}
          chartRequestData={CHART_REQUEST_DATA}
          chartType={AnalyticsDashletType.Line}
          autoScaleY={false}
          showLegend={false}
          isLoading={false}
          error={undefined}
          horizontalSplit={null}
          verticalSplit={null}
        />
      )}
    </div>
  );
}
