import { atom, useAtomValue } from "jotai";
import { selectAtom } from "jotai/utils";
import { type PickD3Scale } from "@visx/scale";
import type { TickFormatter } from "@visx/axis";
import { compact, identity, keyBy, omit, reverse, sortBy } from "lodash-es";
import type { Data, Margin } from "../shared/types";
import { getXTickFormat, getYTickFormat } from "../utils/tickFormat";
import { normalize, bucket, aggregate, isTimeSeries, fillMissingDates } from "../internal/data";
import { categorical } from "../internal/colors";
import { createChartMachineAtom } from "../machines";
import type { BarChartProps } from "./bar-chart";
import {
  BandScale,
  createBandScale,
  createLinearScale,
  createLogScale,
  createOrdinalScale,
  createScale,
  DateBandScale,
} from "../scales";
import { getDateDomain, getNumericDomain, getStringDomain } from "../domain";
import { Domain } from "../domain/types";
import { heightAtom, widthAtom } from "../internal/atoms/dimensions";

/**
 * Hook-based accessors for the bar chart's derived state. Every member is a
 * React hook (the implementation calls `useAtomValue`), so call them only from
 * components or other hooks.
 */
interface BarChartState {
  /** Rows to render as bars: legend-filtered, and normalized when `normalize` is set. */
  useData: () => Data[];
  /** Aggregated rows for the line series; empty when there is no line or it is hidden in the legend. */
  useLineData: () => Data[];
  /** Keys of the bar series only — intended to exclude the line key (`yKeyLine`). */
  useBarKeys: () => string[];
  /** The configured x-axis key. */
  useXKey: () => string;
  /** Chart width minus the left and right margins. */
  useInnerWidth: () => number;
  /** Band scale over the x domain, spanning the inner width. */
  useXScale: () => BandScale | DateBandScale;
  /** Band scale over series keys spanning a single x band (positions grouped bars). */
  useXGroupedScale: () => PickD3Scale<"band", number, string>;
  /** Formatter for x-axis ticks (identity when there is no data). */
  useXTickFormat: () => TickFormatter<string> | TickFormatter<Date>;
  /** Y scale for the bars: log when `yAxisScale` is "log", linear otherwise. */
  useYScale: () => PickD3Scale<"linear" | "log", number, number>;
  /** The configured `yAxisScale`, defaulting to "linear". */
  useYScaleType: () => "linear" | "log" | "auto";
  /**
   * Scale for the line series: the shared y scale when the line is on the left
   * axis, otherwise a separate linear scale; undefined when there is no line domain.
   */
  useYLineScale: () => PickD3Scale<"linear", number, number> | undefined;
  /** Chart height minus the top and bottom margins. */
  useInnerHeight: () => number;
  /** Keys plotted against the left y axis. */
  useYKeysLeft: () => string[];
  /** Keys plotted against the right y axis (the line key), or undefined when none. */
  useYKeysRight: () => string[] | undefined;
  /** Formatter for y-axis ticks, derived from the y domain. */
  useYTickFormat: () => TickFormatter<number>;
  /** Every series key, including the line key when configured. */
  useAllKeys: () => string[];
  /** Key → colour; `colorMap` entries override the categorical palette. */
  useColorScale: () => PickD3Scale<"ordinal", string, string>;
  /** The configured chart margins. */
  useMargin: () => Margin;
  /** Display rows merged with the line series values for the same x. */
  useTooltipData: () => Data[];
}

/**
 * Bar series keys only. `dataKeysAtom` appends `yKeyLine` so the legend and
 * colour scale cover the line, but the line has no bars, so bar renderers must
 * not receive it (see the `useBarKeys` contract).
 * Declared at module level so the atom identity is stable across renders; the
 * atoms it reads are declared further down but are only accessed lazily inside
 * the read callback.
 */
const barKeysAtom = atom((get) => {
  const lineKey = get(yKeyLineAtom);
  const keys = get(dataKeysAtom);
  return lineKey ? keys.filter((k) => k !== lineKey) : keys;
});

/**
 * Hook accessors for the bar chart's derived atoms. Each hook subscribes the
 * calling component to one atom via `useAtomValue`.
 */
