import * as S from './PanelScatterPlot.styles';
import '../css/PanelScatterPlot.less';

import React, {useCallback, useState} from 'react';
import _, {isNumber} from 'lodash';
import {Tab, Checkbox} from 'semantic-ui-react';
import memoize from 'memoize-one';
import useResizeObserver from 'use-resize-observer';

import {
  XAxis,
  YAxis,
  FlexibleXYPlot,
  FlexibleWidthXYPlot,
  VerticalGridLines,
  HorizontalGridLines,
  MarkSeries,
  MarkSeriesPoint,
  VerticalRectSeries,
  Hint,
  LineSeries,
  LineSeriesPoint,
} from 'react-vis';
import '../../node_modules/react-vis/dist/style.css';
import * as urls from '../util/urls';
import SliderInput from './elements/SliderInput';
import LabeledOption from './elements/LabeledOption';
import * as ColorUtil from '../util/colors';
import BoxSelection, {BoxSelectionDrawArea} from './vis/BoxSelection';
import {toRunsDataQuery, RunWithRunsetInfo} from '../containers/RunsDataLoader';
import Input from './Input';
import ProjectFieldSelector from './ProjectFieldSelector';
import {PLASMA_GRADIENT_SPARSE_HEX} from '../util/colors';
import * as SM from '../util/selectionmanager';
import * as Run from '../util/runs';
import * as Filter from '../util/filters';
import * as Panels from '../util/panels';
import * as Query from '../util/queryts';
import {smooth} from '../util/math';
import {
  getPlotMargin,
  formatYAxis,
  getAxisStyleForFontSize,
  PlotFontSizeOrAuto,
} from '../util/plotHelpers';
import PanelError from './elements/PanelError';
import RangeInput from './elements/RangeInput';
import HelpPopup from './elements/HelpPopup';
import {useGatedValue} from '../state/hooks';
import PanelTitle from './elements/PanelTitle';
import * as ViewHooks from '../state/views/hooks';
import * as InteractStateActions from '../state/views/interactState/actions';
import * as InteractStateContext from '../state/views/interactState/context';
import * as GroupSelectionsActionsInternal from '../state/views/groupSelections/actionsInternal';
import {
  GradientPicker,
  DEFAULT_GRADIENT,
  gradientToRGBArray,
  GradientStop,
} from '../components/GradientPicker';
import {SelectionBoundsMenu} from './SelectionBoundsMenu';
import WandbLoader from './WandbLoader';
import {extent} from 'd3';
import {useParams, useHistory} from 'react-router';
import makeComp from '../util/profiler';
import {makeFetchProjectFieldOptions} from '../util/dropdownQueries';

import {useSampleAndQueryToTable} from './Export';

// Panel type name. NOTE(review): not referenced in this part of the file;
// presumably used for panel registration further down — confirm.
const PANEL_TYPE = 'Scatter Plot';

// Synthetic axis key offered as an extra field option; getAxisValue maps
// 'Index' (and 'run:Index') to the run's position in the filtered list.
const INDEX_KEY = 'Index';
// NOTE(review): not referenced in this part of the file — presumably the
// point colour used when no z axis is set; confirm.
const DEFAULT_POINT_COLOR = '#3355bb';
// Grey with a low alpha byte (0x22). NOTE(review): by its name, used for
// points outside the current selection — confirm where it is consumed.
const NOT_VISIBLE_COLOR = '#55555522';

/**
 * Saved configuration of a scatter plot panel. Axis fields hold run key
 * strings (parsed with Run.keyFromString); 'Index' plots each run's position.
 */
export interface ScatterPlotConfig {
  xAxis?: string; // key plotted on the x axis
  yAxis?: string; // key plotted on the y axis
  zAxis?: string; // key mapped to point colour
  chartTitle?: string; // custom title; the editor shows defaultTitle(config) as placeholder
  showMaxYAxisLine?: boolean; // draw the running maximum of y over x
  showMinYAxisLine?: boolean; // draw the running minimum of y over x
  showAvgYAxisLine?: boolean; // draw the running average of y over x
  yAxisLineSmoothingWeight?: number; // exponential smoothing weight for the average line (editor slider 0-0.999)
  color?: string; // NOTE(review): not read in this part of the file
  minColor?: string; // NOTE(review): not read in this part of the file
  maxColor?: string; // NOTE(review): not read in this part of the file
  xAxisLogScale?: boolean; // log scale for the x axis
  yAxisLogScale?: boolean; // log scale for the y axis
  zAxisLogScale?: boolean; // colour by Math.log of the z value
  xAxisMin?: number; // points with x below this are dropped
  xAxisMax?: number; // points with x above this are dropped
  yAxisMin?: number; // points with y below this are dropped
  yAxisMax?: number; // points with y above this are dropped
  zAxisMin?: number; // points with z below this are dropped (linear units, even on log scale)
  zAxisMax?: number; // points with z above this are dropped (linear units, even on log scale)
  legendFields?: string[]; // keys whose values are listed in the hover hint
  customGradient?: GradientStop[]; // colour gradient chosen in the Gradient tab
  fontSize?: PlotFontSizeOrAuto; // NOTE(review): consumed outside this part of the file
}

// Panel props specialised to this panel's config.
type ScatterPlotProps = Panels.PanelProps<ScatterPlotConfig>;

/** One plotted run, as built by computeData. */
interface ScatterDataPoint {
  x: Run.Value; // value of the x-axis key
  y: Run.Value; // value of the y-axis key
  run?: RunWithRunsetInfo;
  runName: string; // grouped display name when the runset is grouped, else run.displayName
  uniqueId: string; // Run.uniqueId(run, grouping); compared against the highlighted run id
  groupKeys: Query.Grouping | undefined; // we need the grouping to get the right color
  colorIndex: // The z-value (Math.log of it when zAxisLogScale is set); null when no z axis
  string | number | boolean | Run.BasicValue[] | Run.WBValue | null | undefined;
  color?: string | number | undefined; // The actual color displayed
  legend?: {
    [key: string]: string; // legend key string -> display value
  };
  highlight?: boolean; // NOTE(review): not set in this part of the file
}

//  static type = 'Scatter Plot';
//  static pickerDisplayName = 'Scatter Plot';

/**
 * Derives the runs-data query backing a scatter plot from the page query.
 *
 * - Uses a page size of 500.
 * - Requests config/summary values for the axis keys, the legend fields, the
 *   runset grouping keys (needed for colouring) and the keys referenced by the
 *   first runset's selections (needed to colour visible runs differently).
 * - Adds `key != null` filters for config/summary axis keys, legend keys and
 *   selection keys.
 *
 * @param query the page query
 * @param config the panel config supplying axis and legend keys
 * @returns the transformed runs-data query
 */
