import { SampleIdentity, SlimVisualization } from '@tensorleap/api-client';
import { SelectedSessionRun } from '../../../ui/molecules/useModelFilter';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useMergedObject } from '../../../core/useMergedObject';
import { groupBy } from 'lodash';
import { useFetchSessionRunsVisualization } from '../../../core/data-fetching/sessionVisualizations';
import { usePushNotifications } from '../../../core/PushNotificationsContext';
import { isSampleAnalysisMsg } from '../../../core/websocket-message-types';

/**
 * A single sample together with the slim visualizations that exist for it,
 * bucketed by the session run that produced them.
 */
export type SampleVisualizations = {
  /** Identity of the sample these visualizations belong to. */
  id: SampleIdentity;
  /** Slim visualizations of this sample, keyed by `sessionRunId`. */
  slimVisualizationsPerSessionId: Record<string, SlimVisualization[]>;
};

/** State and actions returned by `useSampleListState`. */
export type UseSampleListState = {
  /** The session runs passed in by the caller, returned unchanged. */
  selectedSessionRuns: SelectedSessionRun[];
  /** Entries of `samples` whose id is among the caller's active sample ids. */
  activeSamples: SampleVisualizations[];
  /**
   * Updates the selection in response to a click on `sample`; how the click
   * combines with the current selection is controlled by `multiSelectType`.
   */
  onSampleSelect: (
    sample: SampleIdentity,
    multiSelectType: MultiSelectType,
  ) => void;
  /** All samples available for selection. */
  samples: SampleVisualizations[];
  /** NOTE(review): `useSampleListState` currently always reports `false`. */
  isLoading: boolean;
  /** Selects every sample in `samples`. */
  selectAllSamples: () => void;
  /** Selects only the first sample (no-op when there are none). */
  selectFirstSample: () => void;
  /** Re-fetches the slim visualizations for the selected session runs. */
  refetchSlimVisualizations: () => void;
};

/** Inputs to `useSampleListState`. The selection itself is owned by the caller. */
type SampleListStateProps = {
  /** Currently selected sample identities; `undefined` is treated as "nothing selected". */
  activeSampleIds?: SampleIdentity[];
  /** Called with the new list of selected sample identities. */
  onActiveSamplesChange: (samples: SampleIdentity[]) => void;
  /** Project whose session-run visualizations are fetched. */
  projectId: string;
  /** Session runs whose visualizations are fetched (by id) and shown. */
  selectedSessionRuns: SelectedSessionRun[];
  /** When provided, used verbatim as the sample list instead of the fetched data. */
  samplesOverrides?: SampleVisualizations[];
};

/**
 * How a click on a sample is combined with the existing selection
 * by `useSampleListState`'s `onSampleSelect`.
 *
 * Values are spelled out explicitly; they match the implicit 0/1/2 numbering.
 */
export enum MultiSelectType {
  /** Replace the selection with just the clicked sample. */
  Single = 0,
  /** Toggle the clicked sample in or out of the selection. */
  Control = 1,
  /** Add every sample between the previously clicked one and this one. */
  Shift = 2,
}

/**
 * Builds the list of samples that have sample-analysis visualizations for the
 * selected session runs, and manages which of them are active (selected).
 *
 * Samples come from `samplesOverrides` when provided. Otherwise they are built
 * from the fetched slim visualizations of type `'sample_analysis'`, grouped by
 * `jobParms.sampleIdentity` and then by `sessionRunId`. The visualizations are
 * re-fetched whenever a sample-analysis push notification arrives.
 *
 * The selection is controlled by the caller (`activeSampleIds` /
 * `onActiveSamplesChange`). When nothing is selected and samples exist, the
 * first sample is selected automatically.
 *
 * @param props - See `SampleListStateProps`.
 * @returns The sample list, the active samples and selection actions.
 */
