import { ResultOf } from "@graphql-typed-document-node/core";
import { SettingsIcon } from "lucide-react";
import React, { useLayoutEffect, useMemo, useState } from "react";

import { FragmentType, MetricConfig, RuntimeConfigInput, ThresholdInput, gql, useFragment } from "@/apis/nannyml";
import { Dialog, DialogContent, DialogTrigger, alert } from "@/components/Dialog";
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/Table";
import { Button } from "@/components/common/Button";
import { Checkbox } from "@/components/common/Checkbox/Checkbox";
import { CalculatorGroup } from "@/domains/monitoring";

import { ConceptDriftConfig } from "./ConceptDriftConfig";
import { CovariateShiftConfig } from "./CovariateShiftConfig";
import { DataQualityConfig } from "./DataQualityConfig";
import { DescriptiveStatisticsConfig } from "./DescriptiveStatisticsConfig";
import { PerformanceConfig } from "./PerformanceConfig";
import { PerformanceConfigComponentProps, RuntimeConfigComponentProps } from "./types";

/**
 * GraphQL fragment on `RuntimeConfig` selecting every field this module reads when
 * converting a runtime config into a `RuntimeConfigInput`.
 *
 * It also spreads five per-calculator fragments (`PerformanceRuntimeConfig`,
 * `CovariateShiftRuntimeConfig`, `ConceptDriftRuntimeConfig`,
 * `DescriptiveStatisticsRuntimeConfig`, `DataQualityRuntimeConfig`) that are not defined
 * in this file — presumably they belong to the config components imported above
 * (TODO confirm).
 */
const runtimeConfigDetailsFragment = gql(/* GraphQL */ `
  fragment MonitoringRuntimeConfigDetails on RuntimeConfig {
    ...PerformanceRuntimeConfig
    ...CovariateShiftRuntimeConfig
    ...ConceptDriftRuntimeConfig
    ...DescriptiveStatisticsRuntimeConfig
    ...DataQualityRuntimeConfig

    dataChunking {
      chunking
      enabled
      nrOfRows
    }
    performanceTypes {
      type
      enabled
    }
    performanceMetrics {
      metric
      ...MetricThresholdConfig
      estimated {
        enabled
      }
      realized {
        enabled
      }
      ... on BusinessValueMetricConfig {
        truePositiveWeight
        falsePositiveWeight
        trueNegativeWeight
        falseNegativeWeight
      }
    }
    univariateDriftMethods {
      method
      ...MetricThresholdConfig
      categorical {
        enabled
      }
      continuous {
        enabled
      }
      targets {
        enabled
      }
      predictions {
        enabled
      }
      predictedProbabilities {
        enabled
      }
    }
    multivariateDriftMethods {
      method
      enabled
      ...MetricThresholdConfig
    }
    dataQualityMetrics {
      metric
      ...MetricThresholdConfig
      categorical {
        enabled
      }
      continuous {
        enabled
      }
      targets {
        enabled
      }
      predictions {
        enabled
      }
      predictedProbabilities {
        enabled
      }
      normalize
    }
    conceptShiftMetrics {
      metric
      ...MetricThresholdConfig
      enabled
    }
    summaryStatsMetrics {
      __typename
      metric
      ...MetricThresholdConfig
      ... on SummaryStatsSimpleMetricConfig {
        enabled
      }
      ... on SummaryStatsColumnMetricConfig {
        categorical {
          enabled
        }
        continuous {
          enabled
        }
        targets {
          enabled
        }
        predictions {
          enabled
        }
        predictedProbabilities {
          enabled
        }
      }
    }
  }
`);

/**
 * GraphQL fragment on `MetricConfig` selecting a metric's default threshold and its
 * per-segment thresholds (segment id plus threshold). It is spread into every metric
 * selection of the runtime config fragment and unpacked by
 * `convertMetricThresholdsToInput`.
 */
const metricThresholdConfigFragment = gql(/* GraphQL */ `
  fragment MetricThresholdConfig on MetricConfig {
    __typename
    threshold {
      ...ThresholdDetails
    }
    segmentThresholds {
      segment {
        id
      }
      threshold {
        ...ThresholdDetails
      }
    }
  }
`);

/**
 * GraphQL fragment on `Threshold` selecting the fields of the two threshold variants
 * handled in this file: lower/upper bounds for `ConstantThreshold`, and lower/upper
 * multipliers for `StandardDeviationThreshold`. `__typename` is selected so that
 * `convertThresholdToInput` can tell the variants apart.
 */
