import clsx from 'clsx';
import {
  ScatterShape,
  ShapeType,
} from '../../../dashboard/ScatterAnalyzerView/ScatterShape';
import {
  MutualInformationElement,
  NumberOrString,
} from '@tensorleap/api-client';
import { Tooltip } from '@material-ui/core';
import { CorrelatedMetadataOnCluster } from '../../../insights/InsightCardText';
import { ReactNode, useCallback, useMemo } from 'react';

/** Props for {@link LegendItem}. */
type LegendItemProps<Value extends NumberOrString | NumberOrString[]> = {
  /** Text shown next to the icon (when `showNames`) and key used to look up tooltip data. */
  label: NumberOrString;
  /** Payload handed back to `handleLegendClick` / `handleLegendMouseOver`. */
  value: Value;

  /** Icon shape; `LegendItem` defaults it to `'circle'`. */
  shape?: ShapeType;
  /** Icon fill color; when absent the icon is drawn as a grey outline. */
  color?: string;

  /** Whether the label text is rendered next to the icon. */
  showNames: boolean;
  /** When true the label text is dimmed. */
  isHidden?: boolean;
  /** Max number of characters of a string label shown before "..." is appended. */
  truncatedLongtail: number;
  /** NOTE(review): not read by `LegendItem`; kept so existing callers still type-check. */
  hiddenLabels?: string[];

  handleLegendClick?: (value: Value) => void;
  handleLegendMouseOver?: (value: Value) => void;
  handleLegendMouseLeave?: () => void;

  /** Counts looked up by `value` (by `label` when `value` is an array) for the tooltip. */
  appearances?: Map<NumberOrString, number>;
  /** Correlated metadata per label, shown in the tooltip. */
  clusterData?: Record<string, MutualInformationElement[]>;
  /** Per-label map of key -> distance, shown in the tooltip as "Distance to {key}". */
  domainGapMetadata?: Record<string, Record<string, number>>;
};

/**
 * A single legend row: a shape icon, optionally wrapped in a details tooltip,
 * followed (when `showNames` is set) by the possibly truncated label text.
 *
 * Hovering the row calls `handleLegendMouseOver(value)` and leaving it calls
 * `handleLegendMouseLeave()`. Clicking the icon calls `handleLegendClick(value)`
 * whether or not a tooltip is shown.
 */
export function LegendItem<Value extends NumberOrString | NumberOrString[]>({
  shape = 'circle',
  label,
  value,
  handleLegendClick,
  handleLegendMouseOver,
  handleLegendMouseLeave,
  showNames,
  isHidden,
  truncatedLongtail,
  appearances,
  clusterData,
  domainGapMetadata,
  color,
}: LegendItemProps<Value>): JSX.Element {
  // Uncolored items get a grey outline plus hover fill classes; colored items
  // are filled solid with no outline.
  const icon = (
    <svg width="14" height="14" viewBox="0 0 14 14">
      <ScatterShape
        style={{
          stroke: !color ? '#bbb' : 'none',
          fill: color,
        }}
        className={clsx(!color && 'hover:fill-primary-200 fill-primary-50')}
        x={6}
        y={6}
        shapeType={shape}
        size={12}
      />
    </svg>
  );

  const addFilter = useCallback(() => {
    handleLegendClick?.(value);
  }, [handleLegendClick, value]);

  const displayLabel = useMemo(
    () => calcTruncateLabel(label, truncatedLongtail),
    [label, truncatedLongtail],
  );

  const domainGapData = domainGapMetadata?.[label];
  const hasTooltipContent = Boolean(
    appearances || clusterData || domainGapData,
  );

  // The click handler is attached in both branches so items without tooltip
  // data are still clickable. The wrapper div is also what the Tooltip
  // anchors to - the tooltip does not display on a bare SVG.
  const clickableIcon = <div onClick={addFilter}>{icon}</div>;

  return (
    <div
      className="flex flex-row-reverse gap-2 h-6 items-center cursor-pointer pointer-events-auto"
      onMouseEnter={() => handleLegendMouseOver?.(value)}
      onMouseLeave={handleLegendMouseLeave}
    >
      {hasTooltipContent ? (
        <Tooltip
          placement="left"
          interactive
          leaveDelay={300}
          PopperProps={{ style: { marginRight: 30 } }}
          title={
            <LabelLegendTooltip
              appearances={appearances}
              clusterData={clusterData}
              domainGapData={domainGapData}
              label={label}
              value={Array.isArray(value) ? label : value}
              icon={icon}
            />
          }
        >
          {clickableIcon}
        </Tooltip>
      ) : (
        clickableIcon
      )}
      {showNames && (
        <p className={clsx(isHidden && 'opacity-20')}>{displayLabel}</p>
      )}
    </div>
  );
}

