import { useSuspenseQuery } from "@apollo/client";
import { ResultOf } from "@graphql-typed-document-node/core";
import _ from "lodash";
import * as Plotly from "plotly.js";
import { useMemo } from "react";
import Plot from "react-plotly.js";

import {
  getCenterTimestamp,
  getDateRange,
  getEndDate,
  getStartDate,
  getValueRange,
  padRange,
} from "@/adapters/monitoring";
import { AnalysisType, ResultPlotFragment, gql } from "@/apis/nannyml";
import { PlotConfig, PlotDataset, usePlotConfig } from "@/components/monitoring/PlotConfig";
import * as colors from "@/constants/colors";
import { PlotElements, PlotType } from "@/constants/enums";
import { WithAlert, WithThreshold, exceedsThreshold } from "@/domains/threshold";
import { alertDetectedLabels, calculatorLabels, getShortResultLabel } from "@/formatters/monitoring";
import { calculateResultThreshold } from "@/hooks/monitoring";
import { DateLike } from "@/lib/dateUtils";

import { selectDataset, getAxisName, getPlotLayout } from "./ResultPlot.utils";

/**
 * Fetches the results of model `$modelId` that match any entry of `$filter`.
 *
 * Only `TimeSeriesResult` entries select the plotting fields (identifiers, data points and the
 * `TimeSeriesResultThreshold` fragment); any other result type comes back with just `__typename`, `id` and
 * `modelId`.
 */
const getTimeSeriesResultsQuery = gql(/* GraphQL */ `
  query GetTimeSeriesResults($modelId: Int!, $filter: [ModelResultsFilter!]) {
    monitoring_model(id: $modelId) {
      results(filter: $filter) {
        __typename
        id
        modelId
        ... on TimeSeriesResult {
          analysisType
          calculatorType
          metricName
          componentName
          columnName
          data {
            isAnalysis
            startTimestamp
            endTimestamp
            value
            lowerConfidenceBound
            upperConfidenceBound
            nrDataPoints
            samplingError
          }
          ...TimeSeriesResultThreshold
        }
      }
    }
  }
`);

/** A single entry of `monitoring_model.results` as returned by `getTimeSeriesResultsQuery`. */
type TimeSeriesQueryResult = NonNullable<
  ResultOf<typeof getTimeSeriesResultsQuery>["monitoring_model"]
>["results"][number];

/** Query result narrowed to the `TimeSeriesResult` variant, i.e. the entries that carry `data` points. */
type TimeSeriesResultDetails = Extract<TimeSeriesQueryResult, { __typename: "TimeSeriesResult" }>;

/** Time series result extended with its computed `threshold` and a `hasAlert` flag on every data point. */
type TimeSeriesResultType = WithAlert<WithThreshold<TimeSeriesResultDetails>, "data">;

/** A single (threshold-annotated) data point of a `TimeSeriesResultType`. */
type TimeSeriesDataPointType = TimeSeriesResultType["data"][number];

/**
 * Attaches the computed threshold to every result and flags each data point whose value exceeds it.
 * @param results Time series results as returned by the query
 * @returns The results with a `threshold` property and a `hasAlert` flag on every data point
 */
const augmentResultWithThreshold = (results: TimeSeriesResultDetails[]): TimeSeriesResultType[] =>
  results.map((result) => {
    const threshold = calculateResultThreshold(result);
    const flaggedData = result.data.map((dataPoint) => ({
      ...dataPoint,
      hasAlert: exceedsThreshold(dataPoint.value, threshold),
    }));

    return { ...result, threshold, data: flaggedData };
  });

/**
 * Plots one or more time series results that belong to the same model.
 *
 * The data points of `results` are fetched with a suspense query. Thresholds and alert flags are computed on the
 * client, and traces are drawn according to the active plot config. Passing more than one result renders a
 * comparison plot.
 *
 * @param dateRange Optional date range passed on to the plot layout
 * @param className Optional class name for the plot element
 * @param results Results to plot; must be non-empty and all belong to the same model
 * @param width Optional width passed on to the plot layout
 * @throws Error when `results` is empty, mixes models, or the model cannot be found
 */
