import { ResultOf } from "@graphql-typed-document-node/core";
import _ from "lodash";
import React from "react";

import {
  FragmentType,
  MultivariateDriftMethod,
  MultivariateDriftMethodConfigInput,
  ProblemType,
  UnivariateDriftMethod,
  UnivariateDriftMethodConfigInput,
  gql,
  useFragment,
} from "@/apis/nannyml";
import { Table, TableBody, TableCaption, TableCell, TableHead, TableHeader, TableRow } from "@/components/Table";
import { distanceMethodLabels, multivariateDriftMethodLabels, statisticalMethodLabels } from "@/formatters/monitoring";
import { cn } from "@/lib/utils";

import { MetricEnableToggle } from "./MetricEnableToggle";
import { ThresholdConfigCells, ThresholdConfigHeaderCells } from "./ThresholdConfig";
import { RuntimeConfigComponentProps, SegmentRuntimeConfigComponentProps } from "./types";

/**
 * Univariate drift method label maps, grouped under the section title shown
 * in the univariate tables. Each entry is rendered as a header row (the key)
 * followed by one table row per method in the label map (the value).
 */
const univariateLabelGroups = {
  "Distance measures": distanceMethodLabels,
  "Statistical measures": statisticalMethodLabels,
};

/**
 * Fragment on `RuntimeConfig` consumed by the covariate shift config components.
 *
 * Selects, per multivariate drift method, the method name, its lower/upper
 * value limits and the `IsSupportedConfig` fields; and per univariate drift
 * method, the method name, its lower/upper value limits and the
 * `IsSupportedConfig` fields for each column kind (categorical, continuous,
 * targets, predictions, predicted probabilities).
 *
 * NOTE(review): presumably the `gql` helper matches on the exact document
 * text — confirm before reformatting the template literal.
 */
const covariateShiftRuntimeConfigFragment = gql(/* GraphQL */ `
  fragment CovariateShiftRuntimeConfig on RuntimeConfig {
    multivariateDriftMethods {
      ...IsSupportedConfig
      method
      lowerValueLimit
      upperValueLimit
    }
    univariateDriftMethods {
      method
      lowerValueLimit
      upperValueLimit
      categorical {
        ...IsSupportedConfig
      }
      continuous {
        ...IsSupportedConfig
      }
      targets {
        ...IsSupportedConfig
      }
      predictions {
        ...IsSupportedConfig
      }
      predictedProbabilities {
        ...IsSupportedConfig
      }
    }
  }
`);

/**
 * Shared editing helpers for the covariate shift configuration tables.
 *
 * Reads the runtime config fragment through `useFragment`, indexes both the
 * config entries and the current values by their `method` field, and builds
 * curried change handlers: given a method, each returns a callback that merges
 * a partial change into the matching entry and reports the whole updated list
 * through `onValueChange`. Entries for other methods are passed through as-is.
 */
const useCovariateShiftEdit = (
  props: RuntimeConfigComponentProps<FragmentType<typeof covariateShiftRuntimeConfigFragment>>
) => {
  const {
    config: configFragment,
    value: { univariateDriftMethods, multivariateDriftMethods },
    onValueChange,
  } = props;
  const config = useFragment(covariateShiftRuntimeConfigFragment, configFragment);

  const onUnivariateMethodChange = (method: string | UnivariateDriftMethod) => {
    return (change: Partial<UnivariateDriftMethodConfigInput>) => {
      // A missing list is treated as empty, so the reported list is always an array.
      const current = univariateDriftMethods ?? [];
      const updated = current.map((entry) => (entry.method === method ? { ...entry, ...change } : entry));
      onValueChange({ univariateDriftMethods: updated });
    };
  };

  const onMultivariateMethodChange = (method: MultivariateDriftMethod) => {
    return (change: Partial<MultivariateDriftMethodConfigInput>) => {
      const current = multivariateDriftMethods ?? [];
      const updated = current.map((entry) => (entry.method === method ? { ...entry, ...change } : entry));
      onValueChange({ multivariateDriftMethods: updated });
    };
  };

  return {
    univariateMethods: _.keyBy(univariateDriftMethods, "method"),
    multivariateMethods: _.keyBy(multivariateDriftMethods, "method"),
    configMultivariateMethods: _.keyBy(config.multivariateDriftMethods, "method"),
    configUnivariateMethods: _.keyBy(config.univariateDriftMethods, "method"),
    onUnivariateMethodChange,
    onMultivariateMethodChange,
  };
};

/**
 * Combined covariate shift editor.
 *
 * Renders two tables: univariate drift methods (grouped by
 * `univariateLabelGroups`) and multivariate drift methods (one row per
 * `MultivariateDriftMethod` value). Every row shows the method label, its
 * enable toggles and its threshold cells.
 */
