import { groupBy } from 'lodash';
import { hasNodeType } from '../../descriptor/utils';
import {
  isLossNode,
  isOptimizerNode,
  isInputNode,
  isInputsNode,
  isUnselectedInputNode,
  isRequiredUniqeNameNode,
  isLayerNode,
  isVisualizerNode,
  isCustomMetricNode,
  isKindOfLossNode,
  isUnselectedCustomLayerNode,
} from '../../graph-calculation/utils';
import { useMemo } from 'react';
import { useNetworkMapContext } from '../../../core/NetworkMapContext';
import { numberToEnglishWord } from '../../helper';
import { COMPONENT_DESCRIPTORS_MAP } from '../../interfaces';
import { ValidateAssetsStatus } from '../../interfaces/ValidateGraphStatus';
import { getInputsData } from '../../utils';
import { useGraphErrorData } from '../cards/GraphErrorData';
import {
  NodeMessageData,
  useGraphNodeDictErrorData,
} from '../cards/GraphNodeDictErrorData';
import { useGraphNodeErrorData } from '../cards/GraphNodeErrorData';
import {
  NodesMessageData,
  useGraphNodesErrorData,
} from '../cards/GraphNodesErrorData';
import { GetValidateAssetsErrorData } from '../cards/ValidateAssetsErrorData';
import {
  NetworkWizardData,
  GraphErrorKind,
  NetworkWizardCategory,
} from '../types';
import {
  NodeErrorMsg,
  GraphErrorMsg,
  ValidateAssetsErrorType,
} from '../errors';
import { Node } from '@tensorleap/api-client';
import { useCurrentProject } from '../../../core/CurrentProjectContext';

/**
 * Collects the alerts produced by every validation hook in this file and
 * exposes them as a single flat list.
 *
 * @returns All alert groups concatenated in the order they are listed inside
 * the memo below. The list is rebuilt only when one of the groups changes
 * identity.
 */
export function useGraphErrors(): NetworkWizardData[] {
  const noSelectedDatasetAlerts = useGenerateNoSelectedDatasetAlerts();
  const generalErrorAlert = useGenerateGeneralErrorAlert();
  const noOutputNodeAlerts = useGenerateNoOutputNodeAlerts();
  const noInputNodeAlert = useGenerateNoInputNodeAlert();
  const unselectedInputNodeAlerts = useGenerateUnselectedInputNodeAlerts();
  const inputNotConnectedToVisualizerAlerts = useGenerateInputNotConnectedToVisualizerAlerts();
  const manyOptimizersAlerts = useGenerateManyOptimizersAlerts();
  const invalidDatasetAlerts = useGenerateInvalidDatasetAlerts();
  const unselectedCustomLayesAlerts = useGenerateUnselectedCustomLayesAlerts();
  const disconnectedInputsAlerts = useGenerateDisconnectedInputsAlerts();
  const nodesWithSameNameAlerts = useGenerateNodesWithSameNameAlerts();
  const nodesWithoutNameAlerts = useGenerateNodesWithoutNameAlerts();
  const validateAssetAlerts = useGenerateValidateAssetAlerts();
  const errorsFromShapes = useGenerateErrorsFromShapes();

  return useMemo(() => {
    // The order of the groups is the order of the resulting alerts.
    const alertGroups: NetworkWizardData[][] = [
      noSelectedDatasetAlerts,
      generalErrorAlert,
      noOutputNodeAlerts,
      noInputNodeAlert,
      unselectedInputNodeAlerts,
      inputNotConnectedToVisualizerAlerts,
      manyOptimizersAlerts,
      invalidDatasetAlerts,
      unselectedCustomLayesAlerts,
      disconnectedInputsAlerts,
      nodesWithSameNameAlerts,
      nodesWithoutNameAlerts,
      validateAssetAlerts,
      errorsFromShapes,
    ];
    return alertGroups.flat();
  }, [
    noSelectedDatasetAlerts,
    generalErrorAlert,
    noOutputNodeAlerts,
    noInputNodeAlert,
    unselectedInputNodeAlerts,
    inputNotConnectedToVisualizerAlerts,
    manyOptimizersAlerts,
    invalidDatasetAlerts,
    unselectedCustomLayesAlerts,
    disconnectedInputsAlerts,
    nodesWithSameNameAlerts,
    nodesWithoutNameAlerts,
    validateAssetAlerts,
    errorsFromShapes,
  ]);
}