const scatterPlotTransformQuery = (
  query: Query.Query,
  config: ScatterPlotConfig
) => {
  const result = toRunsDataQuery(query, undefined, {
    page: {size: 500},
  });

  // Axis keys are only fetched/filtered when they live in config or summary.
  const filterKeys: Run.Key[] = [];
  [config.xAxis, config.yAxis, config.zAxis].forEach(axis => {
    if (!axis) {
      return;
    }
    const key = Run.keyFromString(axis);
    if (key && (key.section === 'config' || key.section === 'summary')) {
      filterKeys.push(key);
    }
  });
  // Call keyFromString with the field only (passing it straight to .map would
  // also hand it the index and array).
  (config.legendFields || []).forEach(field => {
    const key = Run.keyFromString(field);
    if (key != null) {
      filterKeys.push(key);
    }
  });

  // We need the keys from selections to color the visible runs differently.
  // We look for the selections in the first runset, which is probably
  // something we should change; query.selections appears to be undefined.
  const selectionKeys: Run.Key[] = [];
  if (query.runSets && query.runSets[0] && query.runSets[0].selections) {
    Filter.treeForEach(query.runSets[0].selections, (f: Filter.Filter) => {
      if (Filter.isIndividual(f)) {
        selectionKeys.push(f.key);
      }
    });
  }

  // Built as its own array (not an alias of filterKeys) so selection keys are
  // fetched whether or not any runset is grouped.
  let displayFields: Run.Key[] = filterKeys.slice();
  // We need the grouping values to color properly
  if (query.runSets != null) {
    query.runSets.forEach((rs: any) => {
      if (rs.grouping) {
        displayFields = displayFields.concat(rs.grouping);
      }
    });
  }
  displayFields = displayFields.concat(selectionKeys);

  result.configKeys = displayFields
    .filter(key => key && key.section === 'config')
    .map(key => key!.name);
  result.summaryKeys = displayFields
    .filter(key => key && key.section === 'summary')
    .map(key => key!.name);

  const nonNullKeys = filterKeys.concat(selectionKeys);
  if (nonNullKeys.length > 0 && result.queries.length > 0) {
    result.queries[0].filters = Filter.And([
      result.queries[0].filters,
      ...nonNullKeys.map(key => ({
        key,
        op: '!=' as Filter.ValueOp,
        value: null as Run.Value,
      })),
    ]);
  }
  return result;
};

/**
 * Hook: applies the scatter-plot query transform to the page query and
 * passes the result to useSampleAndQueryToTable (from ./Export).
 */
const useTableData = (pageQuery: Query.Query, config: ScatterPlotConfig) =>
  useSampleAndQueryToTable(
    scatterPlotTransformQuery(pageQuery, config),
    pageQuery,
    config
  );

/** One bar of the z-axis colour-legend gradient. */
interface ZGradientBar {
  x0: number;
  x: number;
  y0: number;
  y: number;
  fill: number;
  color: number;
}

/**
 * Applies optional user bounds to one axis of the scatter plot.
 *
 * @param points candidate points
 * @param axisValue reads this axis' value from a point
 * @param axisMin points whose numeric value is below this are dropped
 * @param axisMax points whose numeric value is above this are dropped
 * @returns the retained points and the axis domain: each bound where given,
 *   otherwise the data extent (0 when there is no data); null when there is
 *   neither data nor a bound.
 */
const boundScatterAxis = (
  points: ScatterDataPoint[],
  axisValue: (point: ScatterDataPoint) => Run.Value,
  axisMin?: number,
  axisMax?: number
) => {
  const values = points
    .filter(point => axisValue(point) != null)
    .map(point => Number(axisValue(point)));
  let [lo, hi] = extent(values);
  let kept = points;
  if (axisMin != null || axisMax != null) {
    lo = axisMin != null ? axisMin : lo || 0;
    hi = axisMax != null ? axisMax : hi || 0;
    kept = points.filter(point => {
      const v = Number(axisValue(point));
      if (axisMin != null && v < axisMin) {
        return false;
      }
      if (axisMax != null && v > axisMax) {
        return false;
      }
      return true;
    });
  }
  const domain = lo != null && hi != null ? [lo, hi] : null;
  return {points: kept, domain};
};

/**
 * Turns runs into scatter points and computes axis domains and the z colour
 * range (memoized on its arguments via memoize-one).
 *
 * - Points without an x or y value are dropped; x/y/z user bounds drop
 *   points outside them.
 * - colorIndex holds the z value, or Math.log of it when zAxisLogScale is set.
 * - zRange and gradientData are in linear units: user z bounds are used as
 *   given, data-derived extremes are mapped back with Math.exp on log scale.
 *
 * @returns {data, gradientData, xDomain, yDomain, zRange}
 */