/** Props for {@link LabelLegendTooltip}. */
export type LabelLegendTooltipProps = {
  /** Header text; also the key used to look up `clusterData`. */
  label: NumberOrString;
  /** Key used to look up the count in `appearances`. */
  value: NumberOrString;
  /** Icon rendered in the tooltip header next to the label. */
  icon: ReactNode;

  /** When provided, an "Appearances" row shows the count for `value`. */
  appearances?: Map<NumberOrString, number>;
  /** Correlated metadata per label; a section is shown when `label` has entries. */
  clusterData?: Record<string, MutualInformationElement[]>;
  /** Map of key -> distance, each rendered as "Distance to {key}: {distance}". */
  domainGapData?: Record<string, number>;
};

/**
 * Tooltip body for a legend item. It renders a header with the icon and label,
 * followed by optional sections for:
 * - correlated cluster metadata for `label`,
 * - domain-gap distances,
 * - the appearance count of `value`.
 */
export function LabelLegendTooltip({
  appearances,
  clusterData,
  domainGapData,
  label,
  value,
  icon,
}: LabelLegendTooltipProps): JSX.Element {
  const clusterElements = clusterData?.[label];

  return (
    <div className="flex justify-end w-full h-full">
      <div className="flex flex-col w-fit h-fit border-[1px] border-gray-700 rounded-xl bg-gray-800 -m-4 text-sm text-gray-350">
        <div className="flex flex-row align-baseline p-2 items-center gap-1">
          {icon}
          <p>{label}</p>
        </div>
        {/* Explicit `> 0`: a bare `length &&` would render a literal "0" for an empty list. */}
        {clusterElements && clusterElements.length > 0 && (
          <div className="p-2 border-t-[1px] border-gray-700">
            <CorrelatedMetadataOnCluster infoElements={clusterElements} />
          </div>
        )}
        {domainGapData && (
          <div className="p-2 border-t-[1px] border-gray-700">
            {Object.entries(domainGapData).map(([metadataKey, distance]) => (
              <p key={metadataKey} className="text-sm">
                {`Distance to ${metadataKey}: ${distance}`}
              </p>
            ))}
          </div>
        )}
        {appearances && (
          <p className="p-2 border-t-[1px] border-gray-700">
            Appearances: {appearances.get(value)}
          </p>
        )}
      </div>
    </div>
  );
}

/**
 * Shortens a string label to at most `truncatedLongtail` UTF-16 code units,
 * appending "..." when anything was cut. Numeric labels are returned unchanged.
 *
 * The cut never splits a surrogate pair (e.g. an emoji). In that case one
 * fewer code unit is kept, so no lone surrogate (rendered as U+FFFD) appears.
 * Negative or NaN limits are treated as 0.
 *
 * @param label - label to display
 * @param truncatedLongtail - maximum number of code units to keep
 * @returns the label, possibly truncated and suffixed with "..."
 */
export function calcTruncateLabel(
  label: NumberOrString,
  truncatedLongtail: number,
): NumberOrString {
  if (typeof label !== 'string') {
    return label;
  }

  // `|| 0` maps NaN to 0; Math.max keeps slice() from counting from the end.
  let end = Math.max(0, Math.trunc(truncatedLongtail) || 0);
  if (end >= label.length) {
    return label;
  }

  const lastKept = label.charCodeAt(end - 1);
  const firstDropped = label.charCodeAt(end);
  const splitsSurrogatePair =
    end > 0 &&
    lastKept >= 0xd800 &&
    lastKept <= 0xdbff &&
    firstDropped >= 0xdc00 &&
    firstDropped <= 0xdfff;
  if (splitsSurrogatePair) {
    end -= 1;
  }

  return `${label.slice(0, end)}...`;
}