/**
 * Produces the alert shown while the graph has nodes but the current project
 * has no selected code integration version.
 *
 * @returns An empty list when the graph is empty or a code integration
 * version is selected. Otherwise the graph-level card when there is no
 * input/inputs node, or the node-level card (pointing at the first such node)
 * when there is one.
 */
function useGenerateNoSelectedDatasetAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const inputNodes = Array.from(nodes.values()).filter(
    (candidate) => isInputNode(candidate) || isInputsNode(candidate)
  );
  const [firstInputNode] = inputNodes;
  const inputNodeCount = inputNodes.length;

  // Both cards are prepared unconditionally so the hook order stays stable.
  const nodeLevelCard = useGraphNodeErrorData({
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.NoIntegrationScript,
    category: NetworkWizardCategory.CODE,
    nodeId: firstInputNode?.id,
  });
  // NOTE(review): this card reuses GraphErrorMsg.NoInputNode even though the
  // trigger is a missing code integration version - confirm this is intended.
  const graphLevelCard = useGraphErrorData({
    type: GraphErrorKind.graph,
    msg: GraphErrorMsg.NoInputNode,
    category: NetworkWizardCategory.INPUTS,
  });

  return useMemo((): NetworkWizardData[] => {
    const nothingToReport = !nodes.size || selectedCodeIntegrationVersion;
    if (nothingToReport) {
      return [];
    }
    if (inputNodeCount === 0) {
      return graphLevelCard;
    }
    return nodeLevelCard;
  }, [
    graphLevelCard,
    nodeLevelCard,
    inputNodeCount,
    nodes.size,
    selectedCodeIntegrationVersion,
  ]);
}

/**
 * Surfaces the general error string held by the network map context as a
 * graph-level alert in the MODEL category.
 *
 * @returns The prepared card when a general error exists, the graph has nodes
 * and a code integration version is selected; an empty list otherwise.
 */
function useGenerateGeneralErrorAlert(): NetworkWizardData[] {
  const { nodes, networkContextGeneralError } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  // The card hook is always called; an empty message is passed when there is
  // no error so the hook order stays stable between renders.
  const generalErrorCard = useGraphErrorData({
    type: GraphErrorKind.graph,
    category: NetworkWizardCategory.MODEL,
    msg: networkContextGeneralError || '',
  });

  return useMemo((): NetworkWizardData[] => {
    if (!networkContextGeneralError) {
      return [];
    }
    if (!nodes.size) {
      return [];
    }
    if (!selectedCodeIntegrationVersion) {
      return [];
    }
    return generalErrorCard;
  }, [
    generalErrorCard,
    networkContextGeneralError,
    nodes.size,
    selectedCodeIntegrationVersion,
  ]);
}

/**
 * Flags layer nodes that have no `output_blocks` data and for which
 * `hasNodeType` finds no layer or loss node through their outgoing
 * connections.
 *
 * @returns One LOSS-category card set built from the offending nodes, or an
 * empty list when the graph is empty or no code integration version is
 * selected.
 */
