import React from 'react';
import { deviation, median, sum } from 'd3-array';

import Alert from '@mui/material/Alert';
import Button from '@mui/material/Button';
import Checkbox from '@mui/material/Checkbox';
import FormControl from '@mui/material/FormControl';
import FormControlLabel from '@mui/material/FormControlLabel';
import FormLabel from '@mui/material/FormLabel';
import Grid from '@mui/material/Grid';
import InputLabel from '@mui/material/InputLabel';
import MenuItem from '@mui/material/MenuItem';
import TextField from '@mui/material/TextField';

import Plot from 'react-plotly.js';

import { phaseSchleicherMarcus, cleanPhotometry } from '../utils';
import { maskPIDs } from '../services/firebase';
import { useLightcurveConfig } from '../services/cookies';

/**
 * Selectable plot quantities (keys) and the axis titles shown for them.
 *
 * The same entries populate both the x- and y-axis menus, in this order.
 * getAxisData computes the derived quantities; any key it does not handle
 * specially is read straight from the photometry row.
 */
const axisLabels = {
  m: 'm (mag)',
  H11: 'H(1,1,α) (mag)',
  Hy11: 'Hy(1,1,α) (mag)',
  Hy110: 'Hy(1,1,0) (mag)',
  rh: 'rh (au)',
  'signed rh': 'rh (au) (negative for pre-perihelion)',
  delta: 'Delta (au)',
  phase: 'Phase angle (deg)',
  obsjd: 'Julian date (days)',
  mjd: 'modified Julian date (days)',
  tmtp: 'T-Tp (days)',
  trueanomaly: 'True anomaly (°)',
  afrho: 'Afρ (cm)',
  afrho0: 'A(0°)fρ (cm)',
  afrho0k: 'A(0°)fρ rh**k (cm)',
  ostat: 'O stat',
  centroid_offset: 'Centroid offset (arcsec)',
  seeing: 'Seeing (arcsec)',
  maglim: 'Mag. limit',
};

/**
 * Solar magnitudes for the ZTF filters, keyed by filter name.
 *
 * getAxisData uses the zr entry as the solar reference in the Afρ formula.
 * NOTE(review): these look like apparent magnitudes of the Sun at 1 au;
 * confirm the source before adding filters.
 */
const mSun = {
  zg: -26.54,
  zr: -26.93,
  zi: -27.05,
};

/**
 * Value of the plot quantity `axis` for one photometry row.
 *
 * The row magnitude for the configured aperture is first moved onto the
 * r-band scale. g-band rows are shifted by -config.g (g-r) and i-band rows by
 * +config.i (r-i). There is no config.r, so r-band rows get no shift. That
 * corrected magnitude feeds the magnitude, absolute-magnitude and Afρ
 * quantities. Keys with no special handling are read directly from the row.
 *
 * @param {object} config - lightcurve settings (aperture, g, i, activitySlope)
 * @param {string} axis - a key of axisLabels
 * @param {object} row - photometry row (m, rh, delta, phase, tmtp, obsjd, ...)
 * @returns {number} the value to plot
 */
