import { useCallback, useMemo } from 'react';
import { Node } from '@tensorleap/api-client';
import { isPredictionNode } from '../../utils';
import {
  isCustomMetricNode,
  isGroundTruthNode,
  isInputsNode,
  isVisualizerNode,
} from '../../graph-calculation/utils';
import { useNetworkMapContext } from '../../../core/NetworkMapContext';
import {
  GraphWarningDataProps,
  useGraphWarningData,
} from '../cards/GraphWarningData';
import {
  NetworkWizardData,
  GraphWarningKind,
  NetworkWizardCategory,
} from '../types';

/**
 * Shape of the connectivity check handed to the warning builders in this
 * file: it receives a node plus a predicate over nodes and answers with a
 * boolean.
 */
type hasNodeConnectedFunc = (
  node: Node,
  predicate: (candidate: Node) => boolean,
) => boolean;

/**
 * Hook that walks every node of the current network map and collects the
 * graph warnings each node produces.
 *
 * For each node, in iteration order of `nodes`, the "prediction needs metric"
 * check runs first and the "visualization required" check second; any warning
 * descriptor they yield is appended in that order.
 *
 * @returns Whatever `useGraphWarningData` produces for the collected warning
 *   descriptors (passed through unchanged).
 */
export function useGenerateGraphWarnings(): NetworkWizardData[] {
  const { nodes, connections } = useNetworkMapContext();

  // True when some connection leaves the given node (`outputNodeId` matches
  // its id) and the node found under `inputNodeId` satisfies `predicate`.
  const hasNodeConnected = useCallback(
    ({ id }: Node, predicate: (node: Node) => boolean): boolean =>
      connections?.some(({ outputNodeId, inputNodeId }) => {
        if (outputNodeId !== id) return false;

        const connectedNode = nodes.get(inputNodeId);
        return !!connectedNode && predicate(connectedNode);
      }) ??
      // Optional chaining yields `undefined` when `connections` is nullish;
      // normalise so the callback always honours its boolean contract.
      false,
    [connections, nodes],
  );

  const warningsData = useMemo((): GraphWarningDataProps[] => {
    const warnings: GraphWarningDataProps[] = [];

    for (const node of Array.from(nodes.values())) {
      const predictionNeedsMetric = predictionNeedsMetricWarning(
        hasNodeConnected,
        node,
      );
      if (predictionNeedsMetric) {
        warnings.push(predictionNeedsMetric);
      }

      const vizualizationRequired = vizualizationRequiredWarning(
        hasNodeConnected,
        node,
      );
      if (vizualizationRequired) {
        warnings.push(vizualizationRequired);
      }
    }

    return warnings;
  }, [hasNodeConnected, nodes]);

  return useGraphWarningData(warningsData);
}

/**
 * Builds the "prediction needs metric" warning for a node, if applicable.
 *
 * @param hasNodeConnected - Connectivity check invoked with the node and the
 *   custom-metric predicate.
 * @param node - Node under inspection.
 * @returns A METRICS-category warning descriptor when the node is a
 *   prediction node and the connectivity check for a custom metric node is
 *   falsy; `undefined` otherwise.
 */
function predictionNeedsMetricWarning(
  hasNodeConnected: hasNodeConnectedFunc,
  node: Node,
): GraphWarningDataProps | undefined {
  // Only prediction nodes are subject to this rule.
  if (!isPredictionNode(node)) {
    return undefined;
  }

  // A connected custom metric satisfies the requirement.
  if (hasNodeConnected(node, isCustomMetricNode)) {
    return undefined;
  }

  return {
    type: GraphWarningKind.predictionNeedsMetric,
    nodeId: node.id,
    category: NetworkWizardCategory.METRICS,
  };
}

/**
 * Builds the "needs visualization" warning for a node, if applicable.
 *
 * The connectivity check for a visualizer node runs first; when it is truthy
 * no warning is produced. Otherwise the node is classified, in this order, as
 * prediction, ground truth, or inputs, and the matching VIS-category warning
 * is returned.
 *
 * @param hasNodeConnected - Connectivity check invoked with the node and the
 *   visualizer predicate.
 * @param node - Node under inspection.
 * @returns The warning descriptor for the first matching node kind, or
 *   `undefined` when a visualizer is connected or no kind matches.
 */
function vizualizationRequiredWarning(
  hasNodeConnected: hasNodeConnectedFunc,
  node: Node,
): GraphWarningDataProps | undefined {
  if (hasNodeConnected(node, isVisualizerNode)) {
    return undefined;
  }

  if (isPredictionNode(node)) {
    return {
      type: GraphWarningKind.predictionNeedsVisualization,
      nodeId: node.id,
      category: NetworkWizardCategory.VIS,
    };
  }

  if (isGroundTruthNode(node)) {
    return {
      type: GraphWarningKind.gtNeedsVisualization,
      category: NetworkWizardCategory.VIS,
      nodeId: node.id,
    };
  }

  if (isInputsNode(node)) {
    return {
      type: GraphWarningKind.inputNeedsVisualization,
      category: NetworkWizardCategory.VIS,
      nodeId: node.id,
    };
  }

  return undefined;
}