const thresholdDetailsFragment = gql(/* GraphQL */ `
  fragment ThresholdDetails on Threshold {
    __typename
    ... on ConstantThreshold {
      lower
      upper
    }
    ... on StandardDeviationThreshold {
      stdLowerMultiplier
      stdUpperMultiplier
    }
  }
`);

/**
 * Unpacks a runtime config fragment reference and derives its mutation-input form.
 *
 * @param configFragment Fragment reference for `MonitoringRuntimeConfigDetails`.
 * @returns A readonly tuple `[config, input]`: the unpacked fragment data and the
 *   `RuntimeConfigInput` converted from it. The input is memoised on `config`.
 */
export const useRuntimeConfig = (configFragment: FragmentType<typeof runtimeConfigDetailsFragment>) => {
  const config = useFragment(runtimeConfigDetailsFragment, configFragment);
  const configInput = useMemo(() => convertRuntimeConfigToInput(config), [config]);

  return [config, configInput] as const;
};

/**
 * Maps the runtime config returned by the API onto the `RuntimeConfigInput` shape.
 *
 * Nested `{ enabled }` objects (e.g. `categorical.enabled`) are flattened into
 * `enabledCategorical`-style flags, and every metric's thresholds are converted through
 * `convertMetricThresholdsToInput`.
 *
 * @param config Unpacked `MonitoringRuntimeConfigDetails` fragment data.
 * @returns The equivalent `RuntimeConfigInput`.
 */
const convertRuntimeConfigToInput = (config: ResultOf<typeof runtimeConfigDetailsFragment>): RuntimeConfigInput => {
  const conceptShiftMetrics = config.conceptShiftMetrics.map((entry) => ({
    metric: entry.metric,
    enabled: entry.enabled,
    ...convertMetricThresholdsToInput(entry),
  }));

  const dataChunking = config.dataChunking.map((entry) => ({
    chunking: entry.chunking,
    enabled: entry.enabled,
    nrOfRows: entry.nrOfRows,
  }));

  const dataQualityMetrics = config.dataQualityMetrics.map((entry) => ({
    metric: entry.metric,
    enabledCategorical: entry.categorical.enabled,
    enabledContinuous: entry.continuous.enabled,
    enabledTargets: entry.targets.enabled,
    enabledPredictions: entry.predictions.enabled,
    enabledPredictedProbabilities: entry.predictedProbabilities.enabled,
    normalize: entry.normalize,
    ...convertMetricThresholdsToInput(entry),
  }));

  const multivariateDriftMethods = config.multivariateDriftMethods.map((entry) => ({
    method: entry.method,
    enabled: entry.enabled,
    ...convertMetricThresholdsToInput(entry),
  }));

  const performanceMetrics = config.performanceMetrics.map((entry) => ({
    metric: entry.metric,
    enabledEstimated: entry.estimated.enabled,
    enabledRealized: entry.realized.enabled,
    // Only the business value metric carries confusion-matrix weights
    businessValue:
      entry.__typename === "BusinessValueMetricConfig"
        ? {
            truePositiveWeight: entry.truePositiveWeight,
            falsePositiveWeight: entry.falsePositiveWeight,
            trueNegativeWeight: entry.trueNegativeWeight,
            falseNegativeWeight: entry.falseNegativeWeight,
          }
        : undefined,
    ...convertMetricThresholdsToInput(entry),
  }));

  const performanceTypes = config.performanceTypes.map((entry) => ({ type: entry.type, enabled: entry.enabled }));

  const summaryStatsMetrics = config.summaryStatsMetrics.map((entry) => ({
    metric: entry.metric,
    ...convertMetricThresholdsToInput(entry),
    // Simple metrics have a single switch; any other variant is toggled per column kind
    ...(entry.__typename === "SummaryStatsSimpleMetricConfig"
      ? { enabled: entry.enabled }
      : {
          enabledCategorical: entry.categorical.enabled,
          enabledContinuous: entry.continuous.enabled,
          enabledTargets: entry.targets.enabled,
          enabledPredictions: entry.predictions.enabled,
          enabledPredictedProbabilities: entry.predictedProbabilities.enabled,
        }),
  }));

  const univariateDriftMethods = config.univariateDriftMethods.map((entry) => ({
    method: entry.method,
    enabledCategorical: entry.categorical.enabled,
    enabledContinuous: entry.continuous.enabled,
    enabledTargets: entry.targets.enabled,
    enabledPredictions: entry.predictions.enabled,
    enabledPredictedProbabilities: entry.predictedProbabilities.enabled,
    ...convertMetricThresholdsToInput(entry),
  }));

  return {
    conceptShiftMetrics,
    dataChunking,
    dataQualityMetrics,
    multivariateDriftMethods,
    performanceMetrics,
    performanceTypes,
    summaryStatsMetrics,
    univariateDriftMethods,
  };
};

