import { useCallback, useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";
import moment from "moment";
import { DEFAULT_COLOR_RANGE, DEFAULT_MARGIN } from "../shared/constants";
import { type Props } from "../types/LineChart";
import { useDimensions } from "../../../hooks/useDimensions";
import { useChartResizeTransition } from "../hooks/useChartResizeTransition";
import {
  isLogScale,
  hideTicksWithNoText,
  calculateChartDimensions,
  drawCenteringGroup,
  getColorScale,
  getTickCountBySize,
  APPEND_ONCE,
  getLogScaleDomain,
  textEllipsis,
} from "../shared/helpers";

import { widthAndHeightInvalid } from "../utils/utils";
import { getTicksFromValueTypes, TickValueType } from "../shared/public";

import "./lineChart.scss";

/** A single plotted point: one x value paired with the y value at the same index. */
type FormattedData = {
  x: any;
  y: any;
};

/** Inline-style offset properties that `conditionallySetOffsetProperty` is allowed to write. */
type OffsetType = "top" | "right" | "bottom" | "left";
/** Type of the `nodes` argument accepted by the `tweenDash` attrTween callback inside `LineChart`. */
type TweenDashNodeType = (d3.BaseType | SVGPathElement)[] | (ArrayLike<d3.BaseType> | ArrayLike<SVGPathElement>);

/**
 * Writes `value` (as `px`) to one of an element's inline offset properties, but only when it
 * differs from the value currently stored there by more than one pixel.
 *
 * This prevents an issue where the value constantly wants to resize within a fraction of a
 * pixel of the previous one.
 *
 * @param ele - Element whose inline style is updated.
 * @param prop - Offset property to write (`top`, `right`, `bottom` or `left`).
 * @param value - New offset in pixels. A `NaN` value is ignored.
 */
const conditionallySetOffsetProperty = (ele: HTMLElement, prop: OffsetType, value: number) => {
  const marginOfError = 1;

  // A NaN target can never be compared meaningfully; leave the element untouched.
  if (Number.isNaN(value)) return;

  // parseFloat reads the leading number of "12.5px" and yields NaN for "" or "auto".
  const prevValue = parseFloat(ele.style[prop]);

  // With no numeric previous value there is nothing to be "close to", so always write.
  // (Previously an unset property was treated as 0px, which silently dropped the first
  // write whenever the target happened to be within 1px of zero.)
  if (Number.isNaN(prevValue)) {
    ele.style[prop] = `${value}px`;
    return;
  }

  const floor = prevValue - marginOfError;
  const ceiling = prevValue + marginOfError;
  if (value > ceiling || value < floor) {
    ele.style[prop] = `${value}px`;
  }
};

/**
 * Pairs every x series with the y series at the same index, producing one array of
 * `{ x, y }` points per series.
 *
 * Special case: when the first x series consists of exactly one `Date`, a second date value
 * is synthesised. Every series then becomes two points that share that series' first y
 * value: one at the start of the date's day and one 15 minutes later (presumably so a
 * line can still be drawn from a single sample).
 *
 * @param xData - One array of x values per series.
 * @param yData - One array of y values per series, parallel to `xData`.
 * @returns One array of points per series.
 */
const getFormattedData = (xData: any[][], yData: any[][]) => {
  const hasSingleDate = xData.length && xData[0].length === 1 && xData[0][0] instanceof Date;

  if (!hasSingleDate) {
    return xData.map((series, seriesIndex) => FormatDataForUse(series, yData[seriesIndex]));
  }

  // Build the two substitute x values shared by every series.
  const dayStart = moment(xData[0][0]).startOf("day");
  const quarterPast = dayStart.clone().add(15, "minutes");
  const mockXData = [dayStart.toDate(), quarterPast.toDate()];

  return Array.from(xData, (_series, seriesIndex) => {
    const firstY = yData[seriesIndex][0];
    return FormatDataForUse(mockXData, [firstY, firstY]);
  });
};

/**
 * Converts the tooltip's extra height into the fraction used to position its pointer.
 *
 * @param input - Extra tooltip height in pixels beyond a single line.
 * @returns `0.5` when `input` is below one line (23px), otherwise `1 / (4 * n)` where `n` is
 *   the number of whole 23px lines contained in `input` (1/4, 1/8, 1/12, ...).
 */
function calculateTopTooltipOffsetScalar(input: number): number {
  // 23 is a bit of a magic number for now: line height + padding.
  const lineHeightWithPadding = 23;
  const wholeLines = Math.floor(input / lineHeightWithPadding);

  return input < lineHeightWithPadding ? 0.5 : 1 / (wholeLines * 4);
}

/**
 * Line chart that draws one solid path per data series (plus optional dashed companion
 * series) with d3 inside a React-managed `<svg>`. It also renders both axes, a hover dot,
 * a horizontal dashed guide line, a value label next to the left axis and a floating
 * tooltip element that is appended to `document.body`.
 *
 * Props, as they are used by this component:
 * - `margins`: `left` / `top` offset the drawing area and are used in hover hit-testing.
 * - `xData` / `yData`: parallel arrays of series. Throws if their lengths differ. The x
 *   domain, the x tick count and hover hit-testing on x use the first series only.
 * - `dashedYData`: optional extra y series (same count as `xData`) drawn as dashed paths.
 * - `xScaleRatio` / `yScaleRatio`: d3 scale factories (default `d3.scaleLinear`).
 * - `xFormatterFunc`, `yFormatterFunc`, `yExtraTickFormat`, `xTickValueType`,
 *   `yTickValueType`: axis tick configuration.
 * - `colorRange`: passed to `getColorScale`; the resulting scale is called with the series
 *   index to pick each stroke colour.
 * - `onClick`: called on svg click with the x and y of the currently hovered point.
 * - `transitionDuration`: handed to `useChartResizeTransition`, whose returned ref supplies
 *   the duration of every d3 transition.
 * - `tooltipFormatter` / `yBlackBoxFormat`: produce the floating tooltip text and the
 *   left-axis label text respectively.
 *
 * @throws Error when `xData` and `yData` (or `dashedYData`) differ in number of series.
 */
export const LineChart: React.FC<Props> = ({
  margins = DEFAULT_MARGIN,
  xData,
  yData,
  dashedYData,
  xScaleRatio = d3.scaleLinear,
  yScaleRatio = d3.scaleLinear,
  xFormatterFunc = (d) => d,
  yFormatterFunc,
  yExtraTickFormat = null,
  xTickValueType,
  yTickValueType,
  colorRange = DEFAULT_COLOR_RANGE,
  onClick,
  transitionDuration = 1000,
  tooltipFormatter = (_, y) => (y.toLocaleString ? y.toLocaleString() : y),
  yBlackBoxFormat = (_, y) => (y.toLocaleString ? y.toLocaleString() : y),
}) => {
  // [x, y] of the point currently under the pointer; empty while nothing is hovered.
  const hoveredPoint = useRef<any[]>([]);
  // Latest onClick kept in a ref so the click listener never has to be re-bound for it.
  const onClickRef = useRef(onClick);
  const svgRef = useRef<SVGSVGElement>(null);
  // Floating tooltip element that lives directly under document.body (see "Mount tooltip").
  const tooltipRef = useRef<HTMLDivElement>();
  const [dim, measuredRef] = useDimensions();

  const [transitionDurationRef, resizeRef] = useChartResizeTransition(transitionDuration);

  const { totalWidth, totalHeight, innerWidth, innerHeight } = calculateChartDimensions(dim, margins);

  /**
   * Creating d3 functions/values necessary to draw the graph
   */
  const formattedDataForLine: FormattedData[][] = useMemo(() => {
    if (xData.length !== yData.length) {
      throw new Error("Data objects must have the same number of elements.");
    }

    return getFormattedData(xData, yData);
  }, [xData, yData]);

  const formattedDashDataForLine: FormattedData[][] = useMemo(() => {
    if (!dashedYData) {
      return [];
    }
    if (xData.length !== dashedYData.length) {
      throw new Error("Data objects must have the same number of elements.");
    }

    return getFormattedData(xData, dashedYData);
  }, [dashedYData, xData]);

  // The x domain is derived from the first series only.
  const xDomain = useMemo(
    () =>
      isLogScale(xScaleRatio)
        ? getLogScaleDomain(d3.max(formattedDataForLine[0], (d) => d.x))
        : d3.extent(formattedDataForLine[0], (d) => d.x),
    [formattedDataForLine, xScaleRatio],
  );

  // The y domain spans every solid and dashed series so all of them fit on the chart.
  const yDomain = useMemo(() => {
    const flattenedYData = formattedDataForLine.flat().map((e) => e.y);
    const flattenedYDashedData = formattedDashDataForLine.flat().map((e) => e.y);

    const largestYValue = d3.max([...flattenedYData, ...flattenedYDashedData]);
    if (isLogScale(yScaleRatio)) {
      return getLogScaleDomain(largestYValue);
    } else {
      return [0, largestYValue];
    }
  }, [formattedDashDataForLine, formattedDataForLine, yScaleRatio]);

  const xScale = useMemo(
    () => xScaleRatio().domain(xDomain).range([0, innerWidth]),
    [xScaleRatio, xDomain, innerWidth],
  );
  // The y range is inverted because SVG y grows downwards.
  const yScale = useMemo(
    () => yScaleRatio().clamp(true).domain(yDomain).range([innerHeight, 0]).nice(),
    [yScaleRatio, yDomain, innerHeight],
  );

  const colorScale = useMemo(
    () => getColorScale(colorRange, formattedDataForLine.length),
    [colorRange, formattedDataForLine],
  );

  const xAxis = useMemo(() => {
    const tickValues = getTicksFromValueTypes(xScale, xTickValueType);

    // Never request more ticks than there are points in the first series.
    const tickCount = getTickCountBySize(innerWidth, xDomain);
    const ticks = Math.min(formattedDataForLine[0].length, tickCount);

    // Negative tick sizes stretch the ticks across the plot area so they act as grid lines.
    return d3
      .axisBottom(xScale)
      .ticks(ticks)
      .tickFormat(xFormatterFunc)
      .tickValues(tickValues)
      .tickSizeInner(-innerHeight)
      .tickSizeOuter(-innerHeight)
      .tickPadding(7);
  }, [xScale, xTickValueType, innerWidth, xDomain, formattedDataForLine, xFormatterFunc, innerHeight]);

  const yAxis = useMemo(() => {
    const tickValues = getTicksFromValueTypes(yScale, yTickValueType, 5);

    // Remove "decimal point" if integers only
    const optionalTickFormat =
      yTickValueType === TickValueType.IntegersOnly ? (d: string) => (+d).toLocaleString() : null;

    return (
      d3
        .axisLeft(yScale)
        .ticks(5, yFormatterFunc)
        .tickSize(-innerWidth)
        .tickPadding(7)
        .tickValues(tickValues)
        // @ts-ignore
        .tickFormat(yExtraTickFormat ?? optionalTickFormat)
    );
  }, [yScale, yTickValueType, yFormatterFunc, innerWidth, yExtraTickFormat]);

  const lineGenerator = useMemo(
    () =>
      d3
        .line<FormattedData>()
        .x((d) => xScale(d.x))
        .y((d) => yScale(d.y)),
    [xScale, yScale],
  );

  /**
   * End d3 functions
   */

  // Hides every hover artefact (dot, guide line, left label, floating tooltip) and
  // forgets the hovered point.
  const hideTooltips = useCallback(() => {
    const svg = d3.select(svgRef.current);
    svg.selectAll(".dot").style("display", "none");
    hoveredPoint.current = [];
    svg.selectAll(".dashedLine").style("visibility", "hidden");
    svg.selectAll(".leftTooltip").style("visibility", "hidden");
    // Guarded instead of asserted: the tooltip element only exists after the mount effect ran.
    if (tooltipRef.current) {
      tooltipRef.current.style.display = "none";
    }
  }, []);

  // Mount tooltip
  useEffect(() => {
    // Structure: <div.line-chart-tooltip><div.pointer/><div.text/></div>. pointerMoved relies
    // on the pointer being the first child and the text being the last child.
    const tooltip = document.createElement("div");
    tooltip.setAttribute("class", "line-chart-tooltip");
    tooltip.setAttribute("data-testid", "line chart tooltip");
    tooltip.style.display = "none";
    const tooltipPointer = document.createElement("div");
    tooltipPointer.setAttribute("class", "line-chart-tooltipPointer");
    tooltip.appendChild(tooltipPointer);

    const tooltipText = document.createElement("div");
    tooltipText.setAttribute("class", "line-chart-tooltipText");
    tooltip.appendChild(tooltipText);
    document.body.appendChild(tooltip);
    tooltipRef.current = tooltip;

    // Capture phase so scrolling inside any nested scroll container also hides the tooltip.
    document.addEventListener("scroll", hideTooltips, true);

    return () => {
      document.body.removeChild(tooltip);
      // The capture flag must match the one used in addEventListener, otherwise the
      // listener is not removed and leaks after unmount.
      document.removeEventListener("scroll", hideTooltips, true);
    };
  }, [hideTooltips]);

  // When true, the drawing effects below bail out without touching the svg.
  const preflightChecksFailed = useMemo(
    () => widthAndHeightInvalid(innerWidth, innerHeight),
    [innerWidth, innerHeight],
  );

  useEffect(() => {
    onClickRef.current = onClick;
  }, [onClick]);

  // Centering chart / axes
  useEffect(() => {
    if (!svgRef.current || preflightChecksFailed) return;

    const svg = d3.select(svgRef.current);
    // Presumably creates or updates the `.centering-group` <g> that the later effects select.
    const mainArea = drawCenteringGroup(svg, margins.left, margins.top, transitionDurationRef.current);

    // Tween factory: on every transition frame, re-run hideTicksWithNoText on the axis labels.
    const axisTween = (_this: SVGGElement) => {
      return function () {
        const textSelection = d3.select(_this).selectAll("text");
        hideTicksWithNoText(textSelection);
      };
    };

    // Passes each x label together with its previous and next sibling labels to textEllipsis.
    const xAxisTextEllipsis = (_data: unknown, index: number, elements: ArrayLike<SVGTextElement>) => {
      const e: [SVGTextElement, SVGTextElement, SVGTextElement] = [
        elements[index - 1],
        elements[index],
        elements[index + 1],
      ];
      textEllipsis(elements[index], 10, -8, e);
    };

    // Add X grid lines with labels
    const xAxisGroup = mainArea
      .selectAll<SVGGElement, boolean[]>(".x-axis-group")
      .data(APPEND_ONCE)
      .join(
        (enter) => {
          enter.append("g").attr("transform", `translate(0, ${innerHeight})`).attr("class", "x-axis-group").call(xAxis);
          hideTicksWithNoText(enter.selectAll("text") as any);
          return enter;
        },
        (update) => {
          update
            .transition()
            .duration(transitionDurationRef.current)
            .attr("transform", `translate(0, ${innerHeight})`)
            .tween("hide x ticks", function () {
              const _this = this as unknown as SVGGElement;
              return axisTween(_this);
            })
            .tween("text ellipsis", function () {
              const innerText = d3.select(this).selectAll<SVGTextElement, unknown>("text");

              return function () {
                innerText.each(xAxisTextEllipsis);
              };
            })
            .call(xAxis);
          return update;
        },
      );
    xAxisGroup.selectAll(".domain").attr("class", "domain axisLine");
    xAxisGroup.selectAll("line").attr("class", "axisLine");
    xAxisGroup.selectAll("text").attr("class", "axisText").attr("dy", "13");

    // Add y grid lines with labels
    const yAxisGroup = mainArea
      .selectAll<SVGGElement, boolean[]>(".y-axis-group")
      .data(APPEND_ONCE)
      .join(
        (enter) => {
          enter.append("g").attr("class", "y-axis-group").call(yAxis);
          // If scale log is the default, this hides the ticks
          hideTicksWithNoText(enter.selectAll("text") as any);
          return enter;
        },
        (update) => {
          update
            .transition()
            .duration(transitionDurationRef.current)
            .tween("hide y ticks", function () {
              const _this = this as unknown as SVGGElement;
              return axisTween(_this);
            })
            .call(yAxis);
          return update;
        },
      );

    // For some reason, yAxisGroup is selecting all "line" children
    mainArea.selectAll(".y-axis-group .domain").remove();
    yAxisGroup.selectAll("line").attr("class", "axisLine");
    yAxisGroup.selectAll("text").attr("class", "axisText").attr("dx", "3");
  }, [preflightChecksFailed, margins.left, margins.top, innerHeight, transitionDurationRef, xAxis, yAxis]);

  // Lines
  useEffect(() => {
    if (!svgRef.current || preflightChecksFailed) return;

    const svg = d3.select(svgRef.current);
    const mainArea = svg.select(".centering-group");

    // For initial line transitions
    // Interpolates stroke-dasharray from "0,len" (nothing drawn) to "len,len" (fully drawn).
    function tweenDash(_data: FormattedData[], index: number, nodes: TweenDashNodeType) {
      // This should only be required by tests, as in prod nodes is a list of path elements
      const getTotalLengthIsDefined = nodes?.[index] && (nodes[index] as SVGPathElement).getTotalLength;
      const l = getTotalLengthIsDefined ? (nodes[index] as SVGPathElement).getTotalLength() : 1;
      const i = d3.interpolateString(`0,${l}`, `${l},${l}`);

      return function (t: number) {
        return i(t);
      };
    }

    function initialTransition(
      path: d3.Selection<SVGPathElement, FormattedData[], d3.BaseType | SVGGElement, unknown>,
    ) {
      path
        .transition()
        .duration(transitionDurationRef.current)
        .ease(d3.easeLinear)
        .attrTween("stroke-dasharray", tweenDash);
    }
    // End line transition helpers

    // Draw the lines
    mainArea
      .selectAll(".path")
      .data(formattedDataForLine)
      .join(
        (enter) =>
          // @ts-ignore
          enter
            .append("path")
            .attr("class", "path")
            .attr("stroke", (_, i) => colorScale(i)!)
            .attr("d", lineGenerator)
            .attr("stroke-dasharray", "0,1")
            .call(initialTransition),
        (update) => {
          // Existing lines morph to the new shape; the draw-in dash pattern is cleared.
          update
            .transition()
            .duration(transitionDurationRef.current)
            .attr("d", lineGenerator)
            .attr("stroke-dasharray", null);
          return update;
        },
      );

    // Dashed companion series: same colour per index, redrawn without a transition.
    mainArea
      .selectAll(".dashedPath")
      .data(formattedDashDataForLine)
      .join(
        (enter) =>
          enter
            .append("path")
            .attr("class", "dashedPath")
            .attr("stroke", (_, i) => colorScale(i)!)
            .attr("fill", "none")
            .attr("d", lineGenerator)
            .attr("stroke-dasharray", "6, 6"),
        (update) => update.attr("d", lineGenerator),
      );
  }, [
    colorScale,
    formattedDashDataForLine,
    formattedDataForLine,
    lineGenerator,
    preflightChecksFailed,
    transitionDurationRef,
  ]);

  // Tooltips / hover states
  useEffect(() => {
    if (!svgRef.current || preflightChecksFailed) return;

    const svg = d3.select(svgRef.current);
    const mainArea = svg.select(".centering-group");

    // Dashed line that also marks current selection
    const LEFT_TOOLTIP_OFFSET = -15; // Distance from left axis
    const LEFT_TOOLTIP_HEIGHT = 20;
    const dashedLine = mainArea
      .selectAll(".dashedLine")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("line").attr("class", "dashedLine"));
    const leftTooltip = mainArea
      .selectAll(".leftTooltip")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("g").attr("class", "leftTooltip"));
    const leftTooltipRect = leftTooltip
      .selectAll(".leftTooltipRect")
      .data(APPEND_ONCE)
      .join(
        (enter) =>
          enter
            .append("rect")
            .attr("class", "leftTooltipRect")
            .attr("height", LEFT_TOOLTIP_HEIGHT)
            .attr("x", LEFT_TOOLTIP_OFFSET)
            .attr("y", -LEFT_TOOLTIP_HEIGHT / 2)
            .attr("rx", "5"),
        (update) =>
          update
            .attr("height", LEFT_TOOLTIP_HEIGHT)
            .attr("x", LEFT_TOOLTIP_OFFSET)
            .attr("y", -LEFT_TOOLTIP_HEIGHT / 2),
      );
    const leftTooltipText = leftTooltip
      .selectAll(".leftTooltipText")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("text").attr("class", "leftTooltipText").attr("y", "0"));

    // Hover states
    const hoverDot = mainArea
      .selectAll<SVGGElement, null>(".dot")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("g").attr("class", "dot").style("display", "none"));

    hoverDot
      .selectAll(".outerCircle")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("circle").attr("r", 8).attr("class", "outerCircle"));
    hoverDot
      .selectAll(".innerCircle")
      .data(APPEND_ONCE)
      .join((enter) => enter.append("circle").attr("r", 3).attr("class", "innerCircle"));

    // Pointer handler: snaps to the nearest x in the first series, then to the series whose
    // y at that x is nearest to the pointer, and positions every hover artefact there.
    function pointerMoved(...args: unknown[]) {
      if (!tooltipRef.current) return;
      hoverDot.style("display", "block");
      tooltipRef.current.style.display = "block";
      dashedLine.style("visibility", "visible");
      leftTooltip.style("visibility", "visible");

      // The listener is read as (datum, index, nodes); nodes[index] is the element handed to d3.mouse.
      const index = args[1] as number;
      const groups = args[2] as d3.ContainerElement[];

      const [xm, ym] = d3.mouse(groups[index]);

      // Horizontal distance from the pointer to every point of the first series.
      const xMeasures = formattedDataForLine[0].map((d) => Math.abs(xScale(d.x) + margins.left - xm));
      const lowestxMeasure = Math.min(...xMeasures);
      const lowestxIndex = xMeasures.indexOf(lowestxMeasure);
      const closestX = formattedDataForLine[0][lowestxIndex];

      // Empty first series (or only NaN distances): nothing to snap to.
      if (closestX == null) return;

      const availableYs = formattedDataForLine.map((d) => d[lowestxIndex].y);
      const availableDashedYs = formattedDashDataForLine.map((d) => d[lowestxIndex].y);
      // Vertical distance from the pointer to each series at the snapped x.
      const yMeasures = availableYs.map((y) => Math.abs(yScale(y) + margins.top - ym));
      const lowestyMeasure = Math.min(...yMeasures);
      const lowestyIndex = yMeasures.indexOf(lowestyMeasure);
      const closestY = availableYs[lowestyIndex];
      const correspondingDashedY = availableDashedYs[lowestyIndex];

      if (closestY == null) return;

      hoverDot.attr("transform", `translate(${xScale(closestX.x)}, ${yScale(closestY)})`);
      hoveredPoint.current = [closestX.x, closestY];

      dashedLine.attr("x1", 0).attr("x2", innerWidth).attr("y1", yScale(closestY)).attr("y2", yScale(closestY));

      leftTooltip.attr("transform", `translate(${0}, ${yScale(closestY)})`);
      leftTooltipText.text(yBlackBoxFormat(closestX.x, closestY, lowestyIndex));

      // Size the label's background rect to the rendered text width plus padding.
      const leftTooltipPadding = 10;
      const leftWidth = (leftTooltipText.node() as SVGTextElement).getBoundingClientRect().width;
      const rectWidth = leftWidth + leftTooltipPadding;
      leftTooltipRect.attr("width", rectWidth).attr("x", LEFT_TOOLTIP_OFFSET - rectWidth);
      leftTooltipText.attr("x", LEFT_TOOLTIP_OFFSET - rectWidth / 2);

      const bound = hoverDot.node()?.getBoundingClientRect() as DOMRect;
      const tooltip = tooltipRef.current;
      const tooltipPointer = tooltip.firstChild as HTMLDivElement;
      const tooltipText = tooltip.lastChild as HTMLDivElement;
      tooltipText.innerText = tooltipFormatter(closestX.x, closestY, lowestyIndex, correspondingDashedY);
      const textWidth = tooltip.clientWidth;

      const topPadding = 45;
      const normalOneLineHeightPlusPadding = 33;
      // Extra height of the tooltip beyond a single line of text.
      const heightOffset = tooltip.clientHeight - normalOneLineHeightPlusPadding;

      // Page coordinates: centred horizontally on the dot and placed above it.
      const leftOffset = bound.x + window.scrollX + bound.width / 2 - textWidth / 2;
      const topOffset = bound.y + window.scrollY - topPadding - heightOffset;

      conditionallySetOffsetProperty(tooltip, "left", leftOffset);
      conditionallySetOffsetProperty(tooltip, "top", topOffset);

      const tooltipPointerLeftOffset = tooltip.clientWidth / 2 - tooltipPointer.clientWidth / 2;
      const tooltipPointerTopOffset =
        tooltip.clientHeight - tooltip.clientHeight * calculateTopTooltipOffsetScalar(heightOffset);
      const tooltipPointerTopPadding = 3;
      conditionallySetOffsetProperty(tooltipPointer, "left", tooltipPointerLeftOffset);
      conditionallySetOffsetProperty(tooltipPointer, "top", tooltipPointerTopOffset + tooltipPointerTopPadding);
    }

    // Binding event listeners
    svg
      .on("pointermove", pointerMoved)
      .on("pointerleave", hideTooltips)
      .on("click", () => onClickRef.current?.(hoveredPoint.current[0], hoveredPoint.current[1]));

    return () => {
      svg.on("pointermove", null).on("pointerleave", null).on("click", null);
    };
  }, [
    formattedDashDataForLine,
    formattedDataForLine,
    hideTooltips,
    innerWidth,
    margins.left,
    margins.top,
    preflightChecksFailed,
    tooltipFormatter,
    xScale,
    yBlackBoxFormat,
    yScale,
  ]);

  return (
    <div className="line-chart-container" ref={measuredRef} data-testid="line-chart">
      <div ref={resizeRef} style={{ width: "100%", height: "100%" }}>
        <svg viewBox={`0 0 ${totalWidth} ${totalHeight}`} ref={svgRef} preserveAspectRatio="none" />
      </div>
    </div>
  );
};

/**
 * Zips one series of x values with its y values into `{ x, y }` points.
 *
 * The result has one point per entry of `x`; where `y` has no value at an index, that
 * point's `y` is `undefined`.
 *
 * @param x - The x values of one series.
 * @param y - The y values of the same series, matched to `x` by index.
 * @returns One `{ x, y }` object per x value.
 */
const FormatDataForUse = (x: any[], y: any[]) => {
  const toPoint = (xValue: any, index: number) => {
    return { x: xValue, y: y[index] };
  };

  return x.map(toPoint);
};