export const TimeSeriesResultPlot = ({
  dateRange,
  className,
  results,
  width,
}: {
  dateRange?: [DateLike, DateLike];
  className?: string;
  results: ResultPlotFragment[];
  width?: number;
}) => {
  const config = usePlotConfig();
  if (!results.length) {
    throw new Error("At least one result is required");
  }
  const modelId = results[0].modelId;
  if (results.some((r) => r.modelId !== modelId)) {
    throw new Error("All results must belong to the same model");
  }

  const {
    data: { monitoring_model: model },
  } = useSuspenseQuery(getTimeSeriesResultsQuery, {
    variables: {
      modelId: modelId,
      filter: results.map((r) => ({
        calculatorTypes: [r.calculatorType],
        metricNames: [r.metricName],
        componentNames: r.componentName === null ? null : [r.componentName],
        columnNames: r.columnName === null ? null : [r.columnName],
        segments: r.segment?.id ? [r.segment.id] : [null],
      })),
    },
  });
  if (!model) {
    throw new Error(`Model ${modelId} not found`);
  }

  // Derived from the fetched data, so it has to be recomputed exactly when that data changes. Keying on the
  // `results` prop would miss query cache updates and recompute whenever the parent passes a new array instance.
  const resultDetails = useMemo(
    () =>
      augmentResultWithThreshold(
        // Only `TimeSeriesResult` entries select `data`; any other result type cannot be plotted here
        model.results.filter(
          (result): result is TimeSeriesResultDetails => result.__typename === "TimeSeriesResult"
        )
      ),
    [model]
  );

  // Cache data to prevent re-computation when date range changes
  const data = useMemo(() => getPlotData(resultDetails, config), [resultDetails, config]);

  return (
    <Plot
      className={className}
      data={data}
      layout={getTimeSeriesPlotLayout(resultDetails, config, dateRange, width)}
      config={{ displayModeBar: false }}
    />
  );
};

/**
 * Creates a trace name function for comparison plots.
 *
 * Performance results are prefixed with "Estimated"/"Realized"; all other results get the calculator label in
 * parentheses. When a dataset is given, " on <dataset>" (lower-cased) is appended.
 *
 * @param dataset Dataset the trace is generated from, if the trace is specific to one
 * @returns A name function that can be provided to `getPlotData` to generate trace names
 */
const getComparisonTraceNameGenerator =
  (dataset?: PlotDataset) =>
  (result: Omit<TimeSeriesResultType, "data">, name: string): string => {
    const onDataset = dataset ? ` on ${dataset.toLowerCase()}` : "";

    switch (result.analysisType) {
      case AnalysisType.EstimatedPerformance:
        return `Estimated ${name}${onDataset}`;
      case AnalysisType.RealizedPerformance:
        return `Realized ${name}${onDataset}`;
      default:
        return `${name} (${calculatorLabels[result.calculatorType]}${onDataset})`;
    }
  };

/**
 * Trace options used when a single result is plotted, keyed by dataset.
 *
 * The `any` entry is used when every dataset is drawn in its own subplot, so its trace names carry no dataset
 * suffix.
 */
const regularTraceOptions: Record<PlotDataset | "any", TraceOptions> = {
  [PlotDataset.Reference]: {
    metricColor: colors.referenceLineColor,
    confidenceBandColor: colors.referenceConfidenceBandColor,
    nameFn: (_result, label) => `${label} (reference)`,
  },
  [PlotDataset.Analysis]: {
    metricColor: colors.analysisLineColor,
    confidenceBandColor: colors.analysisConfidenceBandColor,
    nameFn: (_result, label) => `${label} (analysis)`,
  },
  any: {
    metricColor: colors.referenceLineColor,
    confidenceBandColor: colors.referenceConfidenceBandColor,
    nameFn: (_result, label) => label,
  },
};

/**
 * Trace options for comparison plots (more than one result), one entry per compared result in plotting order.
 *
 * Each entry is keyed by dataset, with `any` used when every dataset has its own subplot. No confidence band
 * colour is set, so the trace generators fall back to `colors.referenceConfidenceBandColor`.
 */
const comparisonTraceOptions: Record<PlotDataset | "any", TraceOptions>[] = [
  // [reference colour, analysis colour] of the first and the second compared result
  [colors.referenceLineColor, colors.analysisLineColor],
  [colors.comparisonLineColorReference, colors.comparisonLineColorAnalysis],
].map(
  ([referenceColor, analysisColor]): Record<PlotDataset | "any", TraceOptions> => ({
    any: {
      nameFn: getComparisonTraceNameGenerator(),
      metricColor: referenceColor,
    },
    [PlotDataset.Reference]: {
      nameFn: getComparisonTraceNameGenerator(PlotDataset.Reference),
      metricColor: referenceColor,
    },
    [PlotDataset.Analysis]: {
      nameFn: getComparisonTraceNameGenerator(PlotDataset.Analysis),
      metricColor: analysisColor,
    },
  })
);

