import { last } from 'lodash';
import yaml from 'js-yaml';
import { ModelGraph, Node } from '@tensorleap/api-client';
import { isCustomLayerNode } from './utils';

/**
 * Top-level document that `generateConnectionsMapYaml` serializes to YAML.
 */
export interface YamlMapping {
  /** One entry per graph node that is not a layer-type node. */
  decorators: DatasetMapping[];
  /**
   * Custom-layer nodes and prediction-layer nodes. `Partial` because
   * prediction-layer entries carry no `inputs`/`outputs` and only
   * `prediction_type` inside `data`.
   */
  layers: Partial<DatasetMapping>[];
}

/**
 * Serialized description of a single graph node as written to the YAML map.
 */
export interface DatasetMapping {
  /** Copied from the node's `name`. */
  operation: string;
  /** The node's `data` payload (a `prediction_type`-only subset for prediction layers). */
  data: Record<string, unknown>;
  /** The node's `data.origin_name`. */
  name: string;
  /** The node's graph id. */
  id: string;
  /** Input socket key -> array of descriptors of the upstream nodes feeding it. */
  inputs: Record<string, unknown>;
  /** Output socket key -> array of descriptors of the downstream nodes it feeds. */
  outputs: Record<string, unknown>;
}

/**
 * Returns the last '-'-separated segment of a socket name
 * (e.g. `'a-b-c'` -> `'c'`; a name without '-' is returned unchanged).
 *
 * @param socketName - Socket identifier as stored on a graph node.
 * @returns The trailing segment; `''` when that segment is empty.
 */
function extractSocketKey(socketName: string): string {
  // `split` always yields a non-empty array, so only `last`'s result needs a fallback.
  return last(socketName.split('-')) ?? '';
}

/**
 * Whether `node` counts as a layer (as opposed to a decorator), judged from
 * its `data.type` and its `name` via `isLayerType`.
 */
function isLayerTypeNode(node: Node): boolean {
  const { data, name } = node;
  return isLayerType(data?.type, name);
}

/**
 * Whether `node` is a layer-type node that carries a truthy
 * `data.prediction_type`.
 *
 * @param node - Graph node to inspect.
 * @returns `true` only for layer-type nodes with a truthy `prediction_type`.
 */
function isPredictionLayer(node: Node): boolean {
  // `prediction_type` is copied verbatim into the YAML elsewhere, so it is
  // presumably a value rather than a flag; coerce so the declared `boolean`
  // return type actually holds.
  return isLayerTypeNode(node) && Boolean(node.data?.prediction_type);
}

/**
 * Classifies a node as a layer from its type and name.
 *
 * @param nodeType - Node `data.type`; `'Layer'` and `'CustomLayer'` are layers.
 * @param nodeName - Node name; `'Representation Block'` is a layer regardless of type.
 * @returns `true` when either criterion matches.
 */
function isLayerType(nodeType?: string, nodeName?: string): boolean {
  if (nodeName === 'Representation Block') {
    return true;
  }
  switch (nodeType) {
    case 'Layer':
    case 'CustomLayer':
      return true;
    default:
      return false;
  }
}

/**
 * Describes, for every input socket of `node`, which upstream nodes feed it.
 *
 * Socket names are reduced to their trailing key with `extractSocketKey`.
 * If two sockets reduce to the same key, the later one wins.
 *
 * @param node - Node whose input sockets are described.
 * @param modelGraph - Graph used to resolve connection node ids.
 * @returns Map from input socket key to an array of
 *   `{ outputKey, operation, name, id }` descriptors of the upstream nodes;
 *   fields are `undefined` when a connection points at a node missing from the graph.
 */
function generateInputsMapping(
  node: Node,
  modelGraph: ModelGraph
): Record<string, unknown> {
  const mapping: Record<string, unknown> = {};
  for (const [socketName, inputValue] of Object.entries(node.inputs)) {
    const inputKey = extractSocketKey(socketName);
    mapping[inputKey] = (inputValue?.connections ?? []).map((connection) => {
      const sourceNode = modelGraph.nodes[connection.node];
      return {
        outputKey: extractSocketKey(connection.output),
        operation: sourceNode?.name,
        // Guard `data` too: elsewhere in this file it is treated as optional,
        // and a node without it must not crash the whole export.
        name: sourceNode?.data?.origin_name,
        id: sourceNode?.id,
      };
    });
  }
  return mapping;
}