/**
 * Extracts the threshold part of a metric config in input form.
 *
 * NOTE(review): `useFragment` is called here from a plain function that runs inside
 * `map` callbacks, so it is presumably the fragment-unmasking helper rather than a
 * stateful React hook — confirm against `@/apis/nannyml`.
 *
 * @param metricConfig Fragment reference for `MetricThresholdConfig`.
 * @returns The metric's default threshold and its per-segment thresholds, each
 *   converted with `convertThresholdToInput`.
 */
const convertMetricThresholdsToInput = (metricConfig: FragmentType<typeof metricThresholdConfigFragment>) => {
  const { threshold: defaultThreshold, segmentThresholds: perSegment } = useFragment(
    metricThresholdConfigFragment,
    metricConfig
  );

  const threshold = convertThresholdToInput(defaultThreshold);
  const segmentThresholds = perSegment.map(({ segment, threshold: segmentThreshold }) => ({
    segmentId: segment.id,
    threshold: convertThresholdToInput(segmentThreshold),
  }));

  return { threshold, segmentThresholds };
};

/**
 * Converts a threshold returned by the API into a `ThresholdInput`.
 *
 * @param thresholdConfig Fragment reference for `ThresholdDetails`, or `null`.
 * @returns `{ constant }` for a `ConstantThreshold`, `{ standardDeviation }` for a
 *   `StandardDeviationThreshold`, and `null` when there is no threshold or its type is
 *   not one of those two.
 */
const convertThresholdToInput = (
  thresholdConfig: FragmentType<typeof thresholdDetailsFragment> | null
): ThresholdInput | null => {
  const details = useFragment(thresholdDetailsFragment, thresholdConfig);
  if (!details) {
    return null;
  }

  switch (details.__typename) {
    case "ConstantThreshold":
      return { constant: { lower: details.lower, upper: details.upper } };
    case "StandardDeviationThreshold":
      return {
        standardDeviation: {
          stdLowerMultiplier: details.stdLowerMultiplier,
          stdUpperMultiplier: details.stdUpperMultiplier,
        },
      };
    default:
      // Unrecognised threshold variant: treated the same as "no threshold"
      return null;
  }
};

/**
 * The per-column-kind enable flags shared by metric inputs that are toggled per kind of
 * column. Every flag is optional and nullable, so an object with none of them set
 * (for example a metric with only a single `enabled` switch) also satisfies this type.
 */
type MetricColumnConfig = {
  enabledCategorical?: boolean | null;
  enabledContinuous?: boolean | null;
  enabledPredictedProbabilities?: boolean | null;
  enabledPredictions?: boolean | null;
  enabledTargets?: boolean | null;
};

/**
 * Tells whether a metric is enabled for at least one kind of column.
 *
 * @param metricConfig Metric input carrying the per-column-kind flags.
 * @returns A truthy value when any flag is set. Because the flags are combined with
 *   `||`, the result is the first truthy flag, or otherwise the value of
 *   `enabledTargets` (which may be `false`, `null` or `undefined`).
 */
const isEnabledColumnConfig = (metricConfig: MetricColumnConfig) => {
  const { enabledCategorical, enabledContinuous, enabledPredictedProbabilities, enabledPredictions, enabledTargets } =
    metricConfig;

  return (
    enabledCategorical || enabledContinuous || enabledPredictedProbabilities || enabledPredictions || enabledTargets
  );
};

/**
 * Returns a copy of a metric input with every per-column-kind flag set to `false`.
 * All other properties are carried over unchanged and the argument is not mutated.
 *
 * @param metricConfig Metric input carrying the per-column-kind flags.
 * @returns A new object with the five `enabled*` flags forced to `false`.
 */
const disableColumnConfig = <T extends MetricColumnConfig>(metricConfig: T) => {
  const allColumnsDisabled = {
    enabledCategorical: false,
    enabledContinuous: false,
    enabledPredictedProbabilities: false,
    enabledPredictions: false,
    enabledTargets: false,
  };

  return { ...metricConfig, ...allColumnsDisabled };
};

/**
 * Describes how one calculator group is presented and toggled in the runtime config UI.
 */
type Configurator = {
  // Name shown in the calculators table and as the heading of the configuration dialog
  label: string;
  // Component rendered inside the configuration dialog for this group
  Component: React.FC<PerformanceConfigComponentProps<ResultOf<typeof runtimeConfigDetailsFragment>>>;
  // Whether at least one metric/method of this group is enabled in the given config
  isEnabled: (runtimeConfig: RuntimeConfigInput) => boolean;
  // Returns only this group's part of the config, with all of its metrics/methods disabled
  disable: (runtimeConfig: RuntimeConfigInput) => Partial<RuntimeConfigInput>;
  // Returns only this group's part of the given default config
  resetToDefault: (defaultConfig: RuntimeConfigInput) => Partial<RuntimeConfigInput>;
};

