import {
  SampleIdentity,
  ScatterViz,
  VisualizationResponse,
} from '@tensorleap/api-client';
import { PreviewContainer } from '../common/PreviewsContainer';
import { SelectedSessionRun } from '../../../ui/molecules/useModelFilter';
import { useCallback, useMemo, useRef, useState } from 'react';
import { GraphIcon, Refresh2Icon, VisualizeIcon } from '../../../ui/icons';
import clsx from 'clsx';
import { compactNumberFormatter } from '../../../core/formatters/number-formatting';
import { Popover, Tooltip } from '@material-ui/core';
import { VisualizationRunningPlaceholder } from '../common/Placeholders';
import api from '../../../core/api-client';
import { usePopulationExploration } from './usePopulationExploration';
import { useFetchScatterVisualizationsMap } from '../../../core/data-fetching/scatter-visualizations-map';
import { NoDataChart } from '../../../ui/charts/common/NoDataChart';
import { useDashletScatterContext } from './DashletScatterContext';
import { MouseHandler } from '../../../core/useSelectionGroup';
import { TOUR_SELECTORS_ENUM } from '../../../tour/ToursConfig';
import { Button } from '../../../ui/atoms/Button';
import {
  usePopupState,
  bindTrigger,
  bindPopover,
} from 'material-ui-popup-state/hooks';
import { ConfirmDialog } from '../../../ui/atoms/DeleteContentDialog';
import { ScatterDataProvider } from '../../ScatterDataContext';
import { ScatterAnalyzerItem } from '../../ScatterAnalyzerView';

/** Props for the population-exploration scatter dashlet view. */
interface ScatterAnalyzerViewProps {
  /** Id of the hosting dashlet; forwarded to the population-exploration hook. */
  dashletId: string;
  /** Session run whose id and selected epoch drive the exploration. */
  sessionRun: SelectedSessionRun;
  /** Forwarded to the scatter item for pointer/selection interaction. */
  mouseHandler: MouseHandler;
  /** Extra classes for the preview container (also used by the error panel). */
  className?: string;
}

/**
 * Population-exploration dashlet body.
 *
 * Feeds the dashlet's projection settings and filters for the selected
 * session run/epoch into `usePopulationExploration`, then renders one of:
 * - the scatter (inside its data provider) once a full visualization exists,
 * - a "processing" placeholder while status is `'loading'`,
 * - an error panel with a Retry action while status is `'error'`,
 * - an empty-chart placeholder otherwise.
 *
 * The header's "Visualize" action needs the current scatter selection, so the
 * selection reported by the scatter item is mirrored into local state.
 */
