import { extent } from "d3";
import { type PickD3Scale } from "@visx/scale";
import type { TickFormatter } from "@visx/axis";
import { useAtomValue } from "jotai";
import { atom } from "jotai";
import { selectAtom } from "jotai/utils";
import { aggregate, isCategoricalData, isTimeSeries } from "../internal/data";
import type { Data } from "../shared/types";
import { prepareData } from "./utils";
import { atomWithMachine } from "jotai-xstate";
import { createChartMachine } from "../machines";
import { identity, isNumber, reverse, sortBy, uniq } from "lodash-es";
import { ScatterChartProps } from "./scatter-chart";
import {
  BandScale,
  createBandScale,
  createLinearScale,
  createLogScale,
  createOrdinalScale,
  createSqrtScale,
  createTimeScale,
  LinearScale,
  LogScale,
  TimeScale,
} from "../scales";
import { formatDate } from "../internal/formatters";
import { formatNumber } from "../big-number/util";
import {
  categorical,
  CategoricalColorScale,
  ContinuousColorScale,
  createCategoricalScale,
  createContinuousScale,
} from "../colors";
import { getDateDomain, getNumericDomain, getStringDomain } from "../domain";
import { Domain } from "../domain/types";
import { heightAtom, widthAtom } from "../internal/atoms/dimensions";

/**
 * Shape of the hook bundle exported by this module.
 *
 * Every member is a zero-argument function that returns the current value of
 * one atom defined in this file.
 */
interface ScatterChartState {
  // --- data ---------------------------------------------------------------
  /** Rows to render. */
  useData: () => Data[];

  // --- field keys ---------------------------------------------------------
  /** Name of the field plotted on the x axis. */
  useXKey: () => string;
  /** Name of the field plotted on the y axis. */
  useYKey: () => string;
  /** Name of the field that drives the point radius, when one is configured. */
  useRKey: () => string | undefined;

  // --- dimensions ---------------------------------------------------------
  /** Chart width with the horizontal margins removed. */
  useInnerWidth: () => number;
  /** Chart height with the vertical margins removed. */
  useInnerHeight: () => number;

  // --- scales -------------------------------------------------------------
  useXScale: () => LinearScale | BandScale | TimeScale;
  useYScale: () => LinearScale | LogScale | BandScale | TimeScale;
  useRScale: () => PickD3Scale<"sqrt", number, number>;
  useColorScale: () => CategoricalColorScale | ContinuousColorScale;

  // --- tick formatters ----------------------------------------------------
  useXTickFormatter: () => TickFormatter<Date> | TickFormatter<number> | TickFormatter<string>;
  useYTickFormatter: () => TickFormatter<number> | TickFormatter<string> | TickFormatter<Date>;
}

/**
 * Holds the props of the scatter chart; every other atom in this module is
 * derived from it.
 *
 * The initial value is an empty object cast to `ScatterChartProps`, so it is
 * only a placeholder: the Provider is expected to initialize it.
 */
export const configAtom = atom<ScatterChartProps>({} as ScatterChartProps); // This will be initialized by Provider

/**
 * Hook bundle for the scatter chart.
 *
 * Each hook subscribes to exactly one atom via `useAtomValue`. The atoms are
 * only touched inside the function bodies because they are declared further
 * down in this module.
 */
const state: ScatterChartState = {
  // data
  useData() {
    return useAtomValue(displayData);
  },

  // field keys
  useXKey() {
    return useAtomValue(xKeyAtom);
  },
  useYKey() {
    return useAtomValue(yKeyAtom);
  },
  useRKey() {
    return useAtomValue(rKeyAtom);
  },

  // dimensions
  useInnerWidth() {
    return useAtomValue(innerWidthAtom);
  },
  useInnerHeight() {
    return useAtomValue(innerHeightAtom);
  },

  // scales
  useXScale() {
    return useAtomValue(xScaleAtom);
  },
  useYScale() {
    return useAtomValue(yScaleAtom);
  },
  useRScale() {
    return useAtomValue(rScaleAtom);
  },
  useColorScale() {
    return useAtomValue(colorScaleAtom);
  },

  // tick formatters
  useXTickFormatter() {
    return useAtomValue(xTickFormatAtom);
  },
  useYTickFormatter() {
    return useAtomValue(yTickFormatAtom);
  },
};
export default state;