function useGenerateNoOutputNodeAlerts(): NetworkWizardData[] {
  const { nodes, connections } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  // Build the lookup once per render instead of once per candidate node; the
  // grouping does not depend on the node being inspected.
  const connectionsByOutputId = groupBy(connections, 'outputNodeId');

  const layersWithoutLayerOrLoss = Array.from(nodes.values()).filter(
    (node) =>
      isLayerNode(node) &&
      !node.data?.output_blocks &&
      !hasNodeType({
        node,
        nodes,
        connectionsByOutputId,
        predicate: (n) => isLayerNode(n) || isLossNode(n),
      })
  );

  const nodesMessageData = layersWithoutLayerOrLoss.map(
    ({ id }): NodeMessageData => ({
      nodeId: id,
      msg: NodeErrorMsg.NoOutputNode,
    })
  );
  const graphNodeErrorData = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.NoOutputNode,
    category: NetworkWizardCategory.LOSS,
    nodesMessageData,
  });

  // Only `nodes.size` is read inside the memo, so depend on it (like the
  // sibling hooks do) rather than on the whole map.
  return useMemo((): NetworkWizardData[] => {
    return !nodes.size || !selectedCodeIntegrationVersion
      ? []
      : graphNodeErrorData;
  }, [graphNodeErrorData, nodes.size, selectedCodeIntegrationVersion]);
}

/**
 * Reports a graph-level alert when the graph contains no input/inputs node.
 *
 * @returns The INPUTS-category card when the graph has nodes, a code
 * integration version is selected and no input node exists; an empty list
 * otherwise.
 */
function useGenerateNoInputNodeAlert(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const inputNodeCount = Array.from(nodes.values()).filter(
    (candidate) => isInputNode(candidate) || isInputsNode(candidate)
  ).length;

  const missingInputCard = useGraphErrorData({
    type: GraphErrorKind.graph,
    msg: GraphErrorMsg.NoInputNode,
    category: NetworkWizardCategory.INPUTS,
  });

  return useMemo((): NetworkWizardData[] => {
    if (!nodes.size || !selectedCodeIntegrationVersion) {
      return [];
    }
    if (inputNodeCount) {
      return [];
    }
    return missingInputCard;
  }, [
    missingInputCard,
    inputNodeCount,
    nodes.size,
    selectedCodeIntegrationVersion,
  ]);
}

/**
 * Reports every node that `isUnselectedInputNode` flags for the currently
 * selected code integration version.
 *
 * @returns INPUTS-category cards for the flagged nodes, or an empty list when
 * the graph is empty, no code integration version is selected, or no node is
 * flagged.
 */
function useGenerateUnselectedInputNodeAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const nodesMessageData: NodeMessageData[] = Array.from(nodes.values())
    .filter((candidate) =>
      isUnselectedInputNode(candidate, selectedCodeIntegrationVersion)
    )
    .map((unselected) => ({
      nodeId: unselected.id,
      msg: NodeErrorMsg.InputIsNotSelected,
    }));
  const unselectedCount = nodesMessageData.length;

  const unselectedInputCards = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.InputIsNotSelected,
    category: NetworkWizardCategory.INPUTS,
    nodesMessageData,
  });

  return useMemo((): NetworkWizardData[] => {
    const canReport =
      nodes.size && selectedCodeIntegrationVersion && unselectedCount;
    return canReport ? unselectedInputCards : [];
  }, [
    unselectedInputCards,
    nodes.size,
    selectedCodeIntegrationVersion,
    unselectedCount,
  ]);
}

/**
 * Reports input/inputs nodes for which `hasNodeType` finds no visualizer node
 * through their outgoing connections.
 *
 * @returns VIS-category cards for those nodes. The list is empty when the
 * graph is empty, no code integration version is selected, there are no input
 * nodes at all, or at least one input node is still unselected.
 */
