import {
  Badge,
  Collapse,
  Dialog,
  DialogTitle,
  IconButton,
  Tooltip,
} from '@material-ui/core';
import {
  DistributionType,
  EpochData,
  GenericDataItem,
  OrderType,
  Session,
} from '@tensorleap/api-client';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useCreateFieldHelper } from '../ui/model-list/utils';
import { Table } from '../ui/model-list/table/Table';
import api from '../core/api-client';
import {
  Check1,
  CloudDownload2,
  CsvIcon,
  ImportIcon,
  Plus,
  SmallFilterEmpty,
  SmallFilterFilled,
  XClose,
} from '../ui/icons';
import { Chip } from '../ui/atoms/Chip';
import { sort, Sorter, useSort } from '../ui/model-list/sorter';
import { ModelFields } from '../ui/model-list/types';
import { useEnvironmentInfo } from '../core/EnvironmentInfoContext';
import download from 'downloadjs';
import { unparse } from 'papaparse';
import { DisplayMultiMetadata } from '../dashboard/dashlet/common/DisplayMetadata';
import { useToggle } from '../core/useToggle';
import clsx from 'clsx';
import SvgDown from '../ui/icons/Down';
import { format } from 'date-fns';
import { useFetchSessionsEpochs } from '../core/data-fetching/session-epochs';
import { useVersionControl } from '../core/VersionControlContext';
import { IconMenu } from '../ui/atoms/IconMenu';
import { useMapIdToKey } from '../core/useUniqueKeyManager';
import {
  BG_900_COLOR_TEMPLATES,
  BG_COLOR_TEMPLATES,
  BG_COLOR_TEMPLATES_HEX,
  CHIP_COLOR_TEMPLATES,
  GENTLE_BG_COLOR_TEMPLATES,
} from '../ui/molecules/colorTemplates';
import { VisibleOrNonVisibleChipIcon } from '../ui/molecules/ModelChip';
import { Input } from '../ui/atoms/Input';
import { truncateLongtail } from '../core/formatters/string-formatting';
import { PlaceholderChip } from '../ui/atoms/PlaceholderChip';
import { LineChart } from '../ui/charts/visualizers/LineChart';
import { sortBy } from 'lodash';
import { NoDataChart } from '../ui/charts/common/NoDataChart';
import { MultiSelectIconMenu } from '../ui/atoms/MultiSelectIconMenu';

/** Props accepted by `SessionInfoDialog`. */
type SessionEpochsDialogProps = {
  /** Invoked by the MUI `Dialog`'s `onClose` and by the header's close button. */
  onClose: () => void;
  /** Session that is selected when the dialog first opens. */
  session: Session;
};

/**
 * One row of the epochs table: a single epoch of a single session.
 *
 * At runtime rows additionally carry one `metrics_<name>` property per
 * discovered metric; those dynamic columns are not part of this type.
 */
type EpochRow = {
  /** `cid` of the session this epoch belongs to. */
  sessionId: string;
  /** Epoch number within its session. */
  epoch: number;
  /** True when the epoch has weights data (i.e. its model was imported). */
  imported: boolean;
  /** Tags attached to the epoch. */
  tags: string[];
  /** Storage path of a user-uploaded model file, when one exists. */
  uploadedModelPath?: string;
};

/**
 * Full-width dialog for comparing the epochs, metrics and global properties
 * of one or more sessions ("experiments").
 *
 * The dialog starts with `session` selected; further sessions can be added
 * from the header. A selected session can also be hidden without being
 * removed: it stays in the header, but its epochs are neither fetched nor
 * displayed.
 *
 * @param session - Session selected when the dialog opens.
 * @param onClose - Called when the dialog requests to close.
 */