export function ScatterAnalyzerView({
  sessionRun,
  className,
  dashletId,
  mouseHandler,
}: ScatterAnalyzerViewProps): JSX.Element {
  const { projectId, settings, filters } = useDashletScatterContext();
  const { projectionMetric, displayParams, domainGapMetadata } = settings;
  const sessionRunId = sessionRun.id;
  const epoch = sessionRun.epochsState.selectedEpoch;

  const exploration = usePopulationExploration({
    projectId,
    sessionRunId,
    filters: filters.dashletAndGlobalFilters,
    dashletId,
    projectionMetric,
    domainGapMetadata,
    displayParams,
    epoch,
  });
  const { fullVisualization, loadingStatus, lastReadyDigest, retry } =
    exploration;

  // Latest selection reported by the scatter; consumed by the header button.
  const [selection, setSelection] = useState<Array<SampleIdentity>>([]);
  const handleSelectionChange = useCallback(
    (sampleIdentities: Array<SampleIdentity>) => setSelection(sampleIdentities),
    [],
  );

  // Only 'updating'/'refreshing' are surfaced as a header indicator;
  // 'loading' is represented by the full-size placeholder in the body instead.
  const headerLoadingType =
    loadingStatus === 'refreshing' || loadingStatus === 'updating'
      ? loadingStatus
      : undefined;

  const header = (
    <div className="flex flex-row h-full gap-2 items-center relative justify-center w-fit">
      {fullVisualization && (
        <CreateVisualizationsButton
          sessionRun={sessionRun}
          epoch={epoch}
          projectId={projectId}
          scatterVisualization={fullVisualization}
          digest={lastReadyDigest}
          selectedSampleIdentities={selection}
        />
      )}
    </div>
  );

  let body: JSX.Element;
  if (fullVisualization) {
    body = (
      <ScatterDataProvider
        scatterVisualization={fullVisualization}
        projectId={projectId}
        epoch={epoch}
        sessionRunId={sessionRunId}
      >
        <ScatterAnalyzerItem
          projectId={projectId}
          sessionRunId={sessionRunId}
          epoch={epoch}
          mouseHandler={mouseHandler}
          onSelectedSampleIdentitiesChange={handleSelectionChange}
          className="flex-1"
        />
      </ScatterDataProvider>
    );
  } else if (loadingStatus === 'loading') {
    body = (
      <VisualizationRunningPlaceholder
        processName="processing..."
        tourId={TOUR_SELECTORS_ENUM.POPULATION_EXPLORATION_PROCESSING_ID}
      />
    );
  } else if (loadingStatus === 'error') {
    body = (
      <div
        className={clsx(
          'flex flex-col flex-1 min-h-[200px] justify-center items-center',
          className,
        )}
      >
        <GraphIcon className="text-gray-700" />
        <span className="text-sm text-gray-400">
          Population Exploration creation failed
        </span>
        <Button onClick={retry} variant="text">
          Retry
        </Button>
      </div>
    );
  } else {
    body = <NoDataChart />;
  }

  return (
    <PreviewContainer
      flex
      className={className}
      sessionRun={sessionRun}
      loadingType={headerLoadingType}
      header={header}
    >
      {body}
    </PreviewContainer>
  );
}

/** Props for the header "Visualize" action of the scatter dashlet. */
interface CreateVisualizationsButtonProps {
  projectId: string;
  /** Session run whose id scopes the visualization requests. */
  sessionRun: SelectedSessionRun;
  epoch: number;
  /** Samples currently selected in the scatter. */
  selectedSampleIdentities: Array<SampleIdentity>;
  /** Scatter visualization whose first payload entry lists the plotted samples. */
  scatterVisualization?: VisualizationResponse;
  /** Digest of the last ready scatter; the button is disabled while it is absent. */
  digest?: string;
}
/**
 * Header "Visualize" action for the population-exploration scatter.
 *
 * When no sample visualizations exist yet and nothing is selected, a click
 * immediately requests visualizations for every plotted sample that is not
 * already in the visualizations map. Otherwise the click opens a popover with
 * the "visualize rest / selected / revisualize all" options.
 *
 * The button is disabled until a ready scatter `digest` is available.
 */