/**
 * Get plot traces for the given results.
 *
 * Each result is drawn once per configured dataset. With `subplotPerDataset` every dataset is moved to its own
 * subplot, which also gets its own threshold lines. Otherwise all datasets share one plot and the thresholds span
 * all of them. Legend entries are de-duplicated by trace name.
 *
 * @param results The results to be plotted
 * @param config Plot configuration to be used
 * @returns Plotly traces
 * @throws Error when `config.type` has no registered trace generator
 */
const getPlotData = (results: TimeSeriesResultType[], config: PlotConfig): Partial<Plotly.PlotData>[] => {
  const getPlotTraces = plotTraceGenerators[config.type];
  if (!getPlotTraces) {
    throw new Error(`Unsupported plot type: ${config.type}`);
  }

  const plotOptions = results.length > 1 ? comparisonTraceOptions : [regularTraceOptions];
  const showThresholds = config.elements.includes(PlotElements.Thresholds);

  // On a single plot with line traces, the reference data is extended with the first analysis point so both
  // datasets render as one connected line
  const connectReferenceToAnalysis =
    !config.subplotPerDataset &&
    config.type === PlotType.Line &&
    config.datasets.includes(PlotDataset.Reference) &&
    config.datasets.includes(PlotDataset.Analysis);

  // Get traces for all datasets of all results
  const traces = results.flatMap((result, resultIdx) => {
    // Only a fixed number of trace styles exists. Cycle through them instead of indexing past the end of the list,
    // which crashed when comparing more results than there are styles.
    const resultOptions = plotOptions[resultIdx % plotOptions.length];

    return config.datasets.flatMap((dataset, idx) => {
      // Copy the selection so appending the connecting point below never mutates data owned by someone else
      const data = [...selectDataset(result, dataset)];
      if (!data.length) {
        return [];
      }

      if (connectReferenceToAnalysis && dataset === PlotDataset.Reference) {
        const firstAnalysisPoint = selectDataset(result, PlotDataset.Analysis)[0];
        // Without analysis data there is nothing to connect to. Spreading `undefined` would otherwise append a
        // point without timestamps or value.
        if (firstAnalysisPoint) {
          data.push({
            ...firstAnalysisPoint,
            hasAlert: false, // Override alert status to avoid alert being shown on reference
          });
        }
      }

      const options = resultOptions[config.subplotPerDataset ? "any" : dataset];
      const datasetTraces = getPlotTraces(result, data, config.elements, options);
      if (config.subplotPerDataset) {
        // If we're using subplots with thresholds, they need to be added to each subplot individually
        if (showThresholds) {
          datasetTraces.push(...getThresholdPlotTraces(result, data, config.type));
        }

        // Move traces to a subplot
        datasetTraces.forEach((trace) => {
          trace.xaxis = getAxisName("x", idx);
          trace.yaxis = getAxisName("y", idx);
        });
      }
      return datasetTraces;
    });
  });

  // When not using subplots, thresholds are added separately to span all datasets
  if (!config.subplotPerDataset && showThresholds) {
    traces.push(
      ...results.flatMap((result) =>
        getThresholdPlotTraces(
          result,
          config.datasets.flatMap((dataset) => selectDataset(result, dataset)),
          config.type
        )
      )
    );
  }

  // Disable duplicate legend entries: only the first trace with a given name keeps its legend entry
  const legendEntries = new Set();
  traces.forEach((trace) => {
    trace.showlegend &&= !legendEntries.has(trace.name);

    if (trace.showlegend) {
      legendEntries.add(trace.name);
    }
  });

  return traces;
};

/**
 * Get plot layout for the given set of results.
 *
 * Starts from the shared `getPlotLayout` defaults and titles every y-axis with the short label of the first
 * result. When each of several datasets gets its own subplot, all y-axes also share one padded value range so the
 * subplots are directly comparable.
 *
 * @param results The results to be plotted
 * @param config Plot configuration to be used
 * @param dateRange Optional date range to be used for the plot
 * @param width Optional width to be used for the plot
 * @returns Plotly layout
 */