/**
 * One configurator per calculator group. Entry order matters: the tables in this file
 * iterate `Object.entries(configurators)` to build their rows.
 *
 * Each `disable` / `resetToDefault` returns only the `RuntimeConfigInput` keys owned by
 * its group.
 */
const configurators: Record<CalculatorGroup, Configurator> = {
  [CalculatorGroup.Performance]: {
    label: "Performance monitoring",
    Component: PerformanceConfig,
    isEnabled: ({ performanceMetrics }) =>
      performanceMetrics.some((metric) => metric.enabledEstimated || metric.enabledRealized),
    disable: ({ performanceMetrics }) => ({
      performanceMetrics: performanceMetrics.map((metric) => ({
        ...metric,
        enabledEstimated: false,
        enabledRealized: false,
      })),
    }),
    resetToDefault: (defaultConfig) => ({ performanceMetrics: defaultConfig.performanceMetrics }),
  },
  [CalculatorGroup.CovariateShift]: {
    label: "Covariate shift detection",
    Component: CovariateShiftConfig,
    // Enabled when any univariate or multivariate drift method is on
    isEnabled: ({ univariateDriftMethods, multivariateDriftMethods }) =>
      univariateDriftMethods.some(isEnabledColumnConfig) || multivariateDriftMethods.some((method) => method.enabled),
    disable: ({ univariateDriftMethods, multivariateDriftMethods }) => ({
      univariateDriftMethods: univariateDriftMethods.map(disableColumnConfig),
      multivariateDriftMethods: multivariateDriftMethods.map((method) => ({ ...method, enabled: false })),
    }),
    resetToDefault: (defaultConfig) => ({
      univariateDriftMethods: defaultConfig.univariateDriftMethods,
      multivariateDriftMethods: defaultConfig.multivariateDriftMethods,
    }),
  },
  [CalculatorGroup.ConceptDrift]: {
    label: "Concept drift detection",
    Component: ConceptDriftConfig,
    isEnabled: ({ conceptShiftMetrics }) => conceptShiftMetrics.some((metric) => metric.enabled),
    disable: ({ conceptShiftMetrics }) => ({
      conceptShiftMetrics: conceptShiftMetrics.map((metric) => ({ ...metric, enabled: false })),
    }),
    resetToDefault: (defaultConfig) => ({ conceptShiftMetrics: defaultConfig.conceptShiftMetrics }),
  },
  [CalculatorGroup.DescriptiveStatistics]: {
    label: "Descriptive statistics",
    Component: DescriptiveStatisticsConfig,
    // A metric counts as enabled through its single switch or through any column flag
    isEnabled: ({ summaryStatsMetrics }) =>
      summaryStatsMetrics.some((metric) => metric.enabled || isEnabledColumnConfig(metric)),
    disable: ({ summaryStatsMetrics }) => ({
      summaryStatsMetrics: summaryStatsMetrics.map((metric) => ({
        ...disableColumnConfig(metric),
        enabled: false,
      })),
    }),
    resetToDefault: (defaultConfig) => ({ summaryStatsMetrics: defaultConfig.summaryStatsMetrics }),
  },
  [CalculatorGroup.DataQuality]: {
    label: "Data quality",
    Component: DataQualityConfig,
    isEnabled: ({ dataQualityMetrics }) => dataQualityMetrics.some(isEnabledColumnConfig),
    disable: ({ dataQualityMetrics }) => ({
      dataQualityMetrics: dataQualityMetrics.map(disableColumnConfig),
    }),
    resetToDefault: (defaultConfig) => ({ dataQualityMetrics: defaultConfig.dataQualityMetrics }),
  },
};

/**
 * Editable overview of the runtime config: one table row per calculator group, with a
 * checkbox that enables/disables the whole group and a "Configure" button that opens
 * the group's configuration component in a dialog.
 *
 * On first layout the component seeds `value` with the defaults derived from `config`
 * (when no value is set) and seeds `kpm` with the first performance metric that has
 * estimated or realized performance enabled (when no KPM is set and such a metric
 * exists). It renders nothing until `value` is set.
 *
 * @param problemType Passed through to the group's configuration component.
 * @param kpm Current KPM selection; passed through to the configuration component.
 * @param value Current runtime config input, or a falsy value before initialisation.
 * @param onValueChange Called with the new (possibly partial) runtime config input.
 * @param onKpmChange Called with the new KPM selection.
 * @param config Fragment reference for `MonitoringRuntimeConfigDetails`.
 */
