import React, { useRef, useEffect } from "react";
import PropTypes from "prop-types";
import * as d3 from "d3";
import * as R from "ramda";

/*
Expected shape of the `selectedCohorts` prop (see Chart.propTypes):

selectedCohorts = [
  {
    id: "...",
    name: "...",
    patientIds: [...],
    sampleIds: [...],
    ...
  }
]
*/

const Chart = ({ selectedCohorts, saveCohort, matchPatientIds }) => {
  const ref = useRef(null);
  useEffect(() => {
    if (selectedCohorts && ref.current) {
      const svg = d3.select(ref.current);
      const height = 312;
      const width = 330;

      // Initialize some util functions

      const union = (cohorts) =>
        new Set(cohorts.map((c) => c.sampleIds).flat());
      const innerIntersect = (a, b) =>
        new Set([...a].filter((_a) => b.has(_a)));
      const allSampleIds = union(selectedCohorts);
      const intersect = (cohorts) =>
        cohorts
          .map((c) => new Set(c.sampleIds))
          .reduce(innerIntersect, allSampleIds);

      // Partition selected cohorts by samples
      const numCohorts = selectedCohorts.length;
      let data = [];
      let bubbles = [];
      if (selectedCohorts.length > 1) {
        // Create sequence 0 to n
        const originalSeq = R.times(R.identity, numCohorts);

        // Create all combinations of cohorts (by index)
        const f = (x) => R.xprod(originalSeq, x);
        let mapping = R.uniq(
          R.map(
            (l) => R.uniq(R.sort(R.subtract, R.flatten(l))),
            R.compose(...R.repeat(f, numCohorts - 1))(originalSeq)
          )
        );

        // Sort mapping by the number of cohorts in a partition (decending)
        mapping = R.reverse(R.sortBy((m) => m.length, mapping));

        // Assign samples to partitions
        let mappedIds = new Set();
        data = mapping.map((m) => {
          const newIds = intersect(R.props(m, selectedCohorts));
          const validIds = [...newIds].filter((id) => !mappedIds.has(id));
          mappedIds = new Set([...mappedIds, ...validIds]);
          return {
            sets: m,
            name: R.join(
              " & ",
              m.map((v) => selectedCohorts[v].name)
            ),
            cohortNames: m.map((v) => selectedCohorts[v].name),
            value: validIds.length,
            ids: validIds,
          };
        });

        // Ignore partitions without any samples
        data = R.filter((r) => r.value, data);

        // Sort partitions by number of samples (decending)
        data = R.reverse(R.sortBy((r) => r.value, data));

        // Create bubbles
        data.forEach((item, i) => {
          const exists = item.sets;
          const missing = R.difference(originalSeq, exists);
          bubbles = R.concat(
            bubbles,
            exists.map((idx) => ({
              x: idx,
              y: i,
              exists: true,
              cohortName: selectedCohorts[idx].name,
            }))
          );
          bubbles = R.concat(
            bubbles,
            missing.map((idx) => ({
              x: idx,
              y: i,
              exists: false,
              cohortName: selectedCohorts[idx].name,
            }))
          );
        });
      }

      const margin = { top: 30, right: 20, bottom: 30, left: 20 };

      const offset = 60;
      const padding = 10;

      const x = d3
        .scaleLinear()
        .domain([0, d3.max(data, (d) => d.value)])
        .range([margin.left + offset + padding, width - margin.right]);

      const x2 = d3
        .scaleLinear()
        .domain([0, numCohorts])
        .range([margin.left, margin.left + offset]);

      const x3 = d3
        .scalePoint()
        .domain([data.length > 1 ? "Cohorts" : ""])
        .range([
          margin.left,
          margin.left + offset - offset / (numCohorts + 0.01),
        ]);

      const x4 = d3
        .scalePoint()
        .domain([data.length > 1 ? "Sample Count" : ""])
        .range([margin.left + offset + padding, width - margin.right]);
      const y = d3
        .scaleBand()
        .domain(d3.range(data.length))
        .rangeRound([margin.top, height - margin.bottom])
        .padding(0.1);

      const xAxis = (g) =>
        g
          .attr("transform", `translate(0,${margin.top})`)
          .call(
            d3
              .axisBottom(x)
              .ticks(width / 80)
              .tickSize(height - margin.top - margin.bottom)
          )
          .call((_) =>
            _.selectAll(".tick:not(:first-of-type) line")
              .attr("stroke-opacity", 0.5)
              .attr("stroke-dasharray", "2,2")
          )
          .call((_) => _.select(".domain").remove());

      const xAxis3 = (g) =>
        g
          .attr("transform", `translate(0,${margin.top - 3})`)
          .call(d3.axisTop(x3).ticks().tickSizeInner(0))
          .call((_) => _.select(".domain").remove());

      const xAxis4 = (g) =>
        g
          .attr("transform", `translate(0,${margin.top - 3})`)
          .call(d3.axisTop(x4).ticks().tickSizeInner(0))
          .call((_) => _.select(".domain").remove());

      const tooltip = d3.select("#tooltip").join("div");
      tooltip.html("");

      // Display nothing if less than two partitions
      if (data < 2) {
        tooltip.html("No partitions to show.");
      }

      // Bars for each partition
      const bars = svg
        .selectAll("g.bars")
        .attr("fill", "steelblue")
        .selectAll("rect")
        .data(data)
        .join("rect")
        .attr("x", x(0))
        .attr("y", (d, i) => y(i))
        .attr("width", (d) => x(d.value) - x(0))
        .attr("height", y.bandwidth());

      // bars
      //   .transition()
      //   .delay((d, i) => i * 100 + 1200)
      //   .duration(600)
      //   .ease(d3.easeLinear);

      const showTooltip = (name) => {
        tooltip
          .style("opacity", 0)
          .style("font-size", `${Math.min(14, 1.8 * (width / name.length))}px`);
        tooltip.transition().duration(100).style("opacity", 1);
        tooltip.html(name);
      };

      bars
        .on("mouseover", (e, d) => {
          d3.select(e.target)
            .transition()
            .duration(100)
            .attr("fill", "darkorange");
          showTooltip(d.name);
        })
        .on("mouseout", (e) => {
          d3.select(e.target)
            .transition()
            .duration(100)
            .attr("fill", "steelblue");
        });

      bars.on("click", (e, d) => {
        const sampleIds = d.ids;
        const patientIds = matchPatientIds(sampleIds);
        saveCohort(patientIds, sampleIds, "Intersection", d.cohortNames);
      });

      // Bubbles to demonstrate the cohort relations
      const bubs = svg
        .selectAll("g.bubbles")
        .selectAll("circle")
        .data(bubbles)
        .join("circle")
        .attr("r", y.bandwidth() / 9)
        .attr("cx", (d) => x2(d.x))
        .attr("cy", (d) => y(d.y) + y.bandwidth() / 2)
        .attr("fill", "steelblue")
        .style("opacity", (d) => (d.exists ? 1 : 0.3));

      // bubs
      //   .transition()
      //   .delay((d, i) => i * 100 + 1200)
      //   .duration(600)
      //   .ease(d3.easeLinear);

      bubs.on("mouseover", (e, d) => {
        if (d.exists) {
          d3.select(e.target)
            .transition()
            .duration(100)
            .attr("fill", "darkorange");
          showTooltip(d.cohortName);
        }
      });

      bubs.on("mouseout", (e, d) => {
        if (d.exists) {
          d3.select(e.target)
            .transition()
            .duration(100)
            .attr("fill", "steelblue");
        }
      });

      // Labels for bars showing count
      svg
        .selectAll("g.labels")
        .attr("fill", "white")
        .attr("text-anchor", "end")
        .attr("font-family", "sans-serif")
        .attr("font-size", 12)
        .selectAll("text")
        .data(data)
        .join("text")
        .attr("x", (d) => x(d.value))
        .attr("y", (d, i) => y(i) + y.bandwidth() / 2)
        .attr("dy", "0.35em")
        .attr("dx", -4)
        .text((d) => d.value)
        .call((text) =>
          text
            .filter((d) => x(d.value) - x(0) < 20) // short bars
            .attr("dx", +4)
            .attr("fill", "black")
            .attr("text-anchor", "start")
        );

      svg.selectAll("g.x-axis").call(xAxis);
      svg.selectAll("g.x-axis-3").call(xAxis3);
      svg.selectAll("g.x-axis-4").call(xAxis4);
    }
  }, [selectedCohorts, saveCohort, matchPatientIds]);

  return (
    <div>
      <div
        style={{
          height: "28px",
          padding: "5px",
          borderRadius: "3px",
          backgroundColor: "rgba(138, 155, 168, 0.15)",
          textAlign: "center",
        }}
        id="tooltip"
      />
      <svg
        ref={ref}
        style={{
          height: "312",
          width: "330",
          marginRight: "0px",
          marginLeft: "0px",
        }}
      >
        <g className="axis x-axis" />
        <g className="axis x-axis-3" />
        <g className="axis x-axis-4" />
        <g className="bubbles" />
        <g className="bars" />
        <g className="labels" />
        <g className="circle" />
      </svg>
    </div>
  );
};

// Shape of a single entry of `selectedCohorts`. The chart itself reads `name`
// and `sampleIds`; the remaining fields are required by this declaration.
const cohortPropType = PropTypes.shape({
  id: PropTypes.string.isRequired,
  name: PropTypes.string.isRequired,
  patientIds: PropTypes.arrayOf(PropTypes.string).isRequired,
  sampleIds: PropTypes.arrayOf(PropTypes.string).isRequired,
});

// Runtime prop validation for Chart; every prop is mandatory.
Chart.propTypes = {
  selectedCohorts: PropTypes.arrayOf(cohortPropType).isRequired,
  saveCohort: PropTypes.func.isRequired,
  matchPatientIds: PropTypes.func.isRequired,
};

export default Chart;
