import React, { useRef, useEffect } from "react";
import * as d3 from "d3";
import PropTypes from "prop-types";

const WaterfallChart = ({ data, width = 1100, height = 800 }) => {
  const svgRef = useRef();

  useEffect(() => {
    d3.select(svgRef.current).selectAll("*").remove();

    const margin = { top: 60, right: 20, bottom: 80, left: 80 };
    const chartWidth = width - margin.left - margin.right;
    const chartHeight = height - margin.top - margin.bottom;

    const groupedData = d3
      .rollups(
        data,
        (v) => ({
          label: v[0].label,
          value: d3.sum(v, (d) => d.value),
          type: v[0].type,
        }),
        (d) => d.label
      )
      .map(([, value]) => value);

    const types = Array.from(new Set(groupedData.map((d) => d.type)));
    let cumulativeTotal = 0;

    const waterfallData = groupedData.map((d, i) => {
      const start = cumulativeTotal;
      cumulativeTotal += d.value;
      return {
        ...d,
        start,
        end: cumulativeTotal,
      };
    });

    const typeTotals = types.map((type) => ({
      type,
      value: d3.sum(
        groupedData.filter((d) => d.type === type && d.value > 0),
        (d) => d.value
      ),
    }));

    const totalBarData = typeTotals
      .sort((a, b) => a.value - b.value)
      .reduce((acc, d, i) => {
        const previousEnd = acc.length ? acc[acc.length - 1].end : 0;
        acc.push({
          label: "Total",
          type: d.type,
          value: d.value,
          start: previousEnd,
          end: previousEnd + d.value,
        });
        return acc;
      }, []);

    const extendedData = [...waterfallData, ...totalBarData];

    if (d3.max(extendedData, (d) => d.end) === 0) {
      d3.select(svgRef.current)
        .append("text")
        .attr("x", chartWidth / 2)
        .attr("y", chartHeight / 2)
        .attr("text-anchor", "middle")
        .attr("font-size", "16px")
        .attr("fill", "#666")
        .text("No data to display");
      return;
    }

    const svg = d3
      .select(svgRef.current)
      .attr("viewBox", `0 0 ${width} ${height}`)
      .attr("preserveAspectRatio", "xMinYMin meet")
      .append("g")
      .attr("transform", `translate(${margin.left},${margin.top})`);

    const x = d3
      .scaleBand()
      .domain(extendedData.map((d) => d.label))
      .range([0, chartWidth])
      .padding(0.3);

    const y = d3
      .scaleLinear()
      .domain([
        d3.min(extendedData, (d) => d.start),
        d3.max(extendedData, (d) => d.end) * 1.2,
      ])
      .nice()
      .range([chartHeight, 0]);

    const colorScale = d3
      .scaleOrdinal()
      .domain(types)
      .range(d3.schemeCategory10);

    svg
      .append("g")
      .attr("transform", `translate(0,${chartHeight})`)
      .call(d3.axisBottom(x).tickSize(0))
      .selectAll(".tick text")
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .attr("dy", "1em")
      .each(function (d) {
        const text = d3.select(this);
        const labelText = d.startsWith("ESRS") ? d.replace("ESRS ", "") : d; // Remove "ESRS" prefix
        const maxChars = 6;
        if (labelText.length > maxChars) {
          const truncated = labelText.slice(0, maxChars) + "…";
          text.text(truncated);
          text.append("title").text(labelText);
        } else {
          text.text(labelText);
        }
      });

    const tooltip = d3
      .select("body")
      .append("div")
      .attr("class", "tooltip")
      .style("position", "absolute")
      .style("visibility", "hidden")
      .style("background", "rgba(0, 0, 0, 0.7)")
      .style("color", "white")
      .style("padding", "5px 10px")
      .style("border-radius", "4px")
      .style("box-shadow", "0 0 10px rgba(0,0,0,0.1)");

    svg
      .selectAll(".bar")
      .data(waterfallData)
      .enter()
      .append("rect")
      .attr("class", "bar")
      .attr("x", (d) => x(d.label))
      .attr("y", (d) => y(Math.max(d.start, d.end)))
      .attr("width", x.bandwidth())
      .attr("height", (d) => Math.abs(y(d.start) - y(d.end)))
      .attr("fill", (d) => colorScale(d.type))
      .on("mouseover", (event, d) => {
        tooltip
          .style("visibility", "visible")
          .html(`Label: ${d.label}<br>Value: ${d.value}`);
      })
      .on("mousemove", (event) => {
        tooltip
          .style("top", `${event.pageY - 10}px`)
          .style("left", `${event.pageX + 10}px`);
      })
      .on("mouseout", () => {
        tooltip.style("visibility", "hidden");
      });

    svg
      .selectAll(".connector")
      .data(waterfallData.slice(0, -1))
      .enter()
      .append("line")
      .attr("class", "connector")
      .attr("x1", (d) => x(d.label) + x.bandwidth() / 2)
      .attr("x2", (d, i) => x(waterfallData[i + 1].label) + x.bandwidth() / 2)
      .attr("y1", (d) => y(d.end))
      .attr("y2", (d, i) => y(waterfallData[i + 1].start))
      .attr("stroke", "#999")
      .attr("stroke-width", 1);

    // Render the total bar with type divisions
    svg
      .selectAll(".total-bar")
      .data(totalBarData)
      .enter()
      .append("rect")
      .attr("class", "total-bar")
      .attr("x", x("Total"))
      .attr("y", (d) => y(d.end))
      .attr("width", x.bandwidth())
      .attr("height", (d) => Math.abs(y(d.start) - y(d.end)))
      .attr("fill", (d) => colorScale(d.type))
      .on("mouseover", (event, d) => {
        tooltip
          .style("visibility", "visible")
          .html(
            `<strong>Type:</strong> ${d.type}<br><strong>Value:</strong> ${d.value}`
          );
      })
      .on("mousemove", (event) => {
        tooltip
          .style("top", `${event.pageY - 10}px`)
          .style("left", `${event.pageX + 10}px`);
      })
      .on("mouseout", () => {
        tooltip.style("visibility", "hidden");
      });

    // Add labels for the total bar
    svg
      .selectAll(".total-label")
      .data(totalBarData.filter((d) => d.value !== 0)) // Filter out 0 values
      .enter()
      .append("text")
      .attr("class", "total-label")
      .attr("x", x("Total") + x.bandwidth() / 2)
      .attr("y", (d) => y(d.start + d.value / 3)) // Center the text in the segment
      .attr("text-anchor", "middle")
      .attr("fill", "white") // Adjust text color based on segment size
      .attr("font-weight", "bold")
      .text((d) => d.value);
    svg
      .selectAll(".bar-label")
      .data(waterfallData)
      .enter()
      .append("text")
      .attr("class", "bar-label")
      .attr("x", (d) => x(d.label) + x.bandwidth() / 2)
      .attr("y", (d) => y(d.end) - 5)
      .attr("text-anchor", "middle")
      .attr("fill", "black")
      .text((d) => d.value);
  }, [data, height, width]);

  return <svg ref={svgRef} style={{ width: "100%", height: "100%" }}></svg>;
};

/**
 * Shape of one input row. Rows that share a `label` are merged by the chart
 * into a single waterfall step whose value is the sum of theirs.
 */
const waterfallRowPropType = PropTypes.shape({
  label: PropTypes.string,
  value: PropTypes.number,
  type: PropTypes.string,
});

// Runtime prop validation for WaterfallChart; only `data` is mandatory.
WaterfallChart.propTypes = {
  data: PropTypes.arrayOf(waterfallRowPropType).isRequired,
  width: PropTypes.number,
  height: PropTypes.number,
};

export default WaterfallChart;