const getAxisData = (config, axis, row) => {
  const band = row['filter'];
  // band[1] is the bare filter letter, e.g. 'g' for 'zg'
  const colorShift = (band === 'zg' ? -1 : 1) * (config[band[1]] || 0);
  const mag = row.m[config.aperture] + colorShift;
  const rh = Math.abs(row.rh);

  switch (axis) {
    case 'm':
      return mag;
    case 'H11':
      return mag - 5 * Math.log10(rh * row.delta);
    case 'rh':
      return rh;
    case 'signed rh':
      return Math.sign(row.tmtp) * rh;
    case 'mjd':
      return row['obsjd'] - 2400000.5;
    case 'trueanomaly':
      // wrap to (-180, 180]
      return row.trueanomaly < 180 ? row.trueanomaly : row.trueanomaly - 360;
    default:
      break;
  }

  if (axis.startsWith('Hy')) {
    // Cometary absolute magnitude. With activity ~ rh**k, the observed flux
    // scales as rh**(k-2). A fixed angular aperture (pixels, < 1000) sees
    // flux ~ 1/delta; a fixed linear aperture (km, >= 1000) sees flux ~ 1/delta**2.
    const deltaCoefficient = (config.aperture < 1000) ? 2.5 : 5;
    let hy = mag + 2.5 * (config.activitySlope - 2) * Math.log10(rh)
      - deltaCoefficient * Math.log10(row.delta);
    // 'Hy110' additionally removes the phase-angle dependence
    if (axis.endsWith('0')) {
      hy -= phaseSchleicherMarcus(row.phase);
    }
    return hy;
  }

  if (axis.startsWith('afrho')) {
    // Both lengths in cm: 1 au = 14959787070000 cm. A pixel aperture is
    // converted via 1.01"/pix and ~72527109.44 cm per arcsec at 1 au. A km
    // aperture is converted via 1e5 cm/km.
    const deltaCm = row.delta * 14959787070000.0;
    const rhoCm = (config.aperture < 1000)
      ? config.aperture * 1.01 * 72527109.437993127 * row.delta
      : config.aperture * 1e5;
    let afrho = 4 * deltaCm ** 2 * rh ** 2 / rhoCm * 10 ** (-0.4 * (mag - mSun['zr']));
    if (axis.startsWith('afrho0')) {
      afrho /= 10 ** (-0.4 * phaseSchleicherMarcus(row.phase));
    }
    if (axis.endsWith('k')) {
      afrho /= rh ** config.activitySlope;
    }
    return afrho;
  }

  return row[axis];
};

/**
 * Error bar for one row on the y-axis quantity `axis`, in that quantity's units.
 *
 * - Magnitude quantities ('m' and the 'H…'/'Hy…' absolute magnitudes) use the
 *   photometric error for the configured aperture.
 * - Afρ quantities convert that magnitude error to a relative error
 *   (σ_F/F = σ_m / 1.0857, where 1.0857 ≈ 2.5/ln 10) and scale it by the Afρ value.
 * - All other quantities, including 'mjd' and 'maglim' (which also start with
 *   "m" but are not measured magnitudes), get 0.
 *
 * @param {object} config - uses config.aperture
 * @param {string} axis - a key of axisLabels
 * @param {object} row - photometry row with `merr` keyed by aperture
 * @returns {number} uncertainty (0 when none applies)
 */
const getAxisUncertainties = (config, axis, row) => {
  if (axis === 'm' || axis.startsWith('H')) {
    return row.merr[config.aperture];
  }
  if (axis.startsWith('afrho')) {
    // the magnitude error lives in row.merr, not directly on the row
    return row.merr[config.aperture] / 1.0857 * getAxisData(config, axis, row);
  }
  return 0;
};

/**
 * Format a color offset for a legend name as "<sign><magnitude>". A leading
 * minus is stripped from the value so negative colors do not render as "+-0.2".
 */
const formatColorOffset = (sign, color) => `${sign}${String(color).replace(/^-/, '')}`;

/**
 * Split photometry rows into the g, r and i groups plotted as separate traces.
 *
 * The [g, r, i] order matters: the Lightcurve click and selection handlers map
 * Plotly curve numbers 0-2 back to these groups. The g and i names show the
 * offset getAxisData applies to put them on the r-band scale:
 * - g is shifted by -(g-r), so g-r = 0.53 is labelled "g-0.53" and g-r = -0.2
 *   is labelled "g+0.2";
 * - i is shifted by +(r-i), so r-i = 0.2 is labelled "i+0.2".
 *
 * @param {object[]} phot - rows with `filter` of 'zg', 'zr' or 'zi'
 * @param {object} config - uses config.g (g-r) and config.i (r-i); these may
 *   be numbers or the numeric strings written by the text fields
 * @returns {{name: string, marker: object, data: object[]}[]} g, r, i groups
 */