const computeData = memoize(
  (
    filtered: RunWithRunsetInfo[],
    xAxis?: string,
    yAxis?: string,
    zAxis?: string,
    xAxisMin?: number,
    xAxisMax?: number,
    yAxisMin?: number,
    yAxisMax?: number,
    zAxisMin?: number,
    zAxisMax?: number,
    zAxisLogScale?: boolean,
    legendKeys?: string[],
    query?: Query.Query
  ) => {
    const points = filtered
      .map((run, idx) => {
        const legend: {
          [key: string]: string;
        } = {};
        (legendKeys || []).forEach(key => {
          legend[key] = Run.displayValue(getAxisValue(run, key, idx));
        });
        const runset = query != null ? Run.lookupRunset(run, query) : undefined;
        const grouping = runset != null ? runset.grouping : undefined;
        const displayName =
          grouping != null && grouping.length > 0
            ? Run.groupedRunDisplayName(run, grouping)
            : run.displayName;
        const zValue = zAxis ? getAxisValue(run, zAxis, idx) : null;

        const point: ScatterDataPoint = {
          x: getAxisValue(run, xAxis || '', idx),
          y: getAxisValue(run, yAxis || '', idx),
          run,
          runName: displayName,
          uniqueId: Run.uniqueId(run, grouping || []),
          groupKeys: grouping,
          // A missing z value stays null instead of Math.log(null) = -Infinity.
          colorIndex:
            zAxisLogScale && zValue != null
              ? Math.log(zValue as number)
              : zValue,
          legend,
        };
        return point;
      })
      .filter(point => point.x != null && point.y != null);

    // The y domain is computed from the points that survive the x bounds.
    const xBounded = boundScatterAxis(
      points,
      point => point.x,
      xAxisMin,
      xAxisMax
    );
    const yBounded = boundScatterAxis(
      xBounded.points,
      point => point.y,
      yAxisMin,
      yAxisMax
    );
    let data = yBounded.points;
    const xDomain = xBounded.domain;
    const yDomain = yBounded.domain;

    let zRange = [0, 1];
    const gradientData: ZGradientBar[] = [];
    if (zAxis) {
      // colorIndex is in log space on a log z scale; user bounds are linear.
      const toColorSpace = (v: number) => (zAxisLogScale ? Math.log(v) : v);
      const fromColorSpace = (v: number) => (zAxisLogScale ? Math.exp(v) : v);

      // `!= null` (not truthiness) so that a z value of 0 counts.
      const zVals = data
        .filter(point => point.colorIndex != null)
        .map(point => Number(point.colorIndex));
      if (zAxisMin != null || zAxisMax != null) {
        data = data.filter(point => {
          const z = Number(point.colorIndex);
          if (zAxisMin != null && z < toColorSpace(zAxisMin)) {
            return false;
          }
          if (zAxisMax != null && z > toColorSpace(zAxisMax)) {
            return false;
          }
          return true;
        });
      }
      const zMin =
        zAxisMin != null ? zAxisMin : fromColorSpace(_.min(zVals) || 0);
      const zMax =
        zAxisMax != null ? zAxisMax : fromColorSpace(_.max(zVals) || 0);

      zRange = [zMin, zMax];
      const breaks = 50;
      const step = (zMax - zMin) / breaks;
      for (let i = 0; i < breaks; i++) {
        const val = zMin + step * i;
        gradientData.push({
          x0: val,
          x: val + step,
          y0: 0,
          y: 1,
          fill: val,
          color: val,
        });
      }
    }

    return {data, gradientData, xDomain, yDomain, zRange};
  }
);

/**
 * Panel editor: a live preview of the plot beside tabs for choosing axes and
 * ranges (Data), title and hint fields (Labels), running min/max/average
 * lines (Annotations) and, when a z axis is in use, the colour gradient.
 */
const ScatterPlotGraphConfig: React.FC<ScatterPlotGraphProps> = makeComp(
  props => {
    const {config, pageQuery} = props;
    const {entityName, projectName} = pageQuery;

    // The z axis counts as in use only if it names a key with a display name.
    const zAxisInUse =
      config.zAxis != null &&
      config.zAxis !== '' &&
      Run.keyStringDisplayName(config.zAxis) !== '';

    // Numeric project fields plus the synthetic index and time keys.
    const fetchProjectFieldOptions = React.useMemo(
      () =>
        makeFetchProjectFieldOptions(entityName, projectName, {
          types: ['number'],
          extraKeys: [INDEX_KEY, ...Run.TIME_KEYS],
        }),
      [entityName, projectName]
    );

    // A labelled dropdown for choosing the key plotted on one axis.
    const renderAxisPicker = (
      label: string,
      helpText: string,
      value: string | undefined,
      onPick: (key: string | undefined) => void
    ) => (
      <LabeledOption
        label={label}
        helpText={helpText}
        option={
          <S.StyledSearchableSelect
            options={fetchProjectFieldOptions}
            infiniteScroll
            value={value}
            onSelect={v => onPick(v as string | undefined)}
          />
        }
      />
    );

    // A running-line checkbox with an explanatory help popup.
    const renderRunningLineToggle = (
      text: string,
      helpText: string,
      checked: boolean | undefined,
      onToggle: (checked: boolean) => void
    ) => (
      <Checkbox
        label={{
          children: (
            <>
              {text}
              <HelpPopup helpText={helpText} />
            </>
          ),
        }}
        checked={checked || false}
        onChange={(e, data) => onToggle(data.checked || false)}
      />
    );

    const dataTab = (
      <Tab.Pane as="div" className="form-grid">
        {renderAxisPicker(
          'X',
          'Variable logged with wandb.config, wandb.summary or wandb.log to use for X axis',
          config.xAxis,
          key => props.updateConfig({...props.config, xAxis: key})
        )}
        {renderAxisPicker(
          'Y',
          'Variable logged with wandb.config, wandb.summary or wandb.log to use for Y axis',
          config.yAxis,
          key => props.updateConfig({...props.config, yAxis: key})
        )}
        {renderAxisPicker(
          'Z',
          'Variable logged with wandb.config, wandb.summary or wandb.log to use for Z axis, represented in color.',
          config.zAxis,
          key => props.updateConfig({...props.config, zAxis: key})
        )}
        <LabeledOption
          label="X Axis"
          helpText="Minimum and maximum values for the x axis.  Disabled if x axis is not a number."
          option={
            <RangeInput
              disabled={
                config.xAxis == null || Run.isTimeKeyString(config.xAxis)
              }
              onMinChange={newVal => {
                props.updateConfig({xAxisMin: newVal});
              }}
              onMaxChange={newVal => {
                props.updateConfig({xAxisMax: newVal});
              }}
              minValue={config.xAxisMin}
              maxValue={config.xAxisMax}
              log
              logValue={config.xAxisLogScale}
              onLogChange={() =>
                props.updateConfig({
                  ...props.config,
                  xAxisLogScale: !props.config.xAxisLogScale,
                })
              }
            />
          }
        />
        <LabeledOption
          label="Y Axis"
          helpText="Minimum and maximum values for the y axis.  Disabled if y axis is not a number."
          option={
            <div className="range-and-log">
              <RangeInput
                disabled={
                  config.yAxis == null || Run.isTimeKeyString(config.yAxis)
                }
                onMinChange={newVal => {
                  props.updateConfig({yAxisMin: newVal});
                }}
                onMaxChange={newVal => {
                  props.updateConfig({yAxisMax: newVal});
                }}
                minValue={config.yAxisMin}
                maxValue={config.yAxisMax}
                log
                logValue={config.yAxisLogScale}
                onLogChange={() =>
                  props.updateConfig({
                    ...props.config,
                    yAxisLogScale: !props.config.yAxisLogScale,
                  })
                }
              />
            </div>
          }
        />
        <LabeledOption
          label="Z Axis"
          helpText="Minimum and maximum values for the z axis.  Disabled if z axis is not a number."
          option={
            <RangeInput
              disabled={
                !zAxisInUse || Run.isTimeKeyString(config.zAxis || '')
              }
              onMinChange={newVal => {
                props.updateConfig({zAxisMin: newVal});
              }}
              onMaxChange={newVal => {
                props.updateConfig({zAxisMax: newVal});
              }}
              minValue={config.zAxisMin}
              maxValue={config.zAxisMax}
              log
              logValue={config.zAxisLogScale}
              onLogChange={() =>
                props.updateConfig({
                  ...props.config,
                  zAxisLogScale: !props.config.zAxisLogScale,
                })
              }
            />
          }
        />
      </Tab.Pane>
    );

    const labelTab = (
      <Tab.Pane as="div" className="form-grid">
        <LabeledOption
          label="Title"
          option={
            <Input
              placeholder={defaultTitle(config)}
              value={config.chartTitle || ''}
              onChange={(e, {value}) => {
                props.updateConfig({
                  chartTitle: value,
                });
              }}
            />
          }
        />
        <br />
        <p>Legend Fields</p>
        <ProjectFieldSelector
          className="legend"
          disabled={false}
          query={pageQuery}
          types={['string', 'number', 'boolean']}
          defaultKeys={[
            'run:displayName',
            'run:name',
            'run:createdAt',
            'run:userName',
          ]}
          fluid
          multi
          selection
          searchByKeyAndText
          value={config.legendFields || ['run:displayName']}
          setValue={value => props.updateConfig({legendFields: value})}
        />
      </Tab.Pane>
    );

    const annotationsTab = (
      <Tab.Pane as="div">
        <p>
          Options designed for plotting metrics over time and tracking the
          progress of a project.
        </p>
        <div className="time-plot-options">
          {renderRunningLineToggle(
            'Plot running maximum y values',
            "This shows a line that scans from left to right over the x axis and plots the running maximum y value.  It's especially useful when x axis is time and you want to plot your maximum metric value over time.",
            config.showMaxYAxisLine,
            checked =>
              props.updateConfig({...props.config, showMaxYAxisLine: checked})
          )}
          {renderRunningLineToggle(
            'Plot running minimum y values',
            "This shows a line that scans from left to right over the x axis and plots the running minimum y value.  It's especially useful when x axis is time and you want to plot your minimum metric value over time.",
            config.showMinYAxisLine,
            checked =>
              props.updateConfig({...props.config, showMinYAxisLine: checked})
          )}
          {renderRunningLineToggle(
            'Plot running average y values',
            'This shows a line that scans from left to right over the x axis and plots the running average y value.  The smoothing parameter will control how much averaging happens.',
            config.showAvgYAxisLine,
            checked =>
              props.updateConfig({...props.config, showAvgYAxisLine: checked})
          )}
          {config.showAvgYAxisLine && (
            <div className="smoothing-average">
              <p>Smoothing for running average</p>
              <SliderInput
                min={0}
                max={0.999}
                step={0.01}
                value={config.yAxisLineSmoothingWeight || 0}
                debounceTime={100}
                hasInput
                onChange={value =>
                  props.updateConfig({yAxisLineSmoothingWeight: value})
                }
              />
            </div>
          )}
        </div>
      </Tab.Pane>
    );

    const settingsPanes = [
      {menuItem: 'Data', render: () => dataTab},
      {menuItem: 'Labels', render: () => labelTab},
      {menuItem: 'Annotations', render: () => annotationsTab},
    ];

    // zAxisInUse already implies a non-empty zAxis string.
    if (zAxisInUse) {
      const customGradient = config.customGradient;
      const gradientTab = (
        <Tab.Pane as="div">
          {' '}
          <GradientPicker
            defaultGradient={
              customGradient
                ? {type: 'customGradient', gradient: customGradient}
                : DEFAULT_GRADIENT
            }
            setGradient={newGradient => {
              props.updateConfig({customGradient: newGradient});
            }}
          />
        </Tab.Pane>
      );
      settingsPanes.push({menuItem: 'Gradient', render: () => gradientTab});
    }

    return (
      <div className="chart-modal">
        <div className="chart-preview">
          <ScatterPlotGraph {...props} loading={false} />
        </div>
        <div className="chart-settings">
          <Tab
            panes={settingsPanes}
            menu={{
              secondary: true,
              pointing: true,
              className: 'chart-settings-tab-menu',
            }}
          />
        </div>
      </div>
    );
  },
  {id: 'ScatterPlotGraphConfig'}
);

