import * as S from './HomeQuickStart.styles';

import React, {FC, useCallback, useMemo, useState} from 'react';
import * as _ from 'lodash';
import makeComp from '../util/profiler';
import {WBIcon} from '@wandb/ui';
import {useViewer} from '../state/viewer/hooks';
import pytorchIcon from '../assets/icon-pytorch.svg';
import pytorchLightningIcon from '../assets/icon-pytorch-lightning.svg';
import kerasIcon from '../assets/icon-keras.svg';
import tensorflowIcon from '../assets/icon-tensorflow.svg';
import huggingFaceIcon from '../assets/icon-hugging-face.svg';
import xgboostIcon from '../assets/xgboost_logo.png';
import scikitIcon from '../assets/icon-scikit.svg';
import pythonIcon from '../assets/icon-python.svg';
import ColabButton from '../components/ColabButton';
import {Button, Dropdown} from 'semantic-ui-react';
import copy from 'copy-to-clipboard';
import {toast} from 'react-toastify';
import {useWindowSize} from '../util/window';
import history from '../util/history';
import {homeQuickStart} from '../util/urls';

// Names of every framework the quickstart page knows about.
// `as const` keeps each entry as a string-literal type so `Framework` below
// becomes a union of exactly these names.
const FRAMEWORKS = [
  'PyTorch',
  'PyTorch Lightning',
  'Keras',
  'TensorFlow',
  'Hugging Face',
  'XGBoost',
  'Scikit-learn',
  'Any Python script',
] as const;
// Union of the literal framework names listed in FRAMEWORKS.
export type Framework = typeof FRAMEWORKS[number];

// Reverse lookup: URL path segment (as produced by frameworkToPathname) ->
// framework name. Built once at module load from the FRAMEWORKS list.
const FRAMEWORK_BY_PATHNAME = new Map<string, Framework>(
  FRAMEWORKS.map((name): [string, Framework] => [
    frameworkToPathname(name),
    name,
  ])
);

// Everything the page needs to render the guide for one framework.
type FrameworkProps = {
  // Framework identifier; also the label shown in the framework selector.
  name: Framework;
  // Optional label used in the "Integrate W&B with ..." heading instead of `name`.
  displayName?: string;
  // Image source for the framework's icon.
  img: string;
  // Alt text for the icon.
  alt: string;
  // Ordered code steps. The first entry is rendered as the command-line block
  // and is left out of "Copy All".
  code: CodeProps[];
  // Colab notebooks linked from the "Try it out in Colab" column.
  colabs: ColabProps[];
};

// One step of an integration guide.
type CodeProps = {
  // Heading rendered above the code block.
  header: string;
  // Snippet text that is displayed and copied to the clipboard.
  code: string;
};

// One Colab notebook link.
type ColabProps = {
  // Title of the link card.
  header: string;
  // Secondary line rendered under the title.
  subheader: string;
  // Link target; also reported in the "colab clicked" analytics event.
  href: string;
};

// Route match handed to the page (presumably supplied by the router — confirm).
type Match = {
  params: {
    // NOTE(review): typed as Framework, but pathnameToFramework looks this
    // value up in FRAMEWORK_BY_PATHNAME, i.e. it is treated as a URL path
    // segment such as "pytorch-lightning" rather than a display name.
    framework?: Framework;
  };
};

// Props for the HomeQuickStart page component.
type HomeQuickStartProps = {
  // Route match whose optional `framework` param selects the initial guide.
  match: Match;
};

/**
 * Builds the guide content (code steps and Colab links) for every framework.
 *
 * The result depends only on `userEntity`, so callers can memoize on it.
 *
 * @param userEntity Entity of the signed-in viewer, if any. When truthy it is
 *   pre-filled as the `entity` argument of the generated `wandb.init(...)` call.
 * @returns One entry per framework, in the order shown in the selector.
 */
