import React from 'react';
import { deviation, median, sum } from 'd3-array';

import Alert from '@mui/material/Alert';
import Button from '@mui/material/Button';
import Checkbox from '@mui/material/Checkbox';
import FormControl from '@mui/material/FormControl';
import FormControlLabel from '@mui/material/FormControlLabel';
import Grid from '@mui/material/Grid';
import MenuItem from '@mui/material/MenuItem';
import TextField from '@mui/material/TextField';

import Plot from 'react-plotly.js';

import { phaseSchleicherMarcus, cleanPhotometry } from '../util';

// Display labels for every quantity offered in the axis drop-down menus,
// keyed by the axis identifier stored in the plot configuration.  Insertion
// order is the order in which the menu options are listed.
const axisLabels = {
  m: 'm (mag)',
  H11: 'H(1,1,α) (mag)',
  Hy11: 'Hy(1,1,α) (mag)',
  Hy110: 'Hy(1,1,0) (mag)',
  rh: 'rh (au)',
  'signed rh': 'rh (au) (negative for pre-perihelion)',
  delta: 'Delta (au)',
  phase: 'Phase angle (deg)',
  tmtp: 'T-Tp (days)',
  afrho: 'Afρ (cm)',
  afrho0: 'A(0°)fρ (cm)',
  afrho0k: 'A(0°)fρ rh**k (cm)',
  ostat: 'O stat',
  dc: 'Centroid offset (arcsec)',
  seeing: 'Seeing (arcsec)',
};

// Per-filter magnitudes used as the solar reference in the Afρ calculation
// (presumably apparent magnitudes of the Sun; confirm against the data source).
const mSun = {
  g: -26.54,
  r: -26.93,
  i: -27.05
};

/**
 * Compute the value plotted on one axis for a single photometry row.
 *
 * The magnitude for the configured aperture is first shifted by the color
 * configured for the row's catalog filter: subtracted for 'g', added for any
 * other filter, and 0 when the filter has no entry in `config`.
 *
 * @param {object} config - plot configuration (`aperture`, `activitySlope`,
 *   and per-filter colors keyed by filter name).
 * @param {string} axis - axis identifier, e.g. 'm', 'H11', 'afrho0k', 'mjd'.
 * @param {object} row - photometry row.
 * @returns {number} the axis value; unrecognized identifiers return `row[axis]`.
 */
const getAxisData = (config, axis, row) => {
  const catalogFilter = row['catalog filter'];
  const colorSign = (catalogFilter === 'g') ? -1 : 1;
  const colorTerm = colorSign * (config[catalogFilter] || 0);
  const mag = row[`m${config.aperture}`] + colorTerm;
  const heliocentric = Math.abs(row.rh);

  if (axis === 'm') {
    return mag;
  }

  if (axis === 'H11') {
    return mag - 5 * Math.log10(heliocentric * row.delta);
  }

  if (axis.startsWith('Hy')) {
    // cometary absolute magnitude
    let absolute = mag + 2.5 * (config.activitySlope - 2) * Math.log10(heliocentric);

    // Apertures below 1000 are fixed angular sizes; anything else (including
    // the linear "5k"-style strings, which compare false) is a fixed linear size.
    const deltaFactor = (config.aperture < 1000) ? 2.5 : 5;
    absolute -= deltaFactor * Math.log10(row.delta);

    // identifiers ending in '0' are corrected to zero phase angle
    if (axis.endsWith('0')) {
      absolute -= phaseSchleicherMarcus(row.phase);
    }
    return absolute;
  }

  if (axis === 'rh') {
    return heliocentric;
  }

  if (axis === 'signed rh') {
    return Math.sign(row.tmtp) * heliocentric;
  }

  if (axis.startsWith('afrho')) {
    // observer distance and aperture radius, both in cm
    const deltaCm = row.delta * 14959787070000.0;
    let rhoCm;
    if (isFinite(config.aperture)) {
      // angular aperture scaled by the observer distance
      rhoCm = config.aperture * 72527109.437993127 * row.delta;
    } else {
      // linear aperture such as "5k": strip the suffix and scale by 1e8
      rhoCm = config.aperture.replace('k', '') * 1e8;
    }

    let afrho = 4 * deltaCm ** 2 * heliocentric ** 2 / rhoCm * 10 ** (-0.4 * (mag - mSun.r));
    if (axis.startsWith('afrho0')) {
      afrho /= 10 ** (-0.4 * phaseSchleicherMarcus(row.phase));
    }
    if (axis.endsWith('k')) {
      afrho /= heliocentric ** config.activitySlope;
    }
    return afrho;
  }

  if (axis === 'mjd') {
    return row.obsjd - 2400000.5;
  }

  return row[axis];
};