const getTimeSeriesPlotLayout = (
  results: TimeSeriesResultType[],
  config: PlotConfig,
  dateRange?: [DateLike, DateLike],
  width?: number
): Partial<Plotly.Layout> => {
  const baseLayout = getPlotLayout(results, config, dateRange, width);

  let sharedYRange: [number, number] | undefined;
  const alignSubplots = config.subplotPerDataset && config.datasets.length > 1;
  if (alignSubplots) {
    // getValueRange presumably yields a [min, max] pair per result; zip turns them into minima and maxima
    const [minima, maxima] = _.zip(...results.map(getValueRange));
    sharedYRange = padRange([_.min(minima) ?? 0, _.max(maxima) ?? 1]);
  }

  // An `undefined` range is skipped by `_.merge` when the base layout already defines one
  const yAxisOverrides: Record<string, object> = {};
  for (const axisKey of Object.keys(baseLayout)) {
    if (!axisKey.startsWith("yaxis")) {
      continue;
    }
    yAxisOverrides[axisKey] = {
      range: sharedYRange,
      title: {
        text: getShortResultLabel(results[0]),
      },
    };
  }

  return _.merge(baseLayout, yAxisOverrides);
};

/** Naming and colour options accepted by the trace generators; every field is optional. */
type TraceOptions = {
  /** Builds a trace name from the result and a base label such as "Metric" or "Confidence band" */
  nameFn?: (result: Omit<TimeSeriesResultType, "data">, name: string) => string;
  /** Colour of the metric line and markers; generators default to `colors.referenceLineColor` */
  metricColor?: string;
  /** Fill colour of the confidence band; generators default to `colors.referenceConfidenceBandColor` */
  confidenceBandColor?: string;
};

/**
 * Generates step plot traces for a result.
 *
 * The metric is drawn as an `"hv"` stepped line through every chunk's start and end timestamps. That line has
 * hover disabled, so a second, marker-only trace at each chunk's centre carries the hover information. Confidence
 * bands and alert markers are added when requested in `plotElements`.
 *
 * @param result Result to be plotted
 * @param data Data points to be plotted
 * @param plotElements Elements to be displayed in the plot
 * @param options Options for configuring the traces; colours fall back to the reference colours when omitted
 * @returns A list of Plotly traces
 */
const getStepPlotTraces = (
  result: TimeSeriesResultType,
  data: TimeSeriesDataPointType[],
  plotElements: PlotElements[],
  options?: TraceOptions
): Partial<Plotly.PlotData>[] => {
  const traceName = options?.nameFn ?? ((_result, name) => name);
  const color = options?.metricColor ?? colors.referenceLineColor;
  const [hoverTemplate, toHoverData] = generatePlotHoverTemplate(result, color);
  const chunkBoundaries = data.flatMap((dp) => [dp.startTimestamp, dp.endTimestamp]);
  const chunkCenters = data.map((dp) => getCenterTimestamp(dp));

  // Fresh style objects per trace, so no two traces share a nested object
  const metricLine = () => ({ color, shape: "hv" as const, width: 2 });
  const metricMarker = () => ({ color, size: 5, symbol: "circle" as const });

  const traces: Partial<Plotly.PlotData>[] = [
    // Stepped metric line; hover is provided by the marker trace below
    {
      type: "scatter",
      name: traceName(result, "Metric"),
      legendgroup: "metric",
      mode: "lines",
      line: metricLine(),
      marker: metricMarker(),
      hoverinfo: "skip",
      showlegend: true,
      x: chunkBoundaries,
      y: data.flatMap((dp) => [dp.value!, dp.value!]),
    },
    // One marker per chunk centre carrying the hover template; shares the legend entry of the line above
    {
      type: "scatter",
      name: traceName(result, "Metric"),
      legendgroup: "metric",
      mode: "markers",
      line: metricLine(),
      marker: metricMarker(),
      hovertemplate: hoverTemplate,
      customdata: data.map(toHoverData),
      showlegend: false,
      x: chunkCenters,
      y: data.map((dp) => dp.value!),
    },
  ];

  if (plotElements.includes(PlotElements.ConfidenceBands) && data.some((dp) => dp.lowerConfidenceBound !== null)) {
    const bandEdge = () => ({ color: "transparent", shape: "hv" as const });
    traces.push(
      // Invisible lower bound, used as the fill base for the upper bound trace that follows it
      {
        type: "scatter",
        name: traceName(result, "Confidence band"),
        legendgroup: "confidence",
        mode: "lines",
        line: bandEdge(),
        hoverinfo: "skip",
        showlegend: false,
        connectgaps: true,
        x: chunkBoundaries,
        y: data.flatMap((dp) => [dp.lowerConfidenceBound!, dp.lowerConfidenceBound!]),
      },
      // Upper bound filled down to the previous (lower bound) trace; owns the legend entry
      {
        type: "scatter",
        name: traceName(result, "Confidence band"),
        legendgroup: "confidence",
        mode: "lines",
        line: bandEdge(),
        hoverinfo: "skip",
        fill: "tonexty",
        fillcolor: options?.confidenceBandColor ?? colors.referenceConfidenceBandColor,
        showlegend: true,
        connectgaps: true,
        x: chunkBoundaries,
        y: data.flatMap((dp) => [dp.upperConfidenceBound!, dp.upperConfidenceBound!]),
      }
    );
  }

  if (plotElements.includes(PlotElements.Alerts)) {
    // Alert markers sit at the centre of each alerting chunk
    const alertX: number[] = [];
    const alertY: number[] = [];
    data.forEach((dp, i) => {
      if (dp.hasAlert) {
        alertX.push(chunkCenters[i]);
        alertY.push(dp.value!);
      }
    });

    if (alertX.length > 0) {
      traces.push({
        type: "scatter",
        name: "Alert",
        mode: "markers",
        marker: {
          color: colors.thresholdLineColor,
          size: 8,
          symbol: "diamond",
        },
        hoverinfo: "skip",
        showlegend: true,
        x: alertX,
        y: alertY,
      });
    }
  }

  return traces;
};