export function SessionInfoDialog({
  session,
  onClose,
}: SessionEpochsDialogProps): JSX.Element {
  const [selectedSessions, setSelectedSessions] = useState<Session[]>([
    session,
  ]);
  // Lazy initializer: avoids allocating a discarded Set on every render.
  const [hiddenIds, setHiddenIds] = useState(() => new Set<string>());
  // Selected sessions that are not hidden; only these are fetched and shown.
  const filterSelectedSessions = useMemo(
    () => selectedSessions.filter((s) => !hiddenIds.has(s.cid)),
    [selectedSessions, hiddenIds]
  );

  // Functional update: the next Set is derived from the latest state instead
  // of the value captured when the callback was created. This keeps the
  // callback stable and prevents consecutive toggles from overwriting each
  // other.
  const toggleHiddenId = useCallback((id: string) => {
    setHiddenIds((previous) => {
      const next = new Set(previous);
      if (next.has(id)) {
        next.delete(id);
      } else {
        next.add(id);
      }
      return next;
    });
  }, []);

  const { get, updateCurrentIds } = useMapIdToKey();
  // Color-template index for a session; falls back to 0 when none is assigned.
  const getSessionColorIndex = useCallback(
    (sessionId: string) => get(sessionId) || 0,
    [get]
  );

  // Sync the id->key mapping with the full selection (hidden sessions
  // included — presumably so colors stay stable while a session is hidden).
  useEffect(() => {
    const ids = selectedSessions.map((s) => s.cid);
    updateCurrentIds(ids);
  }, [selectedSessions, updateCurrentIds]);

  const filteredSessionIds = useMemo(
    () => filterSelectedSessions.map((s) => s.cid),
    [filterSelectedSessions]
  );
  const { epochs = [], projectId } = useFetchSessionsEpochs(filteredSessionIds);

  return (
    <Dialog
      open={true}
      onClose={onClose}
      maxWidth="xl"
      fullWidth
      scroll="paper"
      classes={{ paper: 'bg-gray-850 h-full' }}
    >
      <DialogTitle>
        <SessionInfoHeader
          onClose={onClose}
          value={selectedSessions}
          onChange={setSelectedSessions}
          getColorIndex={getSessionColorIndex}
          hiddenIds={hiddenIds}
          toggleHiddenId={toggleHiddenId}
        />
      </DialogTitle>
      <SessionsContent
        epochs={epochs}
        getColorIndex={getSessionColorIndex}
        sessions={filterSelectedSessions}
        projectId={projectId}
      />
    </Dialog>
  );
}

/** Props for `SessionsContent`. */
type SessionsContentProps = {
  /** Epoch records to display (expected to belong to `sessions`). */
  epochs: EpochData[];
  /** Sessions whose epochs, metrics and properties are displayed. */
  sessions: Session[];
  /** Maps a session id to an index into the shared color-template arrays. */
  getColorIndex: (sessionId: string) => number;
  /** Project id, forwarded to the epochs table. */
  projectId: string;
};

/**
 * Body of the sessions dialog, made of three collapsible sections:
 *  1. "Metrics" — one line chart per numeric metric, with a metric filter;
 *  2. "Epochs" — a sortable table of epochs with one column per metric,
 *     downloadable as CSV;
 *  3. "Global Properties" — each session's `properties`, side by side.
 *
 * Metric columns are discovered dynamically from
 * `epoch.externalData.metrics` and stored on each row as `metrics_<name>`.
 *
 * @param epochs - Epoch records to display.
 * @param projectId - Project id, forwarded to the epochs table.
 * @param sessions - Sessions the epochs belong to (labels, colors, properties).
 * @param getColorIndex - Maps a session id to its color-template index.
 */
