import '../../css/LinePlot.less';

import _ from 'lodash';
import React, {useRef, useMemo} from 'react';
import {Sparklines, SparklinesBars, SparklinesLine} from 'react-sparklines';
import {
  Crosshair,
  FlexibleXYPlot,
  MarkSeries,
  MarkSeriesPoint,
} from 'react-vis';

import {notEmpty} from '@wandb/cg/browser/utils/obj';
import {Line, getPlotMargin, PlotFontSize} from '../../util/plotHelpers';
import Highlight, {BrushEndHandler} from './highlight';
import {markToIcon} from '../LineStylePicker';
import {legendTemplateInsertCrosshairValues} from '../../util/legend';

import * as InteractStateContext from '../../state/views/interactState/context';
import * as InteractStateActions from '../../state/views/interactState/actions';
import makeComp from '../../util/profiler';
import {getTextColor} from '../../util/colors';

/** Props accepted by the `LinePlotCrosshair` overlay component. */
interface LinePlotCrosshairProps {
  /** All plot lines; lines flagged `aux` are excluded from crosshair lookups. */
  lines: Line[];
  /** X-domain forwarded to the plot; `[0]` and `[1]` also clip the crosshair search. */
  xDomain: number[];
  /** Y-domain forwarded to the plot and to the margin computation. */
  yDomain: number[];
  /** Scale type forwarded to the plot as `xType`. */
  xScale: 'linear' | 'log';
  /** Scale type forwarded to the plot as `yType`; also selects the y tick total (2 for log, 5 otherwise). */
  yScale: 'linear' | 'log';
  /** X-axis key; also the key under which the highlighted x value is stored in the interact state. */
  xAxis: string;
  /** Y-axis key, used for the plot margin computation. */
  yAxis: string;
  /** Forwarded to the plot's `onMouseDown`. */
  onMouseDown: React.MouseEventHandler;
  /** Forwarded to the `Highlight` layer's `onMouseUp`. */
  onMouseUp: React.MouseEventHandler;
  /** NOTE(review): not read by the component in this file — confirm whether callers still need it. */
  lastDrawLocation?: any;
  /** Invoked (with the same arguments) when the `Highlight` layer reports a finished brush. */
  onBrushEnd: BrushEndHandler;
  /** When true, neither the crosshair marks nor the crosshair flag are rendered. */
  hideCrosshair: boolean;
  /** Width used to estimate how many characters fit in the truncated crosshair flag. */
  parentWidth: number;
  /** When true, no run is passed to the flag content for highlighting. */
  singleRun: boolean;
  /** Font size forwarded to the plot margin computation. */
  fontSize: PlotFontSize;
}

/**
 * For each line, picks the data point nearest to `highlightX` and returns the
 * values needed to render it in the crosshair (marks and legend flag).
 *
 * Assumes each `line.data` is sorted by ascending `x` (required by the binary
 * search below).
 *
 * @param lines lines to sample; a line contributes at most one entry
 * @param highlightX x-coordinate of the cursor
 * @param xMin lower bound of the visible x-domain
 * @param xMax upper bound of the visible x-domain
 * @returns one crosshair value per line that has a point near the cursor;
 *   lines with no data, or whose last point is at least 1 x-unit left of the
 *   cursor, are omitted
 */
function crossHairValuesFromLines(
  lines: Line[],
  highlightX: number,
  xMin: number,
  xMax: number
) {
  // Clamp the search position to the visible domain.
  let clippedHighlightX = highlightX;
  if (clippedHighlightX < xMin) {
    clippedHighlightX = xMin;
  }
  if (clippedHighlightX > xMax) {
    clippedHighlightX = xMax;
  }
  return lines
    .map(line => {
      if (line.data.length === 0) {
        // Nothing to highlight. Without this guard the "past the end" branch
        // below would read line.data[-1] and throw.
        // We'll filter this out after this map().
        return null;
      }
      const lastIndex = line.data.length - 1;

      // Binary-search for the first point at or after the (clamped) cursor.
      let pointIndex = _.sortedIndexBy(
        line.data,
        {x: clippedHighlightX, y: 0},
        pt => pt.x
      );

      if (pointIndex === line.data.length) {
        // Cursor is to the right of every point: only snap to the last point
        // when the cursor is within 1 x-unit of it.
        if (highlightX < line.data[lastIndex].x + 1) {
          // Keep pointIndex in sync with the chosen point so the minmax /
          // stddev lookups below hit the matching entry.
          pointIndex = lastIndex;
        } else {
          return null;
        }
      } else {
        // Show the point we're closest to
        const nextPoint = line.data[pointIndex];
        const prevPointIndex = Math.max(pointIndex - 1, 0);
        const prevPoint = line.data[prevPointIndex];
        const useNextPoint =
          prevPoint.x < xMin ||
          (nextPoint.x - highlightX < highlightX - prevPoint.x &&
            nextPoint.x <= xMax);
        if (!useNextPoint) {
          // use the point to the left of the cursor
          pointIndex = prevPointIndex;
        }
        // otherwise use the point to the right of the cursor (pointIndex as is)
      }

      const point = line.data[pointIndex];

      // Auxiliary lines are indexed in parallel with line.data.
      const minmaxPoint =
        line.minmaxLine != null ? line.minmaxLine.data[pointIndex] : undefined;
      const stddevPoint =
        line.stddevLine != null ? line.stddevLine.data[pointIndex] : undefined;

      return {
        title: line.title || '',
        color: line.color,
        x: point.x,
        y: point.y,
        mark: line.mark,
        name: line.run ? line.run.name : undefined,
        uniqueId: line.uniqueId,
        minmax: minmaxPoint ? [minmaxPoint.y0, minmaxPoint.y] : undefined,
        stddev: stddevPoint ? stddevPoint.y - point.y : undefined,
        percent: point.legendData?.percent,
        total: point.legendData?.total,
        value: point.legendData?.y ?? point.y,
        original: point.legendData?.original,
      };
    })
    .filter(notEmpty);
}