/**
 * Reads the value plotted for `axisKeyString` from `run`.
 *
 * - 'Index' / 'run:Index' yields `index`, the run's position in the list.
 * - Time keys are converted to epoch milliseconds; a missing timestamp stays
 *   null so the point is dropped rather than plotted at the epoch.
 * - Any other key yields Run.getValue's result, or null if the key string
 *   does not parse.
 */
const getAxisValue = (run: Run.Run, axisKeyString: string, index: number) => {
  if (axisKeyString === 'Index' || axisKeyString === 'run:Index') {
    return index;
  }
  const axisKey = Run.keyFromString(axisKeyString);
  const axisValue = axisKey != null ? Run.getValue(run, axisKey) : null;
  if (!Run.isTimeKeyString(axisKeyString)) {
    return axisValue;
  }
  // new Date(null) is the epoch (0) and new Date(undefined) is NaN; neither
  // is filtered out by the `!= null` point checks, so keep missing as null.
  return axisValue == null ? null : new Date(axisValue as string).getTime();
};

/**
 * Placeholder title built from the configured axes:
 * "y and z v. x" when a z axis with a display name is set, otherwise
 * "y v. x". Returns '' unless both the x and y axes are set.
 */
const defaultTitle = (config: ScatterPlotConfig) => {
  const {xAxis, yAxis, zAxis} = config;
  if (xAxis == null || yAxis == null) {
    return '';
  }
  const yLabel = Run.keyStringDisplayName(yAxis);
  const xLabel = lowercaseSpecialColumnsInTitle(
    Run.keyStringDisplayName(xAxis)
  );
  const zLabel = zAxis ? Run.keyStringDisplayName(zAxis) : '';
  if (zLabel === '') {
    return `${yLabel} v. ${xLabel}`;
  }
  return `${yLabel} and ${lowercaseSpecialColumnsInTitle(zLabel)} v. ${xLabel}`;
};

/**
 * Lower-cases the 'Created' and 'End Time' column names so they read
 * naturally mid-title; every other name is returned unchanged.
 */
function lowercaseSpecialColumnsInTitle(columnName: string) {
  switch (columnName) {
    case 'Created':
    case 'End Time':
      return columnName.toLowerCase();
    default:
      return columnName;
  }
}

// Tick formatter that renders the tick value with Run.formatTimestamp; the
// remaining parameters are accepted but unused.
const formatTimestep = (data: any, index: any, scale: any, tickTotal: any) =>
  Run.formatTimestamp(data);