export function SessionsContent({
  epochs,
  projectId,
  sessions,
  getColorIndex,
}: SessionsContentProps) {
  const { textField, numericField } = useCreateFieldHelper<EpochRow>();

  const sessionById = useMemo(
    () => new Map(sessions.map((session) => [session.cid, session])),
    [sessions]
  );
  const [hiddenMetrics, setHiddenMetrics] = useState<string[]>([]);
  const { rows, fields, metricsTypes } = useMemo(() => {
    // Fixed columns: session color swatch, epoch number, imported flag, tags.
    const fields = [
      textField('sessionId', {
        label: '-',
        format: (sessionId) => (
          <Tooltip
            title={
              <SessionChip
                showFullName
                colorKey={getColorIndex(sessionId as string)}
                session={sessionById.get(sessionId as string) as Session}
              />
            }
          >
            <div
              className={clsx(
                BG_COLOR_TEMPLATES[getColorIndex(sessionId as string)],
                'h-2 w-2'
              )}
            />
          </Tooltip>
        ),
        table: { align: 'left', width: '20px' },
        sortable: { level: 'primary' },
      }),
      numericField('epoch', {
        label: '#',
        format: (epoch) => epoch,
        sortable: { level: 'secondary' },
        table: { align: 'left', width: '20px' },
      }),
      textField('imported', {
        label: 'Imported',
        format: (imported) => (imported ? <Check1 /> : '-'),
        sortable: true,
      }),
      textField('tags', {
        label: 'Tags',
        format: (tags) =>
          Array.isArray(tags) && tags.length ? (
            <>
              {tags.map((tag: string) => (
                <Chip className="h-6 text-center" key={tag}>
                  {tag}
                </Chip>
              ))}
            </>
          ) : (
            '-'
          ),
      }),
    ] as ModelFields<EpochRow>;

    // Every metric name seen in any epoch, mapped to its value type. When the
    // same metric reports different types, a warning is logged and the last
    // type seen wins.
    const metricsTypes = new Map<string, 'number' | 'string' | 'image'>();

    epochs.forEach((epoch) => {
      Object.entries(epoch.externalData?.metrics || {}).forEach(
        ([key, { type }]) => {
          // `Map#has`, not the `in` operator: `key in map` tests object
          // properties, so it never sees map entries and is spuriously true
          // for names inherited from `Map.prototype` (e.g. `size`, `get`).
          if (metricsTypes.has(key)) {
            const existingType = metricsTypes.get(key);
            if (existingType !== type) {
              console.warn('Inconsistent metric type', key, existingType, type);
            }
          }

          metricsTypes.set(key, type);
        }
      );
    });

    const metricKeys = Array.from(metricsTypes.keys());

    // One dynamic column per metric; numeric metrics get a numeric field.
    metricKeys.forEach((metricKey) => {
      const key = `metrics_${metricKey}`;
      const type = metricsTypes.get(metricKey) || 'string';
      const buildFieldFunc = type === 'number' ? numericField : textField;
      fields.push(
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        (buildFieldFunc as (key: string, o: unknown) => any)(key, {
          label: metricKey,
          format: (value: unknown) => (value === undefined ? '- ' : value),
          sortable: true,
        })
      );
    });

    const rows = epochs.map((epoch) => {
      const row: EpochRow = {
        sessionId: epoch.sessionId,
        epoch: epoch.epoch,
        tags: epoch.tags,
        uploadedModelPath: epoch.uploadedModelFilePath,
        imported: !!epoch.weightsData,
      };

      // Attach the dynamic `metrics_<name>` columns, which EpochRow does not type.
      const unknownRow = row as Record<string, unknown>;

      metricKeys.forEach((metricKey) => {
        unknownRow[`metrics_${metricKey}`] =
          epoch.externalData?.metrics[metricKey]?.value;
      });

      return unknownRow as EpochRow;
    });
    return { rows, fields, metricsTypes };
  }, [epochs, getColorIndex, numericField, sessionById, textField]);

  const sorter = useSort(fields);

  const sortedData = useMemo(() => sort(rows, sorter.sortBy), [
    sorter.sortBy,
    rows,
  ]);

  // CSV export: the session id is replaced by a readable session name, and
  // internal columns are dropped.
  // NOTE(review): this exports every sorted row, while `EpochsTable` only
  // displays tagged or imported epochs — confirm the CSV should include
  // rows that are not visible in the table.
  const downloadCsv = useCallback(() => {
    const data = sortedData.map((row) => {
      const session = sessionById.get(row.sessionId);
      const sessionName = session
        ? `${session.modelName} - ${formatDateAndTime(session.createdAt)}`
        : '';

      const newRow: Record<string, unknown> = {
        session: sessionName,
        ...row,
      };
      delete newRow.uploadedModelPath;
      delete newRow.imported;
      delete newRow.sessionId;

      return newRow;
    });
    const csv = unparse(data);
    download(csv, 'epochs.csv', 'text/csv');
  }, [sortedData, sessionById]);

  // One metadata panel per session that has `properties`.
  const sessionsProperties = sessions.flatMap((session) => {
    if (session.properties) {
      return {
        data: session.properties,
        bgClass: GENTLE_BG_COLOR_TEMPLATES[getColorIndex(session.cid)],
        title: (
          <SessionChip
            session={session}
            colorKey={getColorIndex(session.cid)}
          />
        ),
      };
    }
    return [];
  });

  // NOTE(review): per the CSS Grid spec a flexible (`fr`) value is not
  // allowed as the *minimum* of `minmax()`, so browsers reject the whole
  // `gridTemplateRows` declaration below. Confirm the intended layout
  // (perhaps `minmax(auto,2fr) ...`?).
  return (
    <div
      className="grid max-h-full overflow-hidden"
      style={{
        gridTemplateRows: 'minmax(2fr,auto) minmax(1fr,auto) minmax(1fr,auto)',
      }}
    >
      <CollapseContent
        defaultOpen={true}
        label="Metrics"
        actions={
          <SelectMetrics
            metricsTypes={metricsTypes}
            hiddenMetrics={hiddenMetrics}
            onHiddenMetricsChange={setHiddenMetrics}
          />
        }
      >
        <EpochsCharts
          hiddenMetrics={hiddenMetrics}
          className="overflow-auto"
          metricsTypes={metricsTypes}
          epochs={epochs}
          getColorIndex={getColorIndex}
          sessions={sessions}
        />
      </CollapseContent>
      <CollapseContent
        defaultOpen={true}
        label="Epochs"
        actions={
          <Tooltip title="Download CSV" placement="top">
            <IconButton onClick={downloadCsv} className="p-2">
              <CsvIcon />
            </IconButton>
          </Tooltip>
        }
      >
        <EpochsTable
          projectId={projectId}
          rows={sortedData}
          getColorIndex={getColorIndex}
          sorter={sorter}
          fields={fields}
        />
      </CollapseContent>

      <CollapseContent defaultOpen={true} label="Global Properties">
        <DisplayMultiMetadata
          className="overflow-y-auto"
          smallKey
          content={sessionsProperties}
          keysLabel="Properties"
        />
      </CollapseContent>
    </div>
  );
}