function buildFrameworks(userEntity?: string | null): FrameworkProps[] {
  // The install/login step is identical for every framework.
  const installStep: CodeProps = {
    header: 'From the command line, install and log in to wandb',
    code: 'pip install wandb\nwandb login',
  };
  const startRunHeader = 'At the top of your training script, start a new run';
  // `wandb.init(...)` call shared by most guides.
  const initRun = `wandb.init(project="my-test-project"${
    userEntity ? `, entity="${userEntity}"` : ''
  })`;
  // Example hyperparameter dictionary shared by several guides.
  const configSnippet =
    'wandb.config = {\n  "learning_rate": 0.001,\n  "epochs": 100,\n  "batch_size": 128\n}';
  const configWithModelSnippet = `${configSnippet}\n\n# ... Define a model`;

  return [
    {
      name: 'PyTorch',
      img: pytorchIcon,
      alt: 'pytorch',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\n\n${initRun}`,
        },
        {
          header: 'Capture a dictionary of hyperparameters with config',
          code: configSnippet,
        },
        {
          header:
            'Log metrics inside your training loop to visualize model performance',
          code: 'wandb.log({"loss": loss})\n\n# Optional\nwandb.watch(model)',
        },
      ],
      colabs: [
        {
          header: 'Simple PyTorch Integration',
          subheader:
            'Add W&B tracking, versioning, and collaboration to your PyTorch code',
          href: 'https://wandb.me/pytorch-colab',
        },
        {
          header: 'Profiling PyTorch Code',
          subheader:
            'Incorporate PyTorch profiling and W&B logging in your script',
          href: 'https://wandb.me/trace-colab',
        },
      ],
    },
    {
      name: 'PyTorch Lightning',
      img: pytorchLightningIcon,
      alt: 'pytorch-lightning',
      code: [
        installStep,
        {
          header:
            'In your script, create a WandbLogger and name your new W&B project',
          code: 'from pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning import Trainer\n\nwandb_logger = WandbLogger(project="my-test-project")',
        },
        {
          header: 'Add your WandbLogger to your Pytorch Lightning Trainer',
          code: 'trainer = Trainer(logger=wandb_logger)',
        },
        {
          header:
            'Logging begins automatically when you start training your Trainer',
          code: 'trainer.fit(model, datamodule)',
        },
      ],
      colabs: [
        {
          header: 'PyTorch Lightning Integration',
          subheader: 'Supercharge your training with PyTorch Lightning and W&B',
          href: 'https://wandb.me/lit-colab',
        },
      ],
    },
    {
      name: 'Keras',
      img: kerasIcon,
      alt: 'keras',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\nfrom wandb.keras import WandbCallback\n\n${initRun}`,
        },
        {
          header: 'Save model inputs and hyperparameters',
          code: configWithModelSnippet,
        },
        {
          header: 'Log layer dimensions and metrics over time',
          code: 'model.fit(X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()])',
        },
      ],
      colabs: [
        {
          header: 'Simple Keras Integration',
          subheader:
            'Add W&B tracking, versioning, and collaboration to your Keras code',
          href: 'https://wandb.me/keras-colab',
        },
      ],
    },
    {
      name: 'TensorFlow',
      img: tensorflowIcon,
      alt: 'tensorflow',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\nimport tensorflow as tf\n\n${initRun}`,
        },
        {
          header: 'Save model inputs and hyperparameters',
          code: configWithModelSnippet,
        },
        {
          header: 'Log metrics over time to visualize performance',
          code: 'with tf.Session() as sess:\n  # ...\n  wandb.tensorflow.log(tf.summary.merge_all())',
        },
      ],
      colabs: [
        {
          header: 'Simple TensorFlow Integration',
          subheader:
            'Add W&B tracking, versioning, and collaboration to your TensorFlow pipeline',
          href: 'https://wandb.me/tf-colab',
        },
      ],
    },
    {
      name: 'Hugging Face',
      img: huggingFaceIcon,
      alt: 'hugging-face',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\nfrom transformers import TrainingArguments, Trainer\n\n${initRun}`,
        },
        {
          header: 'Add wandb in your Hugging Face "TrainingArguments"',
          code: 'args = TrainingArguments(..., report_to="wandb")',
        },
        {
          header:
            'Logging begins automatically when you start training your Trainer',
          code: 'trainer = Trainer(..., args=args)\ntrainer.train()',
        },
      ],
      colabs: [
        {
          header: 'Hugging Face Integration',
          subheader: 'Optimize Hugging Face models with W&B',
          href: 'https://wandb.me/hf',
        },
      ],
    },
    {
      name: 'XGBoost',
      img: xgboostIcon,
      alt: 'xgboost',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\nimport xgboost as xgb\nfrom wandb.xgboost import wandb_callback\n\n${initRun}`,
        },
        {
          header: 'Add the callback',
          code: 'bst = xgb.train(param, xg_train, num_round, watchlist, callbacks=[wandb_callback()])',
        },
        {
          header: 'Get predictions',
          code: 'pred = bst.predict(xg_test)',
        },
      ],
      colabs: [
        {
          header: 'Interpretable Credit Scorecards with XGBoost',
          subheader:
            'Enable experiment tracking, versioning, and optimization in XGBoost with W&B',
          href: 'https://wandb.me/xgboost-colab',
        },
      ],
    },
    {
      name: 'Scikit-learn',
      img: scikitIcon,
      alt: 'scikit',
      code: [
        installStep,
        {
          header: startRunHeader,
          code: `import wandb\n\n${initRun}`,
        },
        {
          header: 'Log classifier visualizations',
          code: 'wandb.sklearn.plot_classifier(clf, X_train, X_test, y_train, y_test, y_pred, y_probas, labels, model_name="SVC", feature_names=None)',
        },
        {
          header: 'Log regression visualizations',
          code: 'wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test,  model_name="Ridge")',
        },
        {
          header: 'Log clustering visualizations',
          code: 'wandb.sklearn.plot_clusterer(kmeans, X_train, cluster_labels, labels=None, model_name="KMeans")',
        },
      ],
      colabs: [
        {
          header: 'Simple Scikit Integration',
          subheader:
            'Add W&B tracking, versioning, and collaboration to your Scikit code',
          href: 'https://wandb.me/scikit-colab',
        },
      ],
    },
    {
      name: 'Any Python script',
      displayName: 'any Python script',
      img: pythonIcon,
      alt: 'python-script',
      code: [
        installStep,
        {
          header: startRunHeader,
          // NOTE(review): unlike the other guides this snippet has no
          // `import wandb` line — confirm whether that is intentional.
          code: initRun,
        },
        {
          header: 'Save model inputs and hyperparameters',
          code: configSnippet,
        },
        {
          header: 'Log gradients and model parameters',
          code: 'wandb.log({"loss": loss})',
        },
      ],
      colabs: [
        {
          header: 'Intro to Weights & Biases',
          subheader:
            'Using W&B for machine learning experiment tracking, dataset versioning, and project collaboration',
          href: 'https://wandb.me/intro',
        },
      ],
    },
  ];
}

/**
 * Quickstart page: lets the viewer pick a framework and shows the matching
 * integration steps, their API key (when one exists) and Colab links.
 *
 * The selected framework is mirrored into the URL via `history.push`, and
 * selection / copy actions are reported through `window.analytics.track`.
 */
const HomeQuickStart: FC<HomeQuickStartProps> = makeComp(
  ({match}) => {
    const {width: viewportWidth} = useWindowSize();

    const viewer = useViewer();
    const userFirstName = viewer?.name.trim().split(` `)[0];
    const userEntity = viewer?.entity;

    // Names of the viewer's API keys with empty entries dropped, plus a flag
    // for whether there is at least one to show.
    const [userApiKeys, hasApiKey] = useMemo(() => {
      const apiKeys = _.compact(
        viewer?.apiKeys?.edges.map(e => e.node.name) ?? []
      );
      return [apiKeys, apiKeys.length > 0];
    }, [viewer]);

    // Seeded from the URL. pathnameToFramework also redirects to the default
    // guide when the slug is missing or unknown.
    const [framework, setFramework] = useState<Framework>(
      pathnameToFramework(match)
    );

    // Keep the selection in step with the URL after mount (e.g. browser
    // back/forward). Unknown slugs are ignored here; pathnameToFramework
    // above handles the redirect for those.
    const slugFromUrl = match.params.framework;
    React.useEffect(() => {
      const fromUrl = FRAMEWORK_BY_PATHNAME.get(slugFromUrl ?? '');
      if (fromUrl != null) {
        setFramework(fromUrl);
      }
    }, [slugFromUrl]);

    const onClickFramework = useCallback(
      (f: Framework) => {
        const pathname = frameworkToPathname(f);
        setFramework(f);
        history.push(homeQuickStart() + '/' + pathname);
        window.analytics.track(`Quickstart: framework selected`, {
          username: viewer?.username,
          framework: f,
        });
      },
      [viewer]
    );

    // Memoized on userEntity so the array identity is stable across renders;
    // otherwise the memos below and the memoized FrameworkSelector would be
    // invalidated on every render.
    const frameworks = useMemo<FrameworkProps[]>(
      () => buildFrameworks(userEntity),
      [userEntity]
    );

    const frameworksMap = useMemo(() => {
      const fwMap = new Map<Framework, FrameworkProps>();
      frameworks.forEach(f => fwMap.set(f.name, f));
      return fwMap;
    }, [frameworks]);

    const currentFramework: FrameworkProps = useMemo(() => {
      return frameworksMap.get(framework) ?? frameworks[0];
    }, [framework, frameworks, frameworksMap]);

    const copyAllCode = useCallback(() => {
      const allCode = currentFramework.code
        // don't include command line block
        .slice(1)
        .map(c => c.code)
        .join('\n');
      copy(allCode);
      toast('Copied to clipboard.');
      window.analytics.track('Quickstart: all code copied', {
        username: viewer?.username,
        framework: currentFramework.name,
      });
    }, [currentFramework.code, currentFramework.name, viewer]);

    return (
      <S.QuickStartContainer width={viewportWidth} data-test="home-quickstart">
        <S.FrameworkColumn width={viewportWidth}>
          <h1>Welcome{userFirstName ? `, ${userFirstName}` : ''}!</h1>
          <h2>Choose a framework to see integration guide and colab</h2>
          <FrameworkSelector
            frameworks={frameworks}
            framework={framework}
            onClickFramework={onClickFramework}
          />
        </S.FrameworkColumn>
        <S.QuickStartContainerRight width={viewportWidth}>
          <S.IntegrationColumn width={viewportWidth}>
            <h3>Quickstart</h3>
            <h1>Visualize training in your own ML project</h1>
            <h2>
              <img src={currentFramework.img} alt={currentFramework.alt} />
              Integrate W&B with{' '}
              {currentFramework.displayName ?? currentFramework.name}
            </h2>
            {/* Split out command line block from python blocks */}
            {
              <div
                key={currentFramework.name + currentFramework.code[0].header}>
                <h1>#&emsp;{currentFramework.code[0].header}</h1>
                <QuickStartCode
                  code={currentFramework.code[0].code}
                  blockNumber={0}
                  framework={currentFramework.name}
                />
              </div>
            }
            {/* Add API key section if logged in */}
            {hasApiKey ? (
              <div key={currentFramework.name + 'api-key'}>
                <h1>
                  #&emsp;Copy this key and paste it into your command line when
                  asked to authorize your account
                </h1>
                <ApiKeySection apiKey={userApiKeys[0]} />
              </div>
            ) : (
              <></>
            )}
            {/* Render python code blocks */}
            {currentFramework.code.slice(1).map((c, i) => (
              <div key={currentFramework.name + c.header}>
                <h1>#&emsp;{c.header}</h1>
                <QuickStartCode
                  code={c.code}
                  // increment block index by 1 to account for the command line block
                  blockNumber={i + 1}
                  framework={currentFramework.name}
                />
              </div>
            ))}
            <Button primary onClick={copyAllCode} content="Copy All" />
          </S.IntegrationColumn>
          <S.ColabColumn width={viewportWidth}>
            <div>Try it out in Colab</div>
            <S.ColabContainer width={viewportWidth}>
              {currentFramework.colabs.map(c => (
                // Elements produced by map need a stable key; href is unique
                // per notebook.
                <ColabLink
                  key={c.href}
                  colab={c}
                  framework={currentFramework.name}
                />
              ))}
            </S.ColabContainer>
          </S.ColabColumn>
        </S.QuickStartContainerRight>
      </S.QuickStartContainer>
    );
  },
  {id: 'HomeQuickStart', memo: true}
);

// Props for FrameworkSelector.
type FrameworkSelectorProps = {
  // Frameworks to offer, rendered in array order.
  frameworks: FrameworkProps[];
  // Name of the currently selected framework.
  framework: Framework;
  // Invoked with the framework's name when the user picks one.
  onClickFramework: (f: Framework) => void;
};

/**
 * Renders the framework choices twice: as a row of icon tiles and as a
 * dropdown. Both receive the viewport width (presumably so the styles can
 * decide which variant to show — confirm in HomeQuickStart.styles).
 */
const FrameworkSelector: FC<FrameworkSelectorProps> = makeComp(
  props => {
    const {frameworks, framework: selected, onClickFramework} = props;
    const {width} = useWindowSize();

    // Icon tiles; the tile matching the current selection is marked active.
    const tiles = frameworks.map(({name, img, alt}) => (
      <S.Framework
        width={width}
        key={name}
        active={name === selected}
        onClick={() => onClickFramework(name)}>
        <img src={img} alt={alt} />
        {name}
      </S.Framework>
    ));

    // Same choices as dropdown entries.
    const menuItems = frameworks.map(({name}) => (
      <Dropdown.Item
        key={name}
        text={name}
        onClick={() => onClickFramework(name)}
      />
    ));

    return (
      <>
        {tiles}
        <S.FrameworkDropdown text={selected} width={width}>
          <Dropdown.Menu>{menuItems}</Dropdown.Menu>
        </S.FrameworkDropdown>
      </>
    );
  },
  {
    id: 'FrameworkSelector',
    memo: true,
  }
);

// Props for ColabLink.
type ColabLinkProps = {
  // Notebook to link to.
  colab: ColabProps;
  // Framework the notebook belongs to; reported in the click analytics event.
  framework: Framework;
};

/**
 * Link card for a single Colab notebook. Clicking it reports a
 * "Quickstart: colab clicked" analytics event.
 */
const ColabLink: FC<ColabLinkProps> = makeComp(
  ({colab, framework}) => {
    const viewer = useViewer();
    const {width} = useWindowSize();
    const {header, subheader, href} = colab;

    const trackClick = useCallback(() => {
      const payload = {
        username: viewer?.username,
        framework,
        colab: href,
      };
      window.analytics.track('Quickstart: colab clicked', payload);
    }, [viewer, href, framework]);

    return (
      <S.Colab key={header} href={href} onClick={trackClick} width={width}>
        <div>
          <WBIcon name="launch" />
        </div>
        <div>
          <h1>{header}</h1>
          <h2>{subheader}</h2>
          <ColabButton />
        </div>
      </S.Colab>
    );
  },
  {id: 'ColabLink', memo: true}
);

// Props for QuickStartCode.
type QuickStartCodeProps = {
  // Snippet text to display and copy.
  code: string;
  // Zero-based position of this block within the guide.
  blockNumber: number;
  // Framework the snippet belongs to; reported in the copy analytics event.
  framework: Framework;
};

/**
 * One code snippet with a "Copy" control. Copying puts the snippet on the
 * clipboard, shows a toast and reports a "Quickstart: code block copied"
 * analytics event.
 */
const QuickStartCode: FC<QuickStartCodeProps> = makeComp(
  props => {
    const {code, blockNumber, framework} = props;
    const viewer = useViewer();
    const {width} = useWindowSize();

    const handleCopy = useCallback(() => {
      copy(code);
      toast('Copied to clipboard.');
      // Analytics uses a one-based block index.
      const oneBasedBlock = blockNumber + 1;
      window.analytics.track('Quickstart: code block copied', {
        username: viewer?.username,
        framework,
        codeBlock: oneBasedBlock,
      });
    }, [code, blockNumber, framework, viewer]);

    return (
      <S.QuickStartCode width={width}>
        <div>
          <div onClick={handleCopy}>
            <WBIcon name="copy" />
            Copy
          </div>
        </div>
        <div>{code}</div>
      </S.QuickStartCode>
    );
  },
  {id: 'QuickStartCode', memo: true}
);

// Props for ApiKeySection.
type ApiKeySectionProps = {
  // Key text to display and copy.
  apiKey: string;
};

/**
 * Shows the API key next to a copy icon. Clicking the icon copies the key,
 * shows a toast and reports a "Quickstart: API key copied" analytics event.
 */
const ApiKeySection: FC<ApiKeySectionProps> = makeComp(
  props => {
    const {apiKey} = props;
    const viewer = useViewer();

    const handleCopy = useCallback(() => {
      copy(apiKey);
      toast('Copied to clipboard.');
      const payload = {username: viewer?.username};
      window.analytics.track('Quickstart: API key copied', payload);
    }, [apiKey, viewer]);

    return (
      <S.ApiKeySection>
        <div>{apiKey}</div>
        <WBIcon name="copy" onClick={handleCopy} />
      </S.ApiKeySection>
    );
  },
  {id: 'ApiKeySection', memo: true}
);

/**
 * Resolves the framework named by the route's `framework` param.
 *
 * When the param is absent or does not match a known path segment, this
 * pushes the PyTorch quickstart URL onto the history and returns 'PyTorch'.
 *
 * @param match Route match carrying the optional `framework` param.
 * @returns The framework to display.
 */
function pathnameToFramework(match: Match): Framework {
  const slug = match.params.framework ?? '';
  const resolved = FRAMEWORK_BY_PATHNAME.get(slug);
  if (resolved != null) {
    return resolved;
  }
  // Invalid or missing slug: redirect to the default guide.
  history.push(homeQuickStart() + '/pytorch');
  return 'PyTorch';
}

/**
 * Converts a framework name into its URL path segment.
 *
 * 'Any Python script' maps to the fixed segment 'python-script'; every other
 * name is lower-cased with each space replaced by a hyphen.
 *
 * @param f Framework name.
 * @returns Path segment, e.g. 'pytorch-lightning'.
 */
function frameworkToPathname(f: Framework): string {
  // special case
  if (f === 'Any Python script') {
    return 'python-script';
  }
  // A string pattern in replace() only swaps the first match, so use a global
  // regex to hyphenate every space in multi-word names.
  return f.replace(/ /g, '-').toLowerCase();
}

export default HomeQuickStart;
