import { last } from 'lodash';
import { DynamicShapeDetails } from '../../layer-details/DynamicShapeDetails';
import { LayerSetting } from '../../layer-details/LayerSetting';
import {
  ConditinalLabelDescriptor,
  ConditinalLabelDescriptorWithoutType,
  LabelDescriptor,
  LabelDescriptorWithoutType,
  NodeLabels,
} from './types';
import {
  setOrRemoveErrorFromNodeStates,
  updateStateIfPropNotExisted,
} from './utils';
import { layersMetadata } from '../graph-calculation/contract';
import { GraphErrorKind } from '../wizard/types';

/**
 * Label descriptor for nodes marked as a prediction.
 *
 * - `updateState` checks the node's `prediction_type` attribute. When a
 *   prediction type is selected and the node has a known `shape`, it also
 *   checks that the last entry of that shape equals the number of labels
 *   declared for the selected prediction type in the dataset setup.
 * - `add` attaches the label. If the node has no `prediction_type` yet and the
 *   code integration declares exactly one prediction type, that one is
 *   preselected.
 * - `remove` detaches the label and clears `prediction_type`.
 */
export const PredictionLabel: LabelDescriptorWithoutType<NodeLabels.Prediction> =
  {
    name: NodeLabels.Prediction,
    colorTheme: 'prediction',
    CustomDetails: LayerSetting,
    updateState: (node, { datasetSetup, setNodeStates }) => {
      setNodeStates((currentNodeStates) => {
        // Presumably flags (or clears) a missing `prediction_type` on this
        // node's state — see `updateStateIfPropNotExisted`.
        const updatedNodeStates = updateStateIfPropNotExisted(
          node,
          'prediction_type',
          'PredictionType',
          currentNodeStates,
        );

        const predictionName = node.data.prediction_type;
        const selectedPrediction = datasetSetup?.prediction_types.find(
          ({ name }) => name === predictionName,
        );

        const nodeState = updatedNodeStates.get(node.id);
        if (!selectedPrediction || !nodeState?.shape?.length) {
          // Nothing to compare against yet. Return the states produced by the
          // check above: returning `currentNodeStates` here would discard it in
          // exactly the case where `prediction_type` is missing.
          return updatedNodeStates;
        }

        const lastTensorSize = last(nodeState.shape);
        const labelsLength = selectedPrediction.labels.length;

        // The error is set when the sizes differ and removed when they match.
        return setOrRemoveErrorFromNodeStates(
          lastTensorSize === labelsLength,
          {
            type: GraphErrorKind.nodeAttr,
            msg: `“Prediction” labels count of ${labelsLength} does not match the predicted values count of ${lastTensorSize}`,
            nodeId: node.id,
            attrName: 'prediction_type',
          },
          node.id,
          updatedNodeStates,
        );
      });
    },
    add: ({ node, changeNodeProperty, codeIntegrationVersion }) => {
      const predictionTypes =
        codeIntegrationVersion?.metadata.setup?.prediction_types ?? [];
      // Only auto-select when the choice is unambiguous.
      const defaultPredictionType =
        predictionTypes.length === 1 ? predictionTypes[0].name : undefined;
      // An explicitly chosen prediction type on the node takes precedence.
      const prediction_type =
        node.data.prediction_type ?? defaultPredictionType;

      // Copy so the node's existing labels array is not mutated in place.
      const labels = node.labels?.slice() ?? [];
      if (!labels.includes(NodeLabels.Prediction)) {
        labels.push(NodeLabels.Prediction);
      }

      changeNodeProperty({
        nodeId: node.id,
        nodeDataPropsToUpdate: { prediction_type },
        nodePropsToUpdate: { labels },
      });
    },
    remove: ({ node: { id, labels = [] }, changeNodeProperty }) => {
      changeNodeProperty({
        nodeId: id,
        nodeDataPropsToUpdate: { prediction_type: undefined },
        nodePropsToUpdate: {
          labels: labels.filter((l) => l !== NodeLabels.Prediction),
        },
      });
    },
  };

/**
 * Conditional label for layers whose shape is computed dynamically.
 *
 * - `hasLabel` applies the label when any entry of the layer's `shape_calc`
 *   metadata equals `'Dynamic'`.
 * - `updateState` sets an error on the node while its expected output shape
 *   attribute is falsy, and removes it otherwise.
 */
export const DynamicShapeLabel: ConditinalLabelDescriptorWithoutType = {
  name: NodeLabels.DynamicShape,
  CustomDetails: DynamicShapeDetails,
  updateState: (node, { setNodeStates }) => {
    const { id: nodeId } = node;
    // NOTE(review): the key is spelled 'shpae'. It must match whatever writes
    // this attribute (presumably DynamicShapeDetails), so it cannot be
    // corrected here in isolation.
    const attrName = 'expected_output_shpae';
    setNodeStates((current) => {
      return setOrRemoveErrorFromNodeStates(
        node.data[attrName],
        {
          type: GraphErrorKind.nodeAttr,
          msg: 'Please specify the expected output shape',
          nodeId,
          attrName,
        },
        nodeId,
        current,
      );
    });
  },
  hasLabel: (nodeName) => {
    const shapeCalcFns = layersMetadata.get(nodeName)?.shape_calc;
    // Unknown layers or layers without `shape_calc` yield `undefined` -> false.
    return shapeCalcFns?.some((calcFn) => calcFn === 'Dynamic') === true;
  },
};

/**
 * Labels the user can attach or detach explicitly, keyed by label name.
 * Each entry is the corresponding descriptor tagged with `type: 'label'`.
 */
export const LABEL_DESCRIPTORS: Partial<{
  [N in NodeLabels]: LabelDescriptor<N>;
}> = {
  [NodeLabels.Prediction]: { ...PredictionLabel, type: 'label' },
};

/**
 * Labels whose applicability is decided per node by the descriptor's own
 * `hasLabel` predicate, each tagged with `type: 'label'`.
 */
export const CONDITIONAL_LABEL_DESCRIPTORS: ConditinalLabelDescriptor[] = [
  DynamicShapeLabel,
].map(
  (descriptor): ConditinalLabelDescriptor => ({ ...descriptor, type: 'label' }),
);