const splitPhotByFilter = (phot, config) => ([
  {
    name: `g${formatColorOffset(config.g > 0 ? '-' : '+', config.g)}`,
    marker: {
      color: '#2ca02c',
      symbol: 'circle',
    },
    data: phot.filter((row) => (row['filter'] === 'zg'))
  },
  {
    name: 'r',
    marker: {
      color: '#ff7f0e',
      symbol: 'square',
    },
    data: phot.filter((row) => (row['filter'] === 'zr'))
  },
  {
    name: `i${formatColorOffset(config.i < 0 ? '-' : '+', config.i)}`,
    marker: {
      color: '#d62728',
      symbol: 'triangle-up',
    },
    data: phot.filter((row) => (row['filter'] === 'zi'))
  }
]);

/**
 * Return copies of the photometry rows with the plotted coordinates attached.
 *
 * Each row gets `x` and `y` from getAxisData for the configured x/y axes, and
 * `yunc` from getAxisUncertainties for the y axis. The input rows are not modified.
 */
const addAxesData = (phot, config) => phot.map((row) => {
  const { xaxis, yaxis } = config;
  return {
    ...row,
    x: getAxisData(config, xaxis, row),
    y: getAxisData(config, yaxis, row),
    yunc: getAxisUncertainties(config, yaxis, row),
  };
});

/**
 * Plotly y-axis `type` and initial `range` for the selected y quantity.
 *
 * - Afρ quantities use a log axis spanning min/2 … max*2, with the lower end
 *   floored at 0.1. Plotly expects log-axis ranges as log10 values, and only
 *   positive values are considered.
 * - Magnitude quantities ('m', 'maglim', and 'H…'/'Hy…') are reversed
 *   (bright at the top) with 0.5 mag padding.
 * - Every other quantity is padded by 10% of the data span on each side, so
 *   negative values (T-Tp, signed rh, true anomaly) and large offsets (Julian
 *   dates) stay fully visible.
 *
 * Non-finite y values are ignored. If none remain, `range` is undefined and
 * Plotly autoranges.
 *
 * @param {object} config - uses config.yaxis
 * @param {object[]} phot - rows with a numeric `y`
 * @returns {{type: (string|null), range: (number[]|undefined)}}
 */
const getYAxisParameters = (config, phot) => {
  const { yaxis } = config;
  const isAfrho = yaxis.startsWith('afrho');
  // 'mjd' also starts with "m" but is not a magnitude
  const isMagnitude = yaxis === 'm' || yaxis === 'maglim' || yaxis.startsWith('H');
  const type = isAfrho ? 'log' : null;

  const ys = phot
    .map((row) => row.y)
    .filter((y) => Number.isFinite(y) && (!isAfrho || y > 0));
  if (ys.length === 0) {
    return { type, range: undefined };
  }

  // explicit loop avoids spreading very long arrays into Math.min/Math.max
  let min = Infinity;
  let max = -Infinity;
  for (const y of ys) {
    if (y < min) min = y;
    if (y > max) max = y;
  }

  let range;
  if (isAfrho) {
    range = [Math.log10(Math.max(0.1, min / 2)), Math.log10(max * 2)];
  } else if (isMagnitude) {
    range = [max + 0.5, min - 0.5];
  } else {
    const span = max - min;
    // a single distinct value has no span; pad by 10% of its size, or 1 for 0
    const pad = span > 0 ? 0.1 * span : (0.1 * Math.abs(max) || 1);
    range = [min - pad, max + pad];
  }
  return { type, range };
};

/**
 * Difference of two measurements, minuend - subtrahend, with their errors
 * combined in quadrature.
 *
 * @param {{value: number, err: number}} minuend
 * @param {{value: number, err: number}} subtrahend
 * @returns {{value: number, err: number}}
 */
function difference(minuend, subtrahend) {
  const value = minuend.value - subtrahend.value;
  const err = Math.sqrt(minuend.err ** 2 + subtrahend.err ** 2);
  return { value, err };
}

/**
 * Inverse-variance weighted mean of measurements.
 *
 * Each row is weighted by err**-2. The returned error is (Σ weights)**-0.5.
 * Sums use d3's `sum`. An empty list therefore gives value NaN (0/0) and
 * err Infinity, because d3 `sum` of no values is 0.
 *
 * @param {{value: number, err: number}[]} rows
 * @returns {{value: number, err: number}}
 */