/** Props for the legend "flag" shown next to the crosshair. */
type CrosshairFlagContentProps = {
  /** Forwarded to the legend template helper to control truncation. */
  truncateLegend: boolean;
  /** Lines being plotted; only the first line's `type` is inspected (heatmap vs. not). */
  enabledLines: Line[];
  /** Points to list in the flag (or, for heatmaps, the first entry's `values` to sparkline). */
  crosshairValues: MarkSeriesPoint[];
  /** `uniqueId` of the point to mark as highlighted; only applied when more than one point is shown. */
  highlightRun?: string;
  /** X-axis key passed into the legend template's `x` value. */
  xAxis: string;
  /** Max length forwarded to the legend template helper; falls back to 10 when falsy. */
  maxLength?: number;
};
/**
 * Content of the crosshair flag.
 *
 * When the first enabled line is a heatmap, renders a sparkline of the first
 * crosshair value's `values`. Otherwise renders one legend row per crosshair
 * value, ordered by descending `y`.
 */
const CrosshairFlagContent = makeComp(
  (props: CrosshairFlagContentProps) => {
    const {
      crosshairValues,
      enabledLines,
      highlightRun,
      xAxis,
      truncateLegend,
      maxLength,
    } = props;

    const sparkLines =
      enabledLines.length > 0 && enabledLines[0].type === 'heatmap';

    // no need to show highlight if only one run
    const canHighlight = crosshairValues.length > 1;

    // Sort a copy: Array.prototype.sort works in place, and the prop array is
    // shared with the parent (and other consumers), so it must not be mutated
    // during render.
    const sortedValues = sparkLines
      ? []
      : [...crosshairValues].sort(
          (a, b) => (b.y as number) - (a.y as number)
        );

    return (
      <div
        style={{
          color: '#333',
          border: '1px solid #eee',
          padding: '8px 12px',
          background: 'white',
          whiteSpace: 'nowrap',
          lineHeight: '120%',
          position: 'relative',
          fontFamily: 'Inconsolata, monospace',
          width: sparkLines ? 200 : undefined,
        }}>
        {sparkLines ? (
          <Sparklines
            data={crosshairValues[0].values}
            style={{width: '100%', height: '100%'}}>
            <SparklinesLine color={crosshairValues[0].color} />
            <SparklinesBars
              barWidth={8}
              style={{
                fill: crosshairValues[0].color,
                fillOpacity: 0.75,
              }}
            />
          </Sparklines>
        ) : (
          <div>
            {sortedValues.map((point, i) => (
              <div
                key={point.title + ' ' + i}
                className={
                  canHighlight && point.uniqueId === highlightRun
                    ? 'highlighted'
                    : undefined
                }>
                <span
                  style={{
                    display: 'inline-block',
                    color: getTextColor(point.color), // Tooltip uses text color when available
                  }}>
                  {markToIcon(point.mark || 'solid')}
                  <span>
                    {legendTemplateInsertCrosshairValues(
                      point.name + ' ' + i,
                      point.title,
                      true,
                      {
                        x: {xAxis, val: point.x},
                        y: point.value,
                        mean: point.value,
                        min:
                          point.minmax && point.minmax.length >= 2
                            ? point.minmax[0]
                            : undefined,
                        max:
                          point.minmax && point.minmax.length >= 2
                            ? point.minmax[1]
                            : undefined,
                        stddev: point.stddev,
                        percent: point.percent,
                        total: point.total,
                        original: point.original,
                      },
                      'line',
                      truncateLegend,
                      maxLength || 10
                    )}
                  </span>
                </span>
              </div>
            ))}
          </div>
        )}
      </div>
    );
  },
  {
    id: 'CrosshairFlagContent',
  }
);