function useGenerateInputNotConnectedToVisualizerAlerts(): NetworkWizardData[] {
  const { nodes, connections } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const inputNodes = Array.from(nodes.values()).filter(
    (candidate) => isInputNode(candidate) || isInputsNode(candidate)
  );
  const inputCount = inputNodes.length;
  const unselectedInputCount = inputNodes.filter((inputNode) =>
    isUnselectedInputNode(inputNode, selectedCodeIntegrationVersion)
  ).length;

  const connectionsByOutputId = groupBy(connections, 'outputNodeId');
  const inputsWithoutVisualizer = inputNodes.filter(
    (inputNode) =>
      !hasNodeType({
        node: inputNode,
        nodes,
        connectionsByOutputId,
        predicate: isVisualizerNode,
      })
  );
  const nodesMessageData = inputsWithoutVisualizer.map(
    (inputNode): NodeMessageData => ({
      nodeId: inputNode.id,
      msg: NodeErrorMsg.InputHasntVisualizer,
    })
  );

  const missingVisualizerCards = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.InputHasntVisualizer,
    category: NetworkWizardCategory.VIS,
    nodesMessageData,
  });

  return useMemo((): NetworkWizardData[] => {
    if (!nodes.size || !selectedCodeIntegrationVersion) {
      return [];
    }
    // Stay silent while there are no inputs or some input is still unselected.
    if (!inputCount || unselectedInputCount) {
      return [];
    }
    return missingVisualizerCards;
  }, [
    missingVisualizerCards,
    inputCount,
    nodes.size,
    selectedCodeIntegrationVersion,
    unselectedInputCount,
  ]);
}

/**
 * Reports a graph-level alert when the graph holds two or more optimizer
 * nodes.
 *
 * @returns The LOSS-category card when the graph has nodes, a code
 * integration version is selected and at least two optimizer nodes exist; an
 * empty list otherwise.
 */
function useGenerateManyOptimizersAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const optimizerCount = Array.from(nodes.values()).filter(isOptimizerNode)
    .length;

  const manyOptimizersCard = useGraphErrorData({
    type: GraphErrorKind.graph,
    msg: GraphErrorMsg.ManyOptimizers,
    category: NetworkWizardCategory.LOSS,
  });

  return useMemo((): NetworkWizardData[] => {
    const shouldReport =
      nodes.size && selectedCodeIntegrationVersion && optimizerCount >= 2;
    return shouldReport ? manyOptimizersCard : [];
  }, [
    manyOptimizersCard,
    nodes.size,
    optimizerCount,
    selectedCodeIntegrationVersion,
  ]);
}

/**
 * Reports every node that `isUnselectedCustomLayerNode` flags.
 *
 * Unlike most sibling hooks, this one does not look at the selected code
 * integration version.
 *
 * @returns MODEL-category cards for the flagged nodes, or an empty list when
 * the graph is empty or nothing is flagged.
 */
function useGenerateUnselectedCustomLayesAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();

  const nodesMessageData: NodeMessageData[] = Array.from(nodes.values())
    .filter((candidate) => isUnselectedCustomLayerNode(candidate))
    .map((customLayer) => ({
      nodeId: customLayer.id,
      msg: NodeErrorMsg.CustomLayerIsNotSelected,
    }));
  const unselectedCount = nodesMessageData.length;

  const unselectedCustomLayerCards = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.CustomLayerIsNotSelected,
    category: NetworkWizardCategory.MODEL,
    nodesMessageData,
  });

  return useMemo((): NetworkWizardData[] => {
    if (nodes.size && unselectedCount) {
      return unselectedCustomLayerCards;
    }
    return [];
  }, [unselectedCustomLayerCards, nodes.size, unselectedCount]);
}

/**
 * Reports a graph-level alert when the selected code integration version has
 * a falsy `metadata.setup`.
 *
 * @returns The CODE-category card in that case; an empty list when the graph
 * is empty, no version is selected, or `metadata.setup` is truthy.
 */
function useGenerateInvalidDatasetAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const invalidDatasetCard = useGraphErrorData({
    type: GraphErrorKind.graph,
    msg: GraphErrorMsg.InvalidDataset,
    category: NetworkWizardCategory.CODE,
  });

  return useMemo((): NetworkWizardData[] => {
    if (!nodes.size || !selectedCodeIntegrationVersion) {
      return [];
    }
    // The version is known to be defined past the guard above.
    return selectedCodeIntegrationVersion.metadata.setup
      ? []
      : invalidDatasetCard;
  }, [invalidDatasetCard, nodes.size, selectedCodeIntegrationVersion]);
}