/**
 * Collapsible section with a full-width, clickable header bar.
 *
 * Clicking the header toggles the section. Clicks inside the optional
 * `actions` slot are stopped from propagating, so a button placed there
 * does not also collapse the section.
 *
 * @param defaultOpen - Whether the section starts expanded.
 * @param label - Header text, rendered uppercase and followed by a colon.
 * @param actions - Optional controls rendered at the right of the header.
 */
function CollapseContent({
  children,
  defaultOpen,
  label,
  actions,
}: {
  defaultOpen: boolean;
  label: string;
  children: React.ReactNode;
  actions?: React.ReactNode;
}) {
  const [isOpen, toggleOpen] = useToggle(defaultOpen);

  // The chevron is flipped while the section is collapsed.
  const chevronClassName = clsx('transition-transform', !isOpen && '-rotate-180');

  const header = (
    <div
      className="font-semibold h-12 cursor-pointer hover:bg-primary-925 flex items-center justify-between px-2 py-2 bg-gray-800"
      onClick={toggleOpen}
    >
      <div className="uppercase flex items-center gap-2">
        <SvgDown className={chevronClassName} />
        {label}:
      </div>
      {actions && (
        <div onClick={(event) => event.stopPropagation()}>{actions}</div>
      )}
    </div>
  );

  return (
    <div className="flex flex-col overflow-hidden">
      {header}
      <Collapse
        disableStrictModeCompat
        classes={{
          wrapper: 'overflow-hidden max-h-full max-w-full',
          wrapperInner: 'overflow-auto',
        }}
        in={isOpen}
      >
        {children}
      </Collapse>
    </div>
  );
}

