import {
  createContext,
  FC,
  PropsWithChildren,
  useContext,
  useEffect,
  useMemo,
  useState,
} from 'react';
import {
  DataStateType,
  ScatterViz,
  ScatterVizDataState,
  SampleIdentity,
  VisualizationResponse,
  MutualInformationElement,
  NumberOrString,
} from '@tensorleap/api-client';

import api from '../core/api-client';

import { useMergedObject } from '../core/useMergedObject';
import {
  extractScatterSettingOptions,
  ScatterViewSettingValues,
  settingValuesWithDefault,
  useScattersAssets,
} from './dashlet/PopulationExploration/hooks';
import { useDashletScatterContext } from './dashlet/PopulationExploration/DashletScatterContext';
import { UseMultiSelect, useMultiSelect } from '../core/useMultiSelect';
import { Setter } from '../core/types';
import { SortTypeEnum } from '../ui/charts/legend/LabelsLegendMenu';
import { useToggle } from '../core/useToggle';
import { DEFAULT_TRUNCATE_LONG_TAIL } from '../ui/charts/legend/LabelsLegend';
import { uniq } from 'lodash';
import { FindValue } from './dashlet/PopulationExploration/FindSamples';
import { ScaleType } from '../ui/charts/visualizers/ChartBlocks/scale';

/**
 * File name of a visualization payload.
 * NOTE(review): not referenced in this file — confirm where it is consumed.
 */
export const VISUALIZATION_PAYLOAD_FILE_NAME = 'payload.json';

/**
 * Value-count thresholds used by the legend logic in this file when probing
 * metadata fields of type 'range': a field whose probe returns no values, or
 * at least this many values, is flagged as rangeable (for color / for size).
 * NOTE(review): the names say "TO_COUNT_AS_CATEGORY", but reaching the
 * threshold makes the field *rangeable* — consider renaming.
 */
const NUM_OF_COLOR_RANGE_VALUES_TO_COUNT_AS_CATEGORY = 10;
const NUM_OF_SIZE_RANGE_VALUES_TO_COUNT_AS_CATEGORY = 4;

/**
 * A pair of axes naming a 2D plane.
 * NOTE(review): presumably used to project a 3D scatter onto two axes — not
 * referenced in this file; confirm against consumers.
 */
export enum Plane {
  XY = 'XY',
  XZ = 'XZ',
  YZ = 'YZ',
}

/**
 * Legend entry currently hovered, stored in `legendHovered`.
 * NOTE(review): presumably `key` is a metadata field name and `value` the
 * hovered value(s) of that field — confirm against the legend components.
 */
export type HoveredLegendFilter = {
  key: string;
  value: NumberOrString | NumberOrString[];
};

/** Legend configuration derived for the selected color / size fields. */
export interface LegendData {
  /** True while the rangeable flags are being recomputed. */
  isLoading: boolean;
  /** Field the legend was computed for as dot color (`settings.dotColor`). */
  colorField?: string;
  /** Field the legend was computed for as size/shape (`settings.sizeOrShape`). */
  sizeField?: string;
  /** Whether `colorField` should be rendered as a continuous range. */
  isColorFieldRangeable: boolean;
  /** Whether `sizeField` should be rendered as a continuous range. */
  isSizeFieldRangeable: boolean;
  /** NOTE(review): never set by `ScatterDataProvider` in this file. */
  scaleMethod?: ScaleType;
}

/** Value exposed by `ScatterDataContext` / `useScatterData()`. */
export interface ScatterDataContextProps {
  /** The `scatter_data` section of the scatter visualization payload. */
  scatterData: ScatterVizDataState;
  epoch: number;
  sessionRunId: string;
  /** `guid` of the scatter visualization payload. */
  visualizationUUID: string;
  /** Display title; derived from the visualization's `analyze_type` (underscores → spaces). */
  title: string;
  /** NOTE(review): not populated by `ScatterDataProvider` in this file. */
  sample?: SampleIdentity;
  /** Provided by `useScattersAssets`. */
  samplesIdsWithAssets: Set<string>;
  /** Effective view settings: `settingValuesWithDefault` applied to the dashlet's values. */
  settings: ScatterViewSettingValues;
  /** Options extracted from the payload and the visualization displays. */
  settingsOptions: ReturnType<typeof extractScatterSettingOptions>;
  /** Provided by `useScattersAssets`. */
  scatterSampleVisualizationsPrefix: string;
  /** Provided by `useScattersAssets`. */
  visualizationDisplays: VisualizationDisplay[];
  /** Multi-selection of numeric point identifiers. */
  selection: UseMultiSelect<number>;
  /** Active interaction mode; a "pressed" mode takes precedence over the one set via `setScatterMode`. */
  scatterMode: ScatterMode;
  /** From `scatter_data.mi_by_cluster`. */
  miByCluster?: Record<string, Record<string, MutualInformationElement[]>>;
  /** From `scatter_data.domain_gap_dist`. */
  domainGapMetadata?: Record<string, Record<string, number>>;
  /** From `scatter_data.clusters_blob_path`. */
  clusterBlobPaths?: Record<string, Record<string, string>>;
  setScatterMode: Setter<ScatterMode>;
  /** Sets (or clears with `undefined`) the overriding "pressed" mode. */
  setPressedScatterMode: Setter<ScatterMode | undefined>;
  showLegendNames: boolean;
  toggleShowLegendNames: () => void;
  legendTruncatedLongtail: number;
  setLegendTruncatedLongtail: Setter<number>;
  sizeOrShapeOrderMethod: SortTypeEnum;
  setSizeOrShapeOrderMethod: Setter<SortTypeEnum>;
  legendHovered?: HoveredLegendFilter;
  /** Metadata field name → number of distinct values in that field. */
  fieldsUniqueCount: Map<string, number>;
  setLegendHovered: Setter<HoveredLegendFilter | undefined>;
  /** Undefined until the legend has been evaluated for the first time. */
  legendData?: LegendData;
  /** Indices into `scatterData.samples` that match the active find query. */
  markedIndexes: Set<number>;
}