/**
 * Uncertainty to draw for one row on the given axis.
 *
 * Axes whose identifier starts with 'm' or 'H' use the magnitude error of the
 * configured aperture directly.  Afρ axes scale the axis value by the
 * magnitude error divided by 1.0857.  Every other axis has no error bar (0).
 *
 * @param {object} config - plot configuration (reads `aperture`).
 * @param {string} axis - axis identifier.
 * @param {object} row - photometry row.
 * @returns {number} the uncertainty in the units of the axis.
 */
const getAxisUncertainties = (config, axis, row) => {
  const isMagnitudeAxis = axis.startsWith('m') || axis.startsWith('H');
  if (isMagnitudeAxis) {
    return row[`merr${config.aperture}`];
  }

  if (axis.startsWith('afrho')) {
    const fractionalError = row[`merr${config.aperture}`] / 1.0857;
    return fractionalError * getAxisData(config, axis, row);
  }

  return 0;
};

/**
 * Split photometry into one plot series per catalog filter (g, r, i).
 *
 * The g and i series names show the color offset that is applied to the
 * magnitudes: a positive g color is shown as subtracted ("g-0.55") and a
 * positive i color as added ("i+0.25").  For a negative color the sign flips
 * and only the magnitude of the color is printed, so the name never contains
 * a doubled sign such as "g+-0.1".
 *
 * @param {object[]} phot - photometry rows with a 'catalog filter' field.
 * @param {object} config - plot configuration; `g` and `i` hold the colors
 *   (numbers, or numeric strings as produced by the text inputs).
 * @returns {object[]} three `{ name, marker, data }` series in g, r, i order.
 */
const splitPhotByFilter = (phot, config) => {
  // Text of the color without a leading minus sign; keeps the user's own
  // formatting (e.g. trailing zeros) when the value is a string.
  const magnitudeText = (color) => String(color).replace(/^-/, '');
  const rowsFor = (filter) => phot.filter((row) => (row['catalog filter'] === filter));

  return [
    {
      name: 'g' + (config.g > 0 ? "-" : "+") + magnitudeText(config.g),
      marker: {
        color: '#2ca02c',
        symbol: 'circle',
      },
      data: rowsFor('g')
    },
    {
      name: 'r',
      marker: {
        color: '#ff7f0e',
        symbol: 'square',
      },
      data: rowsFor('r')
    },
    {
      name: 'i' + (config.i < 0 ? "-" : "+") + magnitudeText(config.i),
      marker: {
        color: '#d62728',
        symbol: 'triangle-up',
      },
      data: rowsFor('i')
    }
  ];
};

/**
 * Return copies of the photometry rows extended with the plotted values.
 *
 * @param {object[]} phot - photometry rows.
 * @param {object} config - plot configuration (reads `xaxis` and `yaxis`).
 * @returns {object[]} new rows with `x`, `y`, and `yunc` (y uncertainty) added.
 */
const addAxesData = (phot, config) => {
  const withAxes = (row) => {
    const x = getAxisData(config, config.xaxis, row);
    const y = getAxisData(config, config.yaxis, row);
    const yunc = getAxisUncertainties(config, config.yaxis, row);
    return { ...row, x, y, yunc };
  };
  return phot.map(withAxes);
};

/**
 * Choose the y-axis type and display range for the current data.
 *
 * - Afρ axes are logarithmic; the range is given in log10 units, from half
 *   the minimum (but never below 0.1) to twice the maximum.
 * - Axes whose identifier starts with 'm' or 'H' are inverted (largest value
 *   first) and padded by 0.5.
 * - All other axes are padded by 20% away from the data on both sides, which
 *   also holds when values are negative (e.g. 'tmtp' or 'signed rh').
 *
 * Rows whose `y` is not a finite number are ignored.  When no finite value
 * remains, `range` is undefined so the plotting library can pick its own.
 *
 * @param {object} config - plot configuration (reads `yaxis`).
 * @param {object[]} phot - rows that already carry a `y` value.
 * @returns {{type: (string|null), range: (number[]|undefined)}}
 */