/**
 * Generates line plot traces for a result.
 *
 * Every data point is placed at its chunk's start timestamp. The metric line's markers carry the hover
 * information. Confidence bands and alert markers are added when requested in `plotElements`.
 *
 * @param result Result to be plotted
 * @param data Data points to be plotted
 * @param plotElements Elements to be displayed in the plot
 * @param options Options for configuring the plot; colours fall back to the reference colours when omitted
 * @returns A list of Plotly traces
 */
const getLinePlotTraces = (
  result: TimeSeriesResultType,
  data: TimeSeriesDataPointType[],
  plotElements: PlotElements[],
  options?: TraceOptions
): Partial<Plotly.PlotData>[] => {
  const traceName = options?.nameFn ?? ((_result, name) => name);
  const color = options?.metricColor ?? colors.referenceLineColor;
  const [hoverTemplate, toHoverData] = generatePlotHoverTemplate(result, color);
  const chunkStarts = data.map((dp) => dp.startTimestamp);

  const traces: Partial<Plotly.PlotData>[] = [
    {
      type: "scatter",
      name: traceName(result, "Metric"),
      legendgroup: "metric",
      mode: "lines+markers",
      line: { color },
      marker: { color, size: 5, symbol: "circle" },
      hovertemplate: hoverTemplate,
      customdata: data.map(toHoverData),
      showlegend: true,
      x: chunkStarts,
      y: data.map((dp) => dp.value!),
    },
  ];

  const drawBands =
    plotElements.includes(PlotElements.ConfidenceBands) && data.some((dp) => dp.lowerConfidenceBound !== null);
  if (drawBands) {
    const bandName = traceName(result, "Confidence band");
    traces.push(
      // Invisible lower bound, used as the fill base for the upper bound trace that follows it
      {
        type: "scatter",
        name: bandName,
        legendgroup: "confidence",
        mode: "lines",
        line: { color: "transparent" },
        hoverinfo: "skip",
        showlegend: false,
        connectgaps: true,
        x: chunkStarts,
        y: data.map((dp) => dp.lowerConfidenceBound!),
      },
      // Upper bound filled down to the previous (lower bound) trace; owns the legend entry
      {
        type: "scatter",
        name: bandName,
        legendgroup: "confidence",
        mode: "none",
        hoverinfo: "skip",
        fill: "tonexty",
        fillcolor: options?.confidenceBandColor ?? colors.referenceConfidenceBandColor,
        showlegend: true,
        connectgaps: true,
        x: chunkStarts,
        y: data.map((dp) => dp.upperConfidenceBound!),
      }
    );
  }

  const alertPoints: TimeSeriesDataPointType[] = plotElements.includes(PlotElements.Alerts)
    ? data.filter((dp) => dp.hasAlert)
    : [];
  if (alertPoints.length > 0) {
    traces.push({
      type: "scatter",
      name: "Alert",
      mode: "markers",
      marker: {
        color: colors.thresholdLineColor,
        size: 8,
        symbol: "diamond",
      },
      hoverinfo: "skip",
      showlegend: true,
      x: alertPoints.map((dp) => dp.startTimestamp),
      y: alertPoints.map((dp) => dp.value!),
    });
  }

  return traces;
};

