import { ResultOf } from "@graphql-typed-document-node/core";
import _ from "lodash";
import * as Plotly from "plotly.js";
import { useMemo } from "react";
import Plot from "react-plotly.js";

import { getCenterTimestamp, getEndDate, getStartDate } from "@/adapters/monitoring";
import { FragmentType, gql, useFragment } from "@/apis/nannyml";
import { PlotConfig, PlotDataset, usePlotConfig } from "@/components/monitoring/PlotConfig";
import * as colors from "@/constants/colors";
import { PlotElements, PlotType } from "@/constants/enums";
import { DateLike } from "@/lib/dateUtils";

import { getAxisName, getPlotLayout, selectDataset } from "./ResultPlot.utils";

/**
 * GraphQL fragment selecting what the value count distribution plot needs from a `ValueCountDistributionResult`:
 * the column name (aliased as `column`) and, per chunk (aliased as `data`), whether it belongs to the analysis set,
 * its start/end timestamps, its row count and the per-value `value`/`density`/`count` entries (aliased as `catData`).
 */
export const valueCountDistributionResultDetails = gql(/* GraphQL */ `
  fragment ValueCountDistributionResultDetails on ValueCountDistributionResult {
    column: columnName
    data: chunks {
      isAnalysis
      startTimestamp
      endTimestamp
      nrDataPoints

      # Rename to work around name conflict
      catData: data {
        value
        density
        count
      }
    }
  }
`);

/** Result type of the {@link valueCountDistributionResultDetails} fragment, as used by the plot helpers below. */
type ValueCountDistributionResultDetails = ResultOf<typeof valueCountDistributionResultDetails>;

/**
 * Plots a value count distribution result as Plotly bar traces (one trace per categorical value), optionally
 * outlining chunks that have an alert.
 *
 * @param props.dateRange Optional date range forwarded to the layout
 * @param props.className CSS class applied to the Plotly element
 * @param props.result Fragment reference for the value count distribution result
 * @param props.alerts Optional per-chunk alert flags, index-aligned with the result's chunks
 * @param props.width Optional plot width forwarded to the layout
 */
export const ValueCountDistributionPlot = (props: {
  dateRange?: [DateLike, DateLike];
  className?: string;
  result: FragmentType<typeof valueCountDistributionResultDetails>;
  alerts?: (Boolean | null)[];
  width?: number;
}) => {
  const { dateRange, className, width, alerts } = props;
  const result = useFragment(valueCountDistributionResultDetails, props.result);
  const config = usePlotConfig();

  // Traces only depend on result, config and alerts: memoize them so that a date-range change (which only
  // affects the layout) does not rebuild them.
  const traces = useMemo(() => getPlotData(result, config, alerts), [result, config, alerts]);
  const layout = getValueCountDistributionPlotLayout(result, config, dateRange, width);

  return <Plot className={className} data={traces} layout={layout} config={{ displayModeBar: false }} />;
};

/**
 * Get plot traces for the given result
 * @param result The results to be plotted
 * @param config Plot configuration to be used
 * @param alerts Optional per-chunk alert flags, index-aligned with `result.data`
 * @returns Plotly traces
 * @throws Error if `config.type` is not `PlotType.Distribution`
 */
const getPlotData = (
  result: ValueCountDistributionResultDetails,
  config: PlotConfig,
  alerts?: (Boolean | null)[]
): Partial<Plotly.PlotData>[] => {
  if (config.type !== PlotType.Distribution) {
    throw new Error(`Unsupported plot type: ${config.type}`);
  }

  // Merge alert info into data. A new value is built instead of reassigning the `result` parameter; the explicit
  // type keeps the downstream `selectDataset`/`getPlotTraces` calls typed exactly as before.
  const resultWithAlerts: ValueCountDistributionResultDetails = {
    ...result,
    data: result.data.map((chunk, idx) => ({
      ...chunk,
      hasAlert: alerts?.[idx],
    })),
  };

  // Prepare color generator function so that values receive the same color across datasets.
  // Category values are arbitrary strings, so a Map is used: a plain object would return inherited members for keys
  // such as "constructor" or "__proto__", skip the color assignment and then fail on `.concat`.
  const colorMap = new Map<string, number[]>();
  const generateColor = (value: string, opacity: number): string => {
    let rgb = colorMap.get(value);
    if (!rgb) {
      // Assign palette colors in first-seen order, cycling once the palette is exhausted
      rgb = colors.colorPalette[colorMap.size % colors.colorPalette.length];
      colorMap.set(value, rgb);
    }

    return `rgba(${rgb.concat(opacity).join(",")})`;
  };

  // Get traces for all datasets
  const traces = config.datasets.flatMap((dataset, idx) => {
    const data = selectDataset(resultWithAlerts, dataset);
    if (!data.length) {
      return [];
    }

    const traceOption = traceOptions[config.subplotPerDataset ? "any" : dataset];
    const datasetTraces = getPlotTraces(resultWithAlerts, data, config.elements, traceOption, generateColor);
    if (config.subplotPerDataset) {
      // Move traces to this dataset's subplot
      for (const trace of datasetTraces) {
        trace.xaxis = getAxisName("x", idx);
        trace.yaxis = getAxisName("y", idx);
      }
    }

    return datasetTraces;
  });

  // Disable duplicate legend entries: only the first trace with a given name keeps its legend item
  const legendEntries = new Set<Partial<Plotly.PlotData>["name"]>();
  for (const trace of traces) {
    trace.showlegend &&= !legendEntries.has(trace.name);

    if (trace.showlegend) {
      legendEntries.add(trace.name);
    }
  }

  return traces;
};