const getYAxisParameters = (config, phot) => {
  const isLog = config.yaxis.startsWith("afrho");
  const type = isLog ? "log" : null;

  // A loop (rather than Math.min(...values)) keeps very long lightcurves
  // from hitting the engine's argument-count limit.
  let lowest = Infinity;
  let highest = -Infinity;
  for (const row of phot) {
    if (Number.isFinite(row.y)) {
      lowest = Math.min(lowest, row.y);
      highest = Math.max(highest, row.y);
    }
  }

  if (lowest > highest) {
    // no finite data at all
    return { type, range: undefined };
  }

  let range;
  if (isLog) {
    range = [
      Math.log10(Math.max(0.1, lowest / 2)),
      Math.log10(highest * 2)
    ];
  } else if (config.yaxis.startsWith('m') || config.yaxis.startsWith('H')) {
    range = [highest + 0.5, lowest - 0.5];
  } else {
    // Pad away from the data: dividing a negative minimum (or multiplying a
    // negative maximum) by 1.2 would move the limit inside the data.
    range = [
      (lowest < 0) ? lowest * 1.2 : lowest / 1.2,
      (highest < 0) ? highest / 1.2 : highest * 1.2
    ];
  }
  return { type, range };
};

/**
 * Subtract two measurements, combining their errors in quadrature.
 *
 * @param {{value: number, err: number}} a - minuend.
 * @param {{value: number, err: number}} b - subtrahend.
 * @returns {{value: number, err: number}} `a - b` and its propagated error.
 */
function difference(a, b) {
  const value = a.value - b.value;
  const variance = a.err ** 2 + b.err ** 2;
  return { value, err: Math.sqrt(variance) };
}

/**
 * Inverse-variance weighted mean of a list of measurements.
 *
 * Each row is weighted by `err ** -2`; the returned error is the inverse
 * square root of the total weight.
 *
 * @param {{value: number, err: number}[]} rows - measurements to combine.
 * @returns {{value: number, err: number}} weighted mean and its error.
 */
function weightedMean(rows) {
  const weights = rows.map(({ err }) => err ** -2);
  const totalWeight = sum(weights);
  const weightedValues = rows.map(({ value }, position) => value * weights[position]);
  return {
    value: sum(weightedValues) / totalWeight,
    err: totalWeight ** -0.5
  };
}

/**
 * Weighted mean of measurements after rejecting outliers.
 *
 * - No rows: `{ value: null, err: null }`.
 * - One row: that row is returned as is.
 * - Two rows: plain weighted mean (no clipping).
 * - Three or more rows: rows farther than 2.5 sample standard deviations
 *   from the median are dropped, then the weighted mean of the rest is taken.
 *
 * When the values have no scatter (standard deviation is zero or undefined)
 * there is nothing to clip and all rows are kept; dividing by a zero
 * deviation would otherwise reject every row and yield NaN.
 *
 * @param {{value: number, err: number}[]} rows - measurements to combine.
 * @returns {{value: (number|null), err: (number|null)}}
 */
function sigmaClippedMean(rows) {
  if (rows.length === 0) {
    return {
      value: null,
      err: null
    };
  }
  if (rows.length === 1) {
    return rows[0];
  }
  if (rows.length === 2) {
    return weightedMean(rows);
  }

  const values = rows.map(row => row.value);
  const stdev = deviation(values);
  const med = median(values);

  // identical values (or an undefined deviation): keep everything
  if (!(stdev > 0)) {
    return weightedMean(rows);
  }

  const inliers = rows.filter(row => Math.abs(row.value - med) / stdev <= 2.5);
  return weightedMean(inliers);
}

/**
 * Collect the 'm5' magnitudes of one filter as `{ value, err }` pairs.
 *
 * A row is used only when its `filter` matches, its 'm5' value is truthy,
 * and its 'merr5' error is below 0.2.
 *
 * @param {object[]} rows - photometry rows.
 * @param {string} filter - filter name to match against `row.filter`.
 * @returns {{value: number, err: number}[]} the accepted measurements.
 */
function findMagByFilter(rows, filter) {
  return rows.reduce((accepted, row) => {
    const matchesFilter = row.filter === filter;
    const isUsable = matchesFilter && row.m5 && (row.merr5 < 0.2);
    if (isUsable) {
      accepted.push({ value: row.m5, err: row.merr5 });
    }
    return accepted;
  }, []);
}

