import { atom, useAtomValue } from "jotai";
import type { TickFormatter } from "@visx/axis";

import type { Data, Margin } from "../shared/types";
import { selectAtom } from "jotai/utils";
import { bucket, isTimeSeries, normalize } from "../internal/data";
import { prepareData } from "./util";
import { omit } from "lodash-es";
import { categorical } from "../internal/colors";
import { createChartMachineAtom } from "../machines/chart-machine";
import { getXTickFormat, getYTickFormat } from "../utils/tickFormat";
import { AreaChartProps } from "./area-chart";
import {
  BandScale,
  createLinearScale,
  createOrdinalScale,
  createScale,
  DiscreteColorScale,
  LinearScale,
  TimeScale,
} from "../scales";
import { getDateDomain, getNumericDomain } from "../domain";
import { heightAtom, widthAtom } from "../internal/atoms/dimensions";

/**
 * Hooks exposed by this module's default export. Each hook returns the current
 * value of a single derived atom through `useAtomValue`.
 */
interface AreaChartState {
  // --- data ---
  /** Rows to render: bucketed, legend-filtered and optionally normalized. */
  useData: () => Data[];
  /** Same rows as `useData`, exposed for the tooltip. */
  useTooltipData: () => Data[];

  // --- layout ---
  /** Margin taken from the chart props. */
  useMargin: () => Margin;
  /** Plot-area width: chart width minus the left/right margins. */
  useInnerWidth: () => number;
  /** Plot-area height: chart height minus the top/bottom margins. */
  useInnerHeight: () => number;

  // --- keys ---
  /** Field used for the x axis. */
  useXKey: () => string;
  /** y keys as passed in the chart props. */
  useYKeys: () => string[];
  /** Series keys (fields of the first row other than the x key) before legend filtering. */
  useKeys: () => string[];
  /** Series keys present in the displayed (legend-filtered) rows. */
  useDisplayKeys: () => string[];

  // --- scales ---
  /** Time scale when the x values form a time series, otherwise a linear scale. */
  useXScale: () => TimeScale | LinearScale;
  /** Linear y scale over the computed y domain. */
  useYScale: () => LinearScale;
  /** Band scale over the distinct x values, used for tooltips. */
  useTooltipScale: () => BandScale;
  /** Ordinal color scale over the series keys. */
  useColorScale: () => DiscreteColorScale;

  // --- formatters ---
  /** Tick formatter for the y axis. */
  useYTickFormat: () => TickFormatter<number>;
  /** Tick formatter for the x axis (date or string ticks). */
  useXTickFormat: () => TickFormatter<Date> | TickFormatter<string>;
}

/**
 * Raw props of the area chart. Starts out as an empty placeholder; the chart's
 * Provider is expected to write the real props into this atom.
 */
const uninitializedProps = {} as AreaChartProps;
export const propsAtom = atom<AreaChartProps>(uninitializedProps);

/**
 * Hook facade for the area chart. Every hook subscribes the calling component
 * to exactly one derived atom via `useAtomValue`. The atoms are declared further
 * down in this module; they are only dereferenced when a hook actually runs.
 */
const state: AreaChartState = {
  useData() {
    return useAtomValue(displayDataAtom);
  },
  useInnerHeight() {
    return useAtomValue(innerHeightAtom);
  },
  useInnerWidth() {
    return useAtomValue(innerWidthAtom);
  },
  useMargin() {
    return useAtomValue(marginAtom);
  },
  useYKeys() {
    return useAtomValue(yKeysAtom);
  },
  useXKey() {
    return useAtomValue(xKeyAtom);
  },
  useXScale() {
    return useAtomValue(xScaleAtom);
  },
  useYScale() {
    return useAtomValue(yScaleAtom);
  },
  useYTickFormat() {
    return useAtomValue(yTickFormatterAtom);
  },
  useXTickFormat() {
    return useAtomValue(xTickFormatterAtom);
  },
  useKeys() {
    return useAtomValue(keysAtom);
  },
  useDisplayKeys() {
    return useAtomValue(displayKeysAtom);
  },
  useColorScale() {
    return useAtomValue(colorScaleAtom);
  },
  useTooltipScale() {
    return useAtomValue(tooltipScaleAtom);
  },
  useTooltipData() {
    return useAtomValue(tooltipDataAtom);
  },
};

export default state;