// X axis: timestamp ticks for time keys, otherwise no custom formatter.
const formatTickXAxis = (axis: any) =>
  Run.isTimeKeyString(axis) ? formatTimestep : undefined;

// Y axis: timestamp ticks for time keys, otherwise formatYAxis.
const formatTickYAxis = (axis: any) =>
  Run.isTimeKeyString(axis) ? formatTimestep : formatYAxis;

/**
 * Shortens a legend value for the hint when the whole string is a finite,
 * non-integer number: it is rounded to 4 significant digits
 * (e.g. '0.123456' -> '0.1235'). Any other string is returned unchanged.
 */
const formatLegendValueForHint = (value: string) => {
  const num = Number(value);
  if (value.trim() === '' || !Number.isFinite(num) || Number.isInteger(num)) {
    return value;
  }
  // Round-trip through Number to drop the zeros toPrecision pads with.
  return String(Number(num.toPrecision(4)));
};

/**
 * Builds the hint formatter for a hovered point. The returned function lists
 * the run name, then each legend field, then the x, y and z values (skipping
 * any that are null/undefined); values of time keys are shown as timestamps.
 *
 * NOTE(review): when zAxisLogScale is set, computeData stores Math.log of the
 * z value in colorIndex, so the hint shows that log value — confirm intent.
 */
const formatHint = (props: ScatterPlotProps) => {
  return (point: ScatterDataPoint) => {
    const {config} = props;
    const legend = point.legend || {};
    // The previous parseFloat/parseInt check never matched ordinary floats
    // such as '3.14159' (parseInt yields 3), so they were never rounded.
    const legendVals: Array<{
      title: string;
      value: string;
    }> = Object.keys(legend).map(key => ({
      title: Run.keyStringDisplayName(key),
      value: formatLegendValueForHint(legend[key] || ''),
    }));

    const axisEntries = [
      {axisValue: point.x, axisKeyString: config.xAxis},
      {axisValue: point.y, axisKeyString: config.yAxis},
      {axisValue: point.colorIndex, axisKeyString: config.zAxis},
    ];
    const axisVals: Array<{
      title: string;
      value: string;
    }> = [];
    axisEntries.forEach(({axisValue, axisKeyString}) => {
      if (axisValue == null) {
        return;
      }
      const keyString = axisKeyString || '';
      axisVals.push({
        title: Run.keyStringDisplayName(keyString),
        value: Run.isTimeKeyString(keyString)
          ? Run.formatTimestamp(axisValue as string)
          : Run.displayValue(axisValue),
      });
    });

    return [{title: 'Run', value: point.runName}, ...legendVals, ...axisVals];
  };
};

/**
 * Extra plot margin for axes showing timestamps, whose tick labels are long.
 * Returns undefined when none of the axes is a time key; otherwise a margin
 * that widens the bottom/left for a time x axis, the left for a time y axis
 * and the top for a time z axis.
 */
const getUpperPlotMargin = (xAxis?: string, yAxis?: string, zAxis?: string) => {
  const isTimeAxis = (axis?: string) =>
    Boolean(axis && Run.isTimeKeyString(axis));
  const xIsTime = isTimeAxis(xAxis);
  const yIsTime = isTimeAxis(yAxis);
  const zIsTime = isTimeAxis(zAxis);
  if (!(xIsTime || yIsTime || zIsTime)) {
    return undefined;
  }
  let left = 0;
  if (yIsTime) {
    left = 100;
  } else if (xIsTime) {
    left = 55;
  }
  return {
    bottom: xIsTime ? 55 : 30,
    left,
    top: zIsTime ? 20 : 5,
  };
};

/**
 * Builds a line tracking the running minimum, maximum or average of y as x
 * increases, for the case where the scatter plot is a metric over time.
 *
 * @param goal 'min' / 'max' keep the lowest / highest y seen so far; 'avg'
 *   follows the raw y values and, when smoothingWeight is given, applies
 *   exponential smoothing.
 * @param data points to trace; NOT modified (a sorted copy is used).
 * @param smoothingWeight weight passed to smooth() for the 'avg' line.
 * @returns one {x, y} per point, ordered by x with missing x values last.
 */
const calcRunningYAxisLineData = (
  goal: 'min' | 'max' | 'avg',
  data: ScatterDataPoint[],
  smoothingWeight?: number
) => {
  let runningVal: Run.Value = null;

  // Array.prototype.sort works in place; copy first so the caller's array
  // (callers may pass the memoized computeData result) keeps its order.
  const sortedByX = [...data].sort((a, b) => {
    if (a.x == null && b.x == null) {
      return 0;
    }
    if (a.x == null) {
      return 1;
    }
    if (b.x == null) {
      return -1;
    }
    if (isNumber(a.x) && isNumber(b.x)) {
      return a.x - b.x;
    }
    // Ties return 0 so the comparator stays consistent.
    return a.x > b.x ? 1 : a.x < b.x ? -1 : 0;
  });

  let lineData = sortedByX.map(m => {
    if (runningVal === null) {
      runningVal = m.y;
    }
    if (m.y != null && runningVal != null) {
      if (goal === 'min' && m.y < runningVal) {
        runningVal = m.y;
      } else if (goal === 'max' && m.y > runningVal) {
        runningVal = m.y;
      } else if (goal === 'avg') {
        runningVal = m.y;
      }
    }
    return {x: m.x, y: runningVal};
  });

  if (goal === 'avg' && smoothingWeight != null) {
    const smoothedLine = smooth(
      lineData.map(p => p.y as number),
      [],
      smoothingWeight,
      'exponential'
    );
    lineData = lineData.map((p, i) => {
      return {
        x: p.x,
        y: smoothedLine[i],
      };
    });
  }

  return lineData;
};

/**
 * Looks up the selection constraints on the x, y and z axis keys, which are
 * displayed as a gray rectangle over the scatterplot. An axis that is unset,
 * or whose key string does not parse, gets an empty constraint.
 */
const rangesFromBounds = (
  props: ScatterPlotProps,
  groupSelection: SM.GroupSelectionState
) => {
  const constraintsFor = (
    axisKeyString: string | undefined
  ): SM.PanelSelectionConstraint => {
    if (!axisKeyString) {
      return {};
    }
    const key = Run.keyFromString(axisKeyString);
    return key ? SM.getConstraintsForKey(groupSelection, key) : {};
  };

  const {xAxis, yAxis, zAxis} = props.config;
  return {
    xSelect: constraintsFor(xAxis),
    ySelect: constraintsFor(yAxis),
    zSelect: constraintsFor(zAxis),
  };
};