/**
 * Estimate the g-r and r-i colors from the data and store them in the plot
 * configuration.
 *
 * Rows are grouped by the first 10 characters of `row.date`.  For each group
 * the sigma-clipped mean magnitude of the 'gp', 'rp', and 'ip' filters is
 * computed, and a color is recorded whenever both of its magnitudes are
 * finite numbers.  The per-group colors are combined with a weighted mean.
 *
 * On success the configuration is updated with the colors rounded to two
 * decimals and limited to 0.3..0.7 (g-r) and 0.14..0.25 (r-i); a color with
 * no usable pairs is set to 0.55 (g-r) or 0.24 (r-i).  The outcome is always
 * reported through `setPlotMessage`.
 *
 * @param {object[]} data - photometry rows (`date`, `filter`, 'm5', 'merr5').
 * @param {function} setPlotMessage - receives `{ severity, text }`.
 * @param {object} config - current plot configuration.
 * @param {function} setConfig - receives the updated configuration.
 */
function estimateColors(data, setPlotMessage, config, setConfig) {
  if (!data) {
    setPlotMessage({ severity: 'error', text: 'No data available for color estimate.' });
    return;
  }

  const isMeasured = (mean) => Number.isFinite(mean.value);
  const clamp = (value, lower, upper) => Math.min(Math.max(value, lower), upper);

  // group by date
  const dates = new Set(data.map((row) => row.date.substring(0, 10)));
  const gmrs = [];
  const rmis = [];
  dates.forEach((date) => {
    const obs = data.filter((row) => row.date.startsWith(date));
    const g = sigmaClippedMean(findMagByFilter(obs, 'gp'));
    const r = sigmaClippedMean(findMagByFilter(obs, 'rp'));
    const i = sigmaClippedMean(findMagByFilter(obs, 'ip'));

    if (isMeasured(g) && isMeasured(r)) {
      gmrs.push(difference(g, r));
    }
    if (isMeasured(r) && isMeasured(i)) {
      rmis.push(difference(r, i));
    }
  });

  if (gmrs.length === 0 && rmis.length === 0) {
    setPlotMessage({ severity: 'error', text: 'No appropriate color pairs.' });
    return;
  }

  // Only average the colors that actually have measurements.
  const gmr = gmrs.length ? weightedMean(gmrs) : null;
  const rmi = rmis.length ? weightedMean(rmis) : null;

  setConfig({
    ...config,
    'g': gmr ? clamp(Number(gmr.value.toFixed(2)), 0.3, 0.7) : 0.55,
    'i': rmi ? clamp(Number(rmi.value.toFixed(2)), 0.14, 0.25) : 0.24
  });

  const messages = [];
  if (gmr) {
    messages.push(`g-r = ${gmr.value.toFixed(2)} ± ${gmr.err.toFixed(2)} mag from ${gmrs.length} nights`);
  }
  if (rmi) {
    // report the r-i uncertainty, not the g-r one
    messages.push(`r-i = ${rmi.value.toFixed(2)} ± ${rmi.err.toFixed(2)} mag from ${rmis.length} nights`);
  }

  setPlotMessage({ severity: 'success', text: messages.join(', ') + '.' });
}

function PlotControl(props) {
  return (
    <Grid item xs={12} sm={6} lg={4} xl={2}>
      <FormControl sx={{ p: 2, width: '100%' }} {...props} />
    </Grid>
  );
}

// Initial plot configuration.
//   aperture:      photometric aperture, selects the `m<aperture>` columns
//   g, i:          g-r and r-i colors (mag)
//   activitySlope: exponent k used for the rh-dependent corrections
//   xaxis, yaxis:  axis identifiers plotted by default
const defaultConfig = {
  aperture: 5,
  g: 0.55,
  i: 0.25,
  activitySlope: -2,
  xaxis: "tmtp",
  yaxis: "m",
};

