import { AnalyticsDashletType } from '@tensorleap/api-client';
import { useMemo } from 'react';
import { VisualizationFilter } from '../../../../core/types/filters';
import { mapToEsFilters } from '../../../../model-tests/modelTestHelpers';
import { UnifiedXYChartParams } from '../form/utils';
import { ConfusionMatrixVis } from './ConfusionMatrixVis';
import { HeatmapViz } from './HeatmapVis';
import { TableViz } from './TableVis';
import { XYChartVizProps } from './interfaces';
import { XYViz } from './XYViz';
import { useCurrentProject } from '../../../../core/CurrentProjectContext';
import clsx from 'clsx';
import {
  ConfusionMatrixType,
  ConfusionMatrixTypeEnum,
} from '../form/ConfusionMatrix';
import { ConfusionMatrixTableVis } from './ConfusionMatrixTableVis';
import { SelectedSessionRun } from '../../../../ui/molecules/useModelFilter';

/**
 * Maps a dashlet type (and, for confusion-matrix dashlets, its sub-type) to
 * the component that renders it.
 *
 * @param graphType - The dashlet's chart type.
 * @param subType - Confusion-matrix variant; only consulted when `graphType`
 *   is `ConfusionMatrix`. Any value other than `ConfusionMatrixTable`
 *   (including `undefined`) selects `ConfusionMatrixVis`.
 * @returns The visualization component for the given type.
 * @throws Error when `graphType` is not one of the handled dashlet types.
 */
function getVisualizationComponent(
  graphType: AnalyticsDashletType,
  subType?: ConfusionMatrixType
): (xYChartVizProps: XYChartVizProps) => JSX.Element {
  const isXYChart =
    graphType === AnalyticsDashletType.Bar ||
    graphType === AnalyticsDashletType.Line ||
    graphType === AnalyticsDashletType.Area ||
    graphType === AnalyticsDashletType.Donut;
  if (isXYChart) {
    return XYViz;
  }
  if (graphType === AnalyticsDashletType.Table) {
    return TableViz;
  }
  if (graphType === AnalyticsDashletType.Heatmap) {
    return HeatmapViz;
  }
  if (graphType === AnalyticsDashletType.ConfusionMatrix) {
    // F1, BalancedAccuracy, PrCurve, Roc and an unset sub-type all use the
    // chart-based view; only the table variant has a dedicated component.
    return subType === ConfusionMatrixTypeEnum.ConfusionMatrixTable
      ? ConfusionMatrixTableVis
      : ConfusionMatrixVis;
  }
  throw new Error(`Unsupported graph type: ${graphType}`);
}

/** Props accepted by `ElasticVis`. */
export type ElasticVisProps = {
  /** Dashlet chart type; selects which visualization component is rendered. */
  graphType: AnalyticsDashletType;
  /** Chart configuration forwarded to the visualization component. */
  graphParams: UnifiedXYChartParams;
  /** Session runs forwarded to the visualization component. */
  sessionRuns: SelectedSessionRun[];
  /** Active filters; converted with `mapToEsFilters` before being forwarded. */
  filters: VisualizationFilter[];
  /** Forwarded to the visualization component as its filter-change callback. */
  onFiltersChange: (filters: VisualizationFilter[]) => void;
  /** Extra classes merged onto the outer container. */
  className?: string;
};
/**
 * Renders an analytics dashlet. It picks the visualization component matching
 * `graphType` (and, for confusion-matrix dashlets, `graphParams.type`). It then
 * passes that component the current project id, the filters converted by
 * `mapToEsFilters`, and the remaining dashlet props.
 *
 * @throws Error (from `getVisualizationComponent`) when `graphType` is not a
 *   supported dashlet type.
 */
export function ElasticVis({
  sessionRuns,
  graphType,
  className,
  graphParams,
  filters,
  onFiltersChange,
}: ElasticVisProps): JSX.Element {
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  // Memoized so the child receives a stable array reference while `filters`
  // is unchanged. A fresh array on every render may retrigger effects or
  // queries in the child that depend on it.
  const esFilters = useMemo(() => mapToEsFilters(filters), [filters]);

  // Read once so the memo factory and its dependency list use the same value.
  // `graphParams` is a required prop, so no optional chaining is needed.
  const subType = graphParams.type;
  const VizComponent = useMemo(
    () => getVisualizationComponent(graphType, subType),
    [graphType, subType]
  );

  return (
    <div
      className={clsx('flex h-full w-full bg-gray-900', className)}
      // Keeps mousedown from bubbling to ancestors. Presumably this stops
      // chart interaction from starting a drag on the surrounding dashboard
      // layout — TODO confirm.
      onMouseDown={(event) => event.stopPropagation()}
    >
      <div className="flex w-full h-full">
        <VizComponent
          projectId={projectId}
          filters={esFilters}
          graphParams={graphParams}
          sessionRuns={sessionRuns}
          onFiltersChange={onFiltersChange}
          chartType={graphType}
        />
      </div>
    </div>
  );
}
