import {
  useState,
  createContext,
  FC,
  useEffect,
  useCallback,
  useContext,
  useMemo,
} from 'react';
import {
  Session,
  SessionRunData,
  SessionWeightData,
  SlimVersion,
  StatusEnum,
} from '@tensorleap/api-client';
import api from '../core/api-client';
import { useCurrentProject } from './CurrentProjectContext';
import { useMergedObject } from './useMergedObject';
import { useFetchVersionControl } from './data-fetching/version-control';
import { useMapIdToKey } from './useUniqueKeyManager';
import { groupBy } from 'lodash';
import { KeyedMutator } from 'swr';
import { useDebounce } from './useDebounce';
import {
  SELECTED_SESSION_RUN_KEY,
  changeQueryParams,
  getArrayValueFromQueryParams,
} from '../url/url-builder';
import { useHistory, useLocation } from 'react-router-dom';
import { useLocalStorage } from './useLocalStorage';
import { DEFAULT_SESSION_RUN } from '../ui/ProjectPageLoader';
import { Setter } from './types';

/**
 * A session paired with the version that owns it. Held in the provider as
 * `sessionToTrain` (set via `setSessionToTrain`).
 */
export type SessionToTrainCtx = { session: Session; version: SlimVersion };

/**
 * Epoch-selection state attached to each entry of `selectedSessionRunMap`.
 */
export interface epochsState {
  /** Distinct epochs that have weight data for the session run, sorted descending. */
  epochs: number[];
  /** Stores `epoch` as the selected epoch for this session run. */
  selectEpoch: (epoch: number) => void;
  /**
   * The previously selected epoch if it is still present in `epochs`,
   * otherwise `epochs[0]` (the highest epoch).
   * NOTE(review): typed `number`, but is `undefined` at runtime when `epochs`
   * is empty.
   */
  selectedEpoch: number;
}

/** Session-run data extended with its epoch-selection state. */
export type SessionRunDataAndEpoch = SessionRunData & {
  epochsState: epochsState;
};

/**
 * Shape of the value exposed by `VersionControlContext`: the current
 * project's versions, lookups derived from them, the session-run selection
 * (kept in the URL query string) and UI state for session dialogs.
 */
export interface VersionControlContextInterface {
  /** Versions of the current project as returned by `useFetchVersionControl`. */
  versions: SlimVersion[];
  /** Awaits the (debounced) `refetch` of the versions list. */
  fetchVersions: () => Promise<void>;
  /** Deletes a session run, drops it from the selection and triggers a refetch. */
  deleteSessionRun: (sessionRunId: string) => Promise<void>;
  /** Deletes a session and deselects its runs; failures are logged, not rethrown. */
  deleteSession: (sessionId: string) => Promise<void>;
  /** Deletes a version and deselects all its runs; failures are logged, not rethrown. */
  deleteVersion: (versionId: string) => Promise<void>;
  /** Session cid -> the version containing that session. */
  sessionToVersionsMap: Map<string, SlimVersion>;
  /** Session-run cid -> the version containing that run. */
  sessionRunsToVersionsMap: Map<string, SlimVersion>;
  /** Session-run cid -> the session containing that run. */
  sessionRunsToSessionMap: Map<string, Session>;
  /** Selected session-run cid -> run data plus its epoch-selection state. */
  selectedSessionRunMap: Map<string, SessionRunDataAndEpoch>;
  /**
   * Distinct epochs that have weight data for the run, sorted descending.
   * With `finishedOnly`, only weights whose status is `Finished` count.
   */
  getSessionRunEpochs: (
    sessionRunId: SessionRunData['cid'],
    finishedOnly?: boolean
  ) => number[];
  /**
   * Weight records of the run, one per epoch. With `finishedOnly`, only
   * weights whose status is `Finished` are considered.
   */
  getSessionRunWeightData: (
    sessionRunId: SessionRunData['cid'],
    finishedOnly?: boolean
  ) => SessionWeightData[];
  /** Adds or removes the run from the selection stored in the URL. */
  toggleSelectSessionRun: (sessionRunId: string) => void;
  /** Whether the run is currently selected; `false` for `undefined`. */
  isSessionRunSelected: (sessionRunId?: string) => boolean;
  /** Removes the given runs from the selection stored in the URL. */
  removeSelectedSessionRuns: (sessionRunId: string[]) => void;
  /** Toggles the run between visible and hidden (local state, not in the URL). */
  toggleSelectedSessionRunVisibility: (sessionRunId: string) => void;
  /** `false` only for runs hidden via `toggleSelectedSessionRunVisibility`. */
  isSessionRunVisible: (sessionRunId: string) => boolean;
  /** Numeric key assigned to a selected run by `useMapIdToKey`; 0 when none. */
  getSelectedSessionRunUniqueKey: (sessionRunId: string) => number;
  /** Whether the versions list is loading. */
  isLoading: boolean;
  /** Error from fetching the versions list, if any. */
  error?: Error;
  /** Debounced wrapper around the SWR mutator of the versions list. */
  refetch: KeyedMutator<SlimVersion[]>;
  /** Session whose epochs the UI should currently show, if any. */
  sessionToShowItsEpochs?: Session;
  setSessionToShowItsEpochs: Setter<Session | undefined>;
  /** Deletes `sessionToDelete` (if set), then clears it. */
  handleConfirmDeleteModel: () => void;
  /** Session currently chosen for export, or null. */
  sessionToExport: Session | null;
  setSessionToExport: Setter<Session | null>;
  /** Session/version pair currently chosen for training, or null. */
  sessionToTrain: SessionToTrainCtx | null;
  setSessionToTrain: Setter<SessionToTrainCtx | null>;
  /** Session pending delete confirmation, or null. */
  sessionToDelete: Session | null;
  setSessionToDelete: Setter<Session | null>;
}

