import {
  createContext,
  PropsWithChildren,
  useCallback,
  useContext,
  useMemo,
  useState,
} from 'react';
import { uniq } from 'lodash';
import { Setter } from '../../../core/types';
import { useMergedObject } from '../../../core/useMergedObject';
import {
  ScatterViewSettingValues,
  ScatterViewSettingOptions,
  ScatterViewSetting,
  settingValuesWithDefault,
} from './hooks';
import {
  DashboardFiltersParams,
  DashboardFiltersResponse,
  useDashletFilters,
} from '../useDashletFields';
import {
  PopulationExplorationDashletData,
  PopulationExplorationDisplayParams,
  ReductionAlgorithm,
  ScatterVizDataState,
} from '@tensorleap/api-client';
import { useLocalStorage } from '../../../core/useLocalStorage';
import { PopExpSettingsFields } from './PopExpSettings';
import { useToggle } from '../../../core/useToggle';
import { FindValue } from './FindSamples';
import { FilterEnum, FilterFieldMeta } from '../../../filters/helpers';
import { ScaleType } from '../../../ui/charts/visualizers/ChartBlocks/scale';

/** Dashlet title used when the dashlet data carries no `name`. */
export const POP_EXP_DEFAULT_NAME = 'Population Exploration';

/** Default value for `population_exploration_n_samples` in the display params. */
export const DEFAULT_NUM_OF_SAMPLES = 2000;

/** Field-name prefix for metric fields (e.g. `metrics.loss`). */
export const METRICS_PREFIX = 'metrics.';
/** Field-name prefix for metadata fields. */
export const METADATA_PREFIX = 'metadata.';
/**
 * Builds the per-setting view model (available options, current value, and a
 * setter) for the scatter's size/shape, dot-color and preview selectors.
 *
 * Each setter writes a full copy of `scatterValues` with only its own field
 * replaced. The result is memoized on all three arguments.
 *
 * @param scatterSettingOptions - options available for each selector
 * @param scatterValues - currently selected values
 * @param setScatterValues - setter receiving the whole updated values object
 * @returns the `sizeOrShape`, `dotColor` and `previewBy` selector descriptors
 */
export function useScatterOptions(
  scatterSettingOptions: ScatterViewSettingOptions,
  scatterValues: ScatterViewSettingValues,
  setScatterValues: Setter<ScatterViewSettingValues>,
): ScatterViewSetting {
  return useMemo(() => {
    // Replace a subset of fields while keeping the rest of the current values.
    const update = (patch: Partial<ScatterViewSettingValues>) =>
      setScatterValues({ ...scatterValues, ...patch });

    return {
      sizeOrShape: {
        options: scatterSettingOptions.sizeOrShape,
        value: scatterValues.sizeOrShape,
        setOption: (option?: string) => update({ sizeOrShape: option }),
      },
      dotColor: {
        options: scatterSettingOptions.dotColor,
        value: scatterValues.dotColor,
        setOption: (option?: string) => update({ dotColor: option }),
      },
      previewBy: {
        options: scatterSettingOptions.previewBy,
        value: scatterValues.previewBy,
        setOption: (option?: string | null) => update({ previewBy: option }),
      },
    };
  }, [scatterSettingOptions, scatterValues, setScatterValues]);
}

/**
 * Display options for the color range used when dots are colored by a field.
 * Stored per dot-color field by the provider below.
 */
export type ColorRangeViewSettings = {
  /** Scale applied to the color range. */
  scaleType: ScaleType;
  /** Presumably reverses the color range when true — confirm with the chart component. */
  flipColor: boolean;
};

/** Used whenever no settings were stored for the current dot-color field. */
const DEFAULT_COLOR_RANGE_VIEW_SETTINGS: ColorRangeViewSettings = {
  flipColor: false,
  scaleType: 'linear',
};

/** Value exposed to scatter components through `DashletScatterContext`. */
export interface DashletScatterContextValue {
  projectId: string;