// *************************
// * CONFIG
// *************************
// Each atom below exposes a single field of the chart config.

// Layout and raw input.
const marginAtom = selectAtom(configAtom, ({ margin }) => margin);
const rawDataAtom = selectAtom(configAtom, ({ data }) => data);
const typeOverridesAtom = selectAtom(configAtom, ({ typeOverrides }) => typeOverrides);

// Field keys.
const xKeyAtom = selectAtom(configAtom, ({ xKey }) => xKey);
const yKeyAtom = selectAtom(configAtom, ({ yKey }) => yKey);
const rKeyAtom = selectAtom(configAtom, ({ rKey }) => rKey);
const colorKeyAtom = selectAtom(configAtom, ({ colorKey }) => colorKey);

// Explicitly configured axis scale types (may be unset).
const xScaleTypeConfigAtom = selectAtom(configAtom, ({ xAxisScale }) => xAxisScale);
const yScaleTypeConfigAtom = selectAtom(configAtom, ({ yAxisScale }) => yAxisScale);

// Radius range; falls back to 30 when the config leaves it null/undefined.
const rRangeAtom = selectAtom(configAtom, ({ rRange }) => rRange ?? 30);

// X axis ordering.
const xSortDirAtom = selectAtom(configAtom, ({ xAxisSortDirection }) => xAxisSortDirection);
const xSortKeyAtom = selectAtom(configAtom, ({ xAxisSortKey }) => xAxisSortKey);

/**
 * Scale type used for the y axis.
 *
 * Returns the configured `yAxisScale` when it is set; otherwise the type is
 * derived from the data via `determineScale`.
 */
const yScaleTypeAtom = atom((get) => {
  const configured = get(yScaleTypeConfigAtom);
  if (configured) return configured;
  // The data is only read when the type has to be inferred, so this atom does
  // not depend on the data while an explicit scale type is configured.
  return determineScale(get(dataAtom), get(yKeyAtom));
});

/**
 * Scale type used for the x axis.
 *
 * Returns the configured `xAxisScale` when it is set; otherwise the type is
 * derived from the data via `determineScale`.
 */
const xScaleTypeAtom = atom((get) => {
  const configured = get(xScaleTypeConfigAtom);
  if (configured) return configured;
  // The data is only read when the type has to be inferred, so this atom does
  // not depend on the data while an explicit scale type is configured.
  return determineScale(get(dataAtom), get(xKeyAtom));
});

/**
 * Picks a scale type for the field `key` of `data`.
 *
 * @param data - Rows to inspect.
 * @param key - Name of the field whose scale type is needed.
 * @returns `"time"` when `isTimeSeries` accepts the field, `"linear"` when the
 *   first row holds a number for it, and `"band"` otherwise (including when
 *   `data` is empty and `isTimeSeries` rejects it).
 */
function determineScale(data: Data[], key: string) {
  if (isTimeSeries(data, key)) return "time";
  // Only the first row is sampled, so read it directly instead of mapping
  // every row into a throwaway array.
  if (isNumber(data[0]?.[key])) return "linear";
  return "band";
}

// *************************
// * DIMENSIONS
// *************************
/** Chart width minus the left and right margins. */
export const innerWidthAtom = atom((get) => {
  const outerWidth = get(widthAtom);
  const { left, right } = get(marginAtom);
  return outerWidth - left - right;
});

/** Chart height minus the top and bottom margins. */
export const innerHeightAtom = atom((get) => {
  const outerHeight = get(heightAtom);
  const { top, bottom } = get(marginAtom);
  return outerHeight - top - bottom;
});