function weightedMean(rows) {
  const weights = rows.map(({ err }) => err ** -2);
  const totalWeight = sum(weights);
  const weightedValues = rows.map(({ value }, k) => value * weights[k]);
  return {
    value: sum(weightedValues) / totalWeight,
    err: totalWeight ** -0.5,
  };
}

/**
 * Combine measurements into one value, rejecting outliers first.
 *
 * - no rows: {value: null, err: null}
 * - one row: that row object itself
 * - two rows: weightedMean of both (too few to clip)
 * - otherwise: keep rows with |value - median| <= nsigma * stdev (d3
 *   `deviation`, the sample standard deviation) and return their weightedMean
 *
 * @param {{value: number, err: number}[]} rows
 * @param {number} [nsigma=2.5] - clipping threshold in standard deviations
 * @returns {{value: (number|null), err: (number|null)}}
 */
function sigmaClippedMean(rows, nsigma = 2.5) {
  if (rows.length === 0) {
    return { value: null, err: null };
  }
  if (rows.length === 1) {
    return rows[0];
  }
  if (rows.length === 2) {
    return weightedMean(rows);
  }

  const values = rows.map((row) => row.value);
  const stdev = deviation(values);
  const med = median(values);
  // Multiply rather than divide by stdev. When all values are identical,
  // stdev is 0 and |x - med| / 0 would be NaN, rejecting every row.
  const inliers = rows.filter((row) => Math.abs(row.value - med) <= nsigma * stdev);
  return weightedMean(inliers);
}

/**
 * Usable magnitudes in one filter for the configured aperture.
 *
 * A row is usable when:
 * - its `filter` matches;
 * - its magnitude for the aperture is truthy (so missing, null and 0 are skipped);
 * - its magnitude error is below 0.2 mag.
 *
 * @param {object[]} rows - photometry rows with `m`/`merr` keyed by aperture
 * @param {string} filter - 'zg', 'zr' or 'zi'
 * @param {object} config - uses config.aperture
 * @returns {{value: number, err: number}[]}
 */
function findMagByFilter(rows, filter, config) {
  const isUsable = (row) =>
    row.filter === filter
    && row.m[config.aperture]
    && row.merr[config.aperture] < 0.2;
  const toMeasurement = (row) => ({
    value: row.m[config.aperture],
    err: row.merr[config.aperture],
  });
  return rows.filter(isUsable).map(toMeasurement);
}

/**
 * Estimate g-r and r-i colors from the photometry and store them in the config.
 *
 * Steps:
 * 1. Group rows into nights by the first 10 characters of `obsdate`
 *    (presumably "YYYY-MM-DD"; confirm against the data source).
 * 2. For each night, combine the usable magnitudes per filter
 *    (findMagByFilter) with sigmaClippedMean.
 * 3. Record a color for each pair whose two filters both have a value.
 * 4. Combine the nightly colors with weightedMean.
 *
 * On success the config is updated with the colors, rounded to 0.01 and
 * clipped to [0.3, 0.7] (g-r) and [0.14, 0.25] (r-i). A color with no usable
 * nights falls back to defaultConfig. A status message is always reported
 * through setPlotMessage.
 *
 * @param {object[]|null|undefined} data - photometry rows
 * @param {function} setPlotMessage - receives {severity, text}
 * @param {object} config - current config (config.aperture selects the photometry)
 * @param {function} setConfig - receives the updated config
 */
