import React, {useRef} from 'react';
import _ from 'lodash';
import {Table} from 'semantic-ui-react';
import makeComp from '../util/profiler';

/** Vertical pitch of the graph grid in px; presumably matches the rendered table row height — confirm. */
const ROW_HEIGHT = 51;
/** Horizontal pitch in px between adjacent graph columns (tracks). */
const SPACING = 50;

/** Pixel coordinate of the centre of cell number `cell` on a grid whose cells are `size` px wide. */
function cellCenter(cell: number, size: number): number {
  return size * (cell + 0.5);
}

/** Pixel x-coordinate of the centre of graph column `col` (fractional columns allowed). */
function xPos(col: number) {
  return cellCenter(col, SPACING);
}
/** Pixel y-coordinate of the centre of table row `row` (fractional rows allowed). */
function yPos(row: number) {
  return cellCenter(row, ROW_HEIGHT);
}

/**
 * A pair of numbers. Used both for drawn segments
 * (`[column in previous row, column in this row]`) and for graph edges
 * (`[source node id, target node id]`).
 */
type Pair = [number, number];

/** All segments drawn between one row and the row directly above it. */
type Line = Array<Pair>;

/** One `Line` per graph node / table row, in node order. */
type Lines = Array<Line>;

/** Result of laying the graph out on a column grid. */
interface LinesWithInfo {
  /** Number of columns needed to draw every row. */
  width: number;
  /** For each row, the column in which that row's node dot is drawn. */
  dots: number[];
  /** For each row, the segments joining it to the previous row. */
  lines: Lines;
}

/** Props for `GraphPanel`. */
interface GraphPanelProps {
  /** Graph to display; read as `{nodes, edges}` by the component (untyped here). */
  graph: any;
}

/**
 * Heuristically permutes the columns of each row, in place, to shorten the
 * horizontal runs of the segments; this tends to reduce line crossings.
 *
 * Runs one top-down pass (rows 1..n-1, scored against each row's segments to
 * the row above) and one bottom-up pass (rows n-2..0, scored against the
 * segments to the row below). Within a row, every ordered pair of columns is
 * swapped greedily when that lowers the summed squared horizontal distance to
 * the columns they connect to. Sweeps repeat until one makes no swap, capped
 * at `width` sweeps.
 *
 * @param lines per-row segments; endpoints are rewritten to the new columns.
 * @param dots per-row node column; rewritten to the new columns.
 * @param width number of columns in use.
 */
function reorderColumns(lines: Lines, dots: number[], width: number): void {
  for (let pass = 0; pass < 2; pass++) {
    const forward = pass % 2 === 0;
    for (let step = 1; step < lines.length; step++) {
      const row = forward ? step : lines.length - step - 1;

      // position[c]: column that column `c` of this row is moved to.
      const position: number[] = [];
      // neighbours[c]: columns in the adjacent row (above when going forward,
      // below when going backward) that column `c` is connected to.
      const neighbours: number[][] = [];
      for (let c = 0; c < width; c++) {
        position.push(c);
        neighbours.push([]);
      }
      if (forward) {
        for (const [above, here] of lines[row]) {
          neighbours[here].push(above);
        }
      } else {
        for (const [here, below] of lines[row + 1]) {
          neighbours[here].push(below);
        }
      }

      for (let sweep = 0; sweep < width; sweep++) {
        let swaps = 0;
        for (let j = 0; j < width; j++) {
          for (let k = 0; k < width; k++) {
            if (j === k) {
              continue;
            }
            let costPre = 0;
            let costPost = 0;
            for (const n of neighbours[j]) {
              costPre += (n - position[j]) * (n - position[j]);
              costPost += (n - position[k]) * (n - position[k]);
            }
            for (const n of neighbours[k]) {
              costPost += (n - position[j]) * (n - position[j]);
              costPre += (n - position[k]) * (n - position[k]);
            }
            if (costPost < costPre) {
              [position[j], position[k]] = [position[k], position[j]];
              swaps++;
            }
          }
        }
        if (swaps === 0) {
          break;
        }
      }

      // Apply the permutation to every coordinate that refers to this row.
      for (const segment of lines[row]) {
        segment[1] = position[segment[1]];
      }
      dots[row] = position[dots[row]];
      if (row < lines.length - 1) {
        for (const segment of lines[row + 1]) {
          segment[0] = position[segment[0]];
        }
      }
    }
  }
}