/**
 * Formats a timestamp with the date-fns pattern `HH:mm dd/MM`: hours and
 * minutes, then day and month, with no year. This compact form is used in
 * every session label in this module.
 */
function formatDateAndTime(date: Date): string {
  const pattern = 'HH:mm dd/MM';
  return format(date, pattern);
}

/**
 * Header of the sessions dialog.
 *
 * Contains:
 * - a "+" menu for adding or removing sessions (searchable, and limited to
 *   sessions with `hasExternalEpoch`);
 * - a close button;
 * - one chip per selected session, with remove and show/hide controls.
 *
 * @param value - Currently selected sessions.
 * @param onChange - Receives the full new selection after an add or remove.
 * @param hiddenIds - Ids of selected sessions that are currently hidden.
 * @param toggleHiddenId - Flips the hidden state of a session id.
 * @param getColorIndex - Maps a session id to its color-template index.
 * @param onClose - Invoked by the close (X) button.
 */
function SessionInfoHeader({
  value,
  onChange,
  hiddenIds,
  toggleHiddenId,
  getColorIndex,
  onClose,
}: {
  value: Session[];
  onChange: (value: Session[]) => void;
  hiddenIds: Set<string>;
  toggleHiddenId: (value: string) => void;
  getColorIndex: (sessionId: string) => number;
  onClose: () => void;
}) {
  const { versions } = useVersionControl();

  const selectedSessionIds = useMemo(() => new Set(value.map((s) => s.cid)), [
    value,
  ]);

  // Candidate sessions come from every version.
  const allSessions = useMemo(() => versions.flatMap((v) => v.sessions), [
    versions,
  ]);

  const remove = useCallback(
    (sessionId: string) => {
      onChange(value.filter((s) => s.cid !== sessionId));
    },
    [value, onChange]
  );

  const [query, setQueryParam] = useState('');

  // Only sessions with external epochs are offered. The query matches the
  // model name or the session id case-insensitively. It is lower-cased once
  // here instead of twice per session.
  const filteredSessions = useMemo(() => {
    const normalizedQuery = query.toLowerCase();
    return allSessions.filter(
      (session) =>
        session.hasExternalEpoch &&
        (session.modelName.toLowerCase().includes(normalizedQuery) ||
          session.cid.toLowerCase().includes(normalizedQuery))
    );
  }, [allSessions, query]);

  return (
    <div className="flex flex-col w-full">
      <div className="flex items-center gap-2">
        <span className=" text-gray-300 uppercase">experiment tracking </span>
        <Tooltip title="Add experiment">
          <div>
            <IconMenu
              transformOrigin={{ horizontal: 'center', vertical: 'bottom' }}
              iconWrapperClassName="h-8 w-8"
              icon={<Plus className="h-4 w-4" />}
            >
              <Input
                value={query}
                label="Search"
                className="!w-80"
                clean
                onChange={(e) => setQueryParam(e.target.value)}
              />
              {filteredSessions.map((session) => {
                const isAlreadySelected = selectedSessionIds.has(session.cid);
                return (
                  <div
                    key={session.cid}
                    // Clicking a session toggles its membership in the selection.
                    onClick={() => {
                      if (isAlreadySelected) {
                        remove(session.cid);
                      } else {
                        onChange([...value, session]);
                      }
                    }}
                    className={clsx(
                      isAlreadySelected
                        ? BG_900_COLOR_TEMPLATES[getColorIndex(session.cid)]
                        : 'hover:bg-primary-925',
                      'flex items-center gap-2 p-2 cursor-pointer '
                    )}
                  >
                    <span className="font-bold">{session.modelName}</span>
                    <span className="text-sm text-gray-300">
                      {formatDateAndTime(session.createdAt)}
                    </span>
                  </div>
                );
              })}
              {!filteredSessions.length && (
                <span className="text-gray-300 p-2">No experiment found</span>
              )}
            </IconMenu>
          </div>
        </Tooltip>
        <span className="flex-1" />
        <IconButton className="w-10 h-10 -mr-2" onClick={onClose}>
          <XClose className="h-6 w-6" />
        </IconButton>
      </div>
      <div className="flex gap-2 flex-wrap">
        {value.length ? (
          value.map((session) => (
            <SessionChip
              key={session.cid}
              colorKey={getColorIndex(session.cid)}
              session={session}
              onRemove={remove}
              visible={!hiddenIds.has(session.cid)}
              onToggleVisible={toggleHiddenId}
            />
          ))
        ) : (
          <PlaceholderChip className="h-7 text-sm uppercase">
            no experiment selected
          </PlaceholderChip>
        )}
      </div>
    </div>
  );
}