function estimateColors(data, setPlotMessage, config, setConfig) {
  if (!data) {
    setPlotMessage({ severity: 'error', text: 'No data available for color estimate.' });
    return;
  }

  // group rows by night in a single pass (insertion order = first appearance)
  const nights = new Map();
  for (const row of data) {
    const date = row.obsdate.substring(0, 10);
    if (!nights.has(date)) {
      nights.set(date, []);
    }
    nights.get(date).push(row);
  }

  const gmrs = [];
  const rmis = [];
  for (const obs of nights.values()) {
    const g = sigmaClippedMean(findMagByFilter(obs, 'zg', config));
    const r = sigmaClippedMean(findMagByFilter(obs, 'zr', config));
    const i = sigmaClippedMean(findMagByFilter(obs, 'zi', config));

    if (g.value && r.value) {
      gmrs.push(difference(g, r));
    }
    if (r.value && i.value) {
      rmis.push(difference(r, i));
    }
  }

  if (gmrs.length === 0 && rmis.length === 0) {
    setPlotMessage({ severity: 'error', text: 'No appropriate color pairs.' });
    return;
  }

  // round to 0.01 mag, then clip to the accepted range
  const clipColor = (value, lo, hi) => Math.min(Math.max(Number(value.toFixed(2)), lo), hi);

  const messages = [];
  let gColor = defaultConfig.g;
  let iColor = defaultConfig.i;
  if (gmrs.length) {
    const gmr = weightedMean(gmrs);
    gColor = clipColor(gmr.value, 0.3, 0.7);
    messages.push(`g-r = ${gmr.value.toFixed(2)} ± ${gmr.err.toFixed(2)} mag from ${gmrs.length} nights`);
  }
  if (rmis.length) {
    const rmi = weightedMean(rmis);
    iColor = clipColor(rmi.value, 0.14, 0.25);
    messages.push(`r-i = ${rmi.value.toFixed(2)} ± ${rmi.err.toFixed(2)} mag from ${rmis.length} nights`);
  }

  setConfig({
    ...config,
    'g': gColor,
    'i': iColor
  });
  setPlotMessage({ severity: 'success', text: messages.join(', ') + '.' });
}

function PlotControl(props) {
  return (
    <Grid item xs={12} sm={6} lg={2}>
      <FormControl sx={{ p: 2, width: '100%' }} {...props} />
    </Grid>
  );
}

/**
 * Mask (status true) or unmask (status false) the given photometry rows.
 *
 * Collects each row's `pid` and passes them to maskPIDs together with the
 * user's uid and the object id. A null/undefined `data` sends an empty list.
 * NOTE(review): any result from maskPIDs (e.g. a promise) is neither awaited
 * nor checked here. Confirm whether failures should be surfaced.
 */
const setToMasked = (user, objid, data, status) => {
  const rows = data || [];
  const pids = rows.map(({ pid }) => pid);
  maskPIDs(user.uid, objid, pids, status);
};

function PhotometryWarnings({ rows, sx }) {
  const warnings = [];

  if (rows.find(row => row.flags.faint))
    warnings.push('An apparent magnitude (5 pix radius) is fainter than ZTF pipeline 5-sigma limit.');

  if (rows.find(row => row.flags.centroid))
    warnings.push('A centroid is more than 1" outside the 3-sigma ephemeris uncertainty.');

  const message = (warnings.length > 0)
    ? `Caution: ${warnings.join(" ")}`
    : "Current stack photometry is nominal.";
  const severity = (warnings.length > 0) ? "warning" : "success";

  return (
    <Alert severity={severity} sx={sx}>
      {message}
    </Alert>
  );
}

/**
 * Intersection of two lists of photometry, matched by `pid`.
 *
 * Returns a new array with the rows of `b` whose pid also occurs in `a`,
 * keeping `b`'s order and row objects.
 */
function intersection(a, b) {
  const pidsInA = new Set(a.map(({ pid }) => pid));
  return b.filter((row) => pidsInA.has(row.pid));
}

/**
 * Default lightcurve settings, passed to useLightcurveConfig. estimateColors
 * also falls back to `g`/`i` from here.
 *
 * - aperture: photometry aperture; values < 1000 are pixel radii, larger
 *   values are km (see the aperture menu)
 * - g: g-r color (mag); i: r-i color (mag)
 * - activitySlope: k in activity ~ rh**k
 * - xaxis / yaxis: keys of axisLabels
 * - plotAllPhotometry: when false, the "Mask σ>0.36 mag" checkbox is checked
 */