const barChartState: BarChartState = {
  useData: () => useAtomValue(displayDataAtom),
  useLineData: () => useAtomValue(lineDataAtom),
  // Previously returned dataKeysAtom, which also contains the line key.
  useBarKeys: () => useAtomValue(barKeysAtom),
  useXKey: () => useAtomValue(xKeyAtom),
  useInnerWidth: () => useAtomValue(innerWidthAtom),
  useXScale: () => useAtomValue(xScaleAtom),
  useXGroupedScale: () => useAtomValue(xGroupedScaleAtom),
  useXTickFormat: () => useAtomValue(xTickFormatterAtom),
  useYScale: () => useAtomValue(yScaleAtom),
  useYScaleType: () => useAtomValue(yScaleTypeAtom),
  useYLineScale: () => useAtomValue(yScaleLineAtom),
  useInnerHeight: () => useAtomValue(innerHeightAtom),
  useYKeysLeft: () => useAtomValue(yKeysLeftAtom),
  useYKeysRight: () => useAtomValue(yKeysRightAtom),
  useYTickFormat: () => useAtomValue(yTickFormatterAtom),
  useAllKeys: () => useAtomValue(dataKeysAtom),
  useColorScale: () => useAtomValue(colorScaleAtom),
  useMargin: () => useAtomValue(marginAtom),
  useTooltipData: () => useAtomValue(tooltipDataAtom),
};

export default barChartState;

// *************************
// * CONFIG
// *************************
/**
 * Raw chart props. Every config-derived atom below reads from this one.
 * NOTE(review): the initial value is an empty placeholder cast to
 * BarChartProps; presumably the chart component writes real props before any
 * derived atom is read — confirm.
 */
export const configAtom = atom<BarChartProps>({} as BarChartProps);

// Single-prop projections of configAtom (via selectAtom), applying the same
// defaults everywhere they are read through these atoms.
const xKeyAtom = selectAtom(configAtom, ({ xKey }) => xKey);

// Y series configuration.
const yKeysAtom = selectAtom(configAtom, ({ yKeys }) => yKeys);
const yKeyLineAtom = selectAtom(configAtom, ({ yKeyLine }) => yKeyLine);
const yAxisLinePositionAtom = selectAtom(
  configAtom,
  ({ yAxisLinePosition }) => yAxisLinePosition ?? "right",
);

// Data shaping and presentation options.
const rawDataAtom = selectAtom(configAtom, ({ data }) => data);
const normalizeAtom = selectAtom(configAtom, (config) => config.normalize);
const displayTypeAtom = selectAtom(configAtom, ({ displayType }) => displayType ?? "stacked");
const groupingKeyAtom = selectAtom(configAtom, ({ groupingKey }) => groupingKey);
const yScaleTypeAtom = selectAtom(configAtom, ({ yAxisScale }) => yAxisScale ?? "linear");
const marginAtom = selectAtom(configAtom, ({ margin }) => margin);
const maximumCategoriesAtom = selectAtom(configAtom, ({ maximumCategories }) => maximumCategories);
const typeOverridesAtom = selectAtom(configAtom, ({ typeOverrides }) => typeOverrides);
const lineGapAtom = selectAtom(configAtom, ({ lineGap }) => lineGap ?? "none");
const xSortDirAtom = selectAtom(configAtom, ({ xAxisSortDirection }) => xAxisSortDirection);
const xSortKeyAtom = selectAtom(configAtom, ({ xAxisSortKey }) => xAxisSortKey);

// *************************
// * DIMENSIONS
// *************************
/** Drawable width: the overall chart width less the left and right margins. */
export const innerWidthAtom = atom((get) => {
  const outerWidth = get(widthAtom);
  const { left, right } = get(marginAtom);
  return outerWidth - left - right;
});

/** Drawable height: the overall chart height less the top and bottom margins. */
export const innerHeightAtom = atom((get) => {
  const outerHeight = get(heightAtom);
  const { top, bottom } = get(marginAtom);
  return outerHeight - top - bottom;
});

// *************************
// * DATA
// *************************

