import React, { memo, useCallback, useMemo, useRef, useState } from 'react';
import PropTypes from 'prop-types';
import { Grid } from '@material-ui/core';
import {
  ComposedChart,
  CartesianGrid,
  XAxis,
  YAxis,
  Line,
  Scatter,
  Customized,
} from 'recharts';

import { DEFAULT_AXIS_INDEX } from 'constants/graphs';
import { palette } from 'common/theme';
import { RIGHT, LEFT, EOL } from 'constants/common';
import BaseTooltip from 'components/common/graphs/BaseTooltip';
import ProjectionLines from 'components/common/graphs/ProjectionLines';
import ColorRangeBar from 'components/common/graphs/ColorRangeBar';

import { useStyles } from './styles';

// Chart margin applied on every side (see GRAPH_MARGIN).
const CANVAS_OFFSET = 10;
// Space reserved for the main axes: the X axis height and the Y axis width.
const AXIS_WIDTH = 25;
// `fontSize` used for axis tick labels.
const TICK_FONT_SIZE = 10;
// Horizontal shift (dx) of the left axis label; applied negated, i.e. to the left.
const LABEL_LEFT_EDGE_OFFSET = 25;
// Horizontal shift (dx) of the right axis label.
const LABEL_RIGHT_EDGE_OFFSET = 90;

/** Default chart margin: CANVAS_OFFSET on every side. */
const GRAPH_MARGIN = {
  top: CANVAS_OFFSET,
  right: CANVAS_OFFSET,
  bottom: CANVAS_OFFSET,
  left: CANVAS_OFFSET,
};

/** Offset handed to BaseTooltip as `offsetX` / `offsetY`. */
const TOOLTIP_OFFSET = { x: 30, y: -40 };

/**
 * Circle with custom hover radius.
 * @param {Number} hoverR
 * @param {Object} props
 * @return {jsx}
 */
const ScatterPlotCircle = ({ hoverR, ...props }) => {
  const classes = useStyles({ hoverR });

  return <circle className={classes.scatterPoint} {...props} />;
};

/**
 * Scatter chart component
 * @param {React.Ref} chartRef
 * @param {Number} canvasWidth
 * @param {Number} canvasHeight
 * @param {Object} axisLabels
 * @param {Array} ticks
 * @param {Array} values
 * @param {Array} domain
 * @param {{ x: number, y: number }} projectionPoint
 * @param {Function} onRegressLineEnter
 * @param {Function} onRegressLineLeave
 * @param {Function} onScatterEnter
 * @param {Function} onScatterLeave
 * @param {Object} margin
 * @return {jsx}
 */
const ScatterPlotChart = memo(
  ({
    chartRef,
    canvasWidth,
    canvasHeight,
    axisLabels,
    ticks,
    values,
    domain,
    projectionPoint,
    onRegressLineEnter,
    onRegressLineLeave,
    onScatterEnter,
    onScatterLeave,
    margin,
  }) => {
    const classes = useStyles();

    const marginX = margin.right + margin.left;
    const marginY = margin.top + margin.bottom;

    return (
      <ComposedChart
        ref={chartRef}
        width={canvasWidth + AXIS_WIDTH + marginX}
        height={canvasHeight + AXIS_WIDTH + marginY}
        margin={GRAPH_MARGIN}
        className={classes.graphContainer}
        data={values}
      >
        <XAxis
          key={0}
          type="number"
          height={AXIS_WIDTH}
          ticks={ticks}
          interval={0}
          domain={domain}
          dataKey="x"
          tick={{ fontSize: TICK_FONT_SIZE }}
          xAxisId={DEFAULT_AXIS_INDEX}
          label={{
            value: axisLabels.bottom,
            position: 'bottom',
            dy: -10,
          }}
        />
        <YAxis
          key={2}
          type="number"
          width={AXIS_WIDTH}
          ticks={ticks}
          interval={0}
          domain={domain}
          tick={{ fontSize: TICK_FONT_SIZE }}
          yAxisId={DEFAULT_AXIS_INDEX}
          label={{
            value: axisLabels.left,
            angle: -90,
            position: 'center',
            dx: -LABEL_LEFT_EDGE_OFFSET,
          }}
        />
        <XAxis
          key={1}
          type="number"
          orientation="top"
          tick={false}
          xAxisId={1}
          label={{
            value: axisLabels.top,
            position: 'center',
          }}
        />
        <YAxis
          key={3}
          width={1}
          type="number"
          orientation="right"
          tick={false}
          yAxisId={1}
          label={{
            value: axisLabels.right,
            angle: -90,
            position: 'center',
            dx: LABEL_RIGHT_EDGE_OFFSET,
          }}
        />
        <CartesianGrid
          strokeDasharray="3 3"
          shapeRendering="crispedges"
          key={4}
        />

        <Line
          key={5}
          connectNulls
          dataKey="bias"
          type="linear"
          stroke={palette.grey.main}
          dot={false}
          isAnimationActive={false}
          strokeDasharray="10 4"
        />

        {projectionPoint && (
          <Customized
            key={6}
            component={ProjectionLines}
            point={projectionPoint}
            isAnimationActive={false}
            stroke={projectionPoint.color}
            strokeWidth={0.5}
            strokeDasharray="10 4"
          />
        )}

        <Scatter
          key={7}
          shape={({ cx, cy, key, payload: { r, color, hoverR } }) =>
            !!(cx && cy) && (
              <ScatterPlotCircle
                cx={cx}
                cy={cy}
                r={r}
                hoverR={hoverR}
                fill={color}
                key={key}
              />
            )
          }
          dataKey="y"
          onMouseOver={onScatterEnter}
          onFocus={onScatterEnter}
          onMouseOut={onScatterLeave}
          onBlur={onScatterLeave}
        />

        <Line
          key={8}
          connectNulls
          dataKey="regress"
          type="linear"
          stroke={palette.black.main}
          className={classes.regressLine}
          dot={false}
          isAnimationActive={false}
          onMouseOver={onRegressLineEnter}
          onFocus={onRegressLineEnter}
          onMouseOut={onRegressLineLeave}
          onBlur={onRegressLineLeave}
        />
      </ComposedChart>
    );
  }
);