// Read-only slices of the chart props, each derived from `propsAtom` with `selectAtom`.
const rawDataAtom = selectAtom(propsAtom, ({ data }) => data);
const marginAtom = selectAtom(propsAtom, ({ margin }) => margin);
const xKeyAtom = selectAtom(propsAtom, ({ xKey }) => xKey);
const yKeysAtom = selectAtom(propsAtom, ({ yKeys }) => yKeys);
const groupingKeyAtom = selectAtom(propsAtom, ({ groupingKey }) => groupingKey);
// Any falsy displayType falls back to "stacked".
const displayTypeAtom = selectAtom(propsAtom, ({ displayType }) => displayType || "stacked");
// Renamed binding so it does not shadow the imported `normalize` helper.
const normalizeAtom = selectAtom(propsAtom, ({ normalize: shouldNormalize }) => shouldNormalize || false);
// Only null/undefined fall back to 8; an explicit 0 is preserved (and disables bucketing below).
const maximumCategoriesAtom = selectAtom(propsAtom, ({ maximumCategories }) => maximumCategories ?? 8);
const typeOverridesAtom = selectAtom(propsAtom, ({ typeOverrides }) => typeOverrides);

// *************************
// * DIMENSIONS
// *************************
/**
 * Width of the plot area: the chart width minus the left and right margins,
 * clamped at 0.
 *
 * The clamp matters when the container is narrower than its margins (e.g. a
 * measured width of 0). A negative width would give `xScaleAtom` an inverted
 * `[0, negative]` range and hand consumers an invalid negative size.
 */
const innerWidthAtom = atom((get) => {
  const width = get(widthAtom);
  const { left, right } = get(marginAtom);
  return Math.max(0, width - left - right);
});

/**
 * Height of the plot area: the chart height minus the top and bottom margins,
 * clamped at 0 for the same reason as the width (it feeds `yScaleAtom`'s
 * `[height, 0]` range).
 */
const innerHeightAtom = atom((get) => {
  const height = get(heightAtom);
  const { top, bottom } = get(marginAtom);
  return Math.max(0, height - top - bottom);
});

// *************************
// * DATA
// *************************

/**
 * Rows produced by `prepareData` from the raw props data, the x/y keys, the
 * grouping key and any type overrides.
 */
const dataAtom = atom((get) =>
  prepareData(
    get(rawDataAtom),
    get(xKeyAtom),
    get(yKeysAtom),
    get(groupingKeyAtom),
    get(typeOverridesAtom),
  ),
);

/**
 * Prepared rows passed through `bucket` with the `maximumCategories` limit
 * (default 8). A falsy limit such as 0 skips bucketing and returns the
 * prepared rows unchanged.
 */
const limitedBucketedDataAtom = atom((get) => {
  const rows = get(dataAtom);
  const key = get(xKeyAtom);
  const limit = get(maximumCategoriesAtom);
  return limit ? bucket(rows, key, limit) : rows;
});

/**
 * Bucketed rows with the keys in the chart machine's
 * `context.legend.filterKeys` omitted from every row (presumably series the
 * user toggled off in the legend). When no keys are filtered, the bucketed
 * array is returned as-is.
 */
const filteredDataAtom = atom((get) => {
  const rows = get(limitedBucketedDataAtom);
  const hiddenKeys = get(chartMachineAtom).context.legend.filterKeys;
  return hiddenKeys.length > 0 ? rows.map((row) => omit(row, hiddenKeys)) : rows;
});

/**
 * Rows handed to the renderers: the legend-filtered rows, passed through
 * `normalize(rows, xKey)` when the `normalize` prop is truthy.
 */
const displayDataAtom = atom((get) => {
  const rows = get(filteredDataAtom);
  const key = get(xKeyAtom);
  return get(normalizeAtom) ? normalize(rows, key) : rows;
});

// *************************
// * KEYS
// *************************

/**
 * Series keys of a dataset: every field of its first row except `xKey`.
 * Returns `[]` for an empty dataset. Only the first row is inspected.
 */
function seriesKeysOf(rows: ReadonlyArray<object>, xKey: string): string[] {
  const first = rows[0];
  if (first === undefined) return [];
  return Object.keys(first).filter((key) => key !== xKey);
}

/** Series keys of the bucketed data, before legend filtering (feeds the chart machine). */
const keysAtom = atom((get) => seriesKeysOf(get(limitedBucketedDataAtom), get(xKeyAtom)));

/** Series keys still present in the displayed (legend-filtered, maybe normalized) rows. */
const displayKeysAtom = atom((get) => seriesKeysOf(get(displayDataAtom), get(xKeyAtom)));