export default function Lightcurve({
  target, targetPhot, stacks, stackIndex, stackClusters, stackNavigation, selectedStackIndices, setSelectedStackIndices
}) {
  const [config, setConfig] = React.useState(defaultConfig);
  const [plotMessage, setPlotMessage] = React.useState({});
  const [shapes, setShapes] = React.useState([]);
  const [annotations, setAnnotations] = React.useState(true);

  const photometry = addAxesData(cleanPhotometry(targetPhot, config), config);
  const binnedPhotometry = addAxesData(cleanPhotometry(stacks, config), config);

  const stack = (stackIndex < stacks.length) ? stacks[stackIndex] : null;
  const stackPrefix = stack && stacks[stackIndex]['stack prefix'];

  React.useEffect(() => {
    if (stack && !stackClusters.data.stacks[stackPrefix]) {
      setPlotMessage({
        severity: "error",
        text: "Bad stack file name prefix: " + stackPrefix + ".  Please report the error."
      })
    }
  }, [stack, stackClusters]);

  const currentStackFiles = (stack && stackClusters.data.stacks[stackPrefix])
    ? stackClusters.data.stacks[stackPrefix]
    : [];
  const currentStackPhotometry = stack
    ? photometry.filter((row) => currentStackFiles.includes(row.file))
    : [];
  const currentStackBinnedPhotometry = stack
    ? binnedPhotometry.filter((row) => row['stack prefix'] === stackPrefix)
    : [];

  const gri = splitPhotByFilter(binnedPhotometry, config);

  const data = [
    {
      name: 'Individual exposures',
      x: photometry.map(row => row.x),
      y: photometry.map(row => row.y),
      mode: 'markers',
      type: 'scatter',
      marker: {
        size: 4,
        color: 'gray',
        symbol: 'circle',
        opacity: 0.6
      },
      unselected: {
        marker: {
          opacity: 0.6
        }
      }
    },
    ...gri.map(phot => (
      {
        name: phot.name,
        x: phot.data.map(row => row.x),
        y: phot.data.map(row => row.y),
        error_y: {
          type: 'data',
          array: phot.data.map(row => row.yunc),
          visible: true,
          thickness: 0.5,
        },
        mode: 'markers',
        type: 'scatter',
        marker: {
          opacity: 0.5,
          ...phot.marker
        },
        unselected: {
          opacity: 0.5,
          ...phot.marker
        }
      }
    ))
  ];
  if (annotations) {
    data.push(
      {
        name: 'Current exposures',
        x: currentStackPhotometry.map(row => row.x),
        y: currentStackPhotometry.map(row => row.y),
        mode: 'markers',
        type: 'scatter',
        marker: {
          size: 12,
          color: 'black',
          symbol: 'square-open'
        },
        unselected: {
          marker: {
            opacity: 1
          }
        }
      },
      {
        name: 'Current stack',
        x: currentStackBinnedPhotometry.map(row => row.x),
        y: currentStackBinnedPhotometry.map(row => row.y),
        mode: 'markers',
        type: 'scatter',
        marker: {
          size: 12,
          color: 'red',
          symbol: 'circle-open'
        },
        unselected: {
          marker: {
            opacity: 1
          }
        }
      }
    );
  }

  const layout = {
    title: { text: target },
    xaxis: {
      title: axisLabels[config.xaxis],
    },
    yaxis: {
      title: axisLabels[config.yaxis],
      ...getYAxisParameters(config, photometry)
    },
    shapes: shapes,
    legend: {
      orientation: "h"
    },
    uirevision: true,
    autosize: true,
  };

  const selectPoint = (event) => {
    // after clicking on a binned point, update the StackViewer
    const { curveNumber, pointIndex } = event.points[0];

    if (curveNumber === 0) {
      // individual points are always curve number 0      
      const stackFiles = stackClusters.data.files[photometry[pointIndex].file];
      if (stackFiles.length > 0) {
        const stackPrefix = stackFiles[0].slice(0, -9);
        stackNavigation.viewByStackPrefix(stackPrefix);
        return;
      }

      // stack file was not found
      setPlotMessage({
        severity: "error",
        text: "No corresponding stack to view.",
      });
    } else if (curveNumber < 4) {
      // binned points are always curve numbers 1, 2, 3
      const points = gri[curveNumber - 1];
      stackNavigation.viewByStackPrefix(points.data[pointIndex]["stack prefix"]);
    }
  };

  const selectPoints = (selected) => {
    let selectedStackIndices = [];
    let shapes = [];
    if (selected && selected.points.length) {
      // get all stack indices from curve numbers 1, 2, 3
      const points = selected.points
        .filter(point => (point.curveNumber > 0) && (point.curveNumber < 4));
      if (points.length) {
        selectedStackIndices = points
          .map(point => gri[point.curveNumber - 1].data[point.pointIndex])
          .sort((a, b) => a.date.localeCompare(b.date))
          .map(stack => stack.index);
        shapes = [{
          type: 'rect',
          x0: selected.range.x[0],
          x1: selected.range.x[1],
          y0: selected.range.y[0],
          y1: selected.range.y[1],
          fillcolor: 'red',
          opacity: 0.15,
          line: {
            width: 1,
            color: 'red',
            opacity: 0.3,
          }
        }];
      }
    }
    setSelectedStackIndices(selectedStackIndices);
    setShapes(shapes);
  }

  return (
    <>
      <Grid container>
        <Grid item xs={12} sx={{ height: "65vh" }}>
          <Plot
            data={data}
            layout={layout}
            onClick={selectPoint}
            onSelected={selectPoints}
            revision={config.plotRevision}
            useResizeHandler={true}
            style={{ width: '100%', height: '100%' }}
          />
        </Grid>
        <Grid item xs={12}>
          <Alert severity={plotMessage.severity} sx={{ display: !plotMessage.text && 'none', my: 1 }}>
            {plotMessage.text}
          </Alert>
        </Grid>
        <PlotControl>
          <TextField
            id="g"
            label="g-r (mag)"
            helperText="Solar: 0.39 mag"
            type="number"
            InputProps={{
              inputProps: {
                step: 0.01
              }
            }}
            value={config.g}
            onChange={(event) => {
              setConfig({
                ...config,
                g: event.target.value
              });
              event.preventDefault();
            }
            }
          />
        </PlotControl>
        <PlotControl>
          <TextField
            id="i"
            label="r-i (mag)"
            helperText="Solar: 0.12 mag"
            type="number"
            InputProps={{
              inputProps: {
                step: 0.01
              }
            }}
            value={config.i}
            onChange={(event) => {
              setConfig({
                ...config,
                i: event.target.value
              });
              event.preventDefault();
            }
            }
          />
        </PlotControl>
        <PlotControl>
          <TextField
            select
            id="aperture-selection"
            label="Aperture radius"
            value={config.aperture}
            onChange={(event) => {
              setConfig({
                ...config,
                aperture: event.target.value
              })
            }}
          >
            <MenuItem value={2}>2"</MenuItem>
            <MenuItem value={5}>5"</MenuItem>
            <MenuItem value={10}>10"</MenuItem>
            <MenuItem value={12}>12"</MenuItem>
            <MenuItem value={20}>20"</MenuItem>
            <MenuItem value="5k">5000 km</MenuItem>
            <MenuItem value="10k">10000 km</MenuItem>
            <MenuItem value="20k">20000 km</MenuItem>
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            select
            label="y-axis"
            helperText="H = absolute magnitude; Hy = cometary absolute magnitude"
            id="y-axis-selection"
            value={config.yaxis}
            onChange={(event) => {
              setConfig({
                ...config,
                yaxis: event.target.value
              })
            }}
          >
            {Object.entries(axisLabels).map(([key, label]) => <MenuItem key={key} value={key}>{label}</MenuItem>)}
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            select
            label="x-axis"
            id="x-axis-selection"
            value={config.xaxis}
            onChange={(event) => {
              setConfig({
                ...config,
                xaxis: event.target.value
              })
            }}
          >
            {Object.entries(axisLabels).map(([key, label]) => <MenuItem key={key} value={key}>{label}</MenuItem>)}
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            id="magnitude-slope"
            label="Cometary activity slope (k)"
            helperText="Activity ~ rh^k"
            type="number"
            InputProps={{
              inputProps: {
                step: 1
              }
            }}
            value={config.activitySlope}
            onChange={(event) => {
              setConfig({
                ...config,
                activitySlope: event.target.value
              })
            }}
          />
        </PlotControl>
        <PlotControl>
          <Button
            color="primary"
            variant="outlined"
            onClick={() => estimateColors(
              selectedStackIndices.length
                ? binnedPhotometry.filter((row, i) => selectedStackIndices.includes(i))
                : binnedPhotometry,
              setPlotMessage,
              config,
              setConfig
            )}
          >
            Estimate colors
          </Button>
        </PlotControl>
        <PlotControl>
          <Button
            color="primary"
            variant="outlined"
            disabled={selectedStackIndices.length === 0}
            onClick={() => {
              setSelectedStackIndices([]);
              setShapes([]);
            }}
          >
            Clear selection
          </Button>
        </PlotControl>
        <PlotControl>
          <FormControlLabel
            control={
              <Checkbox checked={annotations} onChange={(event) => setAnnotations(event.target.checked)} />
            }
            label="Plot annotations" />
        </PlotControl>
      </Grid>
    </>
  );
}