/**
 * React context for version-control state.
 *
 * The default value is only seen by consumers rendered outside a
 * `VersionControlProvider`. It exposes empty collections and no-op setters.
 * Async operations reject with an `Unimplemented` error so misuse is visible.
 * `void`-returning callbacks are plain no-ops: returning a rejected promise
 * from them would only produce an unhandled rejection, since no caller awaits
 * a `void` result.
 */
export const VersionControlContext = createContext<VersionControlContextInterface>(
  {
    versions: [],
    fetchVersions: () => Promise.reject(Error('Unimplemented')),
    deleteSessionRun: () => Promise.reject(Error('Unimplemented')),
    deleteSession: () => Promise.reject(Error('Unimplemented')),
    deleteVersion: () => Promise.reject(Error('Unimplemented')),
    selectedSessionRunMap: new Map(),
    sessionToVersionsMap: new Map(),
    sessionRunsToVersionsMap: new Map(),
    sessionRunsToSessionMap: new Map(),
    getSessionRunEpochs: () => [0, 0],
    getSessionRunWeightData: () => [],
    // Typed as returning void: a rejected promise here would go unhandled.
    toggleSelectSessionRun: () => undefined,
    isSessionRunSelected: () => false,
    removeSelectedSessionRuns: () => undefined,
    toggleSelectedSessionRunVisibility: () => undefined,
    isSessionRunVisible: () => false,
    getSelectedSessionRunUniqueKey: () => 0,
    isLoading: false,
    refetch: () => Promise.reject(Error('Unimplemented')),
    sessionToShowItsEpochs: undefined,
    setSessionToShowItsEpochs: () => undefined,
    handleConfirmDeleteModel: () => undefined,
    sessionToExport: null,
    setSessionToExport: () => undefined,
    sessionToTrain: null,
    setSessionToTrain: () => undefined,
    sessionToDelete: null,
    setSessionToDelete: () => undefined,
  }
);

/**
 * Returns true when `ids` contains exactly the same distinct ids as
 * `selected`, ignoring order and duplicates.
 */
function isSameSessionRunSelection(
  ids: readonly string[],
  selected: ReadonlySet<string>
): boolean {
  const distinctIds = new Set(ids);
  if (distinctIds.size !== selected.size) {
    return false;
  }
  return Array.from(distinctIds).every((id) => selected.has(id));
}

/**
 * Collects the cids of all session runs that have at least one weight asset.
 * Runs without weight assets are treated as disabled and are not selectable.
 */
function collectSelectableSessionRunIds(versions: SlimVersion[]): Set<string> {
  return new Set(
    versions.flatMap(({ sessions }) =>
      sessions.flatMap(
        ({ sessionRuns }) =>
          sessionRuns
            ?.filter((sr) => !!sr?.weightAssets?.length)
            .map(({ cid }) => cid) ?? []
      )
    )
  );
}

/** Type guard used to drop non-string (e.g. undefined) ids. */
function isString(value: unknown): value is string {
  return typeof value === 'string';
}