/** Panel props plus box-selection and run-highlight state and their setters. */
type ScatterPlotGraphProps = ScatterPlotProps & {
  // Current drag-selection rectangle, or null; renderBoxSelection positions
  // its menu from the box's right/top.
  boxSelection: BoxSelectionDrawArea | null;
  // uniqueId of the highlighted point, or null.
  runHighlight: string | null;
  // Called with null to clear the box selection.
  setBoxSelection(newBoxSelection: BoxSelectionDrawArea | null): void;
  // NOTE(review): not called in this part of the file — presumably set on hover.
  setRunHighlight(newRunHighlight: string | null): void;
};

/**
 * Renders the SelectionBoundsMenu beside the current box selection, letting
 * the user clear the x/y selections or convert them into filters. Renders an
 * empty fragment when there is no box selection or any axis key or
 * constraint is missing.
 */
const renderBoxSelection = (
  props: ScatterPlotGraphProps,
  groupSelection: SM.GroupSelectionState,
  setGroupSelection: (state: SM.GroupSelectionState) => void,
  xAxis?: string,
  yAxis?: string,
  xSelect?: SM.PanelSelectionConstraint,
  ySelect?: SM.PanelSelectionConstraint
) => {
  const box = props.boxSelection;
  if (
    box == null ||
    xAxis == null ||
    yAxis == null ||
    xSelect == null ||
    ySelect == null
  ) {
    return <></>;
  }

  // Offset the menu to the right of the box and slightly below its top edge.
  const translateX = box.right + 33;
  const translateY = box.top + 5;

  const onClear = () => {
    props.setBoxSelection(null);
    setGroupSelection(SM.clearSelections([xAxis, yAxis], groupSelection));
  };

  const onConvertToFilters = () => {
    if (!props.convertSelectionsToFilters) {
      return;
    }
    props.setBoxSelection(null);
    props.convertSelectionsToFilters({
      [xAxis]: xSelect,
      [yAxis]: ySelect,
    });
  };

  return (
    <div
      className="selection-bounds-menu"
      style={{
        visibility: 'visible',
        transform: `translate(${translateX}px, ${translateY}px)`,
      }}>
      <SelectionBoundsMenu
        clearSelections={onClear}
        convertSelectionsToFilters={onConvertToFilters}
      />
    </div>
  );
};

/**
 * Renders the scatter plot chart for the panel.
 *
 * Layout, top to bottom:
 * - the panel title;
 * - an optional "Showing first N runs" notice when the run limit truncates data;
 * - an optional z-axis gradient legend (only when a z axis is configured);
 * - the main x/y plot with box selection, optional running min/max/avg
 *   y-lines, and a hover hint.
 *
 * Point colouring:
 * - With a z axis, each checked point's `color` is set from its `colorIndex`
 *   and mapped by react-vis through `colorDomain`/`colorRange`.
 * - Without a z axis, each checked point gets its literal run colour.
 * - In both modes, points whose run is not checked in any run set's group
 *   selection get a "not visible" colour.
 *
 * NOTE(review): this mutates the `color` field of the point objects returned
 * by `computeData` in place; confirm those objects are not shared elsewhere.
 */