/**
 * Tooltip for regress line
 * @param {Number} x
 * @param {Number} y
 * @param {String} yPosition
 * @param {Object} data
 * @return {jsx}
 */
const ScatterRegressTooltip = ({ x, y, yPosition, data }) => (
  <BaseTooltip
    x={x}
    y={y}
    xPosition={yPosition}
    offsetX={TOOLTIP_OFFSET.x}
    offsetY={TOOLTIP_OFFSET.y}
  >
    Linear regression
    {EOL}r<sup>2</sup> =&nbsp;{data.r2}
    {EOL}a =&nbsp;{data.slope}
    {EOL}b =&nbsp;{data.intercept}
  </BaseTooltip>
);

/**
 * Tooltip for point.
 * @param {Number} x
 * @param {Number} y
 * @param {String} yPosition
 * @param {Object} data
 * @param {Object} axisInfo
 * @return {jsx}
 */
const ScatterTooltip = ({ x, y, yPosition, data, axisInfo }) => (
  <BaseTooltip
    x={x}
    y={y}
    yPosition={yPosition}
    offsetX={TOOLTIP_OFFSET.x}
    offsetY={TOOLTIP_OFFSET.y}
  >
    <strong>{data.occurrence}% occurrence</strong>
    {EOL}
    Model:&nbsp;{data.x.toFixed(2)}
    {axisInfo.x.units}
    {EOL}
    Measure:&nbsp;{data.y.toFixed(2)}
    {axisInfo.y.units}
    {EOL}
    Error:&nbsp;{data.error.toFixed(2)}
    {axisInfo.error.units}
    {EOL}
  </BaseTooltip>
);

/**
 * Chart with scatter points and regression line.
 * @param {Object} commonData
 * @param {Object} margin
 * @param {Number} canvasWidth
 * @param {Number} canvasHeight
 * @return {jsx}
 */