/**
 * Provides `VersionControlContext` for the current project.
 *
 * The selected session runs live in the URL query parameter
 * `SELECTED_SESSION_RUN_KEY`. This provider:
 * - mirrors the raw URL value into local storage;
 * - once versions are loaded, resolves the `DEFAULT_SESSION_RUN` placeholder
 *   and drops unknown or unselectable ids from the URL;
 * - copies the validated selection into local state and the color-key manager.
 *
 * Per-run visibility and epoch choice are local state only.
 *
 * @throws Error when rendered without a current project id.
 */
export const VersionControlProvider: FC = ({ children }) => {
  const [notVisibleSessionRunsMap, setNotVisibleSessionRunsMap] = useState(
    () => new Set<string>()
  );
  const { currentProjectId } = useCurrentProject();

  if (!currentProjectId) {
    throw new Error('currentProjectId is not defined');
  }

  const {
    versions,
    refetch: _refetch,
    isLoading: isLoadingVersions,
    error,
  } = useFetchVersionControl(currentProjectId.toString());
  // Debounced (1000 ms) so several mutations in a row do not each refetch.
  const refetch = useDebounce(_refetch, 1000);
  const fetchVersions = useCallback(async () => {
    await refetch();
  }, [refetch]);

  const { search, pathname } = useLocation();
  const history = useHistory();

  const [, setSelectedSessionRunIdsFromLocalStorage] = useLocalStorage<
    string[]
  >(calcDefaultSessionRunsFromLocalStorageKey(currentProjectId), []);

  const [selectedSessionRunIds, setSelectedSessionRunIds] = useState(
    () => new Set<string>()
  );

  // Persist the raw URL selection per project (presumably read back elsewhere
  // as the default selection — the key is "defaultSessionRuns_<projectId>").
  useEffect(() => {
    const selectedSessionRunIdsFromUrl = getArrayValueFromQueryParams(
      search,
      SELECTED_SESSION_RUN_KEY
    );

    setSelectedSessionRunIdsFromLocalStorage(selectedSessionRunIdsFromUrl);
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [
    search,
    // setSelectedSessionRunIdsFromLocalStorage, //Do not remove this comment, the setter is triggering it self and causing infinite loop. we need to fix this bug on useLocalStorage hook first
  ]);

  const {
    get: getColorKeyById,
    updateCurrentIds: updateColorKeys,
  } = useMapIdToKey();

  // Sync URL -> local selection state once versions are available.
  useEffect(() => {
    if (isLoadingVersions) return;

    const selectedSessionRunIdsFromUrl = getArrayValueFromQueryParams(
      search,
      SELECTED_SESSION_RUN_KEY
    );

    // Compare contents, not only sizes. A same-size change (e.g. [a] -> [b]
    // on back/forward) must still be applied, and duplicate ids in the URL
    // must not make this effect set state on every run (endless loop).
    if (
      isSameSessionRunSelection(
        selectedSessionRunIdsFromUrl,
        selectedSessionRunIds
      )
    ) {
      return;
    }

    const allValidSessionRunIds = collectSelectableSessionRunIds(versions);

    // The DEFAULT_SESSION_RUN placeholder resolves to the first selectable
    // run, or to an empty selection when there is none.
    if (
      selectedSessionRunIdsFromUrl.length === 1 &&
      selectedSessionRunIdsFromUrl[0] === DEFAULT_SESSION_RUN
    ) {
      history.replace({
        pathname,
        search: changeQueryParams(
          search,
          SELECTED_SESSION_RUN_KEY,
          Array.from(allValidSessionRunIds).slice(0, 1)
        ),
      });
      return;
    }

    const hasInvalidIds = selectedSessionRunIdsFromUrl.some(
      (id) => !allValidSessionRunIds.has(id)
    );

    // Drop unknown or unselectable ids. Use `replace`, not `push`: with push,
    // Back would return to the invalid URL, which would be corrected again
    // immediately, trapping the user.
    if (hasInvalidIds) {
      history.replace({
        pathname,
        search: changeQueryParams(
          search,
          SELECTED_SESSION_RUN_KEY,
          selectedSessionRunIdsFromUrl.filter((id) =>
            allValidSessionRunIds.has(id)
          )
        ),
      });
      return;
    }

    updateColorKeys(selectedSessionRunIdsFromUrl);
    setSelectedSessionRunIds(new Set(selectedSessionRunIdsFromUrl));
  }, [
    history,
    isLoadingVersions,
    pathname,
    search,
    selectedSessionRunIds,
    updateColorKeys,
    versions,
  ]);

  const sessionRunById = useMemo(
    () =>
      new Map<string, SessionRunData>(
        versions.flatMap(({ sessions }) =>
          sessions.flatMap(
            ({ sessionRuns }) =>
              sessionRuns?.map((sessionRun) => [sessionRun?.cid, sessionRun]) ??
              []
          )
        )
      ),
    [versions]
  );

  const isSessionRunSelected = useCallback(
    (sessionRunId?: string) =>
      sessionRunId !== undefined && selectedSessionRunIds.has(sessionRunId),
    [selectedSessionRunIds]
  );

  const isSessionRunVisible = useCallback(
    (sessionRunId: string) => !notVisibleSessionRunsMap.has(sessionRunId),
    [notVisibleSessionRunsMap]
  );

  const getSelectedSessionRunUniqueKey = useCallback(
    (sessionRunId: string) => getColorKeyById(sessionRunId) || 0,
    [getColorKeyById]
  );

  /**
   * Weight records for one session run, one per epoch. Only weights
   * referenced by the run's `weightAssets` with a truthy `esMetricIndex` are
   * considered. NOTE(review): an `esMetricIndex` of 0 would be excluded —
   * confirm that is intended.
   */
  const getSessionRunWeightData = useCallback(
    (
      sessionRunId: SessionRunData['cid'],
      finishedOnly = false
    ): SessionWeightData[] => {
      const [sessionRun] = versions
        .flatMap(({ sessions }) =>
          sessions.flatMap(({ sessionRuns }) =>
            sessionRuns?.find(({ cid }) => cid === sessionRunId)
          )
        )
        .filter(
          (sessionRunData): sessionRunData is SessionRunData =>
            sessionRunData !== undefined
        );
      const sessionRunWeightIds = new Set(
        sessionRun?.weightAssets
          ?.filter(({ esMetricIndex }) => !!esMetricIndex)
          .map(({ sessionWeightId }) => sessionWeightId) || []
      );

      if (!sessionRun || sessionRunWeightIds.size === 0) {
        return [];
      }
      const sessionWeights = versions
        .flatMap(({ sessions }) =>
          sessions.flatMap(({ sessionWeights }) =>
            sessionWeights?.filter(
              ({ cid, status }) =>
                sessionRunWeightIds.has(cid) &&
                (!finishedOnly || status === StatusEnum.Finished)
            )
          )
        )
        .filter((weight): weight is SessionWeightData => !!weight);

      // sessionWeights are currently ordered like a log, so an epoch can have
      // several entries. Keep the first entry per epoch whose status is not
      // `Started`, falling back to the first entry.
      const unifyWeights = Object.values(
        groupBy(sessionWeights, ({ epoch }) => epoch)
      ).map(
        (sessionWeights) =>
          sessionWeights.find(({ status }) => status !== StatusEnum.Started) ??
          sessionWeights[0]
      );
      return unifyWeights;
    },
    [versions]
  );

  /** Distinct epochs with weight data for the run, sorted descending. */
  const getSessionRunEpochs = useCallback(
    (sessionRunId: SessionRunData['cid'], finishedOnly = false): number[] => {
      const sessionWeights = getSessionRunWeightData(
        sessionRunId,
        finishedOnly
      );

      return Array.from(new Set(sessionWeights.map(({ epoch }) => epoch))).sort(
        (a, b) => b - a
      );
    },
    [getSessionRunWeightData]
  );

  const [sessionRunToSelectedEpoch, setSessionRunToSelectedEpoch] = useState<
    Map<string, number>
  >(() => new Map());

  const selectEpoch = useCallback(
    (sessionRunId: string, epoch: number) => {
      setSessionRunToSelectedEpoch((current) => {
        const updated = new Map(current);
        updated.set(sessionRunId, epoch);
        return updated;
      });
    },
    [setSessionRunToSelectedEpoch]
  );

  // Selected runs enriched with epoch state. A saved epoch that no longer
  // exists falls back to the highest available epoch.
  const selectedSessionRunMap = useMemo(
    () =>
      Array.from(selectedSessionRunIds).reduce((ret, sessionRunId) => {
        const sessionRun = sessionRunById.get(sessionRunId);
        if (sessionRun) {
          const epochs = getSessionRunEpochs(sessionRunId);
          const savedEpoch = sessionRunToSelectedEpoch.get(sessionRunId);
          const selectedEpoch =
            savedEpoch !== undefined && epochs.includes(savedEpoch)
              ? savedEpoch
              : epochs[0];

          const selectSessionRunEpoch = (epoch: number) => {
            selectEpoch(sessionRunId, epoch);
          };

          ret.set(sessionRunId, {
            ...sessionRun,
            epochsState: {
              selectedEpoch,
              epochs: epochs,
              selectEpoch: selectSessionRunEpoch,
            },
          });
        }
        return ret;
      }, new Map<string, SessionRunDataAndEpoch>()),
    [
      getSessionRunEpochs,
      selectEpoch,
      selectedSessionRunIds,
      sessionRunById,
      sessionRunToSelectedEpoch,
    ]
  );

  const removeSelectedSessionRuns = useCallback<
    VersionControlContextInterface['removeSelectedSessionRuns']
  >(
    (sessionRunIds) => {
      history.push({
        search: changeQueryParams(
          search,
          SELECTED_SESSION_RUN_KEY,
          Array.from(selectedSessionRunIds).filter(
            (id) => !sessionRunIds.includes(id)
          )
        ),
      });
    },
    [history, search, selectedSessionRunIds]
  );

  const toggleSelectedSessionRunVisibility = useCallback(
    (sessionRunId: string) =>
      setNotVisibleSessionRunsMap((currentNotVisible) => {
        const updated = new Set(Array.from(currentNotVisible));
        if (updated.has(sessionRunId)) {
          updated.delete(sessionRunId);
        } else {
          updated.add(sessionRunId);
        }
        return updated;
      }),
    []
  );

  const deleteSessionRun = useCallback(
    async (sessionRunId: string) => {
      if (!currentProjectId) {
        return;
      }

      await api.deleteSessionRun({
        sessionRunId,
        projectId: currentProjectId,
      });

      removeSelectedSessionRuns([sessionRunId]);
      refetch();
    },
    [currentProjectId, refetch, removeSelectedSessionRuns]
  );

  const deleteSession = useCallback(
    async (sessionId: string) => {
      try {
        if (!currentProjectId) {
          console.error(
            "Somehow tried to delete session while project wasn't loaded"
          );
          return;
        }
        await api.deleteSession({ sessionId, projectId: currentProjectId });
        const filteredSessionRuns = (
          versions
            .flatMap(({ sessions }) => sessions)
            .find(({ cid }) => cid === sessionId)
            ?.sessionRuns?.filter((sessionRun) => sessionRun !== undefined)
            .map(({ cid }) => cid) ?? []
        ).filter(isString);
        removeSelectedSessionRuns(filteredSessionRuns);
        refetch();
      } catch (error) {
        console.error(error);
      }
    },
    [currentProjectId, refetch, removeSelectedSessionRuns, versions]
  );

  const deleteVersion = useCallback(
    async (versionId: string) => {
      try {
        const version = versions.find(({ cid }) => cid === versionId);
        if (!version || !currentProjectId) {
          console.error("shouldn't happen");
          return;
        }
        await api.deleteVersion({ versionId, projectId: currentProjectId });
        // Reuse the version found above instead of searching a second time.
        const filteredSessionRuns = version.sessions
          .flatMap(({ sessionRuns }) => sessionRuns)
          .map((sessionRun) => sessionRun?.cid)
          .filter(isString);

        removeSelectedSessionRuns(filteredSessionRuns);
        refetch();
      } catch (error) {
        console.error(error);
      }
    },
    [currentProjectId, refetch, removeSelectedSessionRuns, versions]
  );

  const toggleSelectSessionRun = useCallback(
    (sessionRunId: string) => {
      const newSearch = changeQueryParams(
        search,
        SELECTED_SESSION_RUN_KEY,
        selectedSessionRunIds.has(sessionRunId)
          ? Array.from(selectedSessionRunIds).filter(
              (id) => id !== sessionRunId
            )
          : [...Array.from(selectedSessionRunIds), sessionRunId]
      );

      history.push({
        search: newSearch,
      });
    },
    [search, selectedSessionRunIds, history]
  );

  const sessionToVersionsMap = useMemo(
    () => getSessionToVersionsMap(versions),
    [versions]
  );

  const sessionRunsToVersionsMap = useMemo(
    () => getSessionRunToVersionsMap(versions),
    [versions]
  );

  const sessionRunsToSessionMap = useMemo(
    () => getSessionRunToSessionMap(versions),
    [versions]
  );

  const [
    sessionToShowItsEpochs,
    setSessionToShowItsEpochs,
  ] = useState<Session>();

  const [sessionToExport, setSessionToExport] = useState<Session | null>(null);

  const [
    sessionToTrain,
    setSessionToTrain,
  ] = useState<SessionToTrainCtx | null>(null);

  const [sessionToDelete, setSessionToDelete] = useState<Session | null>(null);

  const handleConfirmDeleteModel = useCallback(async () => {
    try {
      if (sessionToDelete) {
        await deleteSession(sessionToDelete.cid);
      }
    } catch (err) {
      console.error(err);
    }
    setSessionToDelete(null);
  }, [deleteSession, sessionToDelete]);

  const value = useMergedObject({
    versions,
    selectedSessionRunMap,
    toggleSelectSessionRun,
    deleteSessionRun,
    deleteSession,
    deleteVersion,
    isSessionRunSelected,
    isSessionRunVisible,
    removeSelectedSessionRuns,
    toggleSelectedSessionRunVisibility,
    getSelectedSessionRunUniqueKey,
    sessionToVersionsMap,
    sessionRunsToVersionsMap,
    sessionRunsToSessionMap,
    getSessionRunEpochs,
    getSessionRunWeightData,
    fetchVersions,
    isLoading: isLoadingVersions,
    error,
    refetch,
    sessionToShowItsEpochs,
    setSessionToShowItsEpochs,
    handleConfirmDeleteModel,
    sessionToExport,
    setSessionToExport,
    sessionToTrain,
    setSessionToTrain,
    sessionToDelete,
    setSessionToDelete,
  });

  return (
    <VersionControlContext.Provider value={value}>
      {children}
    </VersionControlContext.Provider>
  );
};

/**
 * Returns the nearest `VersionControlContext` value. Outside a
 * `VersionControlProvider`, this is the context's default value.
 */
export const useVersionControl = (): VersionControlContextInterface => {
  const versionControl = useContext(VersionControlContext);
  return versionControl;
};

/**
 * Maps each session cid to the version that contains it.
 * Sessions with a falsy cid are skipped. If a cid appears in several
 * versions, the later version wins.
 */
export function getSessionToVersionsMap(
  versions: SlimVersion[]
): Map<string, SlimVersion> {
  const versionBySessionId = new Map<string, SlimVersion>();
  for (const owningVersion of versions) {
    for (const session of owningVersion.sessions) {
      const sessionId = session?.cid;
      if (sessionId) {
        versionBySessionId.set(sessionId, owningVersion);
      }
    }
  }
  return versionBySessionId;
}
/**
 * Maps each session-run cid to the version that contains it.
 * Missing `sessionRuns` arrays and runs with a falsy cid are skipped. If a
 * cid appears more than once, the later occurrence wins.
 */
export function getSessionRunToVersionsMap(
  versions: SlimVersion[]
): Map<string, SlimVersion> {
  const versionBySessionRunId = new Map<string, SlimVersion>();
  versions.forEach((owningVersion) => {
    owningVersion.sessions.forEach(({ sessionRuns }) => {
      sessionRuns?.forEach((run) => {
        const runId = run?.cid;
        if (runId) {
          versionBySessionRunId.set(runId, owningVersion);
        }
      });
    });
  });
  return versionBySessionRunId;
}

/**
 * Maps each session-run cid to the session that contains it, across all
 * versions. Nullish sessions, missing `sessionRuns` and runs with a falsy cid
 * are skipped. If a cid appears more than once, the later occurrence wins.
 */
export function getSessionRunToSessionMap(
  versions: SlimVersion[]
): Map<string, Session> {
  const sessionBySessionRunId = new Map<string, Session>();
  const allSessions = versions.flatMap(({ sessions }) => sessions);
  for (const owningSession of allSessions) {
    const runs = owningSession?.sessionRuns ?? [];
    for (const run of runs) {
      if (run?.cid) {
        sessionBySessionRunId.set(run.cid, owningSession);
      }
    }
  }
  return sessionBySessionRunId;
}

const DEFAULT_SESSION_RUNS_FROM_LOCAL_STORAGE = 'defaultSessionRuns';

/**
 * Builds the per-project local-storage key under which the provider persists
 * the URL's session-run selection, e.g. `defaultSessionRuns_<projectId>`.
 */
export function calcDefaultSessionRunsFromLocalStorageKey(
  currentProjectId: string
): string {
  const separator = '_';
  return DEFAULT_SESSION_RUNS_FROM_LOCAL_STORAGE.concat(
    separator,
    currentProjectId
  );
}
