import { useCallback, useMemo } from 'react';
import { DatasetSetup } from '@tensorleap/api-client';
import { Position } from '../../../core/position';
import { Plus } from '../../../ui/icons';
import {
  COMPONENT_DESCRIPTORS_MAP,
  isValidNodeName,
} from '../../interfaces/NodeDescriptor';
import { UI_COMPONENTS } from '../../../core/types/ui-components';
import { IMAGE_VISUALIZER_PROPERTIES } from '../../consts';
import { USER_UNIQUE_NAME } from '../../../layer-details/UserUniqueName';
import { isInputNode, isInputsNode } from '../../graph-calculation/utils';
import { generateValidateAssetsQuickFix } from './ValidateAssetsErrorData';
import { useNetworkMapContext } from '../../../core/NetworkMapContext';
import { ValidateAssetsStatus } from '../../interfaces/ValidateGraphStatus';
import { Unreachable } from '../../../core/Errors';
import {
  NetworkWizardData,
  QuickFixProps,
  GraphErrorKind,
  NetworkWizardCategory,
  NetworkWizardErrorSeverity,
} from '../types';
import { GraphNodeErrorType, NodeErrorMsg } from '../errors';
import { NetworkTabsEnum } from '../../NetworkDrawer';

// Distance used to place a node created by a quick fix relative to the node
// it is attached to: added to the source node's first position coordinate and
// subtracted from its second one.
const NODES_OFFSET = 300;

/**
 * Options describing the loss node that the "add loss" quick fix should
 * create.
 */
export interface AddLossParams {
  /**
   * For a built-in loss this is used as the new node's `name`; for a custom
   * loss it is used as the `subType` of a `'CustomLoss'` node.
   */
  newNodeName: string;
  /**
   * When `true` the source node is wired to the loss node's `ground_truth`
   * input, otherwise to its `prediction` input. Not consulted for custom
   * losses, for which no connection is created.
   */
  toGT: boolean;
  /** Whether the loss is a custom loss rather than a built-in one. */
  isCustomLoss: boolean;
}

/**
 * Builds a lookup of every selectable loss, keyed by loss name.
 *
 * Built-in losses (entries of `UI_COMPONENTS` whose `type` is `'Loss'`) are
 * registered first with the value `'Loss'`; custom losses declared on the
 * dataset setup are registered afterwards with the value `'CustomLoss'`, so a
 * custom loss sharing a name with a built-in one takes precedence.
 *
 * @param datasetSetup - Dataset setup that may declare `custom_losses`;
 *   `undefined` yields only the built-in losses.
 * @returns A freshly created record mapping loss name to its kind.
 */
function prepareAllLosses(
  datasetSetup: DatasetSetup | undefined
): Record<string, 'Loss' | 'CustomLoss'> {
  const lossKindByName: Record<string, 'Loss' | 'CustomLoss'> = {};

  // Built-in losses go in first so they keep their position in key order.
  for (const component of UI_COMPONENTS) {
    if (component.type === 'Loss') {
      lossKindByName[component.name] = 'Loss';
    }
  }

  // Custom losses are applied last and therefore win on a name clash.
  for (const customLoss of datasetSetup?.custom_losses || []) {
    lossKindByName[customLoss.name] = 'CustomLoss';
  }

  return lossKindByName;
}

/**
 * Builds the network-wizard entry that describes a single graph-node error,
 * together with the quick fix (if any) that matches the error message.
 *
 * The quick fix is chosen from `msg` / `isValidateAssetsError`:
 * - `InputIsNotSelected` (dataset has inputs): pick one of the dataset inputs.
 * - `InputHasntVisualizer` (dataset has inputs): add a connected Visualizer.
 * - validate-assets error: delegate to `generateValidateAssetsQuickFix`.
 * - `GTHasntLoss` / `NoOutputNode`: pick a loss to add (wired to the loss'
 *   `ground_truth` or `prediction` input respectively).
 * - `NoIntegrationScript`: open the code-integration tab.
 *
 * @param title - Title shown for the error; defaults to `'INVALID NODE'`.
 * @param msg - The error message; also drives quick-fix selection.
 * @param isValidateAssetsError - Marks the error as a validate-assets error.
 * @param category - Wizard category; falls back to
 *   `NetworkWizardCategory.MODEL` when falsy.
 * @param nodeId - Id of the node the error belongs to.
 * @returns An array with exactly one wizard entry, or an empty array when
 *   `nodeId` is falsy.
 */