/**
 * Rows used to order the x domain (consumed by xDomainAtom).
 * When `xAxisSortKey` is unset or equals the x key, this is simply dataAtom.
 * Otherwise it is the raw data aggregated by x with the sort key as the only
 * y key, so each x value carries a sortable total.
 */
const dataBySortKey = atom((get) => {
  const raw = get(rawDataAtom);
  const aggregated = get(dataAtom);
  const xKey = get(xKeyAtom);
  const sortKey = get(xSortKeyAtom);
  if (sortKey && sortKey !== xKey) {
    return aggregate({ data: raw, xKey, yKeys: [sortKey] });
  }
  return aggregated;
});

/**
 * Raw data aggregated by the x key over the configured y keys (honouring
 * `groupingKey` and `typeOverrides`). For time-series x values, missing dates
 * are filled in via fillMissingDates.
 */
export const dataAtom = atom((get) => {
  const xKey = get(xKeyAtom);
  const aggregated = aggregate({
    data: get(rawDataAtom),
    xKey,
    yKeys: get(yKeysAtom),
    groupingKey: get(groupingKeyAtom),
    types: get(typeOverridesAtom),
  });
  return isTimeSeries(aggregated, xKey) ? fillMissingDates(aggregated, xKey) : aggregated;
});

/**
 * Rows for the overlaid line series, aggregated by x over `yKeyLine` only.
 * Empty when no line key is configured or the legend has hidden it.
 * Time-series rows are sorted by date; missing dates are filled unless
 * `lineGap` is "none".
 */
export const lineDataAtom = atom((get) => {
  const yKeyLine = get(yKeyLineAtom);
  const hiddenKeys = get(chartMachineAtom).context.legend.filterKeys;
  if (!yKeyLine || hiddenKeys.includes(yKeyLine)) return [];

  const xKey = get(xKeyAtom);
  const aggregated = aggregate({ data: get(rawDataAtom), xKey, yKeys: [yKeyLine] });
  const gapMode = get(lineGapAtom);
  if (!isTimeSeries(aggregated, xKey)) return aggregated;

  const chronological = sortBy(aggregated, xKey);
  return gapMode === "none" ? chronological : fillMissingDates(chronological, xKey);
});

/** dataAtom passed through `bucket` when `maximumCategories` is set (truthy); unchanged otherwise. */
export const limitedBucketedDataAtom = atom((get) => {
  const rows = get(dataAtom);
  const xKey = get(xKeyAtom);
  const limit = get(maximumCategoriesAtom);
  return limit ? bucket(rows, xKey, limit) : rows;
});

/** Bucketed rows with any legend-hidden keys stripped from every row. */
const filteredDataAtom = atom((get) => {
  const rows = get(limitedBucketedDataAtom);
  const hiddenKeys = get(chartMachineAtom).context.legend.filterKeys;
  return hiddenKeys.length ? (rows.map((row) => omit(row, hiddenKeys)) as Data[]) : rows;
});

/** Filtered rows passed through `normalize` relative to the x key. */
const normalizedDataAtom = atom((get) => normalize(get(filteredDataAtom), get(xKeyAtom)));

/** Rows actually drawn: normalized when the `normalize` prop is set, filtered otherwise. */
const displayDataAtom = atom((get) =>
  get(normalizeAtom) ? get(normalizedDataAtom) : get(filteredDataAtom),
);

// *************************
// * KEYS
// *************************

/**
 * All series keys: every key of the first bucketed row except the x key,
 * followed by `yKeyLine` when configured (so the legend, colour scale and chart
 * machine cover the line too). Falsy entries are dropped by `compact`.
 */
const dataKeysAtom = atom((get) => {
  const data = get(limitedBucketedDataAtom);
  const yKeyLine = get(yKeyLineAtom);
  const xKey = get(xKeyAtom);
  const [firstRow] = data;
  // With no rows (e.g. before data arrives) there are no data keys; the old
  // `Object.keys(data[0]!)` threw a TypeError here instead.
  const dataKeys = firstRow ? Object.keys(firstRow).filter((k) => k !== xKey) : [];
  return compact([...dataKeys, yKeyLine]);
});