  // Registration: each scatter publishes its options and data under a key.
  register: (
    key: string,
    options: ScatterViewSettingOptions,
    payload: ScatterVizDataState,
  ) => void;
  unregister: (key: string) => void;

  // Selector view models plus the values they resolve to.
  viewSettings: ScatterViewSetting;
  viewSettingsValues: ScatterViewSettingValues;

  // Dashlet-level settings and filters.
  settings: PopExpSettingsFields;
  filters: DashboardFiltersResponse;

  // Find-samples state.
  findActive: boolean;
  toggleFindActive: () => void;
  findQuery: FindValue;
  setFindQuery: Setter<FindValue>;
  findFields: FilterFieldMeta[];

  // Currently hovered sample, if any.
  hoverSampleId?: string;
  setHoverSampleId: Setter<string | undefined>;

  // Color-range options for the currently selected dot-color field.
  dotColorViewSettings: ColorRangeViewSettings;
  setDotColorViewSettings: (
    viewSettings: Partial<ColorRangeViewSettings>,
  ) => void;
}

/**
 * Baseline display parameters. The provider spreads any `displayParams`
 * stored on the dashlet data over these, so stored values take precedence.
 */
export const DEFAULT_DISPLAY_PARAMS: PopulationExplorationDisplayParams = {
  population_exploration_n_samples: DEFAULT_NUM_OF_SAMPLES,
  // No balancing fields selected by default.
  balance_by: [],
  should_fill_remaining_with_unbalance: true,
  reduction_algorithm: ReductionAlgorithm.Tsne,
};

/** Settings reported by the context when no provider is mounted. */
const DEFAULT_POP_EXP_SETTINGS: PopExpSettingsFields = {
  dashletName: POP_EXP_DEFAULT_NAME,
  projectionMetric: undefined,
  projectionMetricsFieldsNames: [],
  domainGapMetadata: undefined,
  domainGapMetadataFieldsNames: [],
  balanceFieldsNames: [],
  displayParams: DEFAULT_DISPLAY_PARAMS,
};

/**
 * Fallback context value used outside a `DashletScatterContextProvider`:
 * every callback is a no-op and every collection is empty.
 */
export const contextDefaults: DashletScatterContextValue = {
  projectId: '',
  register: () => undefined,
  unregister: () => undefined,
  viewSettingsValues: {} as ScatterViewSettingValues,
  viewSettings: {} as ScatterViewSetting,
  settings: DEFAULT_POP_EXP_SETTINGS,
  filters: {} as DashboardFiltersResponse,
  findActive: false,
  toggleFindActive: () => undefined,
  findQuery: { value: '' },
  setFindQuery: () => undefined,
  findFields: [],
  hoverSampleId: '',
  setHoverSampleId: () => undefined,
  setDotColorViewSettings: () => undefined,
  dotColorViewSettings: DEFAULT_COLOR_RANGE_VIEW_SETTINGS,
};

const DashletScatterContext = createContext<DashletScatterContextValue>(
  contextDefaults,
);

/** Props accepted by `DashletScatterContextProvider`. */
type DashletScatterContextProviderProps = PropsWithChildren<{
  /** Dashlet id; namespaces the locally persisted scatter view settings. */
  cid: string;
  projectId: string;
  /** Dashlet payload; name, projection metric, domain-gap metadata and display params are read from `data.data`. */
  data?: PopulationExplorationDashletData;
  /** Forwarded to `useDashletFilters`; the provider supplies `projectId` itself. */
  filterProps: Omit<
    DashboardFiltersParams,
    'projectId' | 'useRegisteredFilters'
  >;
}>;

/** Field the find query targets initially. */
const DEFAULT_FIND_KEY = 'sample_id';