export const CovariateShiftConfig = (
  props: RuntimeConfigComponentProps<FragmentType<typeof covariateShiftRuntimeConfigFragment>>
) => {
  const edit = useCovariateShiftEdit(props);
  const { problemType, className } = props;

  return (
    <div className={cn("flex flex-col gap-8", className)}>
      <Table>
        <UnivariateDriftCaption />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <UnivariateDriftEnableHeaderCells problemType={problemType} />
            <ThresholdConfigHeaderCells />
          </TableRow>
        </TableHeader>
        <TableBody>
          {Object.entries(univariateLabelGroups).map(([title, labels]) => (
            <React.Fragment key={title}>
              <HeaderRow>{title}</HeaderRow>
              {Object.entries(labels).map(([method, label]) => {
                // Object.entries yields plain string keys; narrow once for the lookups.
                const methodKey = method as UnivariateDriftMethod;
                const methodConfig = edit.configUnivariateMethods[methodKey];
                const methodValue = edit.univariateMethods[methodKey];
                const handleChange = edit.onUnivariateMethodChange(method);

                return (
                  <TableRow key={method}>
                    <CovariateShiftMethodCell method={label} />
                    <UnivariateDriftEnableCells
                      problemType={problemType}
                      config={methodConfig}
                      value={methodValue}
                      onValueChange={handleChange}
                    />
                    <UnivariateDriftThresholdCells
                      config={methodConfig}
                      value={methodValue}
                      onValueChange={handleChange}
                    />
                  </TableRow>
                );
              })}
            </React.Fragment>
          ))}
        </TableBody>
      </Table>
      <Table className="w-fit">
        <MultivariateDriftCaption />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <MultivariateDriftEnableHeaderCells />
            <ThresholdConfigHeaderCells />
          </TableRow>
        </TableHeader>
        <TableBody>
          {Object.values(MultivariateDriftMethod).map((method) => {
            const methodConfig = edit.configMultivariateMethods[method];
            const methodValue = edit.multivariateMethods[method];
            const handleChange = edit.onMultivariateMethodChange(method);

            return (
              <TableRow key={method}>
                <CovariateShiftMethodCell method={multivariateDriftMethodLabels[method]} />
                <MultivariateDriftEnableCells config={methodConfig} value={methodValue} onValueChange={handleChange} />
                <MultivariateDriftThresholdCells
                  config={methodConfig}
                  value={methodValue}
                  onValueChange={handleChange}
                />
              </TableRow>
            );
          })}
        </TableBody>
      </Table>
    </div>
  );
};

/**
 * Section title row used to separate method groups inside a table body.
 * The single cell spans eight columns.
 */
const HeaderRow = (props: { children: React.ReactNode }) => {
  return (
    <TableRow>
      <TableCell colSpan={8} className="text-xs font-semibold uppercase bg-slate-700/50">
        {props.children}
      </TableCell>
    </TableRow>
  );
};

/**
 * Enable-only covariate shift editor.
 *
 * Renders the univariate and multivariate drift tables with the method label
 * and enable toggles only; no threshold columns are shown.
 */
