import { ReactNode } from 'react';
import {
  InputIcon,
  GroundTruthIcon,
  TagIcon as MetaDataIcon,
  HeatMapIcon,
  HorizontalBarsIcon,
  PictureType,
  TextPlaceholderIcon,
  VerticalGraphIcon,
  Loss,
  LayersIcon,
  MetricIcon,
} from '../ui/icons';
import {
  DatasetSetup,
  GraphValidatorData,
  LeapDataType,
  ModelSetup,
  ValidatedNode,
} from '@tensorleap/api-client';
import { Input } from '../ui/atoms/Input';
import { ValidateAssetsResultEnum } from '../network-editor/interfaces/ValidateGraphStatus';

/**
 * Builds the display title for a dataset list entry.
 *
 * @param category - Kind of asset, e.g. 'Input' or 'Visualizer'.
 * @param name - Asset name; when missing or empty only the category is shown.
 * @returns `"<category> - <name>"`, or just `"<category>"` when there is no name.
 */
export function formatDisplayTitle({
  category,
  name,
}: {
  category: string;
  name: string | undefined;
}): string {
  // Previously the nameless branch returned ` ${category}` with a stray
  // leading space; the title is now the bare category.
  return name ? `${category} - ${name}` : category;
}

/**
 * One labelled, read-only detail field of a dataset list entry
 * (for example NAME, SHAPE, TYPE or LABELS).
 */
export interface DatasetItem {
  /** Text value of the field, already formatted for display. */
  value: string;
  /** Field label. */
  title: string;
}

/**
 * View model for one entry of the dataset asset list: an input, ground truth,
 * metadata field, visualizer, metric, prediction, custom loss or custom layer.
 */
export interface DatasetListItemInput {
  /** Name of the underlying asset; validation status is looked up by this name. */
  fnName: string;
  /** Header text of the entry. */
  title: string;
  /** Icon rendered for the entry. */
  icon: ReactNode;
  /** Read-only detail fields of the entry. */
  subItems: DatasetItem[];
  /** Asset-validation outcome for this entry. */
  validateAssetsResult: ValidateAssetsResultEnum;
  /** Error text reported by asset validation, if any. */
  validateAssetsError?: string;
  /** Optional error message attached to the entry. */
  error?: string;
  /** Optional selection flag. */
  selected?: boolean;
  /** Optional flag; presumably controls whether the entry starts expanded — confirm with the consuming component. */
  openByDefault?: boolean;
}

/**
 * Bit flags selecting which dataset sections to produce.
 * Combine with `|` and test with `&`.
 */
export enum DatasetMask {
  Preprocess = 0x001,
  Inputs = 0x002,
  GroundTruths = 0x004,
  Metadata = 0x008,
  Visualizers = 0x010,
  Predictions = 0x020,
  CustomLoss = 0x040,
  CustomLayer = 0x080,
  Metric = 0x100,
  /** Every section above. */
  All = Preprocess |
    Inputs |
    GroundTruths |
    Metadata |
    Visualizers |
    Predictions |
    CustomLoss |
    CustomLayer |
    Metric,
}

/** Props for the GeneratePreprocess component. */
export interface GeneratePreprocessProps {
  /** Dataset setup whose preprocess split lengths are shown; nothing renders without it. */
  datasetSetup?: DatasetSetup;
}

/**
 * Renders the training/validation/test/unlabeled lengths from the dataset's
 * preprocess section as read-only inputs. A missing length is displayed as '0'.
 * Returns null when no dataset setup is provided.
 */
export function GeneratePreprocess({
  datasetSetup,
}: GeneratePreprocessProps): JSX.Element | null {
  if (!datasetSetup) return null;

  const preprocess = datasetSetup.preprocess;
  // [label, raw length] pairs, in display order.
  const lengthFields = [
    ['TRAINING LENGTH', preprocess?.training_length],
    ['VALIDATION LENGTH', preprocess?.validation_length],
    ['TEST LENGTH', preprocess?.test_length],
    ['UNLABELED LENGTH', preprocess?.unlabeled_length],
  ] as const;

  return (
    <div className="flex flex-col gap-8 p-4">
      {lengthFields.map(([label, length]) => (
        <Input
          key={label}
          label={label}
          value={length?.toString() || '0'}
          readOnly
        />
      ))}
    </div>
  );
}