/**
 * Provides shared state to the scatter plots of a Population Exploration
 * dashlet.
 *
 * The provider owns:
 * - the options and data registered by each scatter (merged for the selectors);
 * - the size/shape, dot-color and preview selections, persisted in local
 *   storage under `pe-settings-${cid}`;
 * - dashlet settings derived from `data` and the filter field metadata;
 * - find-samples and hover state;
 * - per-dot-color-field color-range settings, persisted under
 *   `${projectId}-color-by-view-settings`.
 */
export function DashletScatterContextProvider({
  cid,
  children,
  projectId,
  data,
  filterProps,
}: DashletScatterContextProviderProps): JSX.Element {
  // Options / data published by each registered scatter, keyed by its key.
  const [scatterOptionsByKey, setScatterOptionByKey] = useState<
    Record<string, ScatterViewSettingOptions>
  >({});
  const [scatterDataByKey, setScatterDataByKey] = useState<
    Record<string, ScatterVizDataState>
  >({});

  // Deduplicated union of the options of every registered scatter.
  const scatterOptions = useMemo(() => {
    const allOptions = Object.values(scatterOptionsByKey);
    return {
      sizeOrShape: uniq(allOptions.flatMap((o) => o.sizeOrShape)),
      dotColor: uniq(allOptions.flatMap((o) => o.dotColor)),
      previewBy: uniq(allOptions.flatMap((o) => o.previewBy)),
      domainGapMetadataOptions: uniq(
        allOptions.flatMap((o) => o.domainGapMetadataOptions),
      ),
    };
  }, [scatterOptionsByKey]);

  // One string-enum field per metadata key seen in any registered payload,
  // with the deduplicated values found under that key.
  const findFields = useMemo<FilterFieldMeta[]>(() => {
    const allPayloads = Object.values(scatterDataByKey);
    const keys = uniq(allPayloads.flatMap((p) => Object.keys(p.metadata)));
    return keys.map((key) => {
      const values = uniq(
        allPayloads.flatMap((p) => p.metadata[key]?.body || []),
      );
      return {
        field: key,
        enum: values as FilterEnum<string>,
        type: 'string',
      };
    });
  }, [scatterDataByKey]);

  const [scatterValues, setScatterValues] =
    useLocalStorage<ScatterViewSettingValues>(`pe-settings-${cid}`, {
      sizeOrShape: 'metrics.loss',
      dotColor: 'metrics.loss',
      previewBy: null,
    });

  const dashletName =
    (data?.data?.name as string | undefined) || POP_EXP_DEFAULT_NAME;

  const valuesWithDefault = useMemo(
    () => settingValuesWithDefault(scatterValues, scatterOptions),
    [scatterValues, scatterOptions],
  );

  const viewSettings = useScatterOptions(
    scatterOptions,
    valuesWithDefault,
    setScatterValues,
  );

  const filters = useDashletFilters({
    projectId,
    ...filterProps,
  });

  const projectionMetric = data?.data?.projectionMetric as string | undefined;

  // Metric and metadata fields are both eligible as projection metrics.
  const projectionMetricsFieldsNames = useMemo(
    () =>
      filters.filterFieldsMeta
        .filter(
          ({ field }) =>
            field.startsWith(METRICS_PREFIX) ||
            field.startsWith(METADATA_PREFIX),
        )
        .map(({ field }) => field),
    [filters.filterFieldsMeta],
  );

  const domainGapMetadata = data?.data?.domainGapMetadata as string | undefined;

  // Already referentially stable: it is a field of the memoized `scatterOptions`.
  const domainGapMetadataFieldsNames = scatterOptions.domainGapMetadataOptions;

  // Stored display params (when present) override the defaults field by field.
  const displayParams = useMemo(
    () => ({
      ...DEFAULT_DISPLAY_PARAMS,
      ...(data?.data?.displayParams as
        | PopulationExplorationDisplayParams
        | undefined),
    }),
    [data],
  );

  // Only metadata fields can be balanced on.
  const balanceFieldsNames = useMemo(
    () =>
      filters.filterFieldsMeta
        .filter(({ field }) => field.startsWith(METADATA_PREFIX))
        .map(({ field }) => field),
    [filters.filterFieldsMeta],
  );

  const register = useCallback(
    (
      key: string,
      options: ScatterViewSettingOptions,
      scatterData: ScatterVizDataState,
    ) => {
      setScatterOptionByKey((prev) => ({ ...prev, [key]: options }));
      setScatterDataByKey((prev) => ({ ...prev, [key]: scatterData }));
    },
    [],
  );

  const unregister = useCallback((key: string) => {
    setScatterOptionByKey((prev) => {
      const { [key]: _, ...rest } = prev;
      return rest;
    });
    setScatterDataByKey((prev) => {
      const { [key]: _, ...rest } = prev;
      return rest;
    });
  }, []);

  const [findActive, toggleFindActive] = useToggle(false);
  const [findQuery, setFindQuery] = useState<FindValue>({
    key: DEFAULT_FIND_KEY,
    value: '',
    exact: true,
  });
  const [hoverSampleId, setHoverSampleId] = useState<string | undefined>();

  // Color-range settings persisted per dot-color field name.
  const [dotColorViewSettingsMap, setDotColorViewSettingsMap] = useLocalStorage<
    Record<string, ColorRangeViewSettings>
  >(`${projectId}-color-by-view-settings`, {});

  // Single, consistently guarded read of the selected dot-color field.
  const dotColorKey = viewSettings.dotColor?.value;

  const dotColorViewSettings = useMemo(
    () =>
      dotColorKey
        ? dotColorViewSettingsMap[dotColorKey] ||
          DEFAULT_COLOR_RANGE_VIEW_SETTINGS
        : DEFAULT_COLOR_RANGE_VIEW_SETTINGS,
    [dotColorViewSettingsMap, dotColorKey],
  );

  // Merges a partial update into the settings of the current dot-color field;
  // a no-op when no dot-color field is selected.
  const setDotColorViewSettings = useCallback(
    (newSettings: Partial<ColorRangeViewSettings>) => {
      if (!dotColorKey) return;
      setDotColorViewSettingsMap({
        ...dotColorViewSettingsMap,
        [dotColorKey]: {
          ...dotColorViewSettings,
          ...newSettings,
        },
      });
    },
    [
      dotColorKey,
      dotColorViewSettingsMap,
      dotColorViewSettings,
      setDotColorViewSettingsMap,
    ],
  );

  // Memoized so that a fresh nested object is not created on every render.
  // NOTE(review): presumably useMergedObject compares top-level fields
  // shallowly; an unmemoized nested literal would then force a new context
  // value (and consumer re-renders) each render — confirm.
  const settings = useMemo(
    () => ({
      dashletName,
      projectionMetric,
      projectionMetricsFieldsNames,
      domainGapMetadata,
      domainGapMetadataFieldsNames,
      displayParams,
      balanceFieldsNames,
    }),
    [
      dashletName,
      projectionMetric,
      projectionMetricsFieldsNames,
      domainGapMetadata,
      domainGapMetadataFieldsNames,
      displayParams,
      balanceFieldsNames,
    ],
  );

  const value: DashletScatterContextValue = useMergedObject({
    projectId,
    register,
    unregister,
    viewSettings,
    viewSettingsValues: valuesWithDefault,
    settings,
    filters,
    findFields,
    findActive,
    toggleFindActive,
    findQuery,
    setFindQuery,
    hoverSampleId,
    setHoverSampleId,
    dotColorViewSettings,
    setDotColorViewSettings,
  });

  return (
    <DashletScatterContext.Provider value={value}>
      {children}
    </DashletScatterContext.Provider>
  );
}

/**
 * Reads the nearest `DashletScatterContextProvider` value, falling back to
 * `contextDefaults` when rendered outside a provider.
 */
export function useDashletScatterContext(): DashletScatterContextValue {
  const scatterContext = useContext(DashletScatterContext);
  return scatterContext;
}
