import React, {
  ReactElement,
  useCallback,
  useLayoutEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import useResizeObserver from 'use-resize-observer';
import clsx from 'clsx';
import { Tooltip } from '../ui/mui';
import * as d3 from 'd3';

import {
  Plane,
  useScatterData,
  VISUALIZATION_PAYLOAD_FILE_NAME,
} from './ScatterDataContext';
import { DataPoint, useScatterMapData } from './useScatterMapData';
import useAsyncEffect from '../core/useAsyncEffect';
import {
  CompositeVizData,
  GradsAnalysis,
  NumberOrString,
  VisData,
  VisualizedItem,
} from '@tensorleap/api-client';
import { useEnvironmentInfo } from '../core/EnvironmentInfoContext';
import { useScattersAssets } from './dashlet/PopulationExploration/hooks';
import { MousePosition } from '../core/useSelectionGroup';
import { useSmoothTransition } from '../core/useSmoothTransition';
import { calcNextSelection } from '../core/useMultiSelect';
import {
  DisplayVisualizedItemDataTooltip,
  MetadataTooltipProps,
} from './dashlet/VisualizationDisplay/VisData';
import {
  getMetaDataPreviewValue,
  getVisualizationPreviewValue,
  isMetaDataPreview,
  isVisualizationPreview,
  VisPayloadType,
} from './dashlet/VisualizationDisplay/visDataHelpers';
import { TOUR_SELECTORS_ENUM } from '../tour/ToursConfig';
import { useDashboardContext } from './DashboardContext';
import { ScatterShape, ShapeType } from './ScatterAnalyzerView/ScatterShape';

/** Multiplier applied to each point's `radius` before it is drawn (see `Point`). */
const RADIUS_SIZE_SCALE = 1.5;
/**
 * Size of the chart's internal coordinate system: `w` and `h` define the SVG
 * viewBox, and `w`, `h`, `d` are the output ranges of the x, y and z scales.
 */
const BASE_CHART_SIZE = { h: 220, w: 230, d: 220 };

/**
 * Computes the bounding box of `data` along the x, y and z axes.
 *
 * The extents start from the first point's coordinates, so an empty array
 * yields an all-zero box.
 */
function calcMinMax(data: DataPoint[]) {
  const { x = 0, y = 0, z = 0 } = data[0] || {};
  const bounds = { maxX: x, maxY: y, maxZ: z, minX: x, minY: y, minZ: z };

  for (const point of data) {
    bounds.minX = Math.min(point.x, bounds.minX);
    bounds.minY = Math.min(point.y, bounds.minY);
    bounds.minZ = Math.min(point.z, bounds.minZ);
    bounds.maxX = Math.max(point.x, bounds.maxX);
    bounds.maxY = Math.max(point.y, bounds.maxY);
    bounds.maxZ = Math.max(point.z, bounds.maxZ);
  }

  return bounds;
}

/** d3 linear scale from numbers to numbers; the type of the per-axis scales handed to `Point`. */
type ScaleLinear = d3.ScaleLinear<number, number, never>;

interface TitleProps {
  /** Index of the sample in `scatterData.samples` and in each metadata `body` array. */
  index: number;
}

/**
 * Tooltip content for a single scatter point.
 *
 * Depending on the `previewBy` setting it either shows the point's value for
 * a metadata key, or downloads the sample's visualization payload from client
 * storage and shows that. Nothing is loaded while `previewBy` is unset, or
 * (for visualization previews) when the sample is not listed in
 * `samplesIdsWithAssets`.
 */
function TooltipTitle({ index }: TitleProps): ReactElement {
  const {
    settings: { previewBy },
    samplesIdsWithAssets,
    scatterData,
    visualizationDisplays,
    scatterSampleVisualizationsPrefix,
  } = useScatterData();
  const [visualizationData, setVisualizationData] = useState<
    VisData | GradsAnalysis | CompositeVizData | MetadataTooltipProps
  >();

  const {
    environmentInfo: { clientStoragePrefixUrl },
  } = useEnvironmentInfo();

  useAsyncEffect(async () => {
    if (!previewBy) return;

    if (isMetaDataPreview(previewBy)) {
      const metaDataKey = getMetaDataPreviewValue(previewBy);
      if (!metaDataKey) {
        console.error('Should not happen');
        return;
      }

      // The key or the row may be missing from the loaded metadata; in that
      // case there is nothing to preview, so bail out instead of throwing.
      const value = scatterData.metadata[metaDataKey]?.body[index];
      if (value === undefined || value === null) return;

      setVisualizationData({
        content: value.toString(),
        type: VisPayloadType.Metadata,
      });
      return;
    }

    if (isVisualizationPreview(previewBy)) {
      const sampleIdentity = scatterData.samples[index];
      const sampleId = `${sampleIdentity?.state}_${sampleIdentity?.index}`;
      if (!samplesIdsWithAssets.has(sampleId)) {
        return;
      }

      const visualizationName = getVisualizationPreviewValue(previewBy);
      const visualizationDisplay = visualizationDisplays.find(
        ({ visName }) => visName === visualizationName
      );

      if (visualizationDisplay === undefined) {
        console.error('Should not happen');
        return;
      }

      const blobUrl = `${clientStoragePrefixUrl}/${scatterSampleVisualizationsPrefix}${sampleId}/${visualizationDisplay.visType}/${visualizationDisplay.visName}/${VISUALIZATION_PAYLOAD_FILE_NAME}`;
      const response = await fetch(blobUrl, {
        method: 'GET',
        cache: 'force-cache',
        credentials: 'include',
      });
      // fetch() only rejects on network failure; an HTTP error status still
      // resolves, and its body is not a visualization payload.
      if (!response.ok) {
        console.error(
          `Failed to load visualization payload (${response.status}): ${blobUrl}`
        );
        return;
      }
      const visualizedItem: VisualizedItem = await response.json();
      setVisualizationData(visualizedItem.data);
    }
  }, [
    scatterData,
    samplesIdsWithAssets,
    previewBy,
    index,
    scatterSampleVisualizationsPrefix,
    visualizationDisplays,
    // Part of the blob URL, so the payload must be re-fetched when it changes.
    clientStoragePrefixUrl,
  ]);

  return <DisplayVisualizedItemDataTooltip data={visualizationData} />;
}

interface LabelTooltipProps {
  /** Index of the sample the wrapped element represents. */
  index: number;
  /** Element the tooltip is attached to. */
  children: ReactElement;
}

/**
 * Attaches a preview tooltip to `children` when there is something to show.
 *
 * The children are returned unwrapped when no `previewBy` setting is active,
 * or when the sample has no stored assets and the preview is not a metadata
 * preview.
 */
function LabelTooltip({ index, children }: LabelTooltipProps): ReactElement {
  const { settings, scatterData, samplesIdsWithAssets } = useScatterData();
  const { previewBy } = settings;

  if (previewBy) {
    const identity = scatterData.samples[index];
    const assetKey = `${identity?.state}_${identity?.index}`;
    const canPreview =
      samplesIdsWithAssets.has(assetKey) || isMetaDataPreview(previewBy);

    if (canPreview) {
      return (
        <Tooltip placement="top" title={<TooltipTitle index={index} />}>
          {children}
        </Tooltip>
      );
    }
  }

  return children;
}

/** Props of the `ScatterMap` component. */
export interface ScatterMapProps {
  /** Class applied to the outer container element. */
  className?: string;
  /** Passed, with `sessionRunId` and `epoch`, to `useScattersAssets` to find samples that have assets. */
  projectId: string;
  /** Session run being displayed; insight filters only apply when their `sessionRunId` matches. */
  sessionRunId: string;
  /** Epoch passed to `useScattersAssets`. */
  epoch: number;
  /** One corner of the marquee rectangle, compared against point centres measured from the SVG's top-left corner. */
  startMousePosition: MousePosition;
  /** Opposite corner of the marquee rectangle, in the same coordinates as `startMousePosition`. */
  endMousePosition: MousePosition;
  /** When true, the zoom behaviour's mousedown/mousemove handlers are removed and clicking the background does not clear the selection. */
  disableGrab?: boolean;
  /** True while a marquee is being drawn; the selection is applied when it goes back to false. */
  isDrawing: boolean;
}

/**
 * Interactive scatter plot of the points returned by `useScatterMapData`.
 *
 * Responsibilities:
 * - attaches d3 zoom handling to the SVG and applies the resulting transform
 *   to the group that holds the points;
 * - turns the marquee described by `startMousePosition` / `endMousePosition`
 *   into marked points when drawing starts and into a selection when it ends;
 * - highlights the points matching the hovered legend entry or the
 *   dashboard's insight filter, and applies the `selected` part of an insight
 *   filter as the current selection. Insight filters are only honoured when
 *   both their `sessionRunId` and their digest match this scatter.
 */
export const ScatterMap = React.memo(
  ({
    className,
    projectId,
    sessionRunId,
    epoch,
    startMousePosition,
    endMousePosition,
    disableGrab,
    isDrawing,
  }: ScatterMapProps) => {
    const { selection, scatterData, legendHovered } = useScatterData();
    const data = useScatterMapData();
    const {
      insightScatterSelectionFilter,
      setSelectedScatterInsightFilter,
      setHoveredScatterInsightFilter,
    } = useDashboardContext();

    const { ref: containerRef, width = 0, height = 0 } = useResizeObserver();
    const svgRef = useRef<SVGSVGElement>(null);
    // Last insight filter that was processed, used to skip identical updates.
    const prevFilterRef = useRef(insightScatterSelectionFilter);

    const [zoomTransform, setZoomTransform] = useState<d3.ZoomTransform>();
    const [hoveredIndices, setHoveredIndices] = useState<Set<number>>(
      () => new Set<number>()
    );
    const [isSelecting, setIsSelecting] = useState(false);

    // Clears the highlight while keeping the same Set instance when it is
    // already empty, so React can bail out instead of re-rendering `Circles`.
    const clearHoveredIndices = useCallback(() => {
      setHoveredIndices((prev) => (prev.size === 0 ? prev : new Set<number>()));
    }, []);

    useLayoutEffect(() => {
      const svgElement = svgRef.current as Element;
      if (!svgElement) return;

      const zoomBehavior = d3
        .zoom()
        .scaleExtent([0.5, 25])
        .on('zoom', ({ transform }: d3.D3ZoomEvent<SVGSVGElement, unknown>) => {
          setZoomTransform(transform);
        });

      const selected = d3.select(svgElement).call(zoomBehavior);

      d3.select(svgElement).style('cursor', 'grab');

      zoomBehavior
        .on('start', () => d3.select(svgElement).style('cursor', 'grabbing'))
        .on('end', () => d3.select(svgElement).style('cursor', 'grab'));

      if (disableGrab) {
        // Only the mousedown/mousemove zoom listeners are removed; listeners
        // registered under other event names are left untouched.
        selected
          .on('mousedown.zoom', null)
          .on('mousemove.zoom', null)
          .style('cursor', 'unset');
      }

      return () => {
        selected.on('.zoom', null);
        d3.select(svgElement).style('cursor', null);
      };
    }, [disableGrab]);

    const centerOffset = useMemo(() => {
      const offset =
        -((BASE_CHART_SIZE.w * width) / height - BASE_CHART_SIZE.w) / 2;
      // Before the container is measured `height` is 0, which yields NaN
      // (0 / 0) or ±Infinity (width / 0); neither is a valid viewBox origin.
      return Number.isFinite(offset) ? offset : 0;
    }, [width, height]);
    const centerOffsetView = useSmoothTransition(centerOffset);

    const clearSelection = useCallback(() => {
      if (disableGrab || selection.mode !== 'select') return;
      selection.clear();
    }, [selection, disableGrab]);

    const { samplesIdsWithAssets } = useScattersAssets({
      projectId,
      sessionRunId,
      epoch,
    });

    const handleMarked = useCallback(() => {
      if (!svgRef.current || !isDrawing) return;
      const selectedPoints: number[] = calcMarkedPoints(
        svgRef.current,
        startMousePosition,
        endMousePosition
      );
      selection.setMarked(new Set(selectedPoints));
    }, [isDrawing, startMousePosition, endMousePosition, selection]);

    useLayoutEffect(() => {
      if (isDrawing && !isSelecting) {
        // Drawing has just started.
        setIsSelecting(true);
        handleMarked();
      } else if (!isDrawing && isSelecting) {
        // Drawing has just ended: commit the marquee as the new selection.
        setIsSelecting(false);

        const svgElement = svgRef.current;
        if (svgElement) {
          const selectedPoints: number[] = calcMarkedPoints(
            svgElement,
            startMousePosition,
            endMousePosition
          );
          const selected = calcNextSelection(
            selection.mode,
            selection.selected,
            new Set(selectedPoints)
          );
          selection.setSelected(selected);
        }
        selection.setMarked(undefined);

        setSelectedScatterInsightFilter(undefined);
        setHoveredScatterInsightFilter(undefined);
        clearHoveredIndices();
      }
    }, [
      isDrawing,
      isSelecting,
      startMousePosition,
      endMousePosition,
      handleMarked,
      selection,
      setSelectedScatterInsightFilter,
      setHoveredScatterInsightFilter,
      clearHoveredIndices,
    ]);

    // Original indices of the points that are actually displayed.
    const displayedIndices = useMemo(
      () => new Set(data.map((point) => point.originalIndex)),
      [data]
    );

    /**
     * Returns the original indices of the displayed points whose metadata
     * value for `key` equals `value` (or is contained in it, when `value` is
     * an array). An unknown key yields an empty set.
     */
    const findMatchingIndices = useCallback(
      (key: string, value: NumberOrString | NumberOrString[]) => {
        const matching = new Set<number>();
        const metadataArray = scatterData.metadata[key]?.body;
        if (!metadataArray) return matching;

        const isMatch = Array.isArray(value)
          ? (a: NumberOrString) => value.includes(a)
          : (a: NumberOrString) => a === value;

        // Single pass over the metadata rows; the Set lookup replaces a full
        // scan of `data` for every matching row.
        metadataArray.forEach((metaValue, index) => {
          if (displayedIndices.has(index) && isMatch(metaValue)) {
            matching.add(index);
          }
        });

        return matching;
      },
      [scatterData.metadata, displayedIndices]
    );

    useLayoutEffect(() => {
      // A hovered legend entry takes precedence over any insight filter.
      if (legendHovered) {
        const { key, value } = legendHovered;
        setHoveredIndices(findMatchingIndices(key, value));
        return;
      }

      if (!insightScatterSelectionFilter) {
        prevFilterRef.current = insightScatterSelectionFilter;
        clearHoveredIndices();
        return;
      }

      // Skip filters that are structurally identical to the last processed one.
      if (
        JSON.stringify(prevFilterRef.current) ===
        JSON.stringify(insightScatterSelectionFilter)
      ) {
        return;
      }

      prevFilterRef.current = insightScatterSelectionFilter;

      // Extract the digest from the first available cluster blob path; the
      // map may be missing or empty, in which case there is no digest.
      const firstClusterPath = scatterData.clusters_blob_path
        ? Object.values(scatterData.clusters_blob_path)[0]?.[0]
        : undefined;
      const scatterDigest = firstClusterPath
        ? extractDigest(firstClusterPath)
        : null;

      if (!scatterDigest) {
        clearHoveredIndices();
        return;
      }

      const { hovered, selected } = insightScatterSelectionFilter;

      if (
        hovered &&
        hovered.sessionRunId === sessionRunId &&
        scatterDigest === hovered.digest
      ) {
        setHoveredIndices(
          findMatchingIndices(hovered.filter.key, hovered.filter.value)
        );
      } else {
        clearHoveredIndices();
      }

      if (
        selected &&
        selected.sessionRunId === sessionRunId &&
        scatterDigest === selected.digest
      ) {
        const selectedIndices = findMatchingIndices(
          selected.filter.key,
          selected.filter.value
        );
        selection.setSelected(selectedIndices);
        // The filter has been applied; reset it in the dashboard context.
        setSelectedScatterInsightFilter(undefined);
      }
    }, [
      legendHovered,
      insightScatterSelectionFilter,
      scatterData.metadata,
      selection,
      findMatchingIndices,
      sessionRunId,
      scatterData.clusters_blob_path,
      setSelectedScatterInsightFilter,
      clearHoveredIndices,
    ]);

    // When the container is narrower than the chart's aspect ratio, force the
    // SVG to be at least as wide as the chart is at the container's height.
    const chartRatio = BASE_CHART_SIZE.w / BASE_CHART_SIZE.h;
    const containerRatio = width / height;
    const minWidth = containerRatio < chartRatio ? height * chartRatio : width;

    return (
      <div
        className={className}
        ref={containerRef}
        id={TOUR_SELECTORS_ENUM.POPULATION_EXPLORATION_CIRCLES_ID}
      >
        <div className="absolute top-1 left-2 text-xs text-gray-400">
          <Tooltip title="Number of samples" placement="top">
            <span>{data.length}</span>
          </Tooltip>
        </div>
        <svg
          viewBox={`${centerOffsetView} 0 ${BASE_CHART_SIZE.w} ${BASE_CHART_SIZE.h}`}
          className="h-full w-full"
          style={{
            minWidth,
          }}
          onClick={clearSelection}
          preserveAspectRatio="xMinYMin meet"
          ref={svgRef}
        >
          <g transform={zoomTransform?.toString()}>
            <Circles
              data={data}
              samplesIdsWithAssets={samplesIdsWithAssets}
              transformK={zoomTransform?.k}
              plane={Plane.XY}
              hoveredIndices={hoveredIndices}
            />
          </g>
        </svg>
      </div>
    );
  }
);

ScatterMap.displayName = 'ScatterMap';

interface CirclesProps {
  /** Points to draw; `originalIndex` is used as React key, selection id and hover id. */
  data: (DataPoint & { originalIndex: number })[];
  /** Ids of the form `${state}_${index}` for samples that have assets. */
  samplesIdsWithAssets: Set<string>;
  /** Current zoom factor; sizes are divided by it. Defaults to 1. */
  transformK?: number;
  /** Which two axes are projected onto the chart. */
  plane: Plane;
  /** Original indices of the points that should be drawn highlighted. */
  hoveredIndices: Set<number>;
}

/**
 * Renders one `Point` per entry of `data`, deriving each point's style from
 * its selection / marked / hovered / has-assets state and wiring clicks into
 * the shared selection.
 */
const Circles = React.memo<CirclesProps>(
  ({ data, samplesIdsWithAssets, transformK = 1, plane, hoveredIndices }) => {
    const { selection } = useScatterData();
    const { nextSelected, selected } = selection;

    const { minX, maxX, minY, maxY, minZ, maxZ } = useMemo(
      () => calcMinMax(data),
      [data]
    );

    // Scales from data coordinates to chart coordinates. The y and z ranges
    // are reversed so larger values are drawn closer to the top.
    const { scaleX, scaleY, scaleZ } = useMemo(
      () => ({
        scaleX: d3
          .scaleLinear()
          .domain([minX, maxX])
          .range([0, BASE_CHART_SIZE.w]),
        scaleY: d3
          .scaleLinear()
          .domain([minY, maxY])
          .range([BASE_CHART_SIZE.h, 0]),
        scaleZ: d3
          .scaleLinear()
          .domain([minZ, maxZ])
          .range([BASE_CHART_SIZE.d, 0]),
      }),
      [minX, maxX, minY, maxY, minZ, maxZ]
    );

    // Keep the latest selection in a ref so the click handler can stay
    // referentially stable: memoized points that did not re-render would
    // otherwise keep a handler that closes over an outdated selection.
    const latestSelection = useRef(selection);
    useLayoutEffect(() => {
      latestSelection.current = selection;
    }, [selection]);

    const handleSelect = useCallback(
      (e: React.SyntheticEvent<SVGElement, Event>, sampleIdx: number) => {
        e.preventDefault();
        e.stopPropagation();
        const current = latestSelection.current;
        const nextSelection = calcNextSelection(
          current.mode,
          current.selected,
          new Set([sampleIdx])
        );
        current.setSelected(nextSelection);
        current.setMarked(undefined);
      },
      []
    );

    // Only the 'add' mode gets a dedicated cursor.
    const cursor = selection.mode === 'add' ? 'copy' : 'pointer';

    /**
     * Derives border width, CSS classes, fill colour and stroke colour for a
     * point. Precedence: marked for selection > selected > hovered > has
     * assets > plain. Border widths are divided by the zoom factor.
     */
    const calcCircleStyle = useCallback(
      (
        isNextSelected: boolean,
        isSelected: boolean,
        isHovered: boolean,
        hasAssets: boolean,
        originalColor: string
      ) => {
        const baseSize = 1 / transformK;
        const borderSize =
          isNextSelected || isSelected || isHovered
            ? 1.1 * baseSize
            : hasAssets
            ? 0.4 * baseSize
            : 0;

        const className = isNextSelected
          ? 'opacity-100 fill-primary-300 z-50'
          : isSelected
          ? 'opacity-50 fill-primary-300'
          : isHovered
          ? 'opacity-75'
          : hasAssets
          ? 'opacity-50'
          : 'opacity-40';

        const stroke =
          !isNextSelected && !isSelected && isHovered ? '#00FFFF' : '#FFFFFF';

        // Selected / marked points take their fill from the CSS class instead.
        const color = isNextSelected || isSelected ? undefined : originalColor;

        return { borderSize, className, color, stroke };
      },
      [transformK]
    );

    return (
      <>
        {data.map(
          ({
            x,
            y,
            z,
            radius,
            sample,
            shape,
            color: originalColor,
            originalIndex,
          }) => {
            const isNextSelected = nextSelected?.has(originalIndex) || false;
            const isSelected = selected?.has(originalIndex) || false;
            const isHovered = hoveredIndices.has(originalIndex);

            const hasAssets = samplesIdsWithAssets.has(
              `${sample?.state}_${sample?.index}`
            );

            const { className, borderSize, color, stroke } = calcCircleStyle(
              isNextSelected,
              isSelected,
              isHovered,
              hasAssets,
              originalColor
            );
            return (
              <Point
                key={originalIndex}
                cursor={cursor}
                className={clsx(
                  'hover:opacity-100 transition-opacity duration-200',
                  className
                )}
                sampleIdx={originalIndex}
                x={x}
                y={y}
                z={z}
                scaleX={scaleX}
                scaleY={scaleY}
                scaleZ={scaleZ}
                shape={shape}
                plane={plane}
                color={color}
                borderSize={borderSize}
                radius={radius}
                transformK={transformK}
                onSelect={handleSelect}
                stroke={stroke}
                isSelected={isSelected}
                isHovered={isHovered}
              />
            );
          }
        )}
      </>
    );
  }
);

Circles.displayName = 'Circles';

/**
 * Mouse handler that swallows the event: despite its name it first stops the
 * event from propagating and then cancels its default action.
 */
function preventDefault(e: React.MouseEvent): void {
  e.stopPropagation();
  e.preventDefault();
}

interface PointProps {
  /** CSS classes forwarded to the rendered shape. */
  className?: string;
  /** Sample index: used as the shape's DOM id, the tooltip index and the `onSelect` argument. */
  sampleIdx: number;
  /** Data-space coordinates of the point. */
  x: number;
  y: number;
  z: number;
  /** Scales that map data coordinates to chart coordinates. */
  scaleX: ScaleLinear;
  scaleY: ScaleLinear;
  scaleZ: ScaleLinear;
  /** Pair of axes to project onto the chart. Defaults to `Plane.XY`. */
  plane?: Plane;
  /** Fill colour; when undefined no inline fill is set. */
  color?: string;
  /** Stroke width; multiplied by 1.5 while hovered. Defaults to 0. */
  borderSize?: number;
  /** Base radius, multiplied by `RADIUS_SIZE_SCALE` and divided by `transformK`. */
  radius: number;
  /** Current zoom factor. */
  transformK: number;
  /** Not read by this component. */
  absPosition?: { x: number; y: number };
  /** CSS cursor shown over the shape. */
  cursor: string;
  /** Shape type forwarded to `ScatterShape`. */
  shape: ShapeType;
  /** Stroke colour; replaced by '#00FFFF' while hovered. */
  stroke: string;
  /** Called with the click event and `sampleIdx` when the shape is clicked. */
  onSelect: (
    e: React.SyntheticEvent<SVGElement, Event>,
    sampleIdx: number
  ) => void;
  /** Forwarded to `ScatterShape` as `isSelected`. */
  isSelected: boolean;
  /** Forwarded to `ScatterShape` as `isHighlighted`; also switches the hover stroke styling. */
  isHovered: boolean;
}

/**
 * A single scatter point: projects its coordinates onto the chosen plane and
 * renders a `ScatterShape` wrapped in the preview tooltip.
 *
 * Memoized with React's default shallow prop comparison. A hand-written
 * comparator must compare every prop, including `className` and the
 * `onSelect` callback; leaving any out makes the point keep stale styling or
 * a stale handler.
 */
const Point = React.memo<PointProps>(
  ({
    className,
    sampleIdx,
    x,
    y,
    z,
    scaleX,
    scaleY,
    scaleZ,
    plane = Plane.XY,
    color,
    borderSize = 0,
    radius,
    shape,
    transformK,
    onSelect,
    cursor,
    stroke,
    isSelected,
    isHovered,
  }) => {
    // Chart-space centre of the point for the selected projection plane.
    const { cx, cy } = useMemo(() => {
      let cx: ReturnType<ScaleLinear>;
      let cy: ReturnType<ScaleLinear>;
      switch (plane) {
        case Plane.XY: {
          cx = scaleX(x);
          cy = scaleY(y);
          break;
        }
        case Plane.XZ: {
          cx = scaleX(x);
          cy = scaleZ(z);
          break;
        }
        case Plane.YZ: {
          cx = scaleY(y);
          cy = scaleZ(z);
          break;
        }
        default: {
          throw Error('The selected plane is not supported');
        }
      }
      return { cx, cy };
    }, [x, y, z, scaleX, scaleY, scaleZ, plane]);

    // Dividing by the zoom factor cancels the scaling of the parent group's
    // zoom transform.
    const circleRadius = useMemo(
      () => (radius * RADIUS_SIZE_SCALE) / transformK,
      [radius, transformK]
    );

    const handleClick = useCallback(
      (e: React.SyntheticEvent<SVGElement, Event>) => {
        onSelect(e, sampleIdx);
      },
      [onSelect, sampleIdx]
    );

    const circleStyle = useMemo(
      () => ({
        fill: color,
        stroke: isHovered ? '#00FFFF' : stroke,
        strokeWidth: isHovered ? borderSize * 1.5 : borderSize,
        strokeOpacity: isHovered ? 0.8 : 1,
        cursor,
      }),
      [color, isHovered, stroke, borderSize, cursor]
    );

    return (
      <LabelTooltip index={sampleIdx}>
        <ScatterShape
          key={sampleIdx}
          shapeType={shape}
          x={cx}
          y={cy}
          id={sampleIdx.toString()}
          size={circleRadius * 2}
          isSelected={isSelected}
          className={className}
          isHighlighted={isHovered}
          style={circleStyle}
          onClick={handleClick}
          onMouseDown={preventDefault}
        />
      </LabelTooltip>
    );
  }
);
Point.displayName = 'Point';

/**
 * Returns the sample indices of all point shapes whose centre lies inside the
 * rectangle spanned by the two mouse positions (inclusive on every edge).
 *
 * The mouse positions are compared against shape centres measured from the
 * top-left corner of `svgElm`'s bounding box, and may be given in any order.
 * Only elements whose DOM id is a plain non-negative integer are considered.
 *
 * @param svgElm SVG element that contains the point shapes.
 * @param startMousePosition One corner of the marquee rectangle.
 * @param endMousePosition The opposite corner of the marquee rectangle.
 * @returns The ids of the matching shapes, in document order.
 */
const calcMarkedPoints = (
  svgElm: SVGSVGElement,
  startMousePosition: MousePosition,
  endMousePosition: MousePosition
): number[] => {
  const { left: absoluteX, top: absoluteY } = svgElm.getBoundingClientRect();
  // Every SVG element type a point may be rendered as.
  const pointSelector = 'circle, path, rect, ellipse, polygon';
  const shapes = svgElm.querySelectorAll(pointSelector);
  const selectedPoints: number[] = [];
  const sampleIdPattern = /^\d+$/;

  const minX = Math.min(startMousePosition.x, endMousePosition.x);
  const maxX = Math.max(startMousePosition.x, endMousePosition.x);
  const minY = Math.min(startMousePosition.y, endMousePosition.y);
  const maxY = Math.max(startMousePosition.y, endMousePosition.y);

  shapes.forEach((shape) => {
    // Skip elements that do not carry a sample index: Number('') is 0 and
    // Number('abc') is NaN, so unrelated shapes would otherwise be reported
    // as sample 0 or as an invalid index.
    if (!sampleIdPattern.test(shape.id)) return;

    const { x, y, width, height } = shape.getBoundingClientRect();
    const centerX = x - absoluteX + width / 2;
    const centerY = y - absoluteY + height / 2;

    if (
      centerX >= minX &&
      centerX <= maxX &&
      centerY >= minY &&
      centerY <= maxY
    ) {
      selectedPoints.push(Number(shape.id));
    }
  });

  return selectedPoints;
};

/**
 * Pulls the lowercase-hex digest out of a URL containing a
 * `digest_<hex>/` path segment.
 *
 * @returns The hex digest, or `null` when the URL has no such segment.
 */
function extractDigest(url: string): string | null {
  const found = /digest_([a-f0-9]+)\//.exec(url);
  if (found === null) {
    return null;
  }
  return found[1];
}