/**
 * Fallback context value: `useScatterData()` returns this when no
 * `ScatterDataProvider` is mounted above the caller. All setters are no-ops.
 *
 * NOTE(review): `selection` is `null` cast to `UseMultiSelect<number>`, so a
 * consumer outside a provider will throw if it touches `selection`. Consider
 * making it optional in `ScatterDataContextProps` or supplying a real default.
 */
export const contextDefaults: ScatterDataContextProps = {
  scatterData: {
    data_state: DataStateType.Test,
    scatter_data: [],
    samples: [],
    metadata: {},
  },
  // Also used by the provider as the title fallback when `analyze_type` is missing.
  title: 'Sample Selection',
  epoch: 0,
  sessionRunId: '',
  visualizationUUID: '',
  samplesIdsWithAssets: new Set(),
  scatterSampleVisualizationsPrefix: '',
  visualizationDisplays: [],
  settings: {
    sizeOrShape: '',
    dotColor: '',
    previewBy: undefined,
  },
  settingsOptions: {
    sizeOrShape: [],
    dotColor: [],
    previewBy: [],
    domainGapMetadataOptions: [],
  },
  // Placeholder only — see NOTE(review) above.
  selection: null as unknown as UseMultiSelect<number>,
  scatterMode: 'grab',
  setScatterMode: () => {},
  setPressedScatterMode: () => {},
  showLegendNames: false,
  toggleShowLegendNames: () => {},
  legendTruncatedLongtail: DEFAULT_TRUNCATE_LONG_TAIL,
  setLegendTruncatedLongtail: () => {},
  sizeOrShapeOrderMethod: SortTypeEnum.ASC_ALPHABETICALLY,
  setSizeOrShapeOrderMethod: () => {},
  fieldsUniqueCount: new Map(),
  setLegendHovered: () => {},
  markedIndexes: new Set(),
};
/** React context carrying the scatter view state; defaults to `contextDefaults`. */
const ScatterDataContext =
  createContext<ScatterDataContextProps>(contextDefaults);

/** Pointer interaction modes of the scatter view; `'grab'` is the initial mode. */
export type ScatterMode = 'grab' | 'box-selection' | 'magic-selection';

/** A sample visualization that can be displayed, identified by type and name. */
type VisualizationDisplay = { visType: string; visName: string };

/** Props of `ScatterDataProvider`. */
type ScatterDataProviderProps = PropsWithChildren<{
  projectId: string;
  epoch: number;
  sessionRunId: string;
  /** Visualization response whose first `data.payload` entry is read as a `ScatterViz`. */
  scatterVisualization: VisualizationResponse;
}>;

/**
 * Tracks the legend configuration for the selected color / size fields.
 *
 * Only fields whose metadata `type` is `'range'` are probed with
 * `api.getFieldsValues`; a probed field is flagged as rangeable when the
 * response holds no values or at least the matching
 * `NUM_OF_*_RANGE_VALUES_TO_COUNT_AS_CATEGORY` number of values. Fields that
 * are not `'range'` are never rangeable.
 *
 * While a probe is in flight the previous legend (if any) is kept with
 * `isLoading: true`; responses arriving after the inputs changed or the
 * component unmounted are ignored. On failure the error is logged and both
 * fields are reported as not rangeable.
 *
 * @returns `undefined` until the first evaluation has been committed.
 */