const defaultConfig = {
  aperture: 5,
  g: 0.53,
  i: 0.2,
  activitySlope: 0,
  xaxis: 'tmtp',
  yaxis: 'm',
  plotAllPhotometry: false,
};

export default function Lightcurve({
  target, objid, photometry, stacks, stackIndex, stackNavigation,
  setSelectedStackIndices, highlight, setAperture, user
}) {
  const [config, setConfig] = useLightcurveConfig(target, defaultConfig);
  const [plotMessage, setPlotMessage] = React.useState({});
  const [selectedPhotometry, setSelectedPhotometry] = React.useState([]);
  const [shapes, setShapes] = React.useState([]);

  React.useEffect(() => {
    setPlotMessage({});
  }, [target]);

  const allPhot = addAxesData(photometry, config);
  const phot = cleanPhotometry(allPhot, config);

  const stack = (stackIndex < stacks.length) ? stacks[stackIndex] : null;

  const currentStackAllPhotometry = stack
    ? allPhot.filter((row) => row.stackid === stack.stackid)
    : [];

  const currentStackPhotometry = stack
    ? phot.filter((row) => row.stackid === stack.stackid)
    : [];

  const currentSelectedStackPhotometry = selectedPhotometry
    ? intersection(currentStackPhotometry, selectedPhotometry)
    : currentStackPhotometry;

  const gri = splitPhotByFilter(phot, config);

  const data = [
    ...gri.map(phot => (
      {
        name: phot.name,
        x: phot.data.map(row => row.x),
        y: phot.data.map(row => row.y),
        error_y: {
          type: 'data',
          array: phot.data.map(row => row.yunc),
          visible: true,
          thickness: 0.5,
        },
        mode: 'markers',
        type: 'scatter',
        marker: {
          opacity: 0.5,
          ...phot.marker
        },
        unselected: {
          opacity: 0.5,
          ...phot.marker
        }
      }
    )),
    {
      name: 'Current stack',
      x: [...currentStackPhotometry.map(row => row.x),
      ...(config.yaxis === 'm' ? currentStackAllPhotometry.map(row => row.x) : [])],
      y: [...currentStackPhotometry.map(row => row.y),
      ...(config.yaxis === 'm' ? currentStackAllPhotometry.map(row => row.vmag) : [])],
      mode: 'markers',
      type: 'scatter',
      marker: {
        size: 12,
        color: 'black',
        symbol: 'square-open'
      },
      unselected: {
        marker: {
          opacity: 1
        }
      }
    },
    ...highlight(phot)
  ];

  if (config.yaxis === 'm') {
    data.push({
      name: 'V (JPL)',
      x: allPhot.map(row => row.x),
      y: allPhot.map(row => row.vmag),
      mode: 'markers',
      type: 'scatter',
      marker: {
        opacity: 0.25,
        color: 'black',
        symbol: 'cross'
      },
      unselected: {
        opacity: 0.25,
        color: 'black',
      }
    });
    data.push({
      name: 'ZTF maglim',
      x: allPhot.map(row => row.x),
      y: allPhot.map(row => row.maglim),
      mode: 'markers',
      type: 'scatter',
      visible: 'legendonly',
      marker: {
        opacity: 0.25,
        color: 'black',
        symbol: 'arrow-bar-down'
      }
    });
  }

  const layout = {
    title: { text: target },
    xaxis: {
      title: axisLabels[config.xaxis],
    },
    yaxis: {
      title: axisLabels[config.yaxis],
      ...getYAxisParameters(config, phot)
    },
    shapes: shapes,
    legend: {
      orientation: "h"
    },
    uirevision: true,
    autosize: true,
  };

  const selectPoint = (selected) => {
    // after clicking on a binned point, update the StackViewer
    const { curveNumber, pointIndex } = selected.points[0];

    if (curveNumber < 3) {
      // stack points are always curve numbers 0, 1, 2
      const points = gri[curveNumber];
      stackNavigation.viewByBasename(points.data[pointIndex].basename);
    }
  };

  const selectPoints = (selected) => {
    let selectedStackIndices = [];
    let selectedPhotometry = [];
    let shapes = [];

    if (selected && selected.points.length) {
      // get all stack indices from curve numbers 0, 1, 2
      const points = selected.points
        .filter(point => (point.curveNumber < 3));
      if (points.length) {
        selectedPhotometry = points
          .map(point => gri[point.curveNumber].data[point.pointIndex])
          .sort((a, b) => a.obsdate.localeCompare(b.obsdate));

        const stackids = new Set(selectedPhotometry.map(phot => phot.stackid));

        selectedStackIndices = Array.from(stackids).map(
          stackid =>
            stacks.find(
              (stack) => stack.stackid === stackid
            ).index
        );

        shapes = [{
          type: 'rect',
          x0: selected.range.x[0],
          x1: selected.range.x[1],
          y0: selected.range.y[0],
          y1: selected.range.y[1],
          fillcolor: 'red',
          opacity: 0.15,
          line: {
            width: 1,
            color: 'red',
            opacity: 0.3,
          }
        }];
      }
    }
    setSelectedStackIndices(selectedStackIndices);
    setSelectedPhotometry(selectedPhotometry);
    setShapes(shapes);
  }

  return (
    <>
      <Grid container>
        <Grid item xs={12} sx={{ height: "65vh" }}>
          <Plot
            data={data}
            layout={layout}
            onClick={selectPoint}
            onSelected={selectPoints}
            revision={config.plotRevision}
            useResizeHandler={true}
            style={{ width: '100%', height: '100%' }}
          />
        </Grid>
        <Grid item xs={6}>
          <InputLabel>
            {selectedPhotometry.length} selected point{(selectedPhotometry.length === 1) ? "" : "s"}
          </InputLabel>
          {user.isReviewer &&
            <>
              <Button onClick={() => setToMasked(user, objid, selectedPhotometry, true)} disabled={selectedPhotometry.length === 0}>Mask</Button>
              <Button onClick={() => setToMasked(user, objid, selectedPhotometry, false)} disabled={selectedPhotometry.length === 0}>Unmask</Button>
            </>
          }
          <Button onClick={() => {
            setSelectedStackIndices([]);
            setSelectedPhotometry([]);
            setShapes([]);
          }} disabled={selectedPhotometry.length === 0}>Clear selection</Button>
        </Grid>
        <Grid item xs={6}>
          {(selectedPhotometry.length > 0)
            ? <>
              <InputLabel>
                {currentSelectedStackPhotometry.length} selected point{(currentSelectedStackPhotometry.length === 1) ? "" : "s"} in current stack
              </InputLabel>
              {user.isReviewer &&
                <>
                  <Button onClick={() => setToMasked(user, objid, currentSelectedStackPhotometry, true)} disabled={currentSelectedStackPhotometry.length === 0}>Mask</Button>
                  <Button onClick={() => setToMasked(user, objid, currentSelectedStackPhotometry, false)} disabled={currentSelectedStackPhotometry.length === 0}>Unmask</Button>
                </>
              }
            </>
            : <>
              <InputLabel>
                {currentStackPhotometry.length} point{(currentStackPhotometry.length === 1) ? "" : "s"} in current stack
              </InputLabel>
              {user.isReviewer &&
                <>
                  <Button onClick={() => setToMasked(user, objid, currentStackPhotometry, true)} disabled={currentStackPhotometry.length === 0}>Mask</Button>
                  <Button onClick={() => setToMasked(user, objid, currentStackPhotometry, false)} disabled={currentStackPhotometry.length === 0}>Unmask</Button>
                </>
              }
            </>
          }
        </Grid>
        <Grid item xs={12}>
          {user.isReviewer &&
            <FormLabel>Nominally consider 5-pix radius photometry for masking data.  If this target is best considered with a smaller aperture, note it above.</FormLabel>
          }
          <PhotometryWarnings rows={currentStackPhotometry} config={config} sx={{ my: 1 }} />
        </Grid>
        <Grid item xs={12}>
          <Alert severity={plotMessage.severity} sx={{ display: !plotMessage.text && 'none', my: 1 }}>
            {plotMessage.text}
          </Alert>
        </Grid>
        <PlotControl>
          <TextField
            id="g"
            label="g-r (mag)"
            helperText="Solar: 0.39 mag"
            type="number"
            InputProps={{
              inputProps: {
                step: 0.01
              }
            }}
            value={config.g}
            onChange={(event) => {
              setConfig({
                ...config,
                g: event.target.value
              });
              event.preventDefault();
            }
            }
          />
        </PlotControl>
        <PlotControl>
          <TextField
            id="i"
            label="r-i (mag)"
            helperText="Solar: 0.12 mag"
            type="number"
            InputProps={{
              inputProps: {
                step: 0.01
              }
            }}
            value={config.i}
            onChange={(event) => {
              setConfig({
                ...config,
                i: event.target.value
              });
              event.preventDefault();
            }
            }
          />
        </PlotControl>
        <PlotControl>
          <TextField
            select
            label="y-axis"
            helperText="H = absolute magnitude; Hy = cometary absolute magnitude (with Δ correction)"
            id="y-axis-selection"
            value={config.yaxis}
            onChange={(event) => {
              setConfig({
                ...config,
                yaxis: event.target.value
              })
            }}
          >
            {Object.entries(axisLabels).map(([key, label]) => <MenuItem key={key} value={key}>{label}</MenuItem>)}
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            select
            label="x-axis"
            id="x-axis-selection"
            value={config.xaxis}
            onChange={(event) => {
              setConfig({
                ...config,
                xaxis: event.target.value
              })
            }}
          >
            {Object.entries(axisLabels).map(([key, label]) => <MenuItem key={key} value={key}>{label}</MenuItem>)}
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            select
            id="aperture-selection"
            label="Aperture radius"
            helperText='1 pix = 1.01"'
            value={config.aperture}
            onChange={(event) => {
              setConfig({
                ...config,
                aperture: event.target.value
              });
              setAperture(event.target.value);
            }}
          >
            <MenuItem value={2}>2 pix</MenuItem>
            <MenuItem value={3}>3 pix</MenuItem>
            <MenuItem value={4}>4 pix</MenuItem>
            <MenuItem value={5}>5 pix</MenuItem>
            <MenuItem value={7}>7 pix</MenuItem>
            <MenuItem value={11}>11 pix</MenuItem>
            <MenuItem value={15}>15 pix</MenuItem>
            <MenuItem value={20}>20 pix</MenuItem>
            <MenuItem value={5000}>5000 km</MenuItem>
            <MenuItem value={10000}>10000 km</MenuItem>
            <MenuItem value={15000}>15000 km</MenuItem>
            <MenuItem value={20000}>20000 km</MenuItem>
            <MenuItem value={30000}>30000 km</MenuItem>
            <MenuItem value={40000}>40000 km</MenuItem>
          </TextField>
        </PlotControl>
        <PlotControl>
          <TextField
            id="magnitude-slope"
            label="Cometary activity slope (k)"
            helperText="Activity ~ rh^k"
            type="number"
            InputProps={{
              inputProps: {
                step: 1
              }
            }}
            value={config.activitySlope}
            onChange={(event) => {
              setConfig({
                ...config,
                activitySlope: event.target.value
              })
            }}
          />
        </PlotControl>
        <PlotControl>
          <Button
            color="primary"
            variant="outlined"
            onClick={() => estimateColors(
              selectedPhotometry.length ? selectedPhotometry : phot,
              setPlotMessage,
              config,
              setConfig
            )}
          >
            Estimate colors
          </Button>
        </PlotControl>
        <PlotControl>
          <FormControlLabel
            control={
              <Checkbox
                checked={!config.plotAllPhotometry}
                onChange={() => setConfig({ ...config, plotAllPhotometry: !config.plotAllPhotometry })}
              />
            }
            label="Mask σ>0.36 mag" />
        </PlotControl>
      </Grid>
    </>
  );
}