const ScatterPlot = ({
  commonData,
  canvasWidth,
  canvasHeight,
  margin = GRAPH_MARGIN,
}) => {
  const classes = useStyles();
  const chartRef = useRef(null);
  const wrapperRef = useRef(null);
  const [scatterTooltip, setScatterTooltip] = useState(null);
  const [regressTooltip, setRegressTooltip] = useState(null);

  const {
    ticks,
    axisLabels,
    values,
    axisInfo,
    lineRegress,
    errorTicks,
  } = commonData;
  const domain = useMemo(() => [ticks[0], ticks[ticks.length - 1]], [ticks]);

  const onScatterEnter = useCallback(
    ({ tooltipPosition, payload }) => {
      const { container: chartContainer } = chartRef.current;
      const { left: chartX } = chartContainer.getBoundingClientRect();
      const { left: wrapperX } = wrapperRef.current.getBoundingClientRect();

      setScatterTooltip({
        payload,
        x: tooltipPosition.x + (chartX - wrapperX),
        y: tooltipPosition.y,
      });
    },
    [chartRef, wrapperRef, setScatterTooltip]
  );

  const onScatterLeave = useCallback(() => setScatterTooltip(null), [
    setScatterTooltip,
  ]);

  const onRegressLineEnter = useCallback(
    ({ points }) => {
      const { container: chartContainer } = chartRef.current;
      const {
        width: chartWidth,
        height: chartHeights,
        left: chartX,
        top: chartY,
      } = chartContainer.getBoundingClientRect();
      const {
        left: wrapperX,
        top: wrapperY,
      } = wrapperRef.current.getBoundingClientRect();

      const [{ payload: startPoint }, { payload: endPoint }] = points.filter(
        ({ payload }) => payload.regress !== undefined
      );
      const [minValue, maxValue] = domain;

      const regressLineCenter =
        (endPoint.regress - startPoint.regress) / 2 + startPoint.regress;

      setRegressTooltip({
        x: chartX - wrapperX + chartWidth / 2,
        y: chartY - wrapperY + chartHeights / 2,
        yPosition: (minValue + maxValue) / 2 > regressLineCenter ? LEFT : RIGHT,
      });
    },
    [chartRef, wrapperRef, setRegressTooltip, domain]
  );

  const onRegressLineLeave = useCallback(() => setRegressTooltip(null), [
    setRegressTooltip,
  ]);

  return (
    <Grid
      ref={wrapperRef}
      justifyContent="center"
      container
      className={classes.graphWrapper}
    >
      <Grid item>
        <ScatterPlotChart
          ticks={ticks}
          domain={domain}
          margin={margin}
          values={values}
          chartRef={chartRef}
          axisLabels={axisLabels}
          canvasWidth={canvasWidth}
          canvasHeight={canvasHeight}
          onScatterEnter={onScatterEnter}
          onScatterLeave={onScatterLeave}
          onRegressLineEnter={onRegressLineEnter}
          onRegressLineLeave={onRegressLineLeave}
          projectionPoint={scatterTooltip?.payload}
        />
      </Grid>
      <Grid item className={classes.colorBarContainer}>
        <ColorRangeBar
          width={30}
          className={classes.colorBar}
          ticks={errorTicks}
          height={canvasHeight - TICK_FONT_SIZE - (margin.top + margin.bottom)}
        />
      </Grid>
      {scatterTooltip && (
        <ScatterTooltip
          x={scatterTooltip.x}
          y={scatterTooltip.y}
          yPosition={scatterTooltip.yPosition}
          data={scatterTooltip.payload}
          axisInfo={axisInfo}
        />
      )}
      {regressTooltip && (
        <ScatterRegressTooltip
          x={regressTooltip.x}
          y={regressTooltip.y}
          xPosition={regressTooltip.yPosition}
          data={lineRegress}
        />
      )}
    </Grid>
  );
};

// Runtime prop validation (development only). This object was previously
// assigned to `defaultProps` and then overwritten below, so it never ran.
ScatterPlot.propTypes = {
  canvasWidth: PropTypes.number,
  canvasHeight: PropTypes.number,
  margin: PropTypes.shape({
    top: PropTypes.number,
    right: PropTypes.number,
    bottom: PropTypes.number,
    left: PropTypes.number,
  }),
  // Destructured unconditionally by ScatterPlot, hence required.
  commonData: PropTypes.shape({
    axisLabels: PropTypes.shape({
      top: PropTypes.string,
      left: PropTypes.string,
      bottom: PropTypes.string,
      right: PropTypes.string,
    }).isRequired,
    // Indexed unconditionally to build the axis domain, hence required.
    ticks: PropTypes.arrayOf(PropTypes.number).isRequired,
    values: PropTypes.arrayOf(PropTypes.object),
    errorTicks: PropTypes.arrayOf(PropTypes.number),
    axisInfo: PropTypes.shape({
      error: PropTypes.shape({
        units: PropTypes.string,
      }),
      y: PropTypes.shape({
        units: PropTypes.string,
      }),
      x: PropTypes.shape({
        units: PropTypes.string,
      }),
    }).isRequired,
    lineRegress: PropTypes.shape({
      slope: PropTypes.string,
      r2: PropTypes.string,
      intercept: PropTypes.string,
    }).isRequired,
  }).isRequired,
};

ScatterPlot.defaultProps = {
  canvasWidth: 400,
  canvasHeight: 400,
};

export default ScatterPlot;