/**
 * Transparent overlay plot that tracks the mouse, publishes the highlighted
 * x value / run to the shared interact state, and renders the crosshair marks
 * and legend flag for the currently highlighted x value.
 */
const LinePlotCrosshair = makeComp(
  (props: LinePlotCrosshairProps) => {
    const {
      xAxis,
      xDomain,
      yAxis,
      yDomain,
      xScale,
      yScale,
      onMouseDown,
      lines,
      onMouseUp,
      onBrushEnd,
      hideCrosshair,
      parentWidth,
      singleRun,
    } = props;
    let crosshairValues: MarkSeriesPoint[] | null = null;
    const yAxisTickTotal = yScale === 'log' ? 2 : 5;

    const crosshairFlagRef = useRef<HTMLDivElement>(null);

    // trying to estimate the most characters that will fit in the crosshair box without
    // overflowing
    const CROSSHAIR_CHAR_WIDTH = 10;
    const CROSSHAIR_MAX_CHAR_COUNT = 15;
    const crosshairContentMaxLength = Math.max(
      10,
      parentWidth / CROSSHAIR_CHAR_WIDTH - CROSSHAIR_MAX_CHAR_COUNT
    );

    // Highlighted x value (keyed by the x-axis name) and highlighted run,
    // read from the shared interact state.
    const [domRef, [highlightX, highlightRun]] =
      InteractStateContext.useInteractStateWhenOnScreen(interactState => [
        interactState.highlight[xAxis] as number | undefined,
        interactState.highlight['run:name'] as string,
      ]);

    const setHighlights = InteractStateContext.useInteractStateAction(
      InteractStateActions.setHighlights
    );

    // Auxiliary lines never contribute crosshair values.
    const enabledLines = useMemo(
      () => lines.filter(line => !line.aux),
      [lines]
    );

    const {highlights, crosshairCount} = useMemo(() => {
      // NOTE: this was written for a "false line" (a transparent straight line
      // spanning the chart data, used only for its onNearestX callback). That
      // line no longer exists, so this code may be out of place now --
      // consider refactoring.
      if (enabledLines.length > 0) {
        if (enabledLines[0].type === 'heatmap') {
          // pre-compute highlights for histogram: one entry per distinct x,
          // collecting every row's color value at that x.
          const hls: any = {};
          enabledLines[0].data.forEach(row => {
            hls[row.x] = hls[row.x] || [
              {
                x: row.x,
                title: 'histogram',
                color: enabledLines[0].color,
                values: [],
              },
            ];
            hls[row.x][0].values.push(row.color);
          });
          return {
            highlights: hls,
            crosshairCount: Object.keys(hls).length,
          };
        } else {
          return {
            highlights: null,
            crosshairCount: Math.max(
              ...enabledLines.map(line => line.data.length)
            ),
          };
        }
      }
      return {highlights: null, crosshairCount: 0};
    }, [enabledLines]);

    if (highlightX != null) {
      if (highlights == null) {
        crosshairValues = crossHairValuesFromLines(
          enabledLines,
          highlightX,
          xDomain[0],
          xDomain[1]
        );
      } else {
        // The keys are stringified numeric x values. Number() round-trips
        // them exactly, whereas parseInt would truncate non-integer keys
        // (e.g. "1.5" -> 1) and the lookup below would then miss.
        const xVals = Object.keys(highlights).map(s => Number(s));
        if (xVals.length > 0) {
          const closest = xVals.reduce((prev, curr) =>
            Math.abs(curr - highlightX) < Math.abs(prev - highlightX)
              ? curr
              : prev
          );
          crosshairValues = highlights[closest];
        }
      }
    }

    // willChange: transform forces the crosshair into its own rendering layer
    // which means the plot behind won't be re-rendered when the crosshair moves.
    // This improved FPS from 2 to >60 on large plots.
    return (
      <div style={{width: '100%', height: '100%'}} ref={domRef}>
        <FlexibleXYPlot
          style={{willChange: 'transform'}}
          margin={getPlotMargin({
            axisKeys: {xAxis, yAxis},
            axisDomain: {
              yAxis: yDomain,
            },
            axisType: {yAxis: yScale},
            tickTotal: {yAxis: yAxisTickTotal},
            fontSize: props.fontSize,
          })}
          animation={false}
          xDomain={xDomain}
          yDomain={yDomain}
          xType={xScale}
          yType={yScale}
          onMouseLeave={() => {
            // Clear both highlights when the mouse leaves the plot.
            setHighlights([
              {axis: xAxis, value: undefined},
              {axis: 'run:name', value: undefined},
            ]);
          }}
          onMouseDown={onMouseDown}
          dontCheckIfEmpty>
          <Highlight
            stepCount={crosshairCount}
            onMouseUp={onMouseUp}
            onBrushEnd={
              ((area, zoomedXAxis, zoomedYAxis) => {
                onBrushEnd(area, zoomedXAxis, zoomedYAxis);
              }) as BrushEndHandler
            }
            onMouseMoveWithXY={(mouseX: number, mouseY: number) => {
              const newHighlights: Array<{
                axis: string;
                value: string | number | undefined;
              }> = [{axis: xAxis, value: mouseX}];

              // Compute the next crosshair values now, and then find the
              // line whose crosshair point y-value is nearest to
              // mouseY. This is a little strange but it mostly works.
              const newCrosshairValues = crossHairValuesFromLines(
                enabledLines,
                mouseX,
                xDomain[0],
                xDomain[1]
              ).sort((a, b) => (a.y as number) - (b.y as number));
              if (newCrosshairValues.length > 0) {
                // First value whose y is >= mouseY (values are sorted by y).
                let lineIndex = _.sortedIndexBy(
                  newCrosshairValues,
                  {y: mouseY} as any,
                  pt => pt.y
                );
                if (lineIndex === newCrosshairValues.length) {
                  // Mouse is above every value: use the topmost one.
                  lineIndex = lineIndex - 1;
                } else if (lineIndex !== 0 && newCrosshairValues.length > 1) {
                  // Pick whichever neighbor (above or below) is closer.
                  lineIndex =
                    newCrosshairValues[lineIndex].y - mouseY <
                    mouseY - newCrosshairValues[lineIndex - 1].y
                      ? lineIndex
                      : lineIndex - 1;
                }
                newHighlights.push({
                  axis: 'run:name',
                  value: newCrosshairValues[lineIndex].uniqueId,
                });
              }
              setHighlights(newHighlights);
            }}
          />
          {!hideCrosshair &&
            crosshairValues &&
            crosshairValues.length > 0 &&
            crosshairValues[0].values == null && (
              <MarkSeries size={2} colorType="literal" data={crosshairValues} />
            )}
          {!hideCrosshair && crosshairValues && crosshairValues.length > 0 && (
            <Crosshair
              values={
                highlightX != null
                  ? [
                      // Anchor the crosshair at the value whose x is nearest
                      // to the highlighted x.
                      crosshairValues.reduce((closest, pt) =>
                        Math.abs((pt.x as number) - highlightX!) <
                        Math.abs((closest.x as number) - highlightX!)
                          ? pt
                          : closest
                      ),
                    ]
                  : undefined
              }>
              {(() => {
                const truncatedCrosshairFlagContent = (
                  <CrosshairFlagContent
                    xAxis={xAxis}
                    crosshairValues={crosshairValues}
                    enabledLines={enabledLines}
                    highlightRun={singleRun ? undefined : highlightRun}
                    truncateLegend={true}
                    maxLength={
                      crosshairContentMaxLength
                    }></CrosshairFlagContent>
                );

                const longCrosshairFlagContent = (
                  <CrosshairFlagContent
                    xAxis={xAxis}
                    crosshairValues={crosshairValues}
                    enabledLines={enabledLines}
                    highlightRun={singleRun ? undefined : highlightRun}
                    truncateLegend={false}></CrosshairFlagContent>
                );

                /* The flag content is duplicated as part of a hack to prevent
                the flag from getting cut off by the panel boundaries. The div
                with class line-plot-flag-escaping has display: none overridden
                elsewhere only for the highlighted panel. For the other panels
                this div doesn't appear. */
                return (
                  <>
                    <div className="line-plot-flag" ref={crosshairFlagRef}>
                      {truncatedCrosshairFlagContent}
                    </div>
                    <div
                      className="line-plot-flag-escaping"
                      style={{
                        position: 'fixed',
                        display: 'none',
                        marginTop: crosshairFlagRef.current
                          ? -crosshairFlagRef.current.clientHeight
                          : 0,
                      }}>
                      {longCrosshairFlagContent}
                    </div>
                  </>
                );
              })()}
            </Crosshair>
          )}
        </FlexibleXYPlot>
      </div>
    );
  },
  {id: 'LinePlotCrosshair'}
);

export default LinePlotCrosshair;
