// Credit to: https://bl.ocks.org/d3noob/5028304
import * as d3 from "d3";
import { type EntityType, type SankeyRelation } from "../../types/Sankey";

/**
 * SVG path generator produced by `sankey.link()`.
 *
 * Calling it with a computed link returns the string for the path's `d`
 * attribute. `curvature()` reads the bend factor. `curvature(value)` sets it
 * and returns this same generator (not the layout object), so it can be
 * chained directly into use.
 */
type LinkPathGenerator = {
  (d: ComputedLink): string;
  curvature(): number;
  curvature(newCurvature: number): LinkPathGenerator;
};

/** Factory that returns a fresh link path generator on every call. */
type LinkFunc = () => LinkPathGenerator;

/**
 * Public surface of the layout object returned by `d3Sankey()`.
 *
 * Configuration members are d3-style accessors: call one with no argument to
 * read the current value, or with a value to store it and get the layout
 * object back for chaining.
 */
interface Sankey {
  // --- Geometry -------------------------------------------------------------

  /** Horizontal extent given to every node. */
  nodeWidth(): number;
  nodeWidth(newWidth: number): Sankey;
  /** Vertical gap kept between nodes stacked in the same column. */
  nodePadding(): number;
  nodePadding(newPadding: number): Sankey;
  /** Overall layout dimensions as `[width, height]`. */
  size(): [number, number];
  size(newSize: [number, number]): Sankey;

  // --- Graph data -----------------------------------------------------------

  /** Raw relation rows; when non-empty, `layout` derives nodes and links from them. */
  relations(): SankeyRelation[];
  relations(newRelations: SankeyRelation[]): Sankey;
  /** Graph nodes. */
  nodes(): ComputedNode[];
  nodes(newNodes: ComputedNode[]): Sankey;
  /** Graph links. */
  links(): ComputedLink[];
  links(newLinks: ComputedLink[]): Sankey;

  // --- Computation ----------------------------------------------------------

  /** Runs the full layout with the given number of relaxation passes. */
  layout: (iterations: number) => Sankey;
  /** Recomputes only the per-node link attachment offsets. */
  relayout: () => Sankey;
  /** Returns an SVG path generator for computed links. */
  link: LinkFunc;
}

/**
 * A node of the Sankey graph, carrying both its source data and the geometry
 * the layout assigns to it.
 */
export type ComputedNode = {
  // Identity / source data
  /** Item id taken from the source data; not unique on its own. */
  id?: number;
  /** Item order taken from the source data; not unique on its own. */
  order: number;
  /**
   * Key built from `id` and `order` together, because neither value alone is
   * unique across nodes.
   */
  uniqueId: string;
  /** Display label. */
  name: string;
  type: EntityType;

  // Magnitude
  /**
   * Size of the node. Either set explicitly (self-referencing relation) or
   * derived as the larger of its incoming and outgoing link-value sums.
   */
  value: number;

  // Geometry
  /** Left edge. */
  x: number;
  /** Top edge. */
  y: number;
  width: number;
  height: number;

  // Connectivity
  /** Links that leave this node (this node is their `source`). */
  sourceLinks: ComputedLink[];
  /** Links that arrive at this node (this node is their `target`). */
  targetLinks: ComputedLink[];
};

/**
 * A directed, weighted edge between two nodes, plus the drawing geometry the
 * layout assigns to it.
 */
export type ComputedLink = {
  source: ComputedNode;
  target: ComputedNode;
  /** Flow carried by the link. */
  value: number;
  /** Drawn width of the link band. */
  thickness: number;
  /** Offset from the top of `source` to where this link's band starts. */
  sourceYPosition: number;
  /** Offset from the top of `target` to where this link's band starts. */
  targetYPosition: number;
};

/** Horizontal distance between adjacent columns; each node's column index is multiplied by it. */
const NODE_SPACING = 295;
/** Lower bound applied to every node's height, and its initial height. */
const MIN_NODE_HEIGHT = 75;
/** Factor applied to link thickness so links do not span the full height of their nodes. */
const LINK_SCALAR = 0.5;

/** Builds a node key from an item id and its order, separated by a single space. */
function GenerateUniqueNodeId(id: number, order: number) {
  const idPart = String(id);
  const orderPart = String(order);
  return idPart + " " + orderPart;
}