/**
 * Chip representing a session. It shows the model name (truncated unless
 * `showFullName`) followed by the creation time.
 *
 * - Passing `onRemove` gives the chip a remove affordance.
 * - When both `visible` and `onToggleVisible` are provided, a visibility
 *   toggle icon is rendered.
 * - `visible === false` renders the chip in muted gray instead of its color.
 *
 * @param colorKey - Index into `CHIP_COLOR_TEMPLATES` for the chip's colors.
 */
export function SessionChip({
  colorKey,
  session,
  onRemove,
  onToggleVisible,
  visible,
  showFullName,
}: {
  colorKey: number;
  onRemove?: (id: string) => void;
  session: Session;
  visible?: boolean;
  showFullName?: boolean;
  onToggleVisible?: (id: string) => void;
}) {
  const { cid, modelName, createdAt } = session;

  const handleToggle = useCallback(() => {
    onToggleVisible?.(cid);
  }, [onToggleVisible, cid]);

  const isMuted = visible === false;
  const borderClassName = isMuted
    ? 'border-solid border-gray-700 bg-gray-800'
    : CHIP_COLOR_TEMPLATES[colorKey];

  const displayName = showFullName
    ? modelName
    : truncateLongtail({
        value: modelName,
        startSubsetLength: 15,
        endSubsetLength: 15,
      });

  const handleRemove = onRemove ? () => onRemove(cid) : undefined;

  return (
    <Chip borderClassName={borderClassName} key={cid} onRemove={handleRemove}>
      <span className="font-bold text-sm pr-2">{displayName}</span>
      <span className="text-sm text-gray-300">
        {formatDateAndTime(createdAt)}
      </span>
      {visible !== undefined && onToggleVisible && (
        <VisibleOrNonVisibleChipIcon visible={visible} toggle={handleToggle} />
      )}
    </Chip>
  );
}

/** Props for `EpochsTable`. */
type EpochsTableProps = {
  /** Already-sorted rows; the table itself shows only tagged or imported epochs. */
  rows: EpochRow[];
  /** Column definitions, including the dynamic `metrics_<name>` columns. */
  fields: ModelFields<EpochRow>;
  /** Sort state shared with the caller. */
  sorter: Sorter<EpochRow>;
  /** Maps a session id to its color-template index (used for row backgrounds). */
  getColorIndex: (sessionId: string) => number;
  /** Project id sent with the "import uploaded model" request. */
  projectId: string;
};

/**
 * Inline table of epochs. Only rows that carry at least one tag, or whose
 * weights were imported, are displayed.
 *
 * Each row offers two hover actions:
 * - open the uploaded model file in a new browser tab;
 * - import the uploaded model into the project. This is offered only while
 *   the epoch is not yet imported.
 */