export function useSampleListState({
  activeSampleIds,
  onActiveSamplesChange,
  projectId,
  selectedSessionRuns,
  samplesOverrides,
}: SampleListStateProps): UseSampleListState {
  const selectedSessionRunsIds = useMemo(
    () => selectedSessionRuns.map(({ id }) => id),
    [selectedSessionRuns],
  );
  const { slimVisualizations, refetch: refetchSlimVisualizations } =
    useFetchSessionRunsVisualization(projectId, selectedSessionRunsIds);
  const { lastServerMessage } = usePushNotifications();

  // A sample-analysis message arrived: pull the fresh visualization list.
  useEffect(() => {
    if (isSampleAnalysisMsg(lastServerMessage)) {
      refetchSlimVisualizations();
    }
  }, [lastServerMessage, refetchSlimVisualizations]);

  const samples = useMemo(() => {
    if (samplesOverrides) {
      return samplesOverrides;
    }
    if (!slimVisualizations) {
      return [];
    }

    // Bucket by the serialized identity so structurally-equal identities
    // share a group. Visualizations without a sampleIdentity fall under `{}`.
    const bySampleId = groupBy(
      slimVisualizations.filter((v) => v.type === 'sample_analysis'),
      (v) =>
        JSON.stringify(
          (v.jobParms as { sampleIdentity: SampleIdentity })?.sampleIdentity ||
            {},
        ),
    );

    return Object.entries(bySampleId).map(
      ([sampleId, visualizationsOfSample]): SampleVisualizations => ({
        id: JSON.parse(sampleId) as SampleIdentity,
        slimVisualizationsPerSessionId: groupBy(
          visualizationsOfSample,
          (v) => v.sessionRunId,
        ),
      }),
    );
  }, [samplesOverrides, slimVisualizations]);

  const activeSamples = useMemo(() => {
    if (!activeSampleIds) {
      return [];
    }
    return samples.filter((s) =>
      activeSampleIds.some((id) => isSampleIdEquals(s.id, id)),
    );
  }, [activeSampleIds, samples]);

  // Anchor for Shift range selection. Kept as an identity (not an index) so it
  // still points at the same sample if `samples` is rebuilt in another order.
  const [lastSelectedSample, setLastSelectedSample] =
    useState<SampleIdentity | null>(null);

  // Auto-select the first sample whenever the selection is empty.
  useEffect(() => {
    const hasSelection = !!activeSampleIds && activeSampleIds.length > 0;
    if (!hasSelection && samples.length > 0) {
      onActiveSamplesChange([samples[0].id]);
    }
  }, [activeSampleIds, samples, onActiveSamplesChange]);

  const onSampleSelect = useCallback(
    (sample: SampleIdentity, multiSelectType: MultiSelectType) => {
      const sampleIndex = samples.findIndex((s) =>
        isSampleIdEquals(s.id, sample),
      );
      const selected = samples[sampleIndex];
      // Unknown sample (e.g. a stale id after a refetch): ignore the click
      // rather than putting `undefined` into the selection and crashing below.
      if (selected === undefined) {
        return;
      }
      let newActiveSamples: SampleVisualizations[];

      switch (multiSelectType) {
        case MultiSelectType.Single:
          newActiveSamples = [selected];
          break;
        case MultiSelectType.Control:
          newActiveSamples = activeSamples.some((s) =>
            isSampleIdEquals(s.id, sample),
          )
            ? activeSamples.filter((s) => !isSampleIdEquals(s.id, sample))
            : [...activeSamples, selected];
          break;
        case MultiSelectType.Shift: {
          let anchorIndex = -1;
          if (lastSelectedSample !== null) {
            const anchor = lastSelectedSample;
            anchorIndex = samples.findIndex((s) =>
              isSampleIdEquals(s.id, anchor),
            );
          }
          if (anchorIndex >= 0) {
            const start = Math.min(anchorIndex, sampleIndex);
            const end = Math.max(anchorIndex, sampleIndex);
            // activeSamples are elements of `samples`, so the Set dedupes by reference.
            newActiveSamples = Array.from(
              new Set([...activeSamples, ...samples.slice(start, end + 1)]),
            );
          } else {
            newActiveSamples = [selected];
          }
          break;
        }
        default:
          newActiveSamples = activeSamples;
      }

      onActiveSamplesChange(newActiveSamples.map((s) => s.id));
      setLastSelectedSample(sample);
    },
    [onActiveSamplesChange, activeSamples, samples, lastSelectedSample],
  );

  const selectAllSamples = useCallback(() => {
    onActiveSamplesChange(samples.map((s) => s.id));
  }, [onActiveSamplesChange, samples]);

  const selectFirstSample = useCallback(() => {
    if (samples.length > 0) {
      onActiveSamplesChange([samples[0].id]);
    }
  }, [onActiveSamplesChange, samples]);

  return useMergedObject({
    selectedSessionRuns,
    activeSamples,
    onSampleSelect,
    samples,
    // NOTE(review): loading state from the fetch hook is not surfaced; confirm intended.
    isLoading: false,
    selectAllSamples,
    selectFirstSample,
    refetchSlimVisualizations,
  });
}

/**
 * Compares two sample identities.
 *
 * Identities match when both `index` and `state` are equal. As a special
 * case, a missing `activeSample` matches any sample whose state is
 * `'unlabeled'`.
 *
 * The nullish case is handled first. Otherwise a `sample` without an `index`
 * (e.g. the `{}` placeholder identity) compared against a missing
 * `activeSample` would dereference `activeSample.state` and throw.
 *
 * @param sample - Identity being tested.
 * @param activeSample - Identity to compare against; may be absent.
 * @returns Whether the two identities refer to the same sample.
 */
export function isSampleIdEquals(
  sample: SampleIdentity,
  activeSample: SampleIdentity | undefined,
) {
  if (!activeSample) {
    return sample.state === 'unlabeled';
  }
  return (
    sample.index === activeSample.index && sample.state === activeSample.state
  );
}