function useLegendData({
  projectId,
  sessionRunId,
  dotColor,
  sizeOrShape,
  metadata,
}: {
  projectId: string;
  sessionRunId: string;
  dotColor: string | undefined;
  sizeOrShape: string | undefined;
  metadata: ScatterVizDataState['metadata'] | undefined;
}): LegendData | undefined {
  const [legendData, setLegendData] = useState<LegendData>();

  useEffect(() => {
    // Keep the previous legend visible but flag it as stale while recomputing.
    setLegendData((curr) =>
      curr === undefined ? undefined : { ...curr, isLoading: true },
    );

    const resolvedLegend = (
      isColorFieldRangeable: boolean,
      isSizeFieldRangeable: boolean,
    ): LegendData => ({
      isLoading: false,
      colorField: dotColor,
      sizeField: sizeOrShape,
      isColorFieldRangeable,
      isSizeFieldRangeable,
    });

    // `metadata` is guarded because the provider also treats it as optional.
    // `uniq` avoids requesting one field twice when color and size share it.
    const fieldsToCheck = uniq(
      [dotColor, sizeOrShape].filter(
        (f): f is string => f !== undefined && metadata?.[f]?.type === 'range',
      ),
    );

    if (!fieldsToCheck.length) {
      setLegendData(resolvedLegend(false, false));
      return;
    }

    // Cleared by the cleanup below so stale responses are dropped.
    let isLatestEffectCall = true;

    // `async` turns a synchronous throw from the client into a rejection,
    // so it is handled by `.catch` below.
    const fetchOptions = async (fields: string[]) =>
      api.getFieldsValues({
        projectId,
        fields: fields.map((field) => ({
          field,
          type: 'number',
          size: 12,
        })),
        sessionRunIds: [sessionRunId],
        filters: [],
      });

    fetchOptions(fieldsToCheck)
      .then((res) => {
        if (!isLatestEffectCall) return;

        let isColorFieldRangeable = false;
        let isSizeFieldRangeable = false;

        res.results.forEach(({ field, values }) => {
          const valuesLength = values.length;
          if (field === dotColor) {
            isColorFieldRangeable =
              valuesLength === 0 ||
              valuesLength >= NUM_OF_COLOR_RANGE_VALUES_TO_COUNT_AS_CATEGORY;
          }
          if (field === sizeOrShape) {
            isSizeFieldRangeable =
              valuesLength === 0 ||
              valuesLength >= NUM_OF_SIZE_RANGE_VALUES_TO_COUNT_AS_CATEGORY;
          }
        });

        setLegendData(
          resolvedLegend(isColorFieldRangeable, isSizeFieldRangeable),
        );
      })
      .catch((e: unknown) => {
        if (!isLatestEffectCall) return;
        console.error('Failed to fetchOptions ', e);
        setLegendData(resolvedLegend(false, false));
      });

    return () => {
      isLatestEffectCall = false;
    };
  }, [projectId, sessionRunId, dotColor, sizeOrShape, metadata]);

  return legendData;
}

/**
 * Provides data, view settings, legend state, find results and selection
 * state for one scatter visualization to descendants via `useScatterData()`.
 *
 * The first entry of `scatterVisualization.data.payload` is read as the
 * `ScatterViz` payload. The payload is registered with the dashlet scatter
 * context under `payload.guid`; the effect cleanup unregisters it.
 */