// *************************
// * DATA
// *************************

/** Configured rows after `prepareData` has been applied with the configured type overrides. */
const dataAtom = atom((get) => prepareData(get(rawDataAtom), get(typeOverridesAtom)));

/**
 * Rows used to order the x axis.
 *
 * When no sort key is configured, or the sort key is the x key itself, the
 * prepared data is returned unchanged. Otherwise the rows are aggregated by
 * the x key with the sort key as the only y key.
 */
const dataBySortKey = atom((get) => {
  const unprepared = get(rawDataAtom);
  const prepared = get(dataAtom);
  const xKey = get(xKeyAtom);
  const sortKey = get(xSortKeyAtom);

  const needsAggregation = Boolean(sortKey) && sortKey !== xKey;
  if (!needsAggregation) return prepared;

  // NOTE(review): the aggregation runs on the raw config data, not on the
  // prepared data — confirm that skipping `prepareData` here is intended.
  return aggregate({ data: unprepared, xKey, yKeys: [sortKey] });
});

/**
 * Rows that should be drawn: the prepared data without the rows whose value
 * for the grouping field (color key, or y key when no color key is set) is
 * listed in the legend's filter keys.
 */
const displayData = atom((get) => {
  const rows = get(dataAtom);
  const groupField = get(colorKeyAtom) ?? get(yKeyAtom);
  const hidden = get(filterKeysAtom);

  const isVisible = (row: Data) => !hidden.includes(row[groupField] as string);
  return rows.filter(isVisible);
});

// *************************
// * FORMATTERS
// *************************
/**
 * Tick formatter for the x axis, chosen from the x value of the first row:
 * a date formatter built from the x domain bounds for `Date` values,
 * `formatNumber` for numbers, and `identity` for everything else.
 *
 * With no rows there is nothing to sample, so `identity` is returned instead
 * of throwing on the missing first row.
 */
const xTickFormatAtom = atom((get) => {
  const xKey = get(xKeyAtom);
  const sample = get(dataAtom)[0]?.[xKey];
  if (sample instanceof Date) {
    // The domain bounds are only needed by the date formatter, so the domain
    // atom is read in this branch alone.
    const { min, max } = get(xDomainAtom);
    return formatDate(min as Date, max as Date);
  }
  if (isNumber(sample)) {
    return formatNumber;
  }
  return identity;
});

/**
 * Tick formatter for the y axis, chosen from the y value of the first row:
 * a date formatter built from the y domain bounds for `Date` values,
 * `formatNumber` for numbers, and `identity` for everything else.
 *
 * With no rows there is nothing to sample, so `identity` is returned instead
 * of throwing on the missing first row.
 */
const yTickFormatAtom = atom((get) => {
  const yKey = get(yKeyAtom);
  const sample = get(dataAtom)[0]?.[yKey];
  if (sample instanceof Date) {
    // The domain bounds are only needed by the date formatter, so the domain
    // atom is read in this branch alone.
    const { min, max } = get(yDomainAtom);
    return formatDate(min as Date, max as Date);
  }
  if (isNumber(sample)) {
    return formatNumber;
  }
  return identity;
});

// *************************
// * DOMAINS
// *************************
/**
 * Domain of the y field.
 *
 * A date domain is used when `isTimeSeries` accepts the field, a numeric
 * domain when the first row holds a number for it, and a string domain
 * otherwise.
 */
const yDomainAtom = atom((get) => {
  const data = get(dataAtom);
  const yKey = get(yKeyAtom);
  if (isTimeSeries(data, yKey)) {
    return getDateDomain(data, [yKey]);
  }
  // Optional chaining keeps an empty data set from throwing here; it then
  // falls through to the string domain.
  // NOTE(review): assumes getStringDomain accepts an empty array — confirm.
  if (isNumber(data[0]?.[yKey])) {
    return getNumericDomain(data, [yKey]);
  }
  return getStringDomain(data, [yKey]);
});