const ScatterPlotGraph: React.FC<ScatterPlotGraphProps> = makeComp(
  props => {
    const {xAxis, yAxis, zAxis} = props.config;
    const runSets = ViewHooks.useParts(props.runSetRefs);
    const groupSelections = ViewHooks.useParts(
      runSets.map(runSet => runSet.groupSelectionsRef)
    );
    // Box selections are read from and written to the first run set only.
    const setGroupSelection = ViewHooks.useViewAction(
      runSets[0].groupSelectionsRef,
      GroupSelectionsActionsInternal.setGroupSelections
    );
    const history = useHistory();
    const {entityName, projectName} = useParams<{
      entityName: string;
      projectName: string;
    }>();

    const elementRef = React.useRef<HTMLDivElement>(null);
    const {height = 1} = useResizeObserver<HTMLDivElement>({
      ref: elementRef,
    });

    const {xSelect, ySelect} = rangesFromBounds(props, groupSelections[0]);

    // The missing-axes message takes precedence: without both axes the
    // "runs that logged X and Y" message would interpolate `undefined`.
    let errorMessage: string | undefined;
    if (!xAxis || !yAxis) {
      errorMessage =
        'Select metrics for the X Axis and Y Axis to visualize data in this scatter plot.';
    } else if (props.data.filtered.length === 0) {
      errorMessage = `Select runs that logged ${xAxis} and ${yAxis} to visualize data in this scatter plot.`;
    }

    const {data, gradientData, xDomain, yDomain, zRange} = computeData(
      props.data.filtered,
      xAxis,
      yAxis,
      zAxis,
      props.config.xAxisMin,
      props.config.xAxisMax,
      props.config.yAxisMin,
      props.config.yAxisMax,
      props.config.zAxisMin,
      props.config.zAxisMax,
      props.config.zAxisLogScale,
      props.config.legendFields || [],
      props.pageQuery
    );

    const xScale = props.config.xAxisLogScale ? 'log' : 'linear';
    const yScale = props.config.yAxisLogScale ? 'log' : 'linear';

    const highlight = _.find(
      data,
      point => point.uniqueId === props.runHighlight
    );

    const totalRuns = props.data.totalRuns;
    const maxRuns = props.data.limit;

    const zAxisInUse =
      zAxis != null && zAxis !== '' && Run.keyStringDisplayName(zAxis) !== '';

    // A run is shown in colour if it is checked in at least one run set's
    // group selection.
    const isRunChecked = (run: Parameters<typeof SM.getCheckedState>[1]) =>
      groupSelections.some(
        groupSelection =>
          SM.getCheckedState(groupSelection, run, 1) === 'checked'
      );

    let gradientColorRange: string[] = [];
    let gradientColorDomain: number[] = [];

    // React vis maps the values in color domain to color range and interpolates
    // color values that are not explicitly in color domain
    let colorDomain: number[] = [];
    let colorRange: string[] = [];

    if (zAxisInUse) {
      // Clamp a channel value to 0..255 and format it as two hex digits.
      const toHex = (value: number) => {
        value = Math.max(0, Math.min(255, Math.round(value) || 0));
        return (value < 16 ? '0' : '') + value.toString(16);
      };

      gradientColorRange =
        props.config.customGradient != null
          ? gradientToRGBArray(props.config.customGradient, 10).map(
              rgbArr =>
                '#' +
                toHex(rgbArr[0]) +
                toHex(rgbArr[1]) +
                toHex(rgbArr[2]) +
                'ff'
            )
          : PLASMA_GRADIENT_SPARSE_HEX;

      // Spread the gradient stops evenly across the z range.
      const zDomainStep =
        (zRange[1] - zRange[0]) / (gradientColorRange.length - 1);
      const zDomain = gradientColorRange.map(
        (_color, i) => zRange[0] + i * zDomainStep
      );
      gradientColorDomain = props.config.zAxisLogScale
        ? zDomain.map(z => Math.log(z))
        : zDomain;

      // Reserve a domain value just past the gradient for unchecked points.
      const notVisibleValue = (_.last(zDomain) || 0) + 1;

      colorRange = [...gradientColorRange, NOT_VISIBLE_COLOR];
      colorDomain = [...gradientColorDomain, notVisibleValue];

      data.forEach(p => {
        if (!p.run) {
          return;
        }
        if (isRunChecked(p.run)) {
          if (typeof p.colorIndex === 'number') {
            p.color = p.colorIndex;
          } else if (p.colorIndex != null) {
            p.color = p.colorIndex.toString();
          }
        } else {
          p.color = notVisibleValue;
        }
      });
    } else {
      data.forEach(p => {
        if (!p.run) {
          return;
        }
        p.color = isRunChecked(p.run)
          ? ColorUtil.runColor(
              p.run,
              p.groupKeys || [],
              props.customRunColors
            )
          : NOT_VISIBLE_COLOR;
      });
      colorDomain = [0, 1];
      colorRange = [
        NOT_VISIBLE_COLOR,
        props.config.color || DEFAULT_POINT_COLOR,
      ];
    }

    const fontSize = Panels.getFontSize(
      props.config.fontSize ?? 'auto',
      height
    );
    const fontStyles = getAxisStyleForFontSize(fontSize);
    return (
      <div className="chart" ref={elementRef}>
        <PanelTitle
          title={getTitleFromConfig(props.config)}
          searchQuery={props.searchQuery}
          fontSize={fontSize}
        />
        <div className="chart-content">
          {props.data.loading && <WandbLoader />}
          {!props.data.loading && errorMessage != null ? (
            <PanelError message={errorMessage} />
          ) : (
            <>
              {xAxis != null && yAxis != null && (
                <div className={'scatter-plot-container'}>
                  {renderBoxSelection(
                    props,
                    groupSelections[0],
                    setGroupSelection,
                    xAxis,
                    yAxis,
                    xSelect,
                    ySelect
                  )}

                  {/* A ternary rather than `&&` so a limit of 0 doesn't render "0". */}
                  {maxRuns && totalRuns > maxRuns ? (
                    <div style={{float: 'right', marginRight: 15}}>
                      <span style={{...fontStyles, fontStyle: 'italic'}}>
                        Showing first {maxRuns} runs{' '}
                      </span>
                    </div>
                  ) : null}

                  {zAxis && zAxisInUse ? (
                    // The chart above the graph showing the zrange
                    <div>
                      <FlexibleWidthXYPlot
                        height={Run.isTimeKeyString(zAxis) ? 70 : 55}
                        margin={getUpperPlotMargin(zAxis, undefined, zAxis)}
                        xType={props.config.zAxisLogScale ? 'log' : 'linear'}
                        yDomain={[0, 1]}>
                        <XAxis
                          title={Run.keyStringDisplayName(zAxis)}
                          tickFormat={formatTickXAxis(zAxis) as any}
                          tickLabelAngle={
                            Run.isTimeKeyString(zAxis) ? -30 : undefined
                          }
                          style={fontStyles}
                        />
                        <VerticalRectSeries
                          colorDomain={gradientColorDomain}
                          colorRange={gradientColorRange}
                          fillDomain={gradientColorDomain}
                          fillRange={gradientColorRange}
                          data={gradientData}
                        />
                      </FlexibleWidthXYPlot>
                    </div>
                  ) : null}
                  <div style={{flexGrow: 1, height: '0px', userSelect: 'none'}}>
                    <FlexibleXYPlot
                      margin={getPlotMargin({
                        axisKeys: {xAxis, yAxis, zAxis},
                        axisDomain: {
                          yAxis: yDomain ?? undefined,
                        },
                        axisType: {yAxis: yScale},
                        fontSize,
                      })}
                      xType={xScale}
                      yType={yScale}
                      xDomain={xDomain}
                      yDomain={yDomain}>
                      <VerticalGridLines />
                      <HorizontalGridLines />
                      <BoxSelection
                        xSelect={xSelect}
                        ySelect={ySelect}
                        onSelectChange={(
                          newXSelect,
                          newYSelect,
                          newBoxSelection
                        ) => {
                          const selectionCleared =
                            newXSelect.low == null &&
                            newXSelect.high == null &&
                            newYSelect.low == null &&
                            newYSelect.high == null;

                          if (selectionCleared) {
                            props.setBoxSelection(null);
                            setGroupSelection(
                              SM.clearSelections(
                                [xAxis, yAxis],
                                groupSelections[0]
                              )
                            );
                            return;
                          }

                          // Index is a special option that just sorts the runs in order;
                          // it's not really a key, so selection won't work on it.
                          // This behavior is a little weird but better than crashing.
                          const panelSelections: SM.PanelSelections = {};
                          if (xAxis !== 'Index' && Run.keyFromString(xAxis)) {
                            panelSelections[xAxis] = newXSelect;
                          }
                          if (yAxis !== 'Index' && Run.keyFromString(yAxis)) {
                            panelSelections[yAxis] = newYSelect;
                          }
                          props.setBoxSelection(newBoxSelection);
                          setGroupSelection(
                            SM.setSelectionsToBounds(
                              panelSelections,
                              groupSelections[0]
                            )
                          );
                        }}
                      />
                      <XAxis
                        title={Run.keyStringDisplayName(xAxis)}
                        style={fontStyles}
                        tickFormat={formatTickXAxis(xAxis) as any}
                        tickLabelAngle={
                          Run.isTimeKeyString(xAxis) ? -30 : undefined
                        }
                      />
                      <YAxis
                        title={Run.keyStringDisplayName(yAxis)}
                        tickFormat={formatTickYAxis(yAxis) as any}
                        style={fontStyles}
                      />
                      {props.config.showAvgYAxisLine && (
                        <LineSeries
                          data={
                            calcRunningYAxisLineData(
                              'avg',
                              data,
                              props.config.yAxisLineSmoothingWeight || 0
                            ) as LineSeriesPoint[]
                          }
                        />
                      )}
                      {props.config.showMinYAxisLine && (
                        <LineSeries
                          data={
                            calcRunningYAxisLineData(
                              'min',
                              data
                            ) as LineSeriesPoint[]
                          }
                        />
                      )}
                      {props.config.showMaxYAxisLine && (
                        <LineSeries
                          data={
                            calcRunningYAxisLineData(
                              'max',
                              data
                            ) as LineSeriesPoint[]
                          }
                        />
                      )}
                      <MarkSeries
                        colorDomain={zAxisInUse ? colorDomain : undefined}
                        colorRange={zAxisInUse ? colorRange : undefined}
                        colorType={zAxisInUse ? undefined : 'literal'}
                        data={data as MarkSeriesPoint[]}
                        size={4}
                        onValueMouseOver={value => {
                          props.setRunHighlight(value.uniqueId);
                        }}
                        onValueMouseOut={() => {
                          props.setRunHighlight(null);
                        }}
                        onValueClick={value => {
                          // Points are not guaranteed to carry a run, so guard
                          // the dereference before navigating.
                          const runName = value.run?.name;
                          if (
                            entityName != null &&
                            projectName != null &&
                            runName != null
                          ) {
                            history.push(
                              urls.run({
                                entityName,
                                projectName,
                                name: runName,
                              })
                            );
                          }
                        }}
                      />

                      {highlight && (
                        <Hint
                          value={highlight}
                          format={formatHint(props) as any}
                        />
                      )}
                    </FlexibleXYPlot>
                  </div>
                </div>
              )}
            </>
          )}
        </div>
      </div>
    );
  },
  {id: 'ScatterPlotGraph', memo: true}
);