/**
 * All series keys regardless of legend filtering, used for the color scale.
 * Same derivation as `keysAtom`, so it simply reuses it rather than
 * recomputing an identical list.
 */
const allKeysAtom = atom((get) => get(keysAtom));

// *************************
// * DOMAINS
// *************************

/**
 * `[min, max]` y domain over the displayed series keys. When the display type
 * is "stacked", `getNumericDomain` is asked to stack the values by `xKey`.
 */
const yDomainAtom = atom((get) => {
  const rows = get(displayDataAtom);
  const seriesKeys = get(displayKeysAtom);
  const key = get(xKeyAtom);
  const isStacked = get(displayTypeAtom) === "stacked";
  const domain = getNumericDomain(rows, seriesKeys, {
    stacked: isStacked ? { key } : undefined,
  });
  return [domain.min, domain.max] as [number, number];
});

// *************************
// * FORMATTERS
// *************************
/** y-axis tick formatter, chosen by `getYTickFormat` from the current y domain. */
export const yTickFormatterAtom = atom((get) => getYTickFormat(get(yDomainAtom)));

/** x-axis tick formatter, chosen by `getXTickFormat` from the displayed rows and `xKey`. */
export const xTickFormatterAtom = atom((get) => {
  const rows = get(displayDataAtom);
  const key = get(xKeyAtom);
  return getXTickFormat(rows, key);
});

// *************************
// * SCALES
// *************************
/**
 * x scale across the plot width. Time-series x values get a time scale over
 * their date domain; anything else gets a linear scale over the un-niced
 * numeric domain.
 */
const xScaleAtom = atom((get) => {
  const rows = get(displayDataAtom);
  const key = get(xKeyAtom);
  const plotWidth = get(innerWidthAtom);

  if (!isTimeSeries(rows, key)) {
    const numericDomain = getNumericDomain(rows, [key], { nice: false });
    return createScale([numericDomain.min, numericDomain.max], [0, plotWidth], "linear") as LinearScale;
  }

  const dateDomain = getDateDomain(rows, [key]);
  return createScale([dateDomain.min, dateDomain.max], [0, plotWidth], "time") as TimeScale;
});

/**
 * Band scale over the distinct x values (dates for time series, numbers
 * otherwise) across the plot width, used for tooltip lookup.
 *
 * Both branches pass `{ padding: 0 }`. Time and numeric x axes then get the
 * same band layout, and the time branch no longer depends on createScale's
 * default band padding.
 */
const tooltipScaleAtom = atom((get) => {
  const data = get(displayDataAtom);
  const xKey = get(xKeyAtom);
  const width = get(innerWidthAtom);
  if (isTimeSeries(data, xKey)) {
    const { values } = getDateDomain(data, [xKey]);
    return createScale(values, [0, width], "band", { padding: 0 }) as BandScale;
  }
  const { values } = getNumericDomain(data, [xKey], { nice: false });
  return createScale(values, [0, width], "band", { padding: 0 }) as BandScale;
});

/**
 * Linear y scale mapping the y domain onto `[height, 0]`. Larger values land
 * nearer the top of the plot area.
 */
export const yScaleAtom = atom((get) =>
  createLinearScale(get(yDomainAtom), [get(innerHeightAtom), 0]),
);
/** Optional per-key color overrides from the chart props. */
const colorMapAtom = selectAtom(propsAtom, (config) => config.colorMap);

/**
 * Ordinal color scale over all series keys (taken before legend filtering, so
 * a series keeps its color when others are hidden). A key uses its `colorMap`
 * entry when present; otherwise it gets the categorical palette color at its
 * index, wrapping around the palette.
 *
 * `colorMap` is read through its own `selectAtom` slice, like every other prop
 * in this module, rather than through the whole `propsAtom`. Unrelated prop
 * updates therefore do not rebuild the scale.
 */
const colorScaleAtom = atom((get) => {
  const keys = get(allKeysAtom);
  const colorMap = get(colorMapAtom);
  return createOrdinalScale(
    keys,
    keys.map((k, i) => (colorMap?.[k] ?? categorical[i % categorical.length]) as string),
  );
});

/** Rows backing the tooltip: currently the same rows as `displayDataAtom`. */
export const tooltipDataAtom = atom((get) => get(displayDataAtom));

// *************************
// * Machine
// *************************
/**
 * Chart state-machine atom, built from the series keys. `filteredDataAtom`
 * reads `context.legend.filterKeys` from it.
 */
export const chartMachineAtom = createChartMachineAtom(keysAtom);