export const CovariateShiftEnableConfig = (
  props: RuntimeConfigComponentProps<FragmentType<typeof covariateShiftRuntimeConfigFragment>>
) => {
  const {
    configMultivariateMethods,
    configUnivariateMethods,
    multivariateMethods,
    univariateMethods,
    onUnivariateMethodChange,
    onMultivariateMethodChange,
  } = useCovariateShiftEdit(props);

  return (
    <div className={cn("flex flex-col gap-8", props.className)}>
      <Table>
        <UnivariateDriftCaption />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <UnivariateDriftEnableHeaderCells problemType={props.problemType} />
          </TableRow>
        </TableHeader>
        <TableBody>
          {Object.entries(univariateLabelGroups).map(([title, labels]) => (
            <React.Fragment key={title}>
              <HeaderRow>{title}</HeaderRow>
              {Object.entries(labels).map(([method, label]) => (
                <TableRow key={method}>
                  <CovariateShiftMethodCell method={label} />
                  <UnivariateDriftEnableCells
                    problemType={props.problemType}
                    // Object.entries yields plain string keys, hence the narrowing here.
                    config={configUnivariateMethods[method as UnivariateDriftMethod]}
                    value={univariateMethods[method as UnivariateDriftMethod]}
                    onValueChange={onUnivariateMethodChange(method)}
                  />
                </TableRow>
              ))}
            </React.Fragment>
          ))}
        </TableBody>
      </Table>
      <Table className="w-fit">
        <MultivariateDriftCaption className="whitespace-nowrap" />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <MultivariateDriftEnableHeaderCells />
          </TableRow>
        </TableHeader>
        <TableBody>
          {/* `method` is already a MultivariateDriftMethod here, so no assertion is needed;
              the list key lives on the row, the cells inside are static children. */}
          {Object.values(MultivariateDriftMethod).map((method) => (
            <TableRow key={method}>
              <CovariateShiftMethodCell method={multivariateDriftMethodLabels[method]} />
              <MultivariateDriftEnableCells
                config={configMultivariateMethods[method]}
                value={multivariateMethods[method]}
                onValueChange={onMultivariateMethodChange(method)}
              />
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </div>
  );
};

/**
 * Threshold-only covariate shift editor.
 *
 * Renders the univariate and multivariate drift tables with the method label
 * and threshold cells only. The header cells are rendered with `useSegments`
 * and `props.segmentId` is forwarded to every threshold cell group.
 */
export const CovariateShiftThresholdConfig = (
  props: SegmentRuntimeConfigComponentProps<FragmentType<typeof covariateShiftRuntimeConfigFragment>>
) => {
  const {
    configMultivariateMethods,
    configUnivariateMethods,
    multivariateMethods,
    univariateMethods,
    onUnivariateMethodChange,
    onMultivariateMethodChange,
  } = useCovariateShiftEdit(props);

  return (
    <div className={cn("flex flex-col gap-8", props.className)}>
      <Table>
        <UnivariateDriftCaption />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <ThresholdConfigHeaderCells useSegments />
          </TableRow>
        </TableHeader>
        <TableBody>
          {Object.entries(univariateLabelGroups).map(([title, labels]) => (
            <React.Fragment key={title}>
              <HeaderRow>{title}</HeaderRow>
              {Object.entries(labels).map(([method, label]) => (
                <TableRow key={method}>
                  <CovariateShiftMethodCell method={label} />
                  <UnivariateDriftThresholdCells
                    // Object.entries yields plain string keys, hence the narrowing here.
                    config={configUnivariateMethods[method as UnivariateDriftMethod]}
                    value={univariateMethods[method as UnivariateDriftMethod]}
                    onValueChange={onUnivariateMethodChange(method)}
                    segmentId={props.segmentId}
                  />
                </TableRow>
              ))}
            </React.Fragment>
          ))}
        </TableBody>
      </Table>
      <Table>
        <MultivariateDriftCaption />
        <TableHeader>
          <TableRow>
            <TableHead>Method</TableHead>
            <ThresholdConfigHeaderCells useSegments />
          </TableRow>
        </TableHeader>
        <TableBody>
          {/* `method` is already a MultivariateDriftMethod here, so no assertion is needed;
              the list key lives on the row, the cells inside are static children. */}
          {Object.values(MultivariateDriftMethod).map((method) => (
            <TableRow key={method}>
              <CovariateShiftMethodCell method={multivariateDriftMethodLabels[method]} />
              <MultivariateDriftThresholdCells
                config={configMultivariateMethods[method]}
                value={multivariateMethods[method]}
                onValueChange={onMultivariateMethodChange(method)}
                segmentId={props.segmentId}
              />
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </div>
  );
};

/** First cell of a method row: the method's display label, kept on a single line. */
const CovariateShiftMethodCell = (props: { method: string }) => {
  return <TableCell className="whitespace-nowrap">{props.method}</TableCell>;
};

/** Caption for the univariate drift table: a heading followed by a one-line description. */
const UnivariateDriftCaption = () => {
  const description = "Drift detection methods that use each model feature individually to detect change.";

  return (
    <TableCaption className="caption-top mt-0 mb-4 text-left">
      <h3 className="text-lg">Univariate drift</h3>
      <span className="text-gray-400">{description}</span>
    </TableCaption>
  );
};

/**
 * Caption for the multivariate drift table: a heading followed by a one-line
 * description. An optional `className` is merged after the base caption classes.
 */
const MultivariateDriftCaption = (props: { className?: string }) => {
  const captionClassName = cn("caption-top mt-0 mb-4 text-left", props.className);
  const description = "Drift detection methods that use all model features to detect change.";

  return (
    <TableCaption className={captionClassName}>
      <h3 className="text-lg">Multivariate drift</h3>
      <span className="text-gray-400">{description}</span>
    </TableCaption>
  );
};

/**
 * Header cells for the univariate enable toggles, one per column kind.
 * The "Predicted probabilities" column is omitted for regression problems.
 */
const UnivariateDriftEnableHeaderCells = (props: { problemType: ProblemType }) => {
  const showPredictedProbabilities = props.problemType !== ProblemType.Regression;

  return (
    <>
      <TableHead className="text-center">Target</TableHead>
      <TableHead className="text-center">Predictions</TableHead>
      {showPredictedProbabilities && <TableHead className="text-center">Predicted probabilities</TableHead>}
      <TableHead className="text-center">Continuous features</TableHead>
      <TableHead className="text-center">Categorical features</TableHead>
    </>
  );
};

/**
 * Enable-toggle cells for one univariate drift method, in the same column
 * order as `UnivariateDriftEnableHeaderCells`.
 *
 * Each toggle reads its flag from `value` (an unset flag is shown as `false`)
 * and reports a single-field partial change through `onValueChange`. The
 * predicted-probabilities cell is omitted for regression problems.
 */
const UnivariateDriftEnableCells = (props: {
  problemType: ProblemType;
  config: ResultOf<typeof covariateShiftRuntimeConfigFragment>["univariateDriftMethods"][number];
  value: UnivariateDriftMethodConfigInput;
  onValueChange: (change: Partial<UnivariateDriftMethodConfigInput>) => void;
}) => {
  const { problemType, config, value, onValueChange } = props;
  const showPredictedProbabilities = problemType !== ProblemType.Regression;

  return (
    <>
      <TableCell>
        <MetricEnableToggle
          config={config.targets}
          value={value.enabledTargets ?? false}
          onValueChange={(next) => onValueChange({ enabledTargets: next })}
        />
      </TableCell>
      <TableCell>
        <MetricEnableToggle
          config={config.predictions}
          value={value.enabledPredictions ?? false}
          onValueChange={(next) => onValueChange({ enabledPredictions: next })}
        />
      </TableCell>
      {showPredictedProbabilities && (
        <TableCell>
          <MetricEnableToggle
            config={config.predictedProbabilities}
            value={value.enabledPredictedProbabilities ?? false}
            onValueChange={(next) => onValueChange({ enabledPredictedProbabilities: next })}
          />
        </TableCell>
      )}
      <TableCell>
        <MetricEnableToggle
          config={config.continuous}
          value={value.enabledContinuous ?? false}
          onValueChange={(next) => onValueChange({ enabledContinuous: next })}
        />
      </TableCell>
      <TableCell>
        <MetricEnableToggle
          config={config.categorical}
          value={value.enabledCategorical ?? false}
          onValueChange={(next) => onValueChange({ enabledCategorical: next })}
        />
      </TableCell>
    </>
  );
};

/**
 * Threshold cells for one univariate drift method.
 *
 * Forwards all props to `ThresholdConfigCells` and marks the cells as disabled
 * when none of the method's enable flags is set.
 */
const UnivariateDriftThresholdCells = (props: {
  config: ResultOf<typeof covariateShiftRuntimeConfigFragment>["univariateDriftMethods"][number];
  segmentId?: number;
  value: UnivariateDriftMethodConfigInput;
  onValueChange: (change: Partial<UnivariateDriftMethodConfigInput>) => void;
}) => {
  const { value } = props;
  const enabledSomewhere =
    value.enabledCategorical ||
    value.enabledContinuous ||
    value.enabledPredictedProbabilities ||
    value.enabledPredictions ||
    value.enabledTargets;

  return <ThresholdConfigCells disabled={!enabledSomewhere} {...props} />;
};

/** Header cell for the single multivariate enable toggle column. */
const MultivariateDriftEnableHeaderCells = () => {
  return <TableHead className="text-center">Enabled</TableHead>;
};

/**
 * Enable-toggle cell for one multivariate drift method. The toggle shows
 * `value.enabled` and reports `{ enabled }` through `onValueChange`.
 */
const MultivariateDriftEnableCells = (props: {
  config: ResultOf<typeof covariateShiftRuntimeConfigFragment>["multivariateDriftMethods"][number];
  value: MultivariateDriftMethodConfigInput;
  onValueChange: (change: Partial<MultivariateDriftMethodConfigInput>) => void;
}) => {
  const { config, value, onValueChange } = props;

  return (
    <TableCell>
      <MetricEnableToggle
        config={config}
        value={value.enabled}
        onValueChange={(next) => onValueChange({ enabled: next })}
      />
    </TableCell>
  );
};

/**
 * Threshold cells for one multivariate drift method. Forwards all props to
 * `ThresholdConfigCells` and disables the cells while the method is not enabled.
 */
const MultivariateDriftThresholdCells = (props: {
  segmentId?: number;
  config: ResultOf<typeof covariateShiftRuntimeConfigFragment>["multivariateDriftMethods"][number];
  value: MultivariateDriftMethodConfigInput;
  onValueChange: (change: Partial<MultivariateDriftMethodConfigInput>) => void;
}) => {
  const isDisabled = !props.value.enabled;

  return <ThresholdConfigCells disabled={isDisabled} {...props} />;
};