/**
 * Computes a column ("track") layout for drawing `g` next to the node table,
 * with one row per entry of `g.nodes`, in that order.
 *
 * Reads `g.nodes` (objects with an `id`) and `g.edges` (`[sourceId, targetId]`
 * pairs). An edge whose source id is not a node id throws a TypeError. An edge
 * is only picked up once its source row is reached. NOTE(review): edges
 * therefore presumably point from earlier to later nodes; one whose target
 * never follows its source keeps a track open down to the last row — confirm.
 *
 * @returns the number of columns used, each row's node column (`dots`) and,
 *   per row, the segments from the previous row's columns to this row's.
 */
function layoutGraph(g: any): LinesWithInfo {
  // Bucket edges under their source node so they start when that node is reached.
  const outEdges: {[id: string]: Pair[]} = {};
  for (const node of g.nodes) {
    outEdges[node.id] = [];
  }
  for (const edge of g.edges) {
    outEdges[edge[0]].push(edge);
  }

  const dots: number[] = [];
  const lines: Lines = [];
  // Edges currently passing through the row being built, with their column.
  let active: Array<{index: number; edge: Pair}> = [];
  let width = 0;

  for (const node of g.nodes) {
    const ins: number[] = [];
    const nextActive: typeof active = [];
    const curLines: Line = [];
    for (const {index, edge} of active) {
      if (edge[1] !== node.id) {
        // Edge continues past this node: give it the next free column.
        curLines.push([index, nextActive.length]);
        nextActive.push({index: nextActive.length, edge});
      } else {
        ins.push(index);
      }
    }
    // The node sits just right of all continuing edges.
    const curIndex = nextActive.length;
    width = Math.max(width, curIndex + 1);
    for (const index of ins) {
      curLines.push([index, curIndex]);
    }
    for (const edge of outEdges[node.id]) {
      nextActive.push({index: curIndex, edge});
    }

    active = nextActive;
    lines.push(curLines);
    dots.push(curIndex);
  }

  reorderColumns(lines, dots, width);

  return {width, dots, lines};
}

/**
 * SVG path data for a segment from column `l[0]` in row `row - 1` to column
 * `l[1]` in row `row`. A segment that stays in one column is a straight
 * vertical line. Otherwise it curves out at the half-row mark, runs
 * horizontally, and curves back down into the target column.
 */
function pathForLine(row: number, l: Pair): string {
  if (l[0] === l[1]) {
    return ['M', xPos(l[0]), yPos(row - 1), 'L', xPos(l[1]), yPos(row)].join(
      ' '
    );
  }

  const xOff = l[0] < l[1] ? 0.5 : -0.5;

  return [
    'M',
    xPos(l[0]),
    yPos(row - 1),
    'Q',
    xPos(l[0]),
    yPos(row - 0.5),
    ',',
    xPos(l[0] + xOff),
    yPos(row - 0.5),
    'L',
    xPos(l[1] - xOff),
    yPos(row - 0.5),
    'Q',
    xPos(l[1]),
    yPos(row - 0.5),
    ',',
    xPos(l[1]),
    yPos(row),
  ].join(' ');
}

/**
 * Table of the graph's nodes (name, class, parameter count, output shape).
 * When both `graph.nodes` and `graph.edges` are present, a first column shows
 * the connectivity as an SVG drawing spanning all rows.
 */