/**
 * Domain of the x field, optionally carrying the x values in sorted order.
 *
 * The base domain is a date domain when `isTimeSeries` accepts the field, a
 * numeric domain (requested with `nice: false`) when the first row holds a
 * number, and a string domain otherwise. Unless the sort direction is
 * `"none"`, the domain's `values` are replaced by the x values of the rows
 * sorted by the sort key (the x key when none is configured), reversed for
 * `"desc"`.
 */
const xDomainAtom = atom((get) => {
  const data = get(dataBySortKey);
  const xKey = get(xKeyAtom);
  const sortDir = get(xSortDirAtom);
  const sortKey = get(xSortKeyAtom) ?? xKey;
  // Optional chaining keeps an empty data set from throwing here; it then
  // falls through to the string domain.
  // NOTE(review): assumes getStringDomain accepts an empty array — confirm.
  const sample = data[0]?.[xKey];
  let domain: Domain<Date> | Domain<number> | Domain<string>;
  if (isTimeSeries(data, xKey)) {
    domain = getDateDomain(data, [xKey]);
  } else if (isNumber(sample)) {
    domain = getNumericDomain(data, [xKey], { nice: false });
  } else {
    domain = getStringDomain(data, [xKey]);
  }
  if (sortDir === "none") return domain;
  // `sortBy` and `map` both produce new arrays, so the in-place `reverse`
  // below never touches the atom's source data.
  let sortedValues = sortBy(data, sortKey).map((d) => d[xKey]);
  if (sortDir === "desc") sortedValues = reverse(sortedValues);
  return { ...domain, values: sortedValues as string[] | Date[] };
});

// *************************
// * SCALES
// *************************
/**
 * Scale that maps x values onto `[0, innerWidth]`.
 *
 * - band scale over the domain `values` for `"band"`, or for `"auto"` when the
 *   x field is categorical;
 * - time scale for `"time"`, or for `"auto"` when the x field is a time series;
 * - log scale for `"log"`;
 * - linear scale otherwise.
 *
 * The time and linear scales use `[max, min]` instead of `[min, max]` when the
 * sort direction is `"desc"`.
 */
const xScaleAtom = atom((get) => {
  const { values, min, max } = get(xDomainAtom);
  const width = get(innerWidthAtom);
  const scaleType = get(xScaleTypeAtom);
  const sortDir = get(xSortDirAtom);
  const data = get(dataAtom);
  const xKey = get(xKeyAtom);
  const sortedDomain = sortDir === "desc" ? [max, min] : [min, max];
  if (scaleType === "band" || (isCategoricalData(data, xKey) && scaleType === "auto")) {
    return createBandScale(values as string[], [0, width], { paddingInner: 1, paddingOuter: 0.1 });
  }
  if (scaleType === "time" || (isTimeSeries(data, xKey) && scaleType === "auto")) {
    return createTimeScale(sortedDomain as [Date, Date], [0, width]);
  }
  if (scaleType === "log") {
    // NOTE(review): unlike the time and linear branches, the log scale always
    // uses [min, max] and so ignores a "desc" sort direction — confirm whether
    // that is intended.
    return createLogScale([min, max] as [number, number], [0, width]);
  }
  return createLinearScale(sortedDomain as [number, number], [0, width]);
});

/**
 * Scale that maps y values onto `[innerHeight, 0]`.
 *
 * - band scale over the domain `values` for `"band"`, or for `"auto"` when the
 *   y field is categorical;
 * - time scale for `"time"`, or for `"auto"` when the y field is a time series;
 * - log scale for `"log"`;
 * - linear scale otherwise.
 */