/**
 * Describes, for every output socket of `node`, which downstream nodes it feeds.
 *
 * Socket names are reduced to their trailing key with `extractSocketKey`.
 * If two sockets reduce to the same key, the later one wins.
 *
 * @param node - Node whose output sockets are described.
 * @param modelGraph - Graph used to resolve connection node ids.
 * @returns Map from output socket key to an array of
 *   `{ inputKey, operation, name, id }` descriptors of the downstream nodes;
 *   fields are `undefined` when a connection points at a node missing from the graph.
 */
function generateOutputsMapping(
  node: Node,
  modelGraph: ModelGraph
): Record<string, unknown> {
  const mapping: Record<string, unknown> = {};
  for (const [socketName, outputValue] of Object.entries(node.outputs)) {
    const outputKey = extractSocketKey(socketName);
    mapping[outputKey] = (outputValue?.connections ?? []).map((connection) => {
      const targetNode = modelGraph.nodes[connection.node];
      return {
        inputKey: extractSocketKey(connection.input),
        operation: targetNode?.name,
        // Guard `data` too: elsewhere in this file it is treated as optional,
        // and a node without it must not crash the whole export.
        name: targetNode?.data?.origin_name,
        id: targetNode?.id,
      };
    });
  }
  return mapping;
}

/**
 * Builds the `decorators` section: one entry for every graph node that is not
 * a layer-type node, with its data and both socket mappings.
 *
 * @param modelGraph - Graph whose nodes are scanned; empty slots are skipped.
 * @returns Decorator entries in `Object.values` order of `modelGraph.nodes`.
 */
function getDecoratorMap(modelGraph: ModelGraph): DatasetMapping[] {
  const decorators: DatasetMapping[] = [];
  for (const node of Object.values(modelGraph.nodes)) {
    if (!node || isLayerTypeNode(node)) {
      continue;
    }
    const { name: operation, data, id } = node;
    decorators.push({
      operation,
      data,
      name: data['origin_name'],
      id,
      inputs: generateInputsMapping(node, modelGraph),
      outputs: generateOutputsMapping(node, modelGraph),
    });
  }
  return decorators;
}

/**
 * Builds the `layers` section of the YAML map.
 *
 * - Custom-layer nodes (per `isCustomLayerNode`) get their full `data` plus
 *   output and input socket mappings.
 * - Otherwise, prediction layers get only `prediction_type` in `data` and no
 *   socket mappings.
 *
 * Custom-layer entries list `outputs` before `inputs`; that insertion order is
 * kept because it is the key order `yaml.dump` emits.
 *
 * @param modelGraph - Graph whose nodes are scanned; empty slots are skipped.
 * @returns Layer entries in `Object.values` order of `modelGraph.nodes`.
 */
function getHiddenLayerMap(modelGraph: ModelGraph): Partial<DatasetMapping>[] {
  const layers: Partial<DatasetMapping>[] = [];
  for (const node of Object.values(modelGraph.nodes)) {
    if (!node) {
      continue;
    }
    if (isCustomLayerNode(node)) {
      const { name: operation, data, id } = node;
      layers.push({
        operation,
        data,
        name: data['origin_name'],
        id,
        outputs: generateOutputsMapping(node, modelGraph),
        inputs: generateInputsMapping(node, modelGraph),
      });
    } else if (isPredictionLayer(node)) {
      const { name: operation, data: nodeData, id } = node;
      layers.push({
        operation,
        data: { prediction_type: nodeData.prediction_type },
        name: nodeData['origin_name'],
        id,
      });
    }
  }
  return layers;
}

/**
 * Serializes the connections of `modelGraph` as a YAML document with a
 * `decorators` section (non-layer nodes) and a `layers` section
 * (custom and prediction layers).
 *
 * @param modelGraph - Graph to describe.
 * @returns The YAML text produced by `yaml.dump`.
 */
export function generateConnectionsMapYaml(modelGraph: ModelGraph): string {
  const decorators = getDecoratorMap(modelGraph);
  const layers = getHiddenLayerMap(modelGraph);
  const mapping: YamlMapping = { decorators, layers };
  return yaml.dump(mapping);
}