/** Separator used when rendering list-valued fields (shapes, labels, argument names). */
const LIST_SEPARATOR = ', ';

/**
 * Builds one dataset list entry for a named asset and attaches the
 * asset-validation status found for that name in `validatedNodes`.
 */
function buildListItem({
  name,
  category,
  icon,
  subItems,
  validatedNodes,
}: {
  name: string;
  category: string;
  icon: ReactNode;
  subItems: DatasetItem[];
  validatedNodes: ValidatedNode[] | undefined;
}): DatasetListItemInput {
  const { validateAssetsResult, validateAssetsError } =
    calcValidateAssetsResult(validatedNodes, name);

  return {
    fnName: name,
    icon,
    title: formatDisplayTitle({ name, category }),
    subItems,
    validateAssetsResult,
    validateAssetsError,
  };
}

/**
 * Converts a dataset setup (plus optional model setup and validation results)
 * into the flat list of entries shown in the dataset panel.
 *
 * @param datasetMask - Which sections to include (bitwise OR of `DatasetMask` flags).
 *   The `Preprocess` flag is not used here.
 * @param datasetSetup - Source of inputs, ground truths, metadata, visualizers,
 *   metrics, prediction types and custom losses. When absent, returns `[]`.
 * @param modelSetup - Source of custom layers, if any.
 * @param validateAssetsData - Per-asset validation results; entries without a
 *   match are reported as not tested.
 * @returns Entries ordered: inputs, ground truths, metadata, visualizers,
 *   metrics, predictions, custom losses, custom layers.
 */
export function parseDatasetVersion(
  datasetMask: DatasetMask,
  datasetSetup?: DatasetSetup,
  modelSetup?: ModelSetup,
  validateAssetsData?: GraphValidatorData,
): DatasetListItemInput[] {
  if (!datasetSetup) return [];

  const {
    inputs,
    outputs,
    metadata,
    visualizers,
    metrics,
    prediction_types: predictions,
    custom_losses,
  } = datasetSetup;
  const { custom_layers } = modelSetup ?? {};

  const isEnabled = (flag: DatasetMask): boolean => (datasetMask & flag) !== 0;

  const inputsProps: DatasetListItemInput[] = isEnabled(DatasetMask.Inputs)
    ? inputs.map(({ name, shape }) =>
        buildListItem({
          name,
          category: 'Input',
          icon: <InputIcon />,
          subItems: [
            { title: 'NAME', value: name },
            { title: 'SHAPE', value: shape.join(LIST_SEPARATOR) },
          ],
          validatedNodes: validateAssetsData?.inputs,
        }),
      )
    : [];

  const groundTruthsProps: DatasetListItemInput[] = isEnabled(
    DatasetMask.GroundTruths,
  )
    ? outputs.map(({ name, shape }) =>
        buildListItem({
          name,
          category: 'Ground Truth',
          icon: <GroundTruthIcon />,
          subItems: [
            { title: 'NAME', value: name },
            { title: 'SHAPE', value: shape.join(LIST_SEPARATOR) },
          ],
          validatedNodes: validateAssetsData?.ground_truths,
        }),
      )
    : [];

  const metadataProps: DatasetListItemInput[] = isEnabled(DatasetMask.Metadata)
    ? metadata.map(({ name, type }) =>
        buildListItem({
          name,
          category: 'Metadata',
          icon: <MetaDataIcon />,
          subItems: [
            { title: 'NAME', value: name },
            { title: 'TYPE', value: type },
          ],
          validatedNodes: validateAssetsData?.metadata,
        }),
      )
    : [];

  const visualizersProps: DatasetListItemInput[] = isEnabled(
    DatasetMask.Visualizers,
  )
    ? visualizers.map(({ name, type, arg_names }) =>
        buildListItem({
          name,
          category: 'Visualizer',
          icon: getVisualizerIcon(type),
          subItems: [
            { title: 'NAME', value: name },
            { title: 'TYPE', value: type },
            { title: 'LABELS', value: arg_names.join(LIST_SEPARATOR) },
          ],
          validatedNodes: validateAssetsData?.visualizers,
        }),
      )
    : [];

  const customMetricsProps: DatasetListItemInput[] =
    isEnabled(DatasetMask.Metric) && metrics
      ? metrics.map(({ name, arg_names }) =>
          buildListItem({
            name,
            category: 'Metrics',
            icon: <MetricIcon />,
            subItems: [
              { title: 'NAME', value: name },
              { title: 'LABELS', value: arg_names.join(LIST_SEPARATOR) },
            ],
            validatedNodes: validateAssetsData?.metrics,
          }),
        )
      : [];

  const customLossProps: DatasetListItemInput[] =
    isEnabled(DatasetMask.CustomLoss) && custom_losses
      ? custom_losses.map(({ name, arg_names }) =>
          buildListItem({
            name,
            category: 'Custom Loss',
            icon: <Loss />,
            subItems: [
              { title: 'NAME', value: name },
              { title: 'ARGUMENTS', value: arg_names.join(LIST_SEPARATOR) },
            ],
            // NOTE(review): validation status is looked up in `metrics`, as
            // before — confirm whether GraphValidatorData has a dedicated list
            // for custom losses.
            validatedNodes: validateAssetsData?.metrics,
          }),
        )
      : [];

  const predictionsProps: DatasetListItemInput[] = isEnabled(
    DatasetMask.Predictions,
  )
    ? predictions.map(({ name, labels }) =>
        buildListItem({
          name,
          category: 'Prediction',
          icon: <HeatMapIcon />,
          subItems: [
            { title: 'NAME', value: name },
            { title: 'LABELS', value: labels.join(LIST_SEPARATOR) },
          ],
          // NOTE(review): validation status is looked up in `metrics`, as
          // before — confirm whether GraphValidatorData has a dedicated list
          // for prediction types.
          validatedNodes: validateAssetsData?.metrics,
        }),
      )
    : [];

  const customLayersProps: DatasetListItemInput[] =
    isEnabled(DatasetMask.CustomLayer) && custom_layers
      ? custom_layers.map(({ name, init_arg_names, call_arg_names }) =>
          buildListItem({
            name,
            category: 'Custom Layer',
            icon: <LayersIcon />,
            subItems: [
              { title: 'NAME', value: name },
              {
                title: 'INIT ARG NAMES',
                value: init_arg_names.join(LIST_SEPARATOR),
              },
              {
                title: 'CALL ARG NAMES',
                value: call_arg_names.join(LIST_SEPARATOR),
              },
            ],
            validatedNodes: validateAssetsData?.custom_layers,
          }),
        )
      : [];

  return [
    ...inputsProps,
    ...groundTruthsProps,
    ...metadataProps,
    ...visualizersProps,
    ...customMetricsProps,
    ...predictionsProps,
    ...customLossProps,
    ...customLayersProps,
  ];
}