/** Colors and opacity applied to the traces generated for one dataset. */
type TraceOptions = {
  /** Color of the dataset name heading in the hover label */
  traceColor: string;
  /** Fill color for the dataset; not read by the traces generated in this file */
  fillColor: string;
  /** Outline color of the alert bars */
  alertColor: string;
  /** Opacity applied to the per-value bar colors */
  opacity: number;
};

/** Builds a {@link TraceOptions} entry; every dataset shares the same alert color. */
const createTraceOptions = (traceColor: string, fillColor: string, opacity: number): TraceOptions => ({
  traceColor,
  fillColor,
  alertColor: colors.alertColor,
  opacity,
});

/**
 * Trace options per dataset. The `any` entry is used instead of the dataset-specific one when every dataset is
 * drawn in its own subplot.
 */
const traceOptions: Record<PlotDataset | "any", TraceOptions> = {
  any: createTraceOptions(colors.referenceLineColor, colors.referenceConfidenceBandColor, 0.9),
  [PlotDataset.Analysis]: createTraceOptions(colors.analysisLineColor, colors.analysisConfidenceBandColor, 0.9),
  [PlotDataset.Reference]: createTraceOptions(colors.referenceLineColor, colors.referenceConfidenceBandColor, 0.6),
};

/**
 * Generates traces for a value count distribution result: one bar trace per categorical value (in first-seen
 * order) and, when alerts are enabled and present, one trace outlining the chunks with an alert.
 * @param result Result to be plotted (only its column name is used, for hover information)
 * @param data Data points to be plotted, optionally annotated with `hasAlert`
 * @param plotElements Elements to be displayed in the plot
 * @param options Options for configuring the plot
 * @param generateColor Maps a categorical value and an opacity to a color string; called once per new value
 * @returns A list of Plotly traces
 */
const getPlotTraces = (
  result: Omit<ValueCountDistributionResultDetails, "data">,
  data: (ValueCountDistributionResultDetails["data"][0] & { hasAlert?: Boolean })[],
  plotElements: PlotElements[],
  options: TraceOptions,
  generateColor: (value: string, opacity: number) => string
): Partial<Plotly.PlotData>[] => {
  const [hoverTemplate, getHoverTemplateData] = generatePlotHoverTemplate(result, options.traceColor);

  // One bar trace per categorical value. A Map is used instead of a plain object because values are arbitrary
  // strings: keys such as "toString" or "__proto__" would collide with Object.prototype, and plain objects
  // enumerate integer-like keys ("1", "10") before all others instead of in first-seen order.
  // Traces are typed loosely because their x/y/customdata arrays are appended to in place.
  const traceMap = new Map<string, any>();
  for (const chunk of data) {
    for (const dp of chunk.catData) {
      let trace = traceMap.get(dp.value);
      if (!trace) {
        trace = {
          name: dp.value,
          legendgroup: dp.value,
          marker: {
            color: generateColor(dp.value, options.opacity),
          },
          hovertemplate: hoverTemplate,
          customdata: [],
          showlegend: true,
          x: [],
          y: [],
          type: "bar",
        } as Partial<Plotly.PlotData>;
        traceMap.set(dp.value, trace);
      }

      trace.customdata.push(getHoverTemplateData(chunk, dp));
      trace.x.push(getCenterTimestamp(chunk));
      trace.y.push(dp.density);
    }
  }

  const traces = Array.from(traceMap.values());

  if (plotElements.includes(PlotElements.Alerts)) {
    const alertTimestamps = data.filter((chunk) => chunk.hasAlert).map((chunk) => getCenterTimestamp(chunk));
    if (alertTimestamps.length) {
      // For the alerts to align with the data bars, Plotly needs at least 2 data points with distinct X values.
      // Append (up to) the first two chunk centers as hidden, zero-width padding points. Slicing avoids reading
      // `data[1]` when the dataset consists of a single chunk.
      const paddingTimestamps = data.slice(0, 2).map((chunk) => getCenterTimestamp(chunk));
      traces.push({
        name: "Alerts",
        legendgroup: "alerts",
        marker: {
          color: "transparent",
          line: {
            color: options.alertColor,
            // Visible outline for alert bars, none for the padding points
            width: [...alertTimestamps.map(() => 4), ...paddingTimestamps.map(() => 0)],
          },
        },
        hoverinfo: "skip",
        showlegend: true,
        x: [...alertTimestamps, ...paddingTimestamps],
        y: Array(alertTimestamps.length + paddingTimestamps.length).fill(1),
        base: 0,
        type: "bar",
      });
    }
  }

  return traces;
};