/**
 * Top-level scatter plot panel component.
 *
 * Owns the transient UI state shared by the graph and its config view: the
 * drawn box selection and the hovered point. Hover changes are also
 * forwarded to the view's interact state via `setHighlight('run:name', …)`.
 *
 * Renders `ScatterPlotGraphConfig` when `props.configMode` is set, otherwise
 * `ScatterPlotGraph`.
 */
const PanelScatterPlot: React.FC<ScatterPlotProps> = makeComp(
  props => {
    const [boxSelection, setBoxSelection] =
      useState<BoxSelectionDrawArea | null>(null);

    const [runHighlight, setInternalRunHighlight] = useState<string | null>(
      null
    );

    const setHighlight = InteractStateContext.useInteractStateAction(
      InteractStateActions.setHighlight
    );

    // Update the local highlight and mirror it into the shared interact
    // state, which is cleared by passing `undefined` rather than `null`.
    const setRunHighlight = useCallback(
      (newRunHighlight: string | null) => {
        setInternalRunHighlight(newRunHighlight);
        setHighlight('run:name', newRunHighlight ?? undefined);
      },
      [setInternalRunHighlight, setHighlight]
    );

    // We don't update props while loading. Presumably useGatedValue returns
    // the most recent value for which the gate was true — confirm.
    const savedProps = useGatedValue(props, () => !props.loading);

    // Everything comes from the gated props except runSetRefs, which always
    // comes from the live props: stale refs can crash, and passing the live
    // refs presumably lets the box-selection rectangle follow group selection
    // changes immediately instead of waiting for the page load.
    const graphProps = {
      ...savedProps,
      runSetRefs: props.runSetRefs,
      boxSelection,
      runHighlight,
      setBoxSelection,
      setRunHighlight,
    };

    if (props.configMode) {
      return (
        <div>
          <ScatterPlotGraphConfig {...graphProps} />
        </div>
      );
    }
    return <ScatterPlotGraph {...graphProps} />;
  },
  {id: 'PanelScatterPlot'}
);

export default PanelScatterPlot;

/**
 * Editor descriptors for the scatter plot's config fields. Each entry names
 * an editor kind (`as const` keeps it a literal type), a display label, and
 * optionally a `default` value.
 *
 * NOTE(review): key order may determine the order in which fields appear in
 * the config editor; confirm before reordering.
 */
const configSpec = {
  chartTitle: {
    editor: 'string' as const,
    displayName: 'Chart title',
  },
  // Axis range overrides and log-scale toggles, per axis.
  xAxisMin: {editor: 'number' as const, displayName: 'X min'},
  xAxisMax: {editor: 'number' as const, displayName: 'X max'},
  xAxisLogScale: {
    editor: 'checkbox' as const,
    displayName: 'X log scale',
    default: false,
  },
  yAxisMin: {editor: 'number' as const, displayName: 'Y min'},
  yAxisMax: {editor: 'number' as const, displayName: 'Y max'},
  yAxisLogScale: {
    editor: 'checkbox' as const,
    displayName: 'Y log scale',
    default: false,
  },
  zAxisMin: {editor: 'number' as const, displayName: 'Z min'},
  zAxisMax: {editor: 'number' as const, displayName: 'Z max'},
  zAxisLogScale: {
    editor: 'checkbox' as const,
    displayName: 'Z log scale',
    default: false,
  },
  // Default gradient runs dark red -> orange -> yellow. Offsets appear to be
  // percentages along the gradient (0, 50, 100) — confirm against the
  // gradient editor.
  customGradient: {
    editor: 'gradient' as const,
    displayName: 'Gradient',
    default: [
      {offset: 0, color: '#900000'},
      {offset: 50, color: '#D64F04'},
      {offset: 100, color: '#FFE600'},
    ],
  },
  // Toggles for the running max/min/avg y-value lines.
  showMaxYAxisLine: {
    editor: 'checkbox' as const,
    displayName: 'Plot max y',
    default: false,
  },
  showMinYAxisLine: {
    editor: 'checkbox' as const,
    displayName: 'Plot min y',
    default: false,
  },
  showAvgYAxisLine: {
    editor: 'checkbox' as const,
    displayName: 'Plot avg y',
    default: false,
  },
};

/**
 * Returns the panel title: the user's custom chart title when it is set
 * and non-empty, otherwise the title derived by `defaultTitle`.
 */
function getTitleFromConfig(config: ScatterPlotConfig): string {
  const customTitle = config.chartTitle;
  if (customTitle) {
    return customTitle;
  }
  return defaultTitle(config);
}

/**
 * Panel registration for the scatter plot. It ties the panel type to its
 * component, config editor spec, title derivation, query transform and
 * table-data hook, and enables image, CSV and API exports.
 */
export const Spec: Panels.PanelSpec<typeof PANEL_TYPE, ScatterPlotConfig> = {
  type: PANEL_TYPE,
  icon: 'panel-scatter-plot',
  Component: PanelScatterPlot,
  configSpec,
  getTitleFromConfig,
  transformQuery: scatterPlotTransformQuery,
  useTableData,
  exportable: {
    api: true,
    csv: true,
    image: true,
  },
};