/**
 * Chooses the icon for a visualizer based on its data type.
 * 'ImageMask' and 'TextMask' share the heat-map icon; any other
 * unlisted type yields no icon (null).
 */
function getVisualizerIcon(type: LeapDataType): ReactNode {
  switch (type) {
    case 'Image':
      return <PictureType />;
    case 'Text':
      return <TextPlaceholderIcon />;
    case 'Graph':
      return <VerticalGraphIcon />;
    case 'HorizontalBar':
      return <HorizontalBarsIcon />;
    case 'ImageMask':
    case 'TextMask':
      return <HeatMapIcon />;
    default:
      return null;
  }
}

/**
 * Looks up the validation record for `name` and classifies it.
 *
 * @param validatedNodes - Validation records to search; may be undefined.
 * @param name - Asset name to match against each record's `name`.
 * @returns `NotTested` when no record matches. Otherwise `Failed` when the
 *   record's `error` is truthy and `Passed` when it is not. The record's
 *   `error` value is returned alongside.
 */
function calcValidateAssetsResult(
  validatedNodes: ValidatedNode[] | undefined,
  name: string,
): {
  validateAssetsResult: ValidateAssetsResultEnum;
  validateAssetsError?: string;
} {
  const matchingNode = validatedNodes?.find((node) => node.name === name);

  if (matchingNode === undefined) {
    return {
      validateAssetsResult: ValidateAssetsResultEnum.NotTested,
      validateAssetsError: undefined,
    };
  }

  const { error } = matchingNode;
  return {
    validateAssetsResult: error
      ? ValidateAssetsResultEnum.Failed
      : ValidateAssetsResultEnum.Passed,
    validateAssetsError: error,
  };
}