/**
 * Reports nodes that have declared inputs with no incoming connection.
 *
 * For every node with a known component descriptor and non-dynamic inputs,
 * the input names returned by `getInputsData` are compared with the
 * `inputName`s of the connections grouped under that node's id. Missing names
 * are listed in a per-node message.
 *
 * @returns Cards for the nodes with missing inputs, or an empty list when the
 * graph is empty or no code integration version is selected.
 */
function useGenerateDisconnectedInputsAlerts(): NetworkWizardData[] {
  const { nodes, connections } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const connectionsByInputId = groupBy(connections, 'inputNodeId');

  const nodesMessageData: NodeMessageData[] = [];
  for (const node of Array.from(nodes.values())) {
    const descriptor = COMPONENT_DESCRIPTORS_MAP.get(node.name);
    if (!descriptor) {
      continue;
    }

    const { inputsData, isDynamicInput } = getInputsData(node, descriptor);
    if (isDynamicInput) {
      // Dynamic inputs are skipped entirely.
      continue;
    }

    // A node without any incoming connection yields an empty set.
    const connectedInputNames = new Set(
      connectionsByInputId[node.id]?.map(({ inputName }) => inputName)
    );
    const missingInputNames = inputsData
      .map(({ name }) => name)
      .filter((name) => !connectedInputNames.has(name));
    if (!missingInputNames.length) {
      continue;
    }

    const pluralSuffix = missingInputNames.length > 1 ? 's' : '';
    const missingList = missingInputNames.join(', ');
    nodesMessageData.push({
      nodeId: node.id,
      msg: `${node.name} has the following missing input${pluralSuffix}: ${missingList}`,
      category: getNodeCategory(node),
    });
  }

  // Each message carries its own text and category; the top-level msg is
  // left empty.
  const disconnectedInputCards = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: '',
    category: NetworkWizardCategory.MODEL,
    nodesMessageData,
  });

  return useMemo((): NetworkWizardData[] => {
    if (!nodes.size || !selectedCodeIntegrationVersion) {
      return [];
    }
    return disconnectedInputCards;
  }, [disconnectedInputCards, nodes.size, selectedCodeIntegrationVersion]);
}

/**
 * Reports nodes that pass `isRequiredUniqeNameNode` but whose
 * `data.user_unique_name` is `undefined` or an empty string.
 *
 * The unnamed nodes are grouped by `data.type`. A type with exactly one
 * unnamed node yields a single-node message; a type with several yields one
 * multi-node message covering all of them.
 *
 * @returns The single-node cards followed by the multi-node cards, or an
 * empty list when the graph is empty.
 */
function useGenerateNodesWithoutNameAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();

  const unnamedNodes = Array.from(nodes.values())
    .filter(isRequiredUniqeNameNode)
    .filter(({ data }) => {
      const uniqueName = data.user_unique_name;
      return uniqueName === undefined || uniqueName === '';
    });
  const unnamedNodesByType = Object.entries(
    groupBy(unnamedNodes, ({ data }) => data.type)
  );

  const singleNodeMessages: NodeMessageData[] = unnamedNodesByType
    .filter(([, group]) => group.length === 1)
    .map(([type, [onlyNode]]) => ({
      msg: `${type} node has no name. Assign a name to it to avoid conflicts`,
      nodeId: onlyNode.id,
      category: getNodeCategory(onlyNode),
    }));

  // NOTE(review): the top-level msg/category below (InputIsNotSelected /
  // INPUTS) do not match the "name is missing" title - possibly copied from
  // another hook; confirm how useGraphNodeDictErrorData uses them.
  const singleNodeCards = useGraphNodeDictErrorData({
    title: 'Node name is missing',
    type: GraphErrorKind.node,
    msg: NodeErrorMsg.InputIsNotSelected,
    category: NetworkWizardCategory.INPUTS,
    nodesMessageData: singleNodeMessages,
  });

  const multiNodeMessages: NodesMessageData[] = unnamedNodesByType
    .filter(([, group]) => group.length > 1)
    .map(([type, group]) => {
      const countWord = numberToEnglishWord(group.length);
      return {
        msg: `${countWord} ${type}s nodes have no name. Assign a name to them to avoid conflicts`,
        nodeIds: group.map((member) => member.id),
        // The category is taken from the first node of the group.
        category: getNodeCategory(group[0]),
      };
    });

  const multiNodeCards = useGraphNodesErrorData({
    title: 'Nodes names are missing',
    type: GraphErrorKind.nodes,
    msg: '',
    nodesMessageData: multiNodeMessages,
  });

  return useMemo((): NetworkWizardData[] => {
    return nodes.size ? singleNodeCards.concat(multiNodeCards) : [];
  }, [singleNodeCards, multiNodeCards, nodes.size]);
}

/**
 * Reports groups of nodes that pass `isRequiredUniqeNameNode` and share the
 * same non-empty `data.user_unique_name`.
 *
 * Names are collected in a `Map` rather than a plain object because they are
 * user-entered: with an object, a name such as `__proto__` is written to the
 * prototype and never shows up in `Object.entries`, and integer-like names
 * are reordered. A `Map` treats every name as an ordinary key and keeps
 * first-seen order.
 *
 * @returns One multi-node card per duplicated name, or an empty list when the
 * graph is empty or no code integration version is selected.
 */
function useGenerateNodesWithSameNameAlerts(): NetworkWizardData[] {
  const { nodes } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const nodesWhichRequiredUniqueName = Array.from(nodes.values()).filter(
    isRequiredUniqeNameNode
  );

  const nodesByUniqueName = new Map<
    string,
    { nodeIds: string[]; category: NetworkWizardCategory | undefined }
  >();
  for (const node of nodesWhichRequiredUniqueName) {
    const uniqueName = node.data.user_unique_name;
    if (!uniqueName) {
      // Unnamed nodes are handled by a separate hook.
      continue;
    }
    const sameNameGroup = nodesByUniqueName.get(uniqueName);
    if (sameNameGroup) {
      sameNameGroup.nodeIds.push(node.id);
    } else {
      // The category of a group is taken from the first node seen.
      nodesByUniqueName.set(uniqueName, {
        nodeIds: [node.id],
        category: getNodeCategory(node),
      });
    }
  }

  const nodesMessageData = Array.from(nodesByUniqueName.entries())
    .filter(([, { nodeIds }]) => nodeIds.length > 1)
    .map(
      ([user_unique_name, { nodeIds, category }]): NodesMessageData => ({
        msg: `${numberToEnglishWord(
          nodeIds.length
        )} nodes share the same name: "${user_unique_name}". Assign distinct names to each of them to avoid conflicts`,
        nodeIds,
        category,
      })
    );
  const graphNodesErrorData = useGraphNodesErrorData({
    type: GraphErrorKind.nodes,
    msg: '',
    nodesMessageData,
  });

  return useMemo((): NetworkWizardData[] => {
    return !nodes.size || !selectedCodeIntegrationVersion
      ? []
      : graphNodesErrorData;
  }, [graphNodesErrorData, nodes.size, selectedCodeIntegrationVersion]);
}

/**
 * Converts the asset-validation errors held by the network map context into
 * alert cards.
 *
 * Errors without a `nodeId`, and errors whose message equals
 * `GraphErrorMsg.ValidateAssets`, are skipped. The category comes from the
 * referenced node, falling back to MODEL when the node is not in the graph.
 *
 * @returns The cards, or an empty list when the graph is empty, no code
 * integration version is selected, or the validation status is
 * `CalculatingDigest`.
 */