/**
 * Trace generator for each supported plot type. Plot types without an entry are rejected by `getPlotData` with an
 * "Unsupported plot type" error.
 */
const plotTraceGenerators: Partial<Record<PlotType, typeof getStepPlotTraces>> = {
  [PlotType.Line]: getLinePlotTraces,
  [PlotType.Step]: getStepPlotTraces,
};

/**
 * Generates plot traces for a result's thresholds.
 *
 * One dashed horizontal line is drawn for each non-null bound (lower first, then upper), spanning the date range
 * of `data`. Only the first drawn line shows a legend entry, so both bounds share a single "Threshold" item.
 *
 * @param result The result to be plotted
 * @param data Data points for which the threshold is to be determined
 * @param plotType The type of plot to be generated
 * @returns A list of Plotly traces; empty when `data` is empty
 */
const getThresholdPlotTraces = (
  result: TimeSeriesResultType,
  data: TimeSeriesDataPointType[],
  plotType: PlotType
): Partial<Plotly.PlotData>[] => {
  if (data.length === 0) {
    return [];
  }

  const x = getDateRange(data, plotType);
  const { lower, upper } = result.threshold;

  return [lower, upper]
    .filter((bound) => bound !== null)
    .map(
      (bound, position): Partial<Plotly.PlotData> => ({
        type: "scatter",
        name: "Threshold",
        legendgroup: "thresholds",
        mode: "lines",
        line: {
          color: colors.thresholdLineColor,
          dash: "dash",
          width: 1,
        },
        hoverinfo: "skip",
        showlegend: position === 0,
        x,
        y: [bound, bound],
      })
    );
};

/**
 * Generates a Plotly hover template for a result.
 *
 * `customdata` layout per data point: [0] short result label, [1] dataset ("Analysis"/"Reference"), [2] alert badge
 * (empty when no alert), [3] chunk start date, [4] chunk end date, [5] number of rows, [6] value, [7] sampling
 * error formatted with 4 decimals.
 *
 * @param result Result to generate hover information for
 * @param metricColor Colour used for the dataset heading in the tooltip
 * @returns A template string and a function to generate hover `customdata` input
 */
const generatePlotHoverTemplate = (
  result: TimeSeriesResultType,
  metricColor: string
): [string, (dataPoint: TimeSeriesDataPointType) => (string | number)[]] => {
  const template = `<b style="color:${metricColor}">%{customdata[1]}</b> &nbsp; &nbsp; %{customdata[2]}<br />
    Chunk: <b>%{customdata[3]} - %{customdata[4]}</b> (%{customdata[5]} rows)<br />
    %{customdata[0]}: <b>%{customdata[6]:.4f}</b><br />
    Confidence band: <b>± %{customdata[7]}</b>
    <extra></extra>`;

  const toCustomData = (dataPoint: TimeSeriesDataPointType): (string | number)[] => {
    const resultLabel = getShortResultLabel(result);
    const datasetLabel = dataPoint.isAnalysis ? "Analysis" : "Reference";
    const alertBadge = dataPoint.hasAlert
      ? `<b style="color:${colors.alertColor}">⚠️ ${alertDetectedLabels[result.calculatorType]}</b>`
      : "";
    const chunkStart = getStartDate(dataPoint).toLocaleDateString();
    const chunkEnd = getEndDate(dataPoint).toLocaleDateString();
    // NOTE(review): a missing sampling error is rendered as the literal text "undefined" in the tooltip
    const samplingError = dataPoint.samplingError?.toFixed(4) ?? "undefined";

    return [
      resultLabel,
      datasetLabel,
      alertBadge,
      chunkStart,
      chunkEnd,
      dataPoint.nrDataPoints,
      dataPoint.value!,
      samplingError,
    ];
  };

  return [template, toCustomData];
};