/**
 * Generates the Plotly hover template for value count distribution bars.
 * @param result Result providing the column name shown in the hover label
 * @param metricColor Color of the dataset name ("Analysis"/"Reference") heading
 * @returns A tuple of the template string and a function that turns a chunk and one of its value data points into
 *   the `customdata` row read by the template (indices 0-8)
 */
const generatePlotHoverTemplate = (
  result: Pick<ValueCountDistributionResultDetails, "column">,
  metricColor: string
): [
  string,
  (
    chunk: ValueCountDistributionResultDetails["data"][0] & { hasAlert?: Boolean },
    dp: ValueCountDistributionResultDetails["data"][0]["catData"][0]
  ) => Plotly.Datum[]
] => {
  const hoverTemplate = `<b style="color:${metricColor}">%{customdata[1]}</b> &nbsp; &nbsp; %{customdata[2]}<br />
    Chunk: <b>%{customdata[3]} - %{customdata[4]}</b> (%{customdata[5]} rows)<br />
    %{customdata[0]}: <b>%{customdata[7]}%</b> has value %{customdata[6]} (%{customdata[8]} instances)<br />
    <extra></extra>`;

  const toCustomData = (
    chunk: ValueCountDistributionResultDetails["data"][0] & { hasAlert?: Boolean },
    dp: ValueCountDistributionResultDetails["data"][0]["catData"][0]
  ): Plotly.Datum[] => {
    const datasetLabel = chunk.isAnalysis ? "Analysis" : "Reference";
    const alertLabel = chunk.hasAlert ? `<b style="color:${colors.alertColor}">⚠️ Drift detected</b>` : "";
    const percentage = (dp.density * 100).toFixed(2);

    // Order matters: each position is referenced by index in the template above
    return [
      result.column,
      datasetLabel,
      alertLabel,
      getStartDate(chunk).toLocaleDateString(),
      getEndDate(chunk).toLocaleDateString(),
      chunk.nrDataPoints,
      dp.value,
      percentage,
      dp.count,
    ];
  };

  return [hoverTemplate, toCustomData];
};

/**
 * Get plot layout for the given result: the shared base layout, switched to stacked bars, with every y-axis
 * formatted as a percentage and titled with the result's column name.
 * @param result The results to be plotted
 * @param config Plot configuration to be used
 * @param dateRange Optional date range to be used for the plot
 * @param width Optional width to be used for the plot
 * @returns Plotly layout (the base layout object, deep-merged in place with the overrides)
 */
const getValueCountDistributionPlotLayout = (
  result: ValueCountDistributionResultDetails,
  config: PlotConfig,
  dateRange?: [DateLike, DateLike],
  width?: number
): Partial<Plotly.Layout> => {
  // Use default layout as base
  const baseLayout = getPlotLayout([result], config, dateRange, width);

  // Percentage ticks and a column-name title for every y-axis present in the base layout (yaxis, yaxis2, ...)
  const yAxisOverrides = Object.fromEntries(
    Object.keys(baseLayout)
      .filter((key) => key.startsWith("yaxis"))
      .map((axisKey) => [axisKey, { tickformat: ",.0%", title: { text: result.column } }])
  );

  return _.merge(baseLayout, { barmode: "stack", ...yAxisOverrides });
};