function useGenerateValidateAssetAlerts(): NetworkWizardData[] {
  const {
    nodes,
    validateAssetsStatus,
    validateAssetsErrors,
  } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const validationAlertData: ValidateAssetsErrorType[] = [];
  for (const { msg, nodeId } of validateAssetsErrors) {
    if (!nodeId || msg === GraphErrorMsg.ValidateAssets) {
      continue;
    }
    const relatedNode = nodes.get(nodeId);
    validationAlertData.push({
      type: GraphErrorKind.validateAssets,
      category: relatedNode
        ? getNodeCategory(relatedNode)
        : NetworkWizardCategory.MODEL,
      msg,
    });
  }

  const validateAssetsErrorCards = GetValidateAssetsErrorData(
    validationAlertData
  );

  return useMemo((): NetworkWizardData[] => {
    if (!nodes.size || !selectedCodeIntegrationVersion) {
      return [];
    }
    if (validateAssetsStatus === ValidateAssetsStatus.CalculatingDigest) {
      return [];
    }
    return validateAssetsErrorCards;
  }, [
    nodes.size,
    selectedCodeIntegrationVersion,
    validateAssetsErrorCards,
    validateAssetsStatus,
  ]);
}

/**
 * Turns the errors attached to the entries of `nodesShapesRef.current` into
 * node-level alert cards.
 *
 * Entries are skipped when they have no error, when the error has no
 * `nodeId`, or when the error is of kind `nodeInput` with `inputHasError`
 * set.
 *
 * @returns MODEL-category cards for the remaining errors, or an empty list
 * when the graph is empty or no code integration version is selected.
 */
function useGenerateErrorsFromShapes(): NetworkWizardData[] {
  const { nodes, nodesShapesRef } = useNetworkMapContext();
  const { selectedCodeIntegrationVersion } = useCurrentProject();

  const nodesMessageData: NodeMessageData[] = [];
  for (const { error } of Array.from(nodesShapesRef.current.values())) {
    if (!error || !error.nodeId) {
      continue;
    }
    if (error.type === GraphErrorKind.nodeInput && error.inputHasError) {
      continue;
    }
    nodesMessageData.push({ nodeId: error.nodeId, msg: error.msg });
  }

  const shapeErrorCards = useGraphNodeDictErrorData({
    type: GraphErrorKind.node,
    msg: '',
    category: NetworkWizardCategory.MODEL,
    nodesMessageData,
  });

  return useMemo(() => {
    const canReport = nodes.size && selectedCodeIntegrationVersion;
    return canReport ? shapeErrorCards : [];
  }, [shapeErrorCards, nodes.size, selectedCodeIntegrationVersion]);
}

/**
 * Maps a node to the wizard category its alerts are filed under.
 *
 * The rules are tried in order and the first match wins: loss-like nodes,
 * visualizer nodes, input/inputs nodes, custom metric nodes. Anything else
 * falls back to MODEL.
 *
 * @param node - The node to classify.
 * @returns The category of the first matching rule, or MODEL when none match.
 */
function getNodeCategory(node: Node): NetworkWizardCategory {
  const categoryRules: ReadonlyArray<
    readonly [(candidate: Node) => boolean, NetworkWizardCategory]
  > = [
    [(candidate) => isKindOfLossNode(candidate), NetworkWizardCategory.LOSS],
    [(candidate) => isVisualizerNode(candidate), NetworkWizardCategory.VIS],
    [
      (candidate) => isInputNode(candidate) || isInputsNode(candidate),
      NetworkWizardCategory.INPUTS,
    ],
    [
      (candidate) => isCustomMetricNode(candidate),
      NetworkWizardCategory.METRICS,
    ],
  ];

  const matchedRule = categoryRules.find(([matches]) => matches(node));
  return matchedRule ? matchedRule[1] : NetworkWizardCategory.MODEL;
}