/**
 * Keys plotted against the left y axis: all bar keys, plus the line key only
 * when the line is positioned on the left.
 */
const yKeysLeftAtom = atom((get) => {
  const lineKey = get(yKeyLineAtom);
  const linePosition = get(yAxisLinePositionAtom);
  // dataKeysAtom already appends the line key. Strip it so the key is not
  // listed twice (line on the left) and does not appear on the left axis at
  // all when the line belongs to the right axis.
  const barKeys = get(dataKeysAtom).filter((k) => k !== lineKey);
  if (lineKey && linePosition === "left") return [...barKeys, lineKey];
  return barKeys;
});

/**
 * Keys plotted against the right y axis: `[yKeyLine]` when a line is configured
 * and positioned on the right, otherwise undefined.
 */
const yKeysRightAtom = atom((get): string[] | undefined => {
  const yKeyLine = get(yKeyLineAtom);
  // Use the defaulted position ("right" when unset). Reading the raw prop made
  // an unset position return undefined, disagreeing with every other atom.
  const linePosition = get(yAxisLinePositionAtom);
  if (!yKeyLine || linePosition !== "right") return undefined;
  return [yKeyLine];
});

// *************************
// * DOMAINS
// *************************

/**
 * Numeric [min, max] for the primary (left) y axis.
 * - `normalize` set → always [0, 1].
 * - Otherwise the extent of the display rows over all data keys, summed per x
 *   when `displayType` is "stacked".
 * - When the line sits on the left axis and has rows, its extent is merged in.
 */
const yDomainAtom = atom((get) => {
  const rows = get(displayDataAtom);
  const lineRows = get(lineDataAtom);
  const linePosition = get(yAxisLinePositionAtom);
  const xKey = get(xKeyAtom);
  const keys = get(dataKeysAtom);
  const displayType = get(displayTypeAtom);
  if (get(normalizeAtom)) return [0, 1] as [number, number];

  const stacked = displayType === "stacked" ? { key: xKey } : undefined;
  const bars = getNumericDomain(rows, keys, { stacked });
  // lineDataAtom is non-empty only when a line key exists, so the guard on
  // lineKey is just for the type checker.
  const lineKey = get(yKeyLineAtom);
  if (!lineRows.length || linePosition !== "left" || !lineKey) {
    return [bars.min, bars.max] as [number, number];
  }
  const line = getNumericDomain(lineRows, [lineKey]);
  return [Math.min(bars.min, line.min), Math.max(bars.max, line.max)] as [number, number];
});

/**
 * Domain for a separate line axis. Undefined when there is no line or the
 * line shares the left axis; [0, 1] when `normalize` is set; otherwise the
 * extent of the raw data's line column summed per x (the line is always
 * treated as stacked).
 * NOTE(review): lineDataAtom is not normalized, so the [0, 1] domain may clip
 * the line when `normalize` is on — confirm this is intended.
 */
const yLineDomainAtom = atom((get) => {
  const config = get(configAtom);
  const lineKey = config.yKeyLine;
  if (!lineKey || config.yAxisLinePosition === "left") return undefined;
  if (config.normalize) return [0, 1] as [number, number];
  const extent = getNumericDomain(config.data, [lineKey], { stacked: { key: config.xKey } });
  return [extent.min, extent.max] as [number, number];
});

/**
 * X-axis domain: a date domain for time series, a string domain otherwise.
 * Unless the sort direction is "none", `values` is replaced by the x values
 * ordered by the sort key (x key by default), reversed for "desc".
 */
const xDomainAtom = atom((get) => {
  const rows = get(dataBySortKey);
  const xKey = get(xKeyAtom);
  const direction = get(xSortDirAtom);
  const orderBy = get(xSortKeyAtom) ?? xKey;
  const domain: Domain<Date> | Domain<string> = isTimeSeries(rows, xKey)
    ? getDateDomain(rows, [xKey])
    : getStringDomain(rows, [xKey]);
  if (direction === "none") return domain;

  const ascending = sortBy(rows, orderBy).map((row) => row[xKey]);
  const ordered = direction === "desc" ? reverse(ascending) : ascending;
  return { ...domain, values: ordered as string[] | Date[] };
});