/** Convenience wrapper: builds the node key from an object's `id` and `order` fields. */
function GenerateUniqueNodeIdFromNode(node: { id: number; order: number }) {
  const { id, order } = node;
  return GenerateUniqueNodeId(id, order);
}

/** Builds a link key from the unique keys of its two endpoint nodes, space-separated. */
function GenerateUniqueLinkId(sourceUniqueId: string, targetUniqueId: string) {
  return sourceUniqueId + " " + targetUniqueId;
}

/**
 * Creates a Sankey layout (adapted from https://bl.ocks.org/d3noob/5028304).
 *
 * Configure it through the chainable accessors (`nodeWidth`, `nodePadding`,
 * `size`, and either `relations` or `nodes` + `links`). Then call
 * `layout(iterations)` to assign:
 * - every node an `x`, `y` and `height`;
 * - every link a `thickness` and its attachment offsets on both end nodes.
 *
 * Accessors follow the d3 getter/setter convention. Called with no argument
 * (or `null`/`undefined`), they return the current value. Called with a value,
 * they store it and return the layout object for chaining.
 */
export const d3Sankey = function () {
  let nodeWidth = 24,
    nodePadding = 8,
    size = [1, 1],
    relations: SankeyRelation[] = [],
    nodes: ComputedNode[] = [],
    links: ComputedLink[] = [];

  // Each accessor returns either the stored value or `sankey`. TypeScript
  // cannot check such an implementation against the overloads declared on
  // `Sankey`, so each accessor property carries @ts-ignore. Setters compare
  // against null instead of testing truthiness, so 0 is a settable value.
  const sankey: Sankey = {
    // @ts-ignore
    nodeWidth: function (newWidth?: number) {
      if (newWidth == null) return nodeWidth;
      nodeWidth = newWidth;
      return sankey;
    },
    // @ts-ignore
    nodePadding: function (newPadding?: number) {
      if (newPadding == null) return nodePadding;
      nodePadding = newPadding;
      return sankey;
    },
    // @ts-ignore
    relations: function (newRelations?: SankeyRelation[]) {
      if (newRelations == null) return relations;
      relations = newRelations;
      return sankey;
    },
    // @ts-ignore
    nodes: function (newNodes?: ComputedNode[]) {
      if (newNodes == null) return nodes;
      nodes = newNodes;
      return sankey;
    },
    // @ts-ignore
    links: function (newLinks?: ComputedLink[]) {
      if (newLinks == null) return links;
      links = newLinks;
      return sankey;
    },
    // @ts-ignore
    size: function (newSize?: [number, number]) {
      // The getter hands back a copy, not the internal array.
      if (newSize == null) return [size[0], size[1]];
      size = newSize;
      return sankey;
    },
    layout: function (iterations) {
      computeNodes();
      computeLinks();
      computeNodeLinks();
      computeNodeValues();
      computeNodeBreadths();
      computeNodeDepths(iterations);
      computeLinkDepths();
      return sankey;
    },
    // Re-stacks link attachment points only; node positions are left as-is.
    relayout: () => {
      computeLinkDepths();
      return sankey;
    },
    // @ts-ignore
    link: () => {
      let curvature = 0.5;

      /** Cubic Bezier from the right edge of the source to the left edge of the target. */
      function link(d: ComputedLink) {
        const x0 = d.source.x + d.source.width;
        const x1 = d.target.x;
        const xi = d3.interpolateNumber(x0, x1);
        // Both control points sit at the endpoints' heights, pulled horizontally
        // toward each other by `curvature`.
        const x2 = xi(curvature);
        const x3 = xi(1 - curvature);
        const y0 = d.source.y + d.sourceYPosition + d.thickness / 2;
        const y1 = d.target.y + d.targetYPosition + d.thickness / 2;
        return `M${x0},${y0}C${x2},${y0} ${x3},${y1} ${x1},${y1}`;
      }

      // Getter/setter; null check so that curvature 0 (straight lines) can be set.
      link.curvature = function (newCurvature?: number) {
        if (newCurvature == null) return curvature;
        curvature = newCurvature;
        return link;
      };

      return link;
    },
  };

  /**
   * Returns a lookup from `uniqueId` to node. If several nodes share a
   * `uniqueId`, the first one in `nodes` wins.
   */
  function indexNodesByUniqueId() {
    const index = new Map<string, ComputedNode>();
    for (const node of nodes) {
      if (!index.has(node.uniqueId)) index.set(node.uniqueId, node);
    }
    return index;
  }

  /**
   * Rebuilds `nodes` from `relations` (the relation-based "v2" input).
   *
   * Each relation contributes its current item and its next item. Items are
   * de-duplicated on their combined id + order key. When `relations` is empty
   * this is a no-op, so caller-supplied `nodes` are kept.
   */
  function computeNodes() {
    if (relations.length === 0) return;

    type RelationSubset = {
      id: SankeyRelation["ItemId"];
      title: SankeyRelation["ItemTitle"];
      type: SankeyRelation["ItemType"];
      order: SankeyRelation["ItemOrder"];
    };

    nodes = [];
    const insertedNodes = new Set<string>();
    const testAndInsert = (relation: RelationSubset) => {
      const nodeId = GenerateUniqueNodeIdFromNode(relation);
      if (insertedNodes.has(nodeId)) return;
      insertedNodes.add(nodeId);
      nodes.push({
        width: nodeWidth,
        height: MIN_NODE_HEIGHT,
        name: relation.title,
        id: relation.id,
        order: relation.order,
        uniqueId: nodeId,
        type: relation.type,
        sourceLinks: [],
        targetLinks: [],
        value: 0,
        x: 0,
        y: 0,
      });
    };

    for (const relation of relations) {
      testAndInsert({
        id: relation.ItemId,
        title: relation.ItemTitle,
        type: relation.ItemType,
        order: relation.ItemOrder,
      });
      testAndInsert({
        id: relation.NextItemId,
        title: relation.NextItemTitle,
        type: relation.NextItemType,
        order: relation.NextItemOrder,
      });
    }
  }

  /**
   * Rebuilds `links` from `relations` ("v2" input only; no-op when empty).
   *
   * - A relation whose current and next items are the same node creates no
   *   link. Instead it stores `UsersCompleted` as that node's value.
   * - Every other relation creates one link weighted by `UsersCompleted`.
   *   Repeated source/target pairs keep the first occurrence.
   */
  function computeLinks() {
    if (relations.length === 0) return;
    links = [];
    const nodeByUniqueId = indexNodesByUniqueId();
    const insertedLinks = new Set<string>();

    for (const relation of relations) {
      const currentUniqueId = GenerateUniqueNodeId(relation.ItemId, relation.ItemOrder);
      const nextUniqueId = GenerateUniqueNodeId(relation.NextItemId, relation.NextItemOrder);

      if (currentUniqueId === nextUniqueId) {
        const node = nodeByUniqueId.get(currentUniqueId);
        if (node) node.value = relation.UsersCompleted;
        continue;
      }

      const linkId = GenerateUniqueLinkId(currentUniqueId, nextUniqueId);
      if (insertedLinks.has(linkId)) continue;

      const sourceNode = nodeByUniqueId.get(currentUniqueId);
      const targetNode = nodeByUniqueId.get(nextUniqueId);
      // Guard instead of non-null assertions: a link without both endpoints
      // cannot be drawn.
      if (!sourceNode || !targetNode) continue;

      links.push({
        thickness: 1,
        sourceYPosition: 0,
        targetYPosition: 0,
        source: sourceNode,
        target: targetNode,
        value: relation.UsersCompleted,
      });
      insertedLinks.add(linkId);
    }
  }

  /**
   * Resets and repopulates every node's `sourceLinks` / `targetLinks`.
   *
   * Each link's endpoints are re-bound to the node objects in `nodes` that have
   * the same `uniqueId`. Links whose endpoints are missing, or that have no
   * counterpart in `nodes`, are left unregistered instead of throwing.
   */
  function computeNodeLinks() {
    nodes.forEach((node) => {
      node.sourceLinks = [];
      node.targetLinks = [];
    });
    const nodeByUniqueId = indexNodesByUniqueId();

    links.forEach((link) => {
      if (!link.source || !link.target) return;
      const source = nodeByUniqueId.get(link.source.uniqueId);
      const target = nodeByUniqueId.get(link.target.uniqueId);
      if (!source || !target) return;

      link.source = source;
      link.target = target;
      source.sourceLinks.push(link);
      target.targetLinks.push(link);
    });
  }

  /**
   * Sets each node's value. A non-zero preset value is kept. Otherwise the
   * value is the larger of the node's outgoing and incoming link-value sums.
   */
  function computeNodeValues() {
    nodes.forEach((node) => {
      node.value = node.value || Math.max(d3.sum(node.sourceLinks, value), d3.sum(node.targetLinks, value));
    });
  }

  /**
   * Assigns each node its horizontal position.
   *
   * 1. Breadth-first pass: a node's column index is the length of the longest
   *    chain of links leading into it.
   * 2. Nodes with no outgoing links are moved to the last column.
   * 3. Column indices are scaled by NODE_SPACING.
   */
  function computeNodeBreadths() {
    let frontier: ComputedNode[] = nodes;
    let x = 0;

    // In an acyclic graph the deepest column index is at most nodes.length - 1,
    // so the `x < nodes.length` bound never cuts the walk short. It only stops
    // the walk (instead of hanging) when the links contain a cycle.
    while (frontier.length && x < nodes.length) {
      // A Set keeps each node once per level. Without it the frontier can grow
      // exponentially when several paths converge on the same node.
      const next = new Set<ComputedNode>();
      for (const node of frontier) {
        node.x = x;
        node.width = nodeWidth;
        for (const link of node.sourceLinks) {
          next.add(link.target);
        }
      }
      frontier = Array.from(next);
      ++x;
    }

    moveSinksRight(x);

    // This spaces nodes evenly according to their depth
    scaleNodeBreadths(NODE_SPACING);
    // Below is dynamic spacing based on the width of the sankey diagram
    // scaleNodeBreadths((size[0] - nodeWidth) / (x - 1));
  }

  /** Puts every node without outgoing links in the last column (index `x - 1`). */
  function moveSinksRight(x: number) {
    nodes.forEach((node) => {
      if (!node.sourceLinks.length) {
        node.x = x - 1;
      }
    });
  }

  /** Multiplies every node's column index by `kx`. */
  function scaleNodeBreadths(kx: number) {
    nodes.forEach((node) => {
      node.x *= kx;
    });
  }

  /**
   * Assigns vertical positions.
   *
   * Nodes are grouped into columns ordered left to right. After the initial
   * placement, `iterations` rounds of weighted relaxation run, alternating
   * direction, with overlaps resolved after each pass. The relaxation step
   * shrinks by 1% each round.
   */
  function computeNodeDepths(iterations: number) {
    const nodesByBreadth: ComputedNode[][] = d3
      .nest<ComputedNode, number>()
      .key((d) => d.x.toString())
      // Nest keys are strings. Compare them numerically, otherwise "1180" sorts
      // before "295" and the relaxation passes visit columns out of order.
      .sortKeys((a, b) => d3.ascending(+a, +b))
      .entries(nodes)
      .map((d) => d.values);

    initializeNodeDepth(nodesByBreadth);
    resolveCollisions(nodesByBreadth);

    let alpha = 1;
    for (let pass = 0; pass < iterations; ++pass) {
      alpha *= 0.99;
      relaxRightToLeft(alpha, nodesByBreadth);
      resolveCollisions(nodesByBreadth);
      relaxLeftToRight(alpha, nodesByBreadth);
      resolveCollisions(nodesByBreadth);
    }
  }

  /**
   * Initial vertical placement and sizing.
   *
   * `ky` converts value to height. It is the smallest ratio, over all columns,
   * of (layout height minus that column's padding) to (the column's total
   * value), so the most crowded column fits. Node heights are floored at
   * MIN_NODE_HEIGHT. Link thickness is scaled by LINK_SCALAR and is at least 1.
   */
  function initializeNodeDepth(nodesByBreadth: ComputedNode[][]) {
    const ky = d3.min(nodesByBreadth, (nodeGroup) => {
      return (
        (size[1] - (nodeGroup.length - 1) * nodePadding) /
        // Prevent divide by 0 errors
        atLeastOne(d3.sum(nodeGroup, value))
      );
    }) as number;

    nodesByBreadth.forEach((nodeGroup) => {
      nodeGroup.forEach((node, i) => {
        node.y = i;
        node.height = Math.max(node.value * ky, MIN_NODE_HEIGHT);
      });
    });

    links.forEach((link) => {
      link.thickness = atLeastOne(+link.value * ky * LINK_SCALAR);
    });
  }

  /**
   * Moves each node that has incoming links toward the value-weighted mean
   * center of its source nodes, by the fraction `alpha`.
   */
  function relaxLeftToRight(alpha: number, nodesByBreadth: ComputedNode[][]) {
    nodesByBreadth.forEach((nodeGroup) => {
      nodeGroup.forEach((node) => {
        if (node.targetLinks.length) {
          const y = d3.sum(node.targetLinks, weightedSource) / atLeastOne(d3.sum(node.targetLinks, value));
          node.y += (y - center(node)) * alpha;
        }
      });
    });

    function weightedSource(link: ComputedLink) {
      return center(link.source) * +link.value;
    }
  }

  /**
   * Walks the columns right to left. Each node that has outgoing links moves
   * toward the value-weighted mean center of its target nodes, by the
   * fraction `alpha`.
   */
  function relaxRightToLeft(alpha: number, nodesByBreadth: ComputedNode[][]) {
    nodesByBreadth
      .slice()
      .reverse()
      .forEach((nodeGroup) => {
        nodeGroup.forEach((node) => {
          if (node.sourceLinks.length) {
            const y = d3.sum(node.sourceLinks, weightedTarget) / atLeastOne(d3.sum(node.sourceLinks, value));
            node.y += (y - center(node)) * alpha;
          }
        });
      });

    function weightedTarget(link: ComputedLink) {
      return center(link.target) * +link.value;
    }
  }

  /**
   * Within each column, sorts nodes by `y` (in place) and pushes overlapping
   * nodes apart by at least `nodePadding`. If the column then overflows
   * `size[1]`, it is shifted back up from the bottom.
   */
  function resolveCollisions(nodesByBreadth: ComputedNode[][]) {
    nodesByBreadth.forEach((nodeGroup) => {
      let node = nodeGroup[0],
        dy = 0,
        y0 = 0,
        n = nodeGroup.length;

      // Push any overlapping nodes down.
      nodeGroup.sort(ascendingDepth);
      for (let i = 0; i < n; ++i) {
        node = nodeGroup[i];
        dy = y0 - node.y;
        if (dy > 0) node.y += dy;
        y0 = node.y + node.height + nodePadding;
      }

      // If the bottommost node goes outside the bounds, push it back up.
      dy = y0 - nodePadding - size[1];
      if (dy > 0) {
        y0 = node.y -= dy;

        // Push any overlapping nodes back up.
        for (let i = n - 2; i >= 0; --i) {
          node = nodeGroup[i];
          dy = node.y + node.height + nodePadding - y0;
          if (dy > 0) node.y -= dy;
          y0 = node.y;
        }
      }
    });
  }

  /** Comparator: ascending by top edge. */
  function ascendingDepth(a: ComputedNode, b: ComputedNode) {
    return a.y - b.y;
  }

  /**
   * Stacks each node's links, centered on the node's vertical midpoint.
   *
   * Outgoing links are sorted by their target's `y`; incoming links by their
   * source's `y`. This sets `sourceYPosition` and `targetYPosition` on every
   * registered link.
   */
  function computeLinkDepths() {
    nodes.forEach((node) => {
      node.sourceLinks.sort(ascendingTargetDepth);
      node.targetLinks.sort(ascendingSourceDepth);
    });
    nodes.forEach((node) => {
      const totalSourceThickness = node.sourceLinks.reduce(thicknessSum, 0);
      const totalTargetThickness = node.targetLinks.reduce(thicknessSum, 0);
      const nodeMidpoint = node.height / 2;
      let sy = nodeMidpoint - totalSourceThickness / 2,
        ty = nodeMidpoint - totalTargetThickness / 2;

      node.sourceLinks.forEach((link) => {
        link.sourceYPosition = sy;
        sy += link.thickness;
      });
      node.targetLinks.forEach((link) => {
        link.targetYPosition = ty;
        ty += link.thickness;
      });
    });

    function ascendingSourceDepth(a: ComputedLink, b: ComputedLink) {
      return a.source.y - b.source.y;
    }

    function ascendingTargetDepth(a: ComputedLink, b: ComputedLink) {
      return a.target.y - b.target.y;
    }

    function thicknessSum(prev: number, curr: ComputedLink) {
      return prev + curr.thickness;
    }
  }

  /** Vertical midpoint of a node. */
  function center(node: ComputedNode) {
    return node.y + node.height / 2;
  }

  /** Numeric value accessor shared by the d3 aggregations above. */
  function value(link: ComputedLink | ComputedNode) {
    return +link.value;
  }

  return sankey;
};

/**
 * Clamps `input` so it is never below 1. Used to avoid zero divisors and
 * zero-width links. NaN is passed through unchanged.
 */
function atLeastOne(input: number) {
  return input < 1 ? 1 : input;
}