function CreateVisualizationsButton({
  projectId,
  epoch,
  sessionRun,
  scatterVisualization,
  selectedSampleIdentities,
  digest,
}: CreateVisualizationsButtonProps): JSX.Element {
  const ref = useRef<HTMLButtonElement>(null);
  const popoverState = usePopupState({
    variant: 'popover',
    popupId: 'createVisualizationPopover',
    disableAutoFocus: false,
  });

  const { scatterVisualizationsMapResponse } = useFetchScatterVisualizationsMap(
    projectId,
    sessionRun.id,
    epoch,
  );

  // All samples plotted in the scatter, read from the first payload entry
  // (assumed to be a ScatterViz, as the cast states).
  const allSampleIds = useMemo(
    () =>
      (scatterVisualization?.data.payload[0] as ScatterViz)?.scatter_data
        ?.samples || [],
    [scatterVisualization?.data.payload],
  );

  // Plotted samples whose `${state}_${index}` key is not in the visualizations map yet.
  const samplesToVisualize = useMemo(() => {
    const samplesIds = new Set(scatterVisualizationsMapResponse?.samplesIds);

    return allSampleIds.filter(
      ({ index, state }) => !samplesIds.has(`${state}_${index}`),
    );
  }, [allSampleIds, scatterVisualizationsMapResponse?.samplesIds]);

  const visualizeHeader = useMemo(() => {
    return samplesToVisualize.length
      ? `Visualize ${compactNumberFormatter.format(
          samplesToVisualize.length,
        )} samples`
      : '';
  }, [samplesToVisualize.length]);

  const generateScatterImages = useCallback(async () => {
    if (!digest) return;

    const hasVisualizations =
      (scatterVisualizationsMapResponse?.samplesIds ?? []).length > 0;

    if (hasVisualizations || selectedSampleIdentities.length > 0) {
      // The explicit onClick below replaces the one from bindTrigger, so the
      // popup state never sees the click event; hand it the anchor element.
      popoverState.open(ref.current ?? undefined);
      return;
    }

    // Nothing left to request (e.g. the scatter has no samples).
    if (samplesToVisualize.length === 0) return;

    await api.createSamplesVisualizations({
      projectId,
      sessionRunId: sessionRun.id,
      epoch,
      sampleIdentities: samplesToVisualize,
      digest,
    });
  }, [
    digest,
    epoch,
    popoverState,
    projectId,
    samplesToVisualize,
    scatterVisualizationsMapResponse?.samplesIds,
    selectedSampleIdentities.length,
    sessionRun.id,
  ]);

  return (
    <>
      <Tooltip title={visualizeHeader}>
        <Button
          className={clsx(
            'flex uppercase text-xs h-2',
            samplesToVisualize.length && 'text-warning-300',
          )}
          {...bindTrigger(popoverState)}
          ref={ref}
          variant="text"
          onClick={generateScatterImages}
          disabled={!digest}
          tourId={
            TOUR_SELECTORS_ENUM.POPULATION_EXPLORATION_VISUALIZE_BUTTON_ID
          }
        >
          <span className="mr-1 text-xs">Visualize</span>
          <VisualizeIcon className="w-5 h-5" />
        </Button>
      </Tooltip>
      <Popover
        {...bindPopover(popoverState)}
        anchorOrigin={{ vertical: 'bottom', horizontal: 'left' }}
        anchorEl={ref.current}
        classes={{
          paper: 'bg-gray-850 border border-gray-700',
        }}
      >
        <RefreshScatterVisualizationsDialog
          onClose={popoverState.close}
          digest={digest || ''}
          epoch={epoch}
          projectId={projectId}
          sessionRunId={sessionRun.id}
          sampleToVisualize={samplesToVisualize}
          allSampleIdentities={allSampleIds}
          selectedSampleIdentities={selectedSampleIdentities}
        />
      </Popover>
    </>
  );
}

/** Props for the popover body offering the visualization actions. */
interface RefreshScatterVisualizationsDialogProps {
  projectId: string;
  sessionRunId: string;
  epoch: number;
  /** Scatter digest forwarded to every visualization request. */
  digest: string;
  /** Every sample plotted in the scatter ("Revisualize all"). */
  allSampleIdentities: Array<SampleIdentity>;
  /** Plotted samples that do not have a visualization yet ("Visualize rest"). */
  sampleToVisualize: Array<SampleIdentity>;
  /** Samples currently selected in the scatter ("Visualize selected"). */
  selectedSampleIdentities: Array<SampleIdentity>;
  /** Invoked after a visualization request completes successfully. */
  onClose: () => void;
}
/**
 * Popover body with the visualization actions for the scatter:
 * - "Visualize rest": samples without a visualization yet (no refresh).
 * - "Visualize selected": selected samples that still lack a visualization.
 * - "Revisualize all": after confirmation, re-runs every sample with
 *   `refresh: true`.
 *
 * `onClose` is called once a request resolves successfully.
 */