function EpochsTable({
  projectId,
  rows,
  fields,
  getColorIndex,
  sorter,
}: EpochsTableProps) {
  const visibleRows = useMemo(
    () =>
      rows.filter((row) => {
        const hasTags = !!row.tags?.length;
        return hasTags || row.imported;
      }),
    [rows]
  );

  const {
    environmentInfo: { clientStoragePrefixUrl },
  } = useEnvironmentInfo();

  const hoverActions = useMemo(() => {
    const downloadAction = {
      title: 'Download uploaded model',
      icon: <CloudDownload2 className="mt-1" />,
      filter: (row: EpochRow) => !!row.uploadedModelPath,
      onSelect: async (row: EpochRow) => {
        const path = row.uploadedModelPath;
        if (!path) {
          return;
        }
        // The file is served from client storage under the configured prefix.
        window.open(`${clientStoragePrefixUrl}/${path}`, '_blank')?.focus();
      },
    };

    const importAction = {
      title: `Import uploaded model`,
      icon: <ImportIcon className="mt-1" />,
      filter: (row: EpochRow) => !!row.uploadedModelPath && !row.imported,
      onSelect: async ({ sessionId, epoch, uploadedModelPath }: EpochRow) => {
        // `.onnx` uploads are imported with `transformInputs` enabled.
        await api.importExternalModel({
          projectId,
          sessionId,
          epoch,
          transformInputs: uploadedModelPath?.endsWith('.onnx'),
        });
      },
    };

    return [downloadAction, importAction];
  }, [projectId, clientStoragePrefixUrl]);

  return (
    <Table
      className="overflow-auto"
      inline
      fields={fields}
      data={visibleRows}
      hoverActions={hoverActions}
      actionPosition="start"
      hoverActionsSnapToClass="left-[calc(100%+30px)]"
      bgRowClass={(row) => {
        const { sessionId } = row as EpochRow;
        const background = GENTLE_BG_COLOR_TEMPLATES[getColorIndex(sessionId)];
        return `${background}  hover:bg-primary-925`;
      }}
      sorter={sorter}
    />
  );
}

/** Props for `EpochsCharts`. */
type EpochsChartsProps = {
  /** Epoch records providing the metric values to plot. */
  epochs: EpochData[];
  /** Metric name -> value type; only `'number'` metrics are charted. */
  metricsTypes: Map<string, 'number' | 'string' | 'image'>;
  /** Metric names the user chose not to chart. */
  hiddenMetrics?: string[];
  /** Sessions used to label and color each chart line. */
  sessions: Session[];
  /** Maps a session id to its color-template index. */
  getColorIndex: (sessionId: string) => number;
  /** Extra classes for the grid container. */
  className?: string;
};

/**
 * Two-column grid of line charts, one per visible numeric metric. Each chart
 * plots the metric value against the epoch number, with one line per
 * session. A "NO DATA FOUND" placeholder is shown when no chart is produced.
 *
 * Lines are keyed and colored by the label "<modelName> - <HH:mm dd/MM>".
 */