const GraphPanel: React.FC<GraphPanelProps> = makeComp(
  ({graph}) => {
    // Last graph seen and its layout. The layout is recomputed only when the
    // incoming graph is not deep-equal (lodash isEqual) to the cached one.
    // NOTE(review): the graph is cached by reference, so a caller that mutates
    // the same object in place would get a stale layout — confirm callers don't.
    const memoGraphRef = useRef<any>(null);
    const memoLinesRef = useRef<LinesWithInfo | null>(null);

    const graphToLines = (g: any): LinesWithInfo => {
      if (memoLinesRef.current != null && _.isEqual(g, memoGraphRef.current)) {
        return memoLinesRef.current;
      }
      const layout = layoutGraph(g);
      memoGraphRef.current = g;
      memoLinesRef.current = layout;
      return layout;
    };

    // Guard against a graph without nodes so the table still renders (empty)
    // instead of throwing on `.map` of undefined.
    const nodes: any[] = graph.nodes || [];
    const displayGraph = !!graph.nodes && !!graph.edges;

    const renderGraph = () => {
      const {width, dots, lines} = graphToLines(graph);
      const pxWidth = `${width * SPACING}px`;

      return (
        <Table.Cell
          rowSpan={nodes.length}
          style={{
            borderRight: '1px solid rgba(34,36,38,.1)',
            width: pxWidth,
            maxWidth: pxWidth,
            padding: 0,
          }}>
          <svg
            style={{
              width: pxWidth,
              height: `${lines.length * ROW_HEIGHT}px`,
            }}>
            {lines.map((segments, row) => (
              <React.Fragment key={row}>
                {/* Keyed by position: two segments can share both endpoints
                    (e.g. duplicate edges), so endpoint-based keys collide. */}
                {segments.map((segment, j) => (
                  <path
                    key={j}
                    stroke="black"
                    fill="transparent"
                    d={pathForLine(row, segment)}
                  />
                ))}
                <circle
                  cx={xPos(dots[row])}
                  cy={yPos(row)}
                  r={5}
                  fill="black"
                />
              </React.Fragment>
            ))}
          </svg>
        </Table.Cell>
      );
    };

    return (
      <div className="panel-graph">
        <Table celled singleLine columns={4}>
          <Table.Header>
            <Table.Row>
              {displayGraph && <Table.HeaderCell />}
              <Table.HeaderCell>Name</Table.HeaderCell>
              <Table.HeaderCell>Type</Table.HeaderCell>
              <Table.HeaderCell># Parameters</Table.HeaderCell>
              <Table.HeaderCell>Output Shape</Table.HeaderCell>
            </Table.Row>
          </Table.Header>
          <Table.Body>
            {nodes.map((node: any, i: number) => {
              const name = node.name;
              const className = node.class_name;
              const numParameters: string = _.isArray(node.num_parameters)
                ? node.num_parameters.join(', ')
                : String(node.num_parameters);
              // Only missing dimensions (null/undefined) display as "None";
              // a genuine 0-sized dimension is shown as 0.
              const outputShape: string = _.isArray(node.output_shape)
                ? node.output_shape
                    .map((param: any) => (param == null ? 'None' : param))
                    .join(', ')
                : '';
              return (
                <Table.Row key={'graphtable' + i}>
                  {/* The graph cell spans every row, so it is emitted once. */}
                  {i === 0 && displayGraph && renderGraph()}
                  <Table.Cell title={name} style={{borderLeft: 0}}>
                    {name}
                  </Table.Cell>
                  <Table.Cell title={className}>{className}</Table.Cell>
                  <Table.Cell title={numParameters}>{numParameters}</Table.Cell>
                  <Table.Cell title={outputShape}>{outputShape}</Table.Cell>
                </Table.Row>
              );
            })}
          </Table.Body>
        </Table>
      </div>
    );
  },
  {id: 'GraphPanel'}
);

export default GraphPanel;