const yScaleAtom = atom((get) => {
  const yDomain = get(yDomainAtom);
  const innerHeight = get(innerHeightAtom);
  const kind = get(yScaleTypeAtom);
  const rows = get(dataAtom);
  const field = get(yKeyAtom);

  // The range runs from the inner height down to 0.
  const range: [number, number] = [innerHeight, 0];
  const bounds = [yDomain.min, yDomain.max];

  const wantsBand = kind === "band" || (isCategoricalData(rows, field) && kind === "auto");
  if (wantsBand) {
    return createBandScale(yDomain.values as string[], range, { paddingInner: 1, paddingOuter: 0.1 });
  }

  const wantsTime = kind === "time" || (isTimeSeries(rows, field) && kind === "auto");
  if (wantsTime) {
    return createTimeScale(bounds as [Date, Date], range);
  }

  return kind === "log"
    ? createLogScale(bounds as [number, number], range)
    : createLinearScale(bounds as [number, number], range);
});

/**
 * Square-root scale for point radii.
 *
 * The output range starts at 3 and ends at the configured `rRange`, capped at
 * 30. Without a radius key the domain is `[0, 1]`; with one, it is the extent
 * of that field across the prepared data.
 */
const rScaleAtom = atom((get) => {
  const rKey = get(rKeyAtom);
  const range: [number, number] = [3, Math.min(get(rRangeAtom), 30)];
  if (!rKey) return createSqrtScale([0, 1], range);

  const rValues = get(dataAtom).map((d) => d[rKey]);
  const [lo, hi] = extent(rValues as number[]);
  // `extent` yields [undefined, undefined] when there are no usable values
  // (e.g. empty data); use the same default domain as the no-key case rather
  // than handing an undefined domain to the scale.
  const domain: [number, number] = lo === undefined || hi === undefined ? [0, 1] : [lo, hi];
  return createSqrtScale(domain, range);
});

/**
 * Color scale for the points.
 *
 * Without a color key the scale has a single entry named after the y key:
 * a categorical scale when no color map is configured, otherwise an ordinal
 * scale using the mapped color for the y key (or the first categorical color).
 *
 * With a color key the unique values of that field are collected. If the first
 * one is a number, a continuous scale over their extent is returned; otherwise
 * an ordinal scale where each value takes its color from the color map, or
 * from the categorical palette (cycling by index) when unmapped.
 */
const colorScaleAtom = atom((get) => {
  const rows = get(dataAtom);
  const colorField = get(colorKeyAtom);
  const yField = get(yKeyAtom);
  const palette = get(configAtom).colorMap;

  if (!colorField) {
    if (palette) {
      const single = (palette[yField] ?? categorical[0]) as string;
      return createOrdinalScale([yField], [single]);
    }
    return createCategoricalScale([yField]);
  }

  const distinct = uniq(rows.map((row) => row[colorField] as string | number));

  if (isNumber(distinct[0])) {
    const bounds = extent(distinct as number[]) as [number, number];
    return createContinuousScale(bounds);
  }

  const colors = distinct.map(
    (value, index) => (palette?.[value] ?? categorical[index % categorical.length]) as string,
  );
  return createOrdinalScale(distinct as string[], colors);
});

/**
 * Unique values of the grouping field across the prepared data.
 *
 * The grouping field is the color key, falling back to the radius key and
 * then to the y key.
 *
 * NOTE(review): `displayData` picks its filter field as color key, then y key,
 * with no radius-key fallback — confirm the two are meant to differ.
 */
const keysAtom = atom((get) => {
  const groupField = get(colorKeyAtom) ?? get(rKeyAtom) ?? get(yKeyAtom);
  const groupValues = get(dataAtom).map((row) => row[groupField]);
  return uniq(groupValues) as string[];
});

/**
 * State machine atom for the chart, created from the current legend keys
 * (`keysAtom`) with the `devTools` option enabled.
 */
export const chartMachineAtom = atomWithMachine((get) => createChartMachine(get(keysAtom)), {
  devTools: true,
});

/** The `legend.filterKeys` entry of the chart machine's current context. */
const filterKeysAtom = atom((get) => {
  const { context } = get(chartMachineAtom);
  return context.legend.filterKeys;
});