// *************************
// * FORMATTERS
// *************************
/** X tick formatter derived from the display rows; identity when there are none. */
const xTickFormatterAtom = atom((get) => {
  const xKey = get(xKeyAtom);
  const rows = get(displayDataAtom);
  return rows.length ? getXTickFormat(rows, xKey) : identity;
});

/** Y tick formatter chosen from the primary y domain. */
const yTickFormatterAtom = atom((get) => {
  const domain = get(yDomainAtom);
  return getYTickFormat(domain);
});

// *************************
// * SCALES
// *************************
/** Band scale mapping the x domain values onto [0, innerWidth]. */
const xScaleAtom = atom((get) => {
  const domainValues = get(xDomainAtom).values;
  const rangeEnd = get(innerWidthAtom);
  return createScale(domainValues, [0, rangeEnd], "band") as BandScale | DateBandScale;
});

/**
 * Band scale that splits one x band among the bar series (grouped layout),
 * with 0.1 padding between sub-bands.
 */
const xGroupedScaleAtom = atom((get) => {
  const lineKey = get(yKeyLineAtom);
  // dataKeysAtom appends the line key, but the line is not drawn as a bar.
  // Keeping it here reserved an empty sub-band and narrowed every grouped bar.
  const barKeys = get(dataKeysAtom).filter((k) => k !== lineKey);
  const xScale = get(xScaleAtom);
  return createBandScale(barKeys, [0, xScale.bandwidth()], { padding: 0.1 });
});

/**
 * Primary y scale mapping the y domain onto [innerHeight, 0] (SVG y grows
 * downward). A log scale starts its domain at 1 instead of the domain minimum.
 */
const yScaleAtom = atom((get) => {
  const domain = get(yDomainAtom);
  const pixelRange: [number, number] = [get(innerHeightAtom), 0];
  const scaleType = get(yScaleTypeAtom);
  return scaleType === "log"
    ? createLogScale([1, domain[1]], pixelRange)
    : createLinearScale(domain, pixelRange);
});

/**
 * Scale for the line series: reuses the primary y scale when the line is on
 * the left axis, otherwise a linear scale over yLineDomainAtom (undefined when
 * there is no line domain).
 */
const yScaleLineAtom = atom((get) => {
  const lineDomain = get(yLineDomainAtom);
  const sharesLeftAxis = get(configAtom).yAxisLinePosition === "left";
  if (sharesLeftAxis) return get(yScaleAtom);
  if (lineDomain === undefined) return undefined;
  return createLinearScale(lineDomain, [get(innerHeightAtom), 0]);
});

/**
 * Ordinal key → colour scale. A `colorMap` entry wins; otherwise colours cycle
 * through the categorical palette by key index.
 */
const colorScaleAtom = atom((get) => {
  const keys = get(dataKeysAtom);
  const overrides = get(configAtom).colorMap;
  const palette = keys.map(
    (key, index) => (overrides?.[key] ?? categorical[index % categorical.length]) as string,
  );
  return createOrdinalScale(keys, palette);
});

/**
 * Display rows enriched with the line series' values for the same x, so a
 * tooltip can show both. Returns the display rows unchanged when there is no
 * visible line data.
 */
const tooltipDataAtom = atom((get) => {
  const data = get(displayDataAtom);
  const xKey = get(xKeyAtom);
  const lineData = get(lineDataAtom);
  // lineDataAtom always yields an array ([] when there is no visible line), so
  // the old `!lineData` guard never fired. Test emptiness instead.
  if (!lineData.length) return data;
  const lineByX = keyBy(lineData, xKey);
  return data.map((d) => ({ ...d, ...lineByX[d[xKey] as string] }));
});

/**
 * Chart state-machine atom, created from dataKeysAtom (data keys plus the line
 * key). Its `context.legend.filterKeys` is read by filteredDataAtom and
 * lineDataAtom to hide series. The atoms above only reference it inside their
 * read callbacks, so declaring it last is safe.
 */
export const chartMachineAtom = createChartMachineAtom(dataKeysAtom);