function RefreshScatterVisualizationsDialog({
  onClose,
  projectId,
  sessionRunId,
  epoch,
  allSampleIdentities,
  sampleToVisualize,
  selectedSampleIdentities,
  digest,
}: RefreshScatterVisualizationsDialogProps): JSX.Element {
  // Prevents duplicate requests when a button is clicked again while a
  // previous request is still pending. A ref avoids re-renders and state
  // updates after the popover has closed.
  const isSubmittingRef = useRef(false);

  const handleCreateSamplesVisualizations = useCallback(
    async (refresh: boolean, samplesToVisualize: Array<SampleIdentity>) => {
      if (isSubmittingRef.current) return;
      isSubmittingRef.current = true;
      try {
        await api.createSamplesVisualizations({
          projectId,
          sessionRunId,
          epoch,
          sampleIdentities: samplesToVisualize,
          digest,
          refresh,
        });
        onClose();
      } finally {
        isSubmittingRef.current = false;
      }
    },
    [projectId, sessionRunId, epoch, digest, onClose],
  );

  const [isRevisualizeDialogOpen, setIsRevisualizeDialogOpen] = useState(false);

  const openRevisualizeDialog = useCallback(
    () => setIsRevisualizeDialogOpen(true),
    [],
  );

  const handleCloseRevisualizeDialog = useCallback(
    () => setIsRevisualizeDialogOpen(false),
    [],
  );

  const handleConfirmRevisualizeDialog = useCallback(() => {
    // Fire-and-forget: the confirm dialog closes immediately; the popover
    // closes via onClose once the request succeeds.
    void handleCreateSamplesVisualizations(true, allSampleIdentities);
    handleCloseRevisualizeDialog();
  }, [
    handleCreateSamplesVisualizations,
    allSampleIdentities,
    handleCloseRevisualizeDialog,
  ]);

  // Selected samples that are also in `sampleToVisualize`, i.e. still unvisualized.
  const unvisualizedSelectedSampleIdentities = useMemo(() => {
    const sampleToVisualizeSet = new Set(
      sampleToVisualize.map((s) => `${s.state}_${s.index}`),
    );
    return selectedSampleIdentities.filter((ss) =>
      sampleToVisualizeSet.has(`${ss.state}_${ss.index}`),
    );
  }, [selectedSampleIdentities, sampleToVisualize]);

  return (
    <>
      <div className="flex flex-col gap-4 p-2">
        {sampleToVisualize.length > 0 && (
          <Tooltip
            title={`Visualize ${compactNumberFormatter.format(
              sampleToVisualize.length,
            )} samples`}
          >
            <Button
              variant="outline"
              onClick={() =>
                void handleCreateSamplesVisualizations(false, sampleToVisualize)
              }
            >
              Visualize rest
            </Button>
          </Tooltip>
        )}
        {selectedSampleIdentities.length > 0 && (
          /* The button may be disabled, and a disabled button does not emit
             the events MUI Tooltip listens to, so it is wrapped in a span. */
          <Tooltip
            title={`Visualize ${compactNumberFormatter.format(
              unvisualizedSelectedSampleIdentities.length,
            )} samples`}
          >
            <span className="flex flex-col">
              <Button
                variant="outline"
                disabled={unvisualizedSelectedSampleIdentities.length === 0}
                onClick={() =>
                  void handleCreateSamplesVisualizations(
                    false,
                    unvisualizedSelectedSampleIdentities,
                  )
                }
              >
                Visualize selected
              </Button>
            </span>
          </Tooltip>
        )}
        <Tooltip
          title={`Visualize ${compactNumberFormatter.format(
            allSampleIdentities.length,
          )} samples`}
        >
          <Button variant="outline" onClick={openRevisualizeDialog}>
            Revisualize all
          </Button>
        </Tooltip>
      </div>

      <ConfirmDialog
        title="This action will delete all existing visualizations. Proceeding will initiate a new visualization process for all samples. Are you sure you want to continue?"
        isOpen={isRevisualizeDialogOpen}
        onClose={handleCloseRevisualizeDialog}
        onConfirm={handleConfirmRevisualizeDialog}
        confirmButtonText="Continue"
        confirmButtonIcon={<Refresh2Icon />}
      />
    </>
  );
}