export const RuntimeConfig = ({
  problemType,
  kpm,
  value,
  onValueChange,
  onKpmChange,
  config: configFragment,
}: PerformanceConfigComponentProps<FragmentType<typeof runtimeConfigDetailsFragment>>) => {
  const [config, defaultConfig] = useRuntimeConfig(configFragment);
  const [history, setHistory] = useState<RuntimeConfigInput[]>([]);

  // Initialize runtime config with default values. Intentionally runs once on mount.
  useLayoutEffect(() => {
    if (!value) {
      onValueChange(defaultConfig);
    }
    if (!kpm) {
      // `find` returns undefined when no performance metric is enabled; in that case there
      // is no sensible default KPM, so leave it unset instead of throwing during layout.
      const firstEnabledMetric = config.performanceMetrics.find((m) => m.estimated.enabled || m.realized.enabled);
      if (firstEnabledMetric) {
        onKpmChange({ metric: firstEnabledMetric.metric });
      }
    }
  }, []);

  // Track history of config changes. Updating state during render is safe here because
  // the condition becomes false as soon as the new value has been appended.
  if (value && value !== history.at(-1)) {
    setHistory((history) => history.concat(value));
  }

  // Nothing to render if there is no value
  if (!value) {
    return null;
  }

  const onCalculatorGroupEnabledChange = (configurator: Configurator, enabled: boolean) => {
    if (!enabled) {
      onValueChange(configurator.disable(value));
    } else {
      // Find the last config where the calculator was enabled
      for (let i = history.length - 1; i >= 0; i--) {
        if (configurator.isEnabled(history[i])) {
          // NOTE(review): this passes the entire historical config, not only this group's
          // keys, so changes made to other groups since then may be reverted — confirm
          // that this is intended.
          onValueChange(history[i]);
          return;
        }
      }

      // The group was never enabled in any recorded config (including the defaults)
      alert({
        title: "Cannot enable",
        message: "This calculator is not supported. Check the configuration screen for more details.",
        variant: "info",
      });
    }
  };

  return (
    <Table className="w-fit">
      <TableHeader>
        <TableRow>
          <TableHead>Calculators</TableHead>
          <TableHead>Enabled</TableHead>
          <TableHead />
        </TableRow>
      </TableHeader>
      <TableBody>
        {Object.entries(configurators).map(([group, configurator]) => (
          <TableRow key={group}>
            <TableCell>{configurator.label}</TableCell>
            <TableCell className="text-center">
              <Checkbox
                checked={configurator.isEnabled(value)}
                onCheckedChange={(enabled) => onCalculatorGroupEnabledChange(configurator, Boolean(enabled))}
              />
            </TableCell>
            <TableCell>
              <Dialog>
                <DialogTrigger asChild>
                  <Button cva={{ size: "mediumLong" }} className="flex items-center gap-2">
                    <SettingsIcon size={16} />
                    Configure
                  </Button>
                </DialogTrigger>
                <DialogContent className="text-pale max-w-full w-fit overflow-auto max-h-full">
                  <h3 className="mb-6 text-lg text-center font-semibold">{configurator.label}</h3>
                  <configurator.Component
                    problemType={problemType}
                    kpm={kpm}
                    config={config}
                    value={value}
                    onValueChange={onValueChange}
                    onKpmChange={onKpmChange}
                  />
                </DialogContent>
              </Dialog>
            </TableCell>
          </TableRow>
        ))}
      </TableBody>
    </Table>
  );
};

/**
 * Read-only overview of a runtime config: one row per calculator group with a disabled
 * checkbox showing whether the group is enabled.
 *
 * @param runtimeConfig The runtime config input to summarise.
 */
export const RuntimeConfigSummary = ({ runtimeConfig }: { runtimeConfig: RuntimeConfigInput }) => {
  const rows = Object.entries(configurators).map(([groupKey, { label, isEnabled }]) => (
    <TableRow key={groupKey}>
      <TableCell>{label}</TableCell>
      <TableCell className="text-center">
        <Checkbox checked={isEnabled(runtimeConfig)} aria-readonly disabled />
      </TableCell>
    </TableRow>
  ));

  return (
    <Table className="w-fit">
      <TableHeader>
        <TableRow>
          <TableHead>Calculators</TableHead>
          <TableHead className="text-center">Enabled</TableHead>
        </TableRow>
      </TableHeader>
      <TableBody>{rows}</TableBody>
    </Table>
  );
};