export function useGraphNodeErrorData({
  title = 'INVALID NODE',
  msg,
  isValidateAssetsError,
  category,
  nodeId,
}: GraphNodeErrorType): NetworkWizardData[] {
  // Everything needed from the network-map context is read in one call.
  const {
    nodes,
    addNewNode,
    addNewConnection,
    changeNodeProperty,
    updateConnection,
    addPredictionLabel,
    getNewNodeId,
    selectNode,
    datasetSetup,
    validateAssetsStatus,
    setOpenNetworkTab,
    onFitNodeToScreen,
    validateAssetsButtonState,
    handleValidateAssetsClicked,
  } = useNetworkMapContext();

  // Adds a Visualizer node next to the given node and connects the node's
  // selected (or first) dataset input to the visualizer's `data` input.
  const addVisualizer = useCallback(
    (targetNodeId: string) => {
      const dsNode = nodes.get(targetNodeId);
      if (!dsNode) {
        console.warn(`Node #${targetNodeId} does not exist`);
        return;
      }

      // Validate everything before touching the graph, so a failure never
      // leaves a half-configured, unconnected visualizer behind.
      const datasetInputs = datasetSetup?.inputs;
      if (!datasetInputs || !datasetInputs.length) {
        console.error('Dataset node was not found or has no outputs');
        return;
      }

      const [dsNodePosX, dsNodePosY] = dsNode.position;
      const visualizerPosition: Position = [
        dsNodePosX + NODES_OFFSET,
        dsNodePosY - NODES_OFFSET,
      ];

      // NOTE(review): assumes getNewNodeId() returns the id the next
      // addNewNode() call will assign — confirm against the context.
      const visualizerNodeId = getNewNodeId();
      addNewNode({ name: 'Visualizer', position: visualizerPosition });
      changeNodeProperty({
        nodeId: visualizerNodeId,
        nodeDataPropsToUpdate: {
          ...IMAGE_VISUALIZER_PROPERTIES,
          [USER_UNIQUE_NAME]: IMAGE_VISUALIZER_PROPERTIES.name,
        },
      });

      // Prefer the output already selected on the node, else the first input.
      const datasetNodeFirstInputName =
        dsNode.data['output_name'] || datasetInputs[0].name;
      addNewConnection({
        outputNodeId: targetNodeId,
        outputName: datasetNodeFirstInputName,
        inputNodeId: visualizerNodeId,
        inputName: 'data',
        isDynamicInput: false,
      });

      onFitNodeToScreen(targetNodeId);
      selectNode(visualizerNodeId);
    },
    [
      addNewConnection,
      addNewNode,
      changeNodeProperty,
      datasetSetup?.inputs,
      getNewNodeId,
      nodes,
      onFitNodeToScreen,
      selectNode,
    ]
  );

  // Stores the chosen dataset input on the node and refreshes its connection.
  const selectInput = useCallback(
    (targetNodeId: string, inputName?: string) => {
      if (!inputName) throw new Unreachable();
      changeNodeProperty({
        nodeId: targetNodeId,
        nodeDataPropsToUpdate: { output_name: inputName },
      });
      updateConnection(targetNodeId, undefined, [inputName]);
    },
    [changeNodeProperty, updateConnection]
  );

  // Adds a loss node next to the given node. Built-in losses are also wired
  // to the node (ground-truth or prediction input); custom losses are only
  // created, without a connection.
  const addLoss = useCallback(
    (
      { newNodeName, toGT, isCustomLoss }: AddLossParams,
      targetNodeId: string
    ): void => {
      const layerNode = nodes.get(targetNodeId);
      if (!layerNode) {
        console.error(`Node #${targetNodeId} does not exist`);
        return;
      }

      // The prediction wiring is the only step that can fail, so its output
      // name is resolved before any node is created; bailing out here leaves
      // the graph untouched instead of leaving an orphan loss node.
      const needsPredictionOutput = !isCustomLoss && !toGT;
      const predictionOutputName = needsPredictionOutput
        ? COMPONENT_DESCRIPTORS_MAP.get(layerNode.name)?.outputs_data
            .outputs[0]?.name
        : undefined;
      if (needsPredictionOutput && !predictionOutputName) {
        console.error(`Node #${layerNode.name} does not have output name`);
        return;
      }

      const [layerNodePosX, layerNodePosY] = layerNode.position;
      const lossPosition: Position = [
        layerNodePosX + NODES_OFFSET,
        layerNodePosY - NODES_OFFSET,
      ];

      // NOTE(review): assumes getNewNodeId() returns the id the next
      // addNewNode() call will assign — confirm against the context.
      const lossId = getNewNodeId();
      if (isCustomLoss) {
        addNewNode({
          name: 'CustomLoss',
          position: lossPosition,
          subType: newNodeName,
        });
      } else {
        addNewNode({ name: newNodeName, position: lossPosition });

        if (toGT) {
          // The ground-truth output name comes from the dataset setup stored
          // on the first input/inputs node found in the graph.
          const datasetNode = Array.from(nodes.values()).find(
            (n) => isInputNode(n) || isInputsNode(n)
          );
          const { outputs = [] } =
            datasetNode?.data.datasetVersion?.metadata?.setup || {};
          const outputName = outputs[0]?.name || '';

          addNewConnection({
            inputNodeId: lossId,
            inputName: 'ground_truth',
            outputNodeId: targetNodeId,
            outputName,
            isDynamicInput: false,
          });
        } else if (predictionOutputName) {
          // Always truthy here (checked above); the test only narrows the type.
          addNewConnection({
            inputNodeId: lossId,
            inputName: 'prediction',
            outputNodeId: targetNodeId,
            outputName: predictionOutputName,
            isDynamicInput: false,
          });
          addPredictionLabel(layerNode);
        }
      }

      onFitNodeToScreen(targetNodeId);
      selectNode(lossId);
    },
    [
      addNewConnection,
      addNewNode,
      addPredictionLabel,
      getNewNodeId,
      nodes,
      onFitNodeToScreen,
      selectNode,
    ]
  );

  // Picks the quick fix matching the error; the order of the checks matters
  // because the first matching condition wins.
  const quickFix = useMemo((): QuickFixProps | undefined => {
    if (!nodeId) {
      return undefined;
    }

    if (
      !!datasetSetup?.inputs?.length &&
      msg === NodeErrorMsg.InputIsNotSelected
    ) {
      return {
        title: 'Inputs',
        selectOptions: datasetSetup.inputs.map(({ name }) => name),
        onSelect: (value?: string) => {
          selectInput(nodeId, value);
        },
      };
    }

    if (
      !!datasetSetup?.inputs?.length &&
      msg === NodeErrorMsg.InputHasntVisualizer
    ) {
      return {
        onSelect: () => addVisualizer(nodeId),
        title: 'Add',
        tooltipMsg: 'Add Visualizer',
        icon: <Plus className="h-5 w-5" />,
      };
    }

    if (isValidateAssetsError) {
      return generateValidateAssetsQuickFix(
        validateAssetsButtonState,
        handleValidateAssetsClicked
      );
    }

    if (msg === NodeErrorMsg.GTHasntLoss) {
      const allLossNodes = prepareAllLosses(datasetSetup);
      return {
        title: 'Losses',
        selectOptions: Object.keys(allLossNodes),
        onSelect: (value?: string) => {
          if (!value) {
            console.warn('The Loss type is missing');
            return;
          }

          // Custom loss names are user-defined, so only built-in names are
          // checked against the known node names.
          const isCustomLoss = allLossNodes[value] === 'CustomLoss';
          if (isCustomLoss || isValidNodeName(value)) {
            addLoss(
              {
                newNodeName: value,
                toGT: true,
                isCustomLoss,
              },
              nodeId
            );
          } else {
            console.warn(`${value} is not a valid node name`);
          }
        },
      };
    }

    if (msg === NodeErrorMsg.NoOutputNode) {
      const allLossNodes = prepareAllLosses(datasetSetup);
      return {
        title: 'Losses',
        selectOptions: Object.keys(allLossNodes),
        onSelect: (newNodeName?: string) => {
          if (!newNodeName) {
            console.warn('The Loss type is missing');
            return;
          }
          const isCustomLoss = allLossNodes[newNodeName] === 'CustomLoss';
          if (isCustomLoss || isValidNodeName(newNodeName)) {
            addLoss({ newNodeName, toGT: false, isCustomLoss }, nodeId);
          } else {
            console.warn(`"${newNodeName}" is not a valid node name`);
          }
        },
      };
    }

    if (msg === NodeErrorMsg.NoIntegrationScript) {
      return {
        title: 'Select',
        tooltipMsg: 'Select script',
        icon: <Plus className="h-5 w-5" />,
        onSelect: () => {
          setOpenNetworkTab(NetworkTabsEnum.CodeIntegration);
        },
      };
    }

    // No quick fix is available for any other error.
    return undefined;
  }, [
    addLoss,
    addVisualizer,
    datasetSetup,
    handleValidateAssetsClicked,
    isValidateAssetsError,
    msg,
    nodeId,
    selectInput,
    setOpenNetworkTab,
    validateAssetsButtonState,
  ]);

  // A validate-assets error is shown as loading while validation is running.
  const isValidateAssetsErrorAndCalculating =
    !!isValidateAssetsError &&
    validateAssetsStatus === ValidateAssetsStatus.Calculating;

  // The entry key is the message concatenated with the node id.
  const calculateKey = useCallback(() => msg + nodeId, [msg, nodeId]);

  return useMemo(() => {
    if (!nodeId) {
      return [];
    }
    return [
      {
        errorType: GraphErrorKind.node,
        category: category || NetworkWizardCategory.MODEL,
        showNodeFooter: true,
        title: title,
        message: msg,
        calculateKey,
        showNode: () => {
          if (nodeId) {
            onFitNodeToScreen(nodeId);
            selectNode(nodeId);
          }
        },
        quickFixes: quickFix ? [quickFix] : [],
        errorSeverity: NetworkWizardErrorSeverity.ERROR,
        isLoading: isValidateAssetsErrorAndCalculating,
        key: calculateKey(),
      },
    ];
  }, [
    nodeId,
    category,
    title,
    msg,
    calculateKey,
    quickFix,
    isValidateAssetsErrorAndCalculating,
    onFitNodeToScreen,
    selectNode,
  ]);
}
