import classNames from 'classnames';
import Color from 'color';
import * as d3 from 'd3';
import _ from 'lodash';
import React, {
  FC,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import {
  AreaSeries,
  Borders,
  FlexibleXYPlot,
  GradientDefs,
  HeatmapSeries,
  HorizontalGridLines,
  LineSeries,
  LineSeriesCanvas,
  LineSeriesPoint,
  LineSeriesProps,
  MarkSeries,
  MarkSeriesCanvas,
  MarkSeriesPoint,
  XAxis,
  YAxis,
} from 'react-vis';
import DelayedRender from '../../containers/DelayedRender';
import * as globals from '../../css/globals.styles';
import '../../css/LinePlot.less';
import * as InteractStateActions from '../../state/views/interactState/actions';
import * as InteractStateContext from '../../state/views/interactState/context';
import {getTextColor} from '../../util/colors';
import {useUnicornFeatureFlagEnabled} from '../../util/featureFlags';
import {
  convertTimestepToSeconds,
  formatXAxisNonTime,
  formatYAxis,
  getAxisStyleForFontSize,
  getPlotMargin,
  LegendPosition,
  Line,
  PlotFontSize,
  Timestep,
  XAxisType,
  YAxisType,
} from '../../util/plotHelpers';
import makeComp from '../../util/profiler';
import {truncateString} from '../../util/runhelpers';
import {makePropsAreEqual} from '../../util/shouldUpdate';
import {markToIcon} from '../LineStylePicker';
import FancyLegend from './FancyLegend';
import {HeatmapSeriesCanvas} from './HeatmapSeries';
import {DomainArea} from './highlight';
import * as S from './LinePlot.styles';
import LinePlotCrosshair from './LinePlotCrosshair';

/** A fully resolved axis range as a `[min, max]` pair. */
export type Domain = [number, number];
/**
 * An axis range as a `[min, max]` pair where either bound may be `null`.
 * In this file, `LinePlot` fills `null` bounds from the plotted data.
 */
export type DomainMaybe = [number | null, number | null];

/** Props for the inner react-vis plot (`LinePlotPlotComp` / `LinePlotPlot`). */
interface LinePlotPlotProps {
  // Ref handed to the FlexibleXYPlot element.
  domRef: React.RefObject<HTMLDivElement>;
  // Name of the x axis; the value 'Wall Time' switches the x axis to a time scale.
  xAxis: string;
  // Scale type passed to the plot as `yType`; 'log' also reduces the y tick count.
  yScale: YAxisType;
  // Series to draw; each entry's `type` / `mark` selects the react-vis series used.
  lines: Line[];
  // 'log' selects a log x axis (unless xAxis is 'Wall Time').
  xScale: 'linear' | 'log';
  // Name of the y axis; used when computing the plot margin.
  yAxis: string;
  // Resolved x range passed to the plot.
  xDomain: Domain;
  // Resolved y range passed to the plot; also used when computing the plot margin.
  yDomain: Domain;
  // NOTE(review): accepted but not read by LinePlotPlotComp in this file.
  lastDrawLocation?: any;
  // Explicit x axis title; falls back to a truncated `xAxis` when empty.
  xAxisTitle?: string;
  // Explicit y axis title.
  yAxisTitle?: string;
  // NOTE(review): accepted but not read by LinePlotPlotComp in this file.
  monotonic?: boolean;
  // When true, lines are never drawn in the highlighted style.
  singleRun?: boolean;
  // showLegend, legendPosition and fontSize form the plot's React `key`,
  // so changing any of them forces the plot to remount.
  showLegend?: boolean;
  legendPosition: LegendPosition;
  // Also selects the axis font styles and feeds the margin calculation.
  fontSize: PlotFontSize;
  // true: SVG line/heatmap series; false (default): canvas variants.
  svg?: boolean;
}

/**
 * Renders the react-vis plot itself: grid lines, one series per entry in
 * `lines`, a dot at the end of each line-like series, faded borders and axes.
 *
 * The run currently highlighted in the interact state (`run:name`) is drawn
 * last so that it appears above the other series.
 *
 * The `lines` prop is never mutated; ordering is done on a copy.
 */
const LinePlotPlotComp: FC<LinePlotPlotProps> = makeComp(
  ({
    domRef,
    xAxis,
    yScale,
    lines,
    xScale,
    yAxis,
    xDomain,
    yDomain,
    xAxisTitle,
    yAxisTitle,
    singleRun,
    showLegend = false,
    legendPosition,
    fontSize,
    svg = false,
  }: LinePlotPlotProps) => {
    const highlightRun = InteractStateContext.useInteractState(
      interactState => interactState.highlight['run:name']
    );

    let xType: XAxisType = 'linear';
    if (xAxis === 'Wall Time') {
      xType = 'time';
    } else if (xScale === 'log') {
      xType = 'log';
    }

    // Swap to canvas or SVG based on rendering mode
    const HeatmapSeriesComponent = svg ? HeatmapSeries : HeatmapSeriesCanvas;

    const LineSeriesComponent: React.FC<LineSeriesProps> = svg
      ? (LineSeries as any)
      : LineSeriesCanvas;

    const xAxisTickTotal = xType === 'log' ? 3 : 5;
    const yAxisTickTotal = yScale === 'log' ? 2 : 5;
    const horizontalGridLinesTickTotal = yScale === 'log' ? 3 : 5;
    const fontStyles = getAxisStyleForFontSize(fontSize);

    // Sort a copy: Array.prototype.sort works in place, and `lines` is a prop
    // owned by the parent, so sorting it directly would mutate the caller's
    // array during render.
    const sortedLines = [...lines].sort((aLine, bLine) => {
      // put the highlighted lines in front
      const aLineHighlighted = highlightRun === aLine.uniqueId ? 1 : 0;
      const bLineHighlighted = highlightRun === bLine.uniqueId ? 1 : 0;
      return bLineHighlighted - aLineHighlighted;
    });

    return (
      <FlexibleXYPlot
        className={'line-plot-rv'}
        key={showLegend + legendPosition + fontSize} // needed to force rerender when showLegend or legendPosition change
        ref={domRef as any}
        margin={getPlotMargin({
          axisKeys: {xAxis, yAxis},
          axisDomain: {yAxis: yDomain},
          axisType: {yAxis: yScale},
          tickTotal: {yAxis: yAxisTickTotal},
          fontSize,
        })}
        xDomain={xDomain}
        yDomain={yDomain}
        yType={yScale}
        xType={xType}
        padding={{top: 10}}>
        <HorizontalGridLines
          style={{strokeWidth: 0.5}}
          tickTotal={horizontalGridLinesTickTotal}
        />
        {
          sortedLines
            .map((line, i) => {
              // Dash pattern per mark; the SVG series takes a comma-joined
              // string while the canvas series takes the number array.
              let strokeDashArray: any =
                line.mark === 'dashed'
                  ? [4, 2]
                  : line.mark === 'dotted'
                  ? [1, 1]
                  : line.mark === 'dotdash'
                  ? [5, 2, 1, 2]
                  : line.mark === 'dotdotdash'
                  ? [5, 2, 1, 2, 1, 2]
                  : undefined;
              strokeDashArray = svg
                ? strokeDashArray?.join(',')
                : strokeDashArray;

              // Highlighting only applies when there is more than one line,
              // we are not on a single-run page, and the line belongs to a run.
              const highlighted =
                lines.length > 1 &&
                !singleRun &&
                line.run &&
                highlightRun === line.uniqueId;

              const highlightedAreaColor = Color(line.color)
                .alpha(0.3)
                .rgb()
                .string();
              const highlightedAreaStrokeColor = Color(line.color)
                .darken(1)
                .rgb()
                .string();
              const transparentAreaStrokeColor = Color(line.color)
                .alpha(0)
                .rgb()
                .string();

              const strokeWidth = (line.lineWidth ?? 1) * 1.5;
              const highlightedStrokeWidth = strokeWidth + 1.5;

              return line.type === 'area' ? (
                <AreaSeries
                  key={i}
                  // color is only for shaded area
                  color={highlighted ? highlightedAreaColor : line.color}
                  opacity={1}
                  data={line.data}
                  getNull={d => d.y !== null}
                  // stroke is the outline around shaded area
                  stroke={
                    highlighted
                      ? highlightedAreaStrokeColor
                      : transparentAreaStrokeColor
                  }
                />
              ) : line.type === 'scatter' ? (
                <MarkSeries
                  key={i}
                  color={line.color}
                  data={line.data as MarkSeriesPoint[]}
                  getNull={d => d.y !== null}
                  size={2}
                />
              ) : line.type === 'heatmap' ? (
                <HeatmapSeriesComponent
                  key={i}
                  colorRange={['#f0f0f0', line.color]}
                  data={line.data}
                  size={2}
                />
              ) : line.type === 'points' || line.mark === 'points' ? (
                <MarkSeriesCanvas
                  key={i}
                  color={line.color}
                  data={line.data as LineSeriesPoint[]}
                  size={2}
                />
              ) : (
                <LineSeriesComponent
                  key={i}
                  color={line.color}
                  data={line.data as LineSeriesPoint[]}
                  getNull={d => d.y !== null}
                  size={2}
                  strokeDasharray={strokeDashArray}
                  strokeStyle={'solid'}
                  strokeWidth={
                    highlighted ? highlightedStrokeWidth : strokeWidth
                  }
                />
              );
            })
            .reverse() // reverse so the top line is the first line
        }
        {sortedLines.map((line, i) => {
          if (
            line.type === 'area' ||
            line.type === 'scatter' ||
            line.type === 'heatmap'
          ) {
            return null;
          }

          // A line without points has no last point to mark; passing
          // `[undefined]` as data would break the getNull accessor below.
          if (line.data.length === 0) {
            return null;
          }

          // add a dot mark to the end of each line
          return (
            <MarkSeries
              key={i}
              color={line.color}
              data={[line.data[line.data.length - 1] as MarkSeriesPoint]}
              getNull={d => d.y !== null}
              size={2}
            />
          );
        })}
        <GradientDefs>
          <linearGradient id="leftFadeGradient" x1="0" x2="1" y1="0" y2="0">
            <stop offset="0%" stopColor="white" stopOpacity={1} />
            <stop offset="60%" stopColor="white" stopOpacity={1} />
            <stop offset="100%" stopColor="white" stopOpacity={0} />
          </linearGradient>
          <linearGradient id="rightFadeGradient" x1="1" x2="0" y1="0" y2="0">
            <stop offset="0%" stopColor="white" stopOpacity={1} />
            <stop offset="100%" stopColor="white" stopOpacity={0} />
          </linearGradient>
          <linearGradient id="bottomFadeGradient" x1="0" x2="0" y1="1" y2="0">
            <stop offset="0%" stopColor="white" stopOpacity={1} />
            <stop offset="40%" stopColor="white" stopOpacity={1} />
            <stop offset="100%" stopColor="white" stopOpacity={0} />
          </linearGradient>
          <linearGradient id="topFadeGradient" x1="0" x2="0" y1="0" y2="1">
            <stop offset="0%" stopColor="white" stopOpacity={1} />
            <stop offset="100%" stopColor="white" stopOpacity={0} />
          </linearGradient>
        </GradientDefs>
        <Borders
          className="plot-border"
          style={{
            left: {fill: 'url(#leftFadeGradient)', opacity: 1},
            right: {fill: 'url(#rightFadeGradient)', opacity: 1},
            bottom: {fill: 'url(#bottomFadeGradient)', opacity: 1},
            top: {fill: 'url(#topFadeGradient)', opacity: 1},
          }}
        />

        <XAxis
          // This is the line that marks the x axis
          tickTotal={0}
          on0
        />

        <XAxis
          // This is the legend
          title={xAxisTitle || truncateString(xAxis)}
          style={{
            ...fontStyles,
            line: {strokeWidth: 0},
          }}
          tickFormat={formatXAxisNonTime(xType, xDomain[0], xDomain[1])}
          // React vis (or maybe d3) doesn't seem to respect the tickTotal
          // well in the case of log scale, so a smaller value is safer.  See WB-4525
          tickTotal={xAxisTickTotal}
        />

        <XAxis
          // These are the little tick marks
          width={0}
          tickTotal={xAxisTickTotal}
          tickFormat={tick => ''}
        />

        <YAxis
          title={yAxisTitle}
          tickFormat={tick => formatYAxis(tick)}
          tickTotal={yAxisTickTotal}
          style={{
            ...fontStyles,
            line: {stroke: 'none'},
          }}
        />
      </FlexibleXYPlot>
    );
  },
  {
    id: 'LinePlotPlotComp',
    memo: makePropsAreEqual({
      name: 'LinePlotPlot',
      deep: ['yDomain', 'xDomain'],
      ignoreFunctions: true,
      debug: false,
      verbose: false,
    }),
  }
);

// Pass-through wrapper: registers a separate 'LinePlotPlot' component with
// makeComp and forwards every prop untouched to LinePlotPlotComp.
const LinePlotPlot = makeComp(
  (props: LinePlotPlotProps) => <LinePlotPlotComp {...props} />,
  {id: 'LinePlotPlot'}
);

/** Props for the `LinePlot` component (legend + plot + crosshair). */
interface LinePlotProps {
  // NOTE(review): entityName / projectName are not read by LinePlot in this file.
  entityName?: string;
  projectName?: string;
  // User-specified axis ranges; null bounds are filled in from the data.
  xDomain?: DomainMaybe;
  yDomain?: DomainMaybe;
  // Series to plot. Points with a non-finite x or y are dropped before plotting;
  // lines left with no points are not drawn and appear greyed out in the legend.
  lines: Line[];
  // When set, zoom bounds are passed through convertTimestepToSeconds
  // before being handed to zoomCallback.
  timestep?: Timestep;
  // Defaults to 'linear'. With 'log', points with x <= 0 are dropped.
  xScale?: 'linear' | 'log';
  // Defaults to 'linear'. For non-log scales the y range may be snapped to include 0.
  yScale?: 'linear' | 'log';
  // Render the legend above/beside the plot.
  showLegend?: boolean;
  // Forwarded to FancyLegend as `disableRunLink`.
  disableRunLinks?: boolean;
  xAxis: string; // Name of the xAxis
  yAxis: string; // Name of the yAxis
  xAxisTitle?: string; // optional string for exactly this xaxis title displayed in graph
  yAxisTitle?: string; // optional string for exactly this yaxis title displayed in graph
  // Rendered inside the legend, before the per-line entries.
  legendPrefix?: React.Component | JSX.Element;
  // Defaults to 'north'; 'east' / 'west' give a vertical legend, anything else horizontal.
  legendPosition?: LegendPosition;
  // Defaults to 'small'.
  fontSize?: PlotFontSize;
  // When true, the computed y range is based on values between the
  // 5th percentile (lower bound) and 95th percentile (upper bound).
  ignoreOutliers?: boolean;
  // NOTE(review): not read by LinePlot in this file.
  zooming?: boolean; // is zooming happening
  parentWidth: number; // width of parent element for crosshair truncating
  singleRun: boolean; // is inside a single run page (turns off highlighting)
  // Forwarded to the inner plot to choose SVG over canvas rendering.
  svg?: boolean;
  // NOTE(review): the three override maps below are not read by LinePlot in this file.
  overrideLineTitles?: {
    [key: string]: string;
  };
  overrideColors?: {
    [key: string]: {
      color: string;
      transparentColor: string;
    };
  };
  overrideMark?: {
    [key: string]: string;
  };
  // NOTE(review): not called by LinePlot in this file.
  setYScale?(yScale: string): void;
  // Called when a brush selection ends, with the selected area's left/right
  // x bounds (undefined when there is no area).
  zoomCallback?(xAxisMin?: number, xAxisMax?: number): void;
}

/**
 * Line plot with optional legend and crosshair.
 *
 * Steps performed on each render:
 *  1. Drop non-finite points (and, on a log x scale, points with x <= 0);
 *     lines left empty are treated as "hidden" (greyed out in the legend).
 *  2. Compute the x/y domains from the user-supplied domains, the current
 *     zoom selection and the data.
 *  3. Clamp y values to the domain +/- 100000 for drawing only; the
 *     crosshair still receives the unclamped lines.
 *  4. If any line is a heatmap, draw only the first heatmap.
 *
 * Renders nothing when no line has drawable points, and an alert when the
 * domain cannot be resolved or the data lies outside the user's domain.
 */
const LinePlot: FC<LinePlotProps> = makeComp(
  ({
    xDomain: userXDomain,
    yDomain: userYDomain,
    lines,
    timestep,
    xScale = 'linear',
    yScale = 'linear',
    showLegend,
    disableRunLinks,
    xAxis,
    yAxis,
    xAxisTitle,
    yAxisTitle,
    legendPrefix,
    ignoreOutliers,
    parentWidth,
    singleRun,
    svg,
    legendPosition = 'north',
    fontSize = 'small',
    zoomCallback,
  }: LinePlotProps) => {
    const domRef = useRef<HTMLDivElement | null>(null);
    const [hideCrosshair, setHideCrosshair] = useState(false);
    const [lastDrawLocation, setLastDrawLocation] = useState<DomainArea | null>(
      null
    );
    // Only the setter is used: setting this state triggers a re-render.
    const [, setRecentResize] = useState(true);
    const [zoomedXAxis, setZoomedXAxis] = useState(false);
    const [zoomedYAxis, setZoomedYAxis] = useState(false);

    const unicornCrosshair = useUnicornFeatureFlagEnabled();

    const setHighlightRun = InteractStateContext.useInteractStateAction(
      InteractStateActions.setHighlight
    );

    const onWindowResize = useCallback(() => {
      setRecentResize(true);
    }, []);

    useEffect(() => {
      window.addEventListener('resize', onWindowResize);
      return () => {
        window.removeEventListener('resize', onWindowResize);
      };
    }, [onWindowResize]);

    // Memoized so that the domain / clamping memos below, which depend on
    // `filteredLines`, can actually reuse their results between renders.
    // `hiddenLines` are the lines with no drawable points left.
    const [hiddenLines, filteredLines] = useMemo(() => {
      const finiteLines: Line[] = lines.map(line => {
        const dataWithFinitePoints = line.data.filter(
          pt => _.isFinite(pt.x) && _.isFinite(pt.y)
        );
        return {
          ...line,
          // a log x axis cannot place x <= 0
          data:
            xScale === 'log'
              ? dataWithFinitePoints.filter(pt => pt.x > 0)
              : dataWithFinitePoints,
        };
      });
      // partition returns [matching, non-matching] without mutating its input
      return _.partition(finiteLines, line => line.data.length === 0);
    }, [lines, xScale]);

    const isHeatmap = filteredLines[0]?.type === 'heatmap';

    const [calcXDomain, calcYDomain] = useMemo(() => {
      let retXDomain = userXDomain != null ? [...userXDomain] : [null, null];
      let retYDomain = userYDomain != null ? [...userYDomain] : [null, null];
      if (zoomedXAxis && lastDrawLocation != null) {
        retXDomain = [lastDrawLocation.left, lastDrawLocation.right];
      }
      if (zoomedYAxis && lastDrawLocation != null) {
        retYDomain = [lastDrawLocation.bottom, lastDrawLocation.top];
      }

      // Everything is already specified; no need to look at the data.
      if (
        retXDomain[0] != null &&
        retXDomain[1] != null &&
        retYDomain[0] != null &&
        retYDomain[1] != null
      ) {
        return [retXDomain, retYDomain];
      }

      if (filteredLines.length === 0) {
        return [retXDomain, retYDomain];
      }

      // Only points inside the active x range contribute to the y range.
      const filteredPoints = _.flatMap(filteredLines, line => {
        let points = line.data;
        if (zoomedXAxis && lastDrawLocation != null) {
          points = points.filter(
            point =>
              point.x >= lastDrawLocation.left &&
              point.x <= lastDrawLocation.right
          );
        } else if (userXDomain != null) {
          points = points.filter(
            point =>
              (userXDomain[0] == null || point.x >= userXDomain[0]) &&
              (userXDomain[1] == null || point.x <= userXDomain[1])
          );
        }
        return points;
      });

      const [xMinRaw, xMaxRaw] = d3.extent(
        filteredPoints.map(point => point.x)
      );
      let xMin = xMinRaw ?? 0;
      let xMax = xMaxRaw ?? 0;

      let minYPoints = filteredPoints.map(point =>
        point.y0 != null ? Math.min(point.y0, point.y) : point.y
      );
      let maxYPoints = filteredPoints.map(point =>
        point.y0 != null ? Math.max(point.y0, point.y) : point.y
      );

      // TODO(aswanberg): Consider adding padding like tensorboard does.
      if (ignoreOutliers) {
        minYPoints = _.sortBy(minYPoints);
        maxYPoints = _.sortBy(maxYPoints);

        // Fall back only when the quantile is missing or NaN. A quantile of
        // exactly 0 is a valid bound and must not be replaced (which `||`
        // would do, silently disabling outlier trimming on that side).
        const quantileOr = (
          quantile: number | undefined,
          fallback: number
        ): number =>
          quantile == null || isNaN(quantile) ? fallback : quantile;

        const lowerBound = quantileOr(
          d3.quantile(minYPoints, 0.05),
          -Infinity
        );
        const upperBound = quantileOr(d3.quantile(maxYPoints, 0.95), Infinity);

        minYPoints = minYPoints.filter(yVal => yVal >= lowerBound);
        maxYPoints = maxYPoints.filter(yVal => yVal <= upperBound);
      }

      let yMin = _.min(minYPoints) ?? 0;
      let yMax = _.max(maxYPoints) ?? 0;

      const xRange = xMax - xMin;
      const yRange = yMax - yMin;

      // A zero-width x range cannot be drawn; widen it around the value.
      if (xRange === 0) {
        if (isHeatmap) {
          // ridiculous hacky fix to single step heatmap
          xMin = -0.49998;
          xMax = 0.49999;
        } else {
          xMin -= Math.abs(xMin);
          xMax += xMax === 0 ? 1 : Math.abs(xMax);
        }
      }

      // arbitrary epsilon to account for artifacts when smoothing constant lines
      if (yRange < 1e-8) {
        yMin -= yMax === 0 ? 2 : Math.abs(yMin);
        yMax += yMax === 0 ? 2 : Math.abs(yMax);
      }

      if (yScale !== 'log') {
        // smart guess for when the x axis should be y=0
        const smartSnapFactor = zoomedXAxis ? 0.5 : 2;
        if (yMin > 0) {
          if (yMin < yRange * smartSnapFactor) {
            yMin = 0;
          }
        }
        if (yMax < 0) {
          if (-yMax < yRange * smartSnapFactor) {
            yMax = 0;
          }
        }
      }

      // User / zoom supplied bounds win; data-derived bounds fill the gaps.
      retXDomain[0] = retXDomain[0] ?? xMin;
      retXDomain[1] = retXDomain[1] ?? xMax;
      retYDomain[0] = retYDomain[0] ?? yMin;
      retYDomain[1] = retYDomain[1] ?? yMax;

      return [retXDomain, retYDomain];
    }, [
      yScale,
      filteredLines,
      userXDomain,
      userYDomain,
      lastDrawLocation,
      ignoreOutliers,
      zoomedXAxis,
      zoomedYAxis,
      isHeatmap,
    ]);

    // After calculating the domain, let's filter the lines a bit again. We want to avoid
    // a bug (https://github.com/wandb/core/issues/2323) where lines aren't drawn when
    // there are extreme outliers present. So let's clamp points to an offset of the domain.
    // We'll still use the unclamped points in the crosshair so as to not obviously present "wrong" data.
    const clampedFilteredLines = useMemo(() => {
      if (calcYDomain[0] == null || calcYDomain[1] == null) {
        return filteredLines;
      }

      const yDomainLowerBound = calcYDomain[0] - 100000;
      const yDomainUpperBound = calcYDomain[1] + 100000;
      return filteredLines.map(line => ({
        ...line,
        data: line.data.map(pt => ({
          ...pt,
          y: _.clamp(pt.y, yDomainLowerBound, yDomainUpperBound),
        })),
      }));
    }, [filteredLines, calcYDomain]);

    // if user passed in multiple heatmaps only display the first one and nothing else
    const linesOrHeatmap = useMemo(() => {
      const heatmap = clampedFilteredLines.find(
        line => line.type === 'heatmap'
      );
      if (heatmap == null) {
        return clampedFilteredLines;
      }
      return [heatmap];
    }, [clampedFilteredLines]);

    // All hooks have run by this point, so early returns are safe from here on.
    if (filteredLines.length === 0) {
      return null;
    }

    if (
      calcXDomain[0] == null ||
      calcXDomain[1] == null ||
      calcYDomain[0] == null ||
      calcYDomain[1] == null
    ) {
      return <S.InvalidDataAlert>Unable to process data.</S.InvalidDataAlert>;
    }
    const notNullXDomain = calcXDomain as Domain;
    const notNullYDomain = calcYDomain as Domain;

    if (
      dataOutOfBounds(notNullXDomain, userXDomain) ||
      dataOutOfBounds(notNullYDomain, userYDomain)
    ) {
      return (
        <S.InvalidDataAlert>
          Data out of bounds. Try adjusting the axis range.
        </S.InvalidDataAlert>
      );
    }

    const legendOrientation =
      legendPosition === 'east' || legendPosition === 'west'
        ? 'vertical'
        : 'horizontal';

    return (
      <S.LinePlot
        className="line-plot"
        ref={domRef}
        legendPosition={legendPosition}>
        {showLegend && (
          <S.LinePlotLegend
            className="line-plot-legend"
            fontSize={fontSize}
            orientation={legendOrientation}>
            {legendPrefix}
            {lines
              .filter(line => !line.aux)
              .map((line, i) => {
                // Lines with no drawable points stay in the legend, greyed out.
                const hidden: boolean = hiddenLines.some(
                  l => l.uniqueId === line.uniqueId
                );
                const color = !hidden ? line.color : globals.gray500;

                return (
                  <S.LegendEntry
                    orientation={legendOrientation}
                    className={classNames(
                      'legend-entry',
                      hidden && 'line-plot-title-disabled'
                    )}
                    key={'line' + i}>
                    <span
                      key={'line-plot-color' + i}
                      className="line-plot-color"
                      style={{display: 'inline', color}}>
                      {markToIcon(line.mark ?? 'solid')}
                    </span>
                    <span
                      className="line-plot-title"
                      onMouseOver={() => {
                        if (!singleRun) {
                          setHighlightRun('run:name', line.uniqueId);
                        }
                      }}
                      onMouseOut={() => {
                        if (!singleRun) {
                          setHighlightRun('run:name', undefined);
                        }
                      }}
                      style={{color: getTextColor(color)}}
                      key={line.title}>
                      {line.fancyTitle &&
                        FancyLegend({
                          ...line.fancyTitle,
                          disableRunLink: disableRunLinks,
                        })}
                    </span>
                  </S.LegendEntry>
                );
              })}
          </S.LinePlotLegend>
        )}

        <DelayedRender>
          <div
            className={classNames(
              'plot-container',
              'line-plot-container',
              lastDrawLocation?.classes,
              {unicorn: unicornCrosshair}
            )}>
            <LinePlotPlot
              xAxis={xAxis}
              yAxis={yAxis}
              xAxisTitle={xAxisTitle}
              yAxisTitle={yAxisTitle}
              xDomain={notNullXDomain}
              yDomain={notNullYDomain}
              xScale={xScale}
              yScale={yScale}
              lines={linesOrHeatmap}
              lastDrawLocation={lastDrawLocation}
              domRef={domRef}
              singleRun={singleRun}
              legendPosition={legendPosition}
              fontSize={fontSize}
              showLegend={showLegend}
              svg={svg}
            />
            <div className={'line-plot-crosshair'}>
              <LinePlotCrosshair
                xDomain={
                  isHeatmap
                    ? [notNullXDomain[0] - 0.5, notNullXDomain[1] + 0.5]
                    : notNullXDomain
                }
                yDomain={notNullYDomain}
                xScale={xScale}
                yScale={yScale}
                xAxis={xAxis}
                yAxis={yAxis}
                lines={filteredLines}
                lastDrawLocation={lastDrawLocation}
                parentWidth={parentWidth}
                singleRun={singleRun}
                fontSize={fontSize}
                onMouseDown={() => {
                  setHideCrosshair(true);
                  setRecentResize(false);
                }}
                onMouseUp={() => setHideCrosshair(false)}
                hideCrosshair={hideCrosshair}
                onBrushEnd={(area, newZoomedXAxis, newZoomedYAxis) => {
                  setLastDrawLocation(area);
                  setZoomedXAxis(newZoomedXAxis);
                  setZoomedYAxis(newZoomedYAxis);
                  if (zoomCallback != null) {
                    let left = area?.left;
                    let right = area?.right;
                    if (timestep != null) {
                      if (left != null) {
                        left = convertTimestepToSeconds(left, timestep);
                      }
                      if (right != null) {
                        right = convertTimestepToSeconds(right, timestep);
                      }
                    }
                    zoomCallback(left, right);
                  }
                }}
              />
            </div>
          </div>
        </DelayedRender>
      </S.LinePlot>
    );
  },
  {
    id: 'LinePlot',
    memo: makePropsAreEqual({
      name: 'LinePlot',
      deep: ['yDomain', 'xDomain'],
      ignoreFunctions: true,
      debug: false,
      verbose: false,
    }),
  }
);

export default LinePlot;

/**
 * Reports whether the computed domain lies entirely outside the range the
 * user asked for.
 *
 * @param calcDomain resolved `[min, max]` of what will be plotted
 * @param userDomain user-requested `[min, max]`; either bound may be null
 * @returns true when the user's max is below the computed min, or the user's
 *   min is above the computed max; false otherwise (including when either
 *   argument is missing)
 */
function dataOutOfBounds(calcDomain: Domain, userDomain?: DomainMaybe) {
  if (calcDomain == null || userDomain == null) {
    return false;
  }

  const [dataLow, dataHigh] = calcDomain;
  const [userLow, userHigh] = userDomain;

  // user's upper bound sits below everything in the data
  if (userHigh != null && dataLow != null && userHigh < dataLow) {
    return true;
  }

  // user's lower bound sits above everything in the data
  return userLow != null && dataHigh != null && userLow > dataHigh;
}