function EpochsCharts({
  metricsTypes,
  epochs,
  getColorIndex,
  sessions,
  className,
  hiddenMetrics,
}: EpochsChartsProps) {
  const sessionById = useMemo(
    () => new Map(sessions.map((session) => [session.cid, session])),
    [sessions]
  );

  // Series label -> hex color of the owning session.
  const graphColorMap = useMemo(
    () =>
      Object.fromEntries(
        sessions.map((session) => [
          `${session.modelName} - ${formatDateAndTime(session.createdAt)}`,
          BG_COLOR_TEMPLATES_HEX[getColorIndex(session.cid)] as string,
        ])
      ),
    [sessions, getColorIndex]
  );

  const chartsData = useMemo(() => {
    const hiddenMetricsSet = new Set(hiddenMetrics);

    // Series label for an epoch's session; '' when the session is unknown.
    const sessionNameOf = (sessionId: string) => {
      const owner = sessionById.get(sessionId);
      return owner
        ? `${owner.modelName} - ${formatDateAndTime(owner.createdAt)}`
        : '';
    };

    const chartableKeys = Array.from(metricsTypes.keys()).filter(
      (metricKey) =>
        metricsTypes.get(metricKey) === 'number' &&
        !hiddenMetricsSet.has(metricKey)
    );

    return chartableKeys.map((metricKey) => {
      // Epochs that lack this metric contribute no point.
      const points: GenericDataItem[] = [];
      for (const epoch of epochs) {
        const value = epoch.externalData?.metrics[metricKey]?.value;
        if (value === undefined) {
          continue;
        }
        const sessionName = sessionNameOf(epoch.sessionId);
        points.push({
          innerKey: sessionName,
          data: {
            epoch: epoch.epoch,
            [metricKey]: value,
            sessionName,
          },
        });
      }

      const data = sortBy(points, (item) => item.data.epoch);

      const chart = (
        <LineChart
          graphData={{ data }}
          hiddenLabels={[]}
          hoverLabel={metricKey}
          showXAxisLine
          showYLabel
          xAxisDomain={undefined}
          yAxisDomain={undefined}
          colorMap={graphColorMap}
          chartRequestData={{
            xField: 'epoch',
            yField: metricKey,
            innerSplit: {
              distribution: DistributionType.Distinct,
              field: 'sessionName',
              order: OrderType.Asc,
              limit: null,
            },
          }}
          mapValue={(value) => value}
          axisType="number"
        />
      );

      return { metricKey, chart };
    });
  }, [epochs, sessionById, metricsTypes, graphColorMap, hiddenMetrics]);

  const isEmpty = chartsData.length === 0;

  return (
    <div
      className={clsx('grid p-4 grid-cols-2 gap-4 overflow-auto', className)}
    >
      {chartsData.map(({ metricKey, chart }) => (
        <div key={metricKey} className="h-60 p-2 bg-gray-900 rounded-lg">
          {chart}
        </div>
      ))}

      {isEmpty && <NoDataChart className="col-span-2" text="NO DATA FOUND" />}
    </div>
  );
}

/** Props for `SelectMetrics`. */
type SelectMetricsProps = {
  /** Names of the metrics currently hidden from the charts. */
  hiddenMetrics: string[];
  /** Receives the new list of hidden metric names. */
  onHiddenMetricsChange: (hiddenMetrics: string[]) => void;
  /** Metric name -> value type; only `'number'` metrics are offered. */
  metricsTypes: Map<string, 'number' | 'string' | 'image'>;
};

/**
 * Filter menu choosing which numeric metrics are charted.
 *
 * `hiddenMetrics` is passed as the menu value together with
 * `invertedSelection` — presumably so the menu treats `value` as the
 * unchecked set. The icon shows a badge with the number of hidden metrics
 * that still exist in `metricsTypes`.
 */
function SelectMetrics({
  metricsTypes,
  hiddenMetrics,
  onHiddenMetricsChange,
}: SelectMetricsProps) {
  // Only numeric metrics can be charted, so only they are offered.
  const numericMetrics = useMemo(() => {
    const keys: string[] = [];
    metricsTypes.forEach((type, key) => {
      if (type === 'number') {
        keys.push(key);
      }
    });
    return keys;
  }, [metricsTypes]);

  const hiddenMetricsCount = useMemo(
    () =>
      hiddenMetrics.reduce(
        (count, metric) => (metricsTypes.has(metric) ? count + 1 : count),
        0
      ),
    [hiddenMetrics, metricsTypes]
  );

  const tooltipTitle = hiddenMetricsCount
    ? `${hiddenMetricsCount} hidden(s)`
    : 'Filter metrics ';

  const filterIcon = hiddenMetricsCount ? (
    <Badge badgeContent={hiddenMetricsCount} color="primary">
      <SmallFilterFilled />
    </Badge>
  ) : (
    <SmallFilterEmpty />
  );

  const filterLabelIcon = (
    <Tooltip title={tooltipTitle}>
      <div>{filterIcon}</div>
    </Tooltip>
  );

  return (
    <MultiSelectIconMenu
      iconWrapperClassName="p-2 h-8 w-8"
      icon={filterLabelIcon}
      invertedSelection
      noOptionsText="No metrics"
      value={hiddenMetrics}
      options={numericMetrics}
      onChange={onHiddenMetricsChange}
    />
  );
}