export const ScatterDataProvider: FC<ScatterDataProviderProps> = ({
  children,
  projectId,
  epoch,
  sessionRunId,
  scatterVisualization,
}): JSX.Element => {
  const {
    scatterSampleVisualizationsPrefix,
    visualizationDisplays,
    samplesIdsWithAssets,
  } = useScattersAssets({ projectId, sessionRunId, epoch });
  // NOTE(review): unchecked cast — assumes payload[0] exists and is a ScatterViz.
  const payload = useMemo(
    () => scatterVisualization.data.payload[0] as ScatterViz,
    [scatterVisualization.data.payload],
  );

  const {
    viewSettingsValues: settingsValues,
    register,
    unregister,
    findActive,
    findQuery,
  } = useDashletScatterContext();
  const scatterData = payload.scatter_data;

  // A "pressed" mode, when set, overrides the persistently selected mode.
  const [selectedScatterMode, setScatterMode] = useState<ScatterMode>('grab');
  const [pressedScatterMode, setPressedScatterMode] = useState<ScatterMode>();
  const scatterMode = pressedScatterMode ?? selectedScatterMode;
  const selection = useMultiSelect<number>();

  // Metadata field name -> number of distinct values in that field.
  const fieldsUniqueCount = useMemo(() => {
    const metadata = scatterData.metadata;
    if (!metadata) return new Map<string, number>();
    return new Map(
      Object.keys(metadata).map((key): [string, number] => [
        key,
        uniq(metadata[key].body).length,
      ]),
    );
  }, [scatterData.metadata]);

  const settingsOptions = useMemo(
    () => extractScatterSettingOptions(payload, visualizationDisplays || []),
    [payload, visualizationDisplays],
  );

  const settings = useMemo(
    () => settingValuesWithDefault(settingsValues, settingsOptions),
    [settingsValues, settingsOptions],
  );

  const { dotColor, sizeOrShape } = settings;

  const legendData = useLegendData({
    projectId,
    sessionRunId,
    dotColor,
    sizeOrShape,
    metadata: scatterData.metadata,
  });

  const [showLegendNames, toggleShowLegendNames] = useToggle(true);
  const [legendHovered, setLegendHovered] = useState<HoveredLegendFilter>();
  const [legendTruncatedLongtail, setLegendTruncatedLongtail] =
    useState<number>(DEFAULT_TRUNCATE_LONG_TAIL);
  const [sizeOrShapeOrderMethod, setSizeOrShapeOrderMethod] = useState(
    SortTypeEnum.ASC_ALPHABETICALLY,
  );

  useEffect(() => {
    register(payload.guid, settingsOptions, scatterData);
    return () => unregister(payload.guid);
  }, [settingsOptions, payload.guid, scatterData, register, unregister]);

  const markedIndexes = useMarkedIndexes(scatterData, findActive, findQuery);

  const value = useMergedObject({
    visualizationUUID: payload.guid,
    clusterBlobPaths: scatterData.clusters_blob_path,
    miByCluster: scatterData.mi_by_cluster,
    domainGapMetadata: scatterData.domain_gap_dist,
    scatterData,
    epoch,
    sessionRunId,
    settings,
    settingsOptions,
    title:
      /**
       * NOTE:
       * This is only done because of the storybook mocks being incomplete,
       * once they'll be updated to real `node-server` responses then this'll be removed.
       */
      scatterVisualization.info?.analyze_type?.replace(/_/g, ' ') ||
      contextDefaults.title,
    samplesIdsWithAssets,
    scatterSampleVisualizationsPrefix,
    visualizationDisplays,
    selection,
    scatterMode,
    setScatterMode,
    setPressedScatterMode,
    showLegendNames,
    toggleShowLegendNames,
    legendTruncatedLongtail,
    setLegendTruncatedLongtail,
    sizeOrShapeOrderMethod,
    setSizeOrShapeOrderMethod,
    fieldsUniqueCount,
    legendHovered,
    setLegendHovered,
    legendData,
    markedIndexes,
  });
  return (
    <ScatterDataContext.Provider value={value}>
      {children}
    </ScatterDataContext.Provider>
  );
};
/**
 * Reads the scatter view state from the nearest `ScatterDataProvider`.
 * Without a provider above the caller, React hands back `contextDefaults`.
 */
export function useScatterData(): ScatterDataContextProps {
  const scatterContext = useContext(ScatterDataContext);
  return scatterContext;
}

// Name shown for this context in React DevTools.
ScatterDataContext.displayName = 'ScatterDataContext';

/**
 * Computes which samples match the active "find" query.
 *
 * Matching is case-insensitive on the stringified value; `findQuery.exact`
 * selects whole-value equality, otherwise substring matching is used.
 * - With `findQuery.key` set, only that metadata field is searched.
 * - Otherwise a sample matches when its `sample.index` or any of its metadata
 *   field values matches the query.
 * Missing (`undefined`) and `null` metadata values never match.
 *
 * @returns Positions in `scatterData.samples` that match; empty when find is
 *   inactive or the query is empty.
 */
function useMarkedIndexes(
  scatterData: ScatterVizDataState,
  findActive: boolean,
  findQuery: FindValue,
): Set<number> {
  return useMemo(() => {
    const markedIndexes = new Set<number>();
    if (!findActive || !findQuery.value) return markedIndexes;

    const query = findQuery.value.toLowerCase();
    const metadata = scatterData.metadata || {};
    const keys = Object.keys(metadata);
    const searchKey = findQuery.key;

    const matchesQuery = findQuery.exact
      ? (value: NumberOrString) => String(value).toLowerCase() === query
      : (value: NumberOrString) => String(value).toLowerCase().includes(query);

    // A metadata cell matches only if it holds a value AND that value matches
    // the query. (Previously the all-fields scan marked any non-null cell.)
    const cellMatches = (key: string, index: number): boolean => {
      const value = metadata[key]?.body[index];
      return value !== undefined && value !== null && matchesQuery(value);
    };

    scatterData.samples.forEach((sample, index) => {
      const isMarked = searchKey
        ? cellMatches(searchKey, index)
        : matchesQuery(sample.index) ||
          keys.some((key) => cellMatches(key, index));
      if (isMarked) markedIndexes.add(index);
    });

    return markedIndexes;
  }, [scatterData.metadata, scatterData.samples, findActive, findQuery]);
}
