/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2019 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#ifdef HAVE_LIBFFTW3
#include <fftw3.h>
#endif

#include <cdi.h>

#include "process_int.h"
#include "cdo_vlist.h"
#include "param_conversion.h"
#include "statistic.h"
#include "cdo_options.h"
#include "cimdOmp.h"

// Number of time steps by which the per-time-step buffers (vdate, vtime, vars)
// grow whenever reading runs past the currently allocated size.
static constexpr int NALLOC_INC = 1024;

/*
 * Per-thread work memory for the time-series Fourier transforms.
 *   real, imag          : data arrays of the intrinsic FFT (length nts)
 *   work_r, work_i      : scratch arrays for ft_r(); only sized when nts is
 *                         not a power of 2
 *   in_fft, out_fft,plan: FFTW buffers and plan (only with HAVE_LIBFFTW3)
 * The FFTW members are raw handles that must be released explicitly with
 * fftw_free()/fftw_destroy_plan(). They default to nullptr so an entry that
 * was never set up does not hold an indeterminate pointer.
 */
struct FourierMemory
{
  Varray<double> real;
  Varray<double> imag;
  Varray<double> work_r;
  Varray<double> work_i;
#ifdef HAVE_LIBFFTW3
  fftw_complex *in_fft = nullptr;
  fftw_complex *out_fft = nullptr;
  fftw_plan plan = nullptr;
#endif
};

/*
 * Transforms, in place, the complex time series stored at every grid point of
 * field (varID, levelID) over the nts time steps, using the per-thread FFTW
 * plans in ompmem. Real and imaginary parts are interleaved in the field
 * data (index 2*i and 2*i+1).
 *
 * A grid point whose real or imaginary part equals missval in any time step
 * is set to missval in all time steps and is not transformed. Transformed
 * values are scaled by 1/sqrt(nts).
 *
 * 'sign' is not read here; the direction was fixed when the plans were
 * created. Without HAVE_LIBFFTW3 the function does nothing.
 */
static void
fourier_fftw(int sign, int varID, int levelID, int nts, size_t gridsize, double missval, FieldVector3D &vars,
             std::vector<FourierMemory> &ompmem)
{
#ifdef HAVE_LIBFFTW3
  const double scale = 1. / std::sqrt(nts);

#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
  for (size_t ip = 0; ip < gridsize; ip++)
    {
      auto &mem = ompmem[cdo_omp_get_thread_num()];
      const size_t ir = 2 * ip;  // real part
      const size_t ii = ir + 1;  // imaginary part

      // Gather the time series of this grid point into the FFTW input buffer.
      bool anyMissing = false;
      for (int t = 0; t < nts; t++)
        {
          const auto &data = vars[t][varID][levelID].vec;
          mem.in_fft[t][0] = data[ir];
          mem.in_fft[t][1] = data[ii];
          if (DBL_IS_EQUAL(data[ir], missval) || DBL_IS_EQUAL(data[ii], missval)) anyMissing = true;
        }

      if (anyMissing)
        {
          // One missing value invalidates the whole series at this point.
          for (int t = 0; t < nts; t++)
            {
              auto &data = vars[t][varID][levelID].vec;
              data[ir] = missval;
              data[ii] = missval;
            }
          continue;
        }

      fftw_execute(mem.plan);

      for (int t = 0; t < nts; t++)
        {
          auto &data = vars[t][varID][levelID].vec;
          data[ir] = mem.out_fft[t][0] * scale;
          data[ii] = mem.out_fft[t][1] * scale;
        }
    }
#endif
}

/*
 * Transforms, in place, the complex time series stored at every grid point of
 * field (varID, levelID) over the nts time steps, using the built-in routines.
 * fft() is used when nts is a power of 2. Otherwise ft_r() is used, with the
 * per-thread work_r/work_i scratch arrays. Real and imaginary parts are
 * interleaved in the field data (index 2*i and 2*i+1). 'sign' is passed
 * through to the routine unchanged.
 *
 * A grid point whose real or imaginary part equals missval in any time step
 * is set to missval in all time steps and is not transformed. No scaling is
 * applied here beyond what fft()/ft_r() do themselves.
 */
static void
fourier_intrinsic(int sign, int varID, int levelID, int nts, size_t gridsize, double missval, FieldVector3D &vars,
                  std::vector<FourierMemory> &ompmem)
{
  const bool usePow2Fft = !(nts & (nts - 1));

#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
  for (size_t ip = 0; ip < gridsize; ip++)
    {
      auto &mem = ompmem[cdo_omp_get_thread_num()];
      auto *re = mem.real.data();
      auto *im = mem.imag.data();
      const size_t ir = 2 * ip;  // real part
      const size_t ii = ir + 1;  // imaginary part

      // Gather the time series of this grid point into the thread's arrays.
      bool anyMissing = false;
      for (int t = 0; t < nts; t++)
        {
          const auto &data = vars[t][varID][levelID].vec;
          re[t] = data[ir];
          im[t] = data[ii];
          if (DBL_IS_EQUAL(re[t], missval) || DBL_IS_EQUAL(im[t], missval)) anyMissing = true;
        }

      if (anyMissing)
        {
          // One missing value invalidates the whole series at this point.
          for (int t = 0; t < nts; t++)
            {
              auto &data = vars[t][varID][levelID].vec;
              data[ir] = missval;
              data[ii] = missval;
            }
          continue;
        }

      if (usePow2Fft)
        fft(re, im, nts, sign);
      else
        ft_r(re, im, nts, sign, mem.work_r.data(), mem.work_i.data());

      for (int t = 0; t < nts; t++)
        {
          auto &data = vars[t][varID][levelID].vec;
          data[ir] = re[t];
          data[ii] = im[t];
        }
    }
}

void *
Fourier(void *process)
{
  int nrecs;
  int varID, levelID;
  int nalloc = 0;
  size_t nmiss;

  cdoInitialize(process);

  bool use_fftw = false;
  if (Options::Use_FFTW)
    {
#ifdef HAVE_LIBFFTW3
      if (Options::cdoVerbose) cdoPrint("Using fftw3 lib");
      use_fftw = true;
#else
      if (Options::cdoVerbose) cdoPrint("LIBFFTW3 support not compiled in!");
#endif
    }

  if (Options::cdoVerbose && !use_fftw) cdoPrint("Using intrinsic FFT function!");

  operatorInputArg("the sign of the exponent (-1 for normal or 1 for reverse transformation)!");
  const auto sign = parameter2int(cdoOperatorArgv(0));

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  VarList varList;
  varListInit(varList, vlistID1);

  const auto nvars = vlistNvars(vlistID1);
  FieldVector3D vars;
  std::vector<int64_t> vdate;
  std::vector<int> vtime;

  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      if (tsID >= nalloc)
        {
          nalloc += NALLOC_INC;
          vdate.resize(nalloc);
          vtime.resize(nalloc);
          vars.resize(nalloc);
        }

      vdate[tsID] = taxisInqVdate(taxisID1);
      vtime[tsID] = taxisInqVtime(taxisID1);

      fieldsFromVlist(vlistID1, vars[tsID], FIELD_NONE);

      for (int recID = 0; recID < nrecs; recID++)
        {
          cdoInqRecord(streamID1, &varID, &levelID);
          const auto gridsize = varList[varID].gridsize;
          vars[tsID][varID][levelID].resize(2 * gridsize);
          cdoReadRecord(streamID1, vars[tsID][varID][levelID].vec.data(), &nmiss);
          vars[tsID][varID][levelID].nmiss = nmiss;
        }

      tsID++;
    }

  int nts = tsID;

  std::vector<FourierMemory> ompmem(Threading::ompNumThreads);

  if (use_fftw)
    {
#ifdef HAVE_LIBFFTW3
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          ompmem[i].in_fft = fftw_alloc_complex(nts);
          ompmem[i].out_fft = fftw_alloc_complex(nts);
          ompmem[i].plan = fftw_plan_dft_1d(nts, ompmem[i].in_fft, ompmem[i].out_fft, sign, FFTW_ESTIMATE);
        }
      if (Options::cdoVerbose) fftw_print_plan(ompmem[0].plan);
#endif
    }
  else
    {
      const bool isPower2 = ((nts & (nts - 1)) == 0);
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          ompmem[i].real.resize(nts);
          ompmem[i].imag.resize(nts);
          if (!isPower2) ompmem[i].work_r.resize(nts);
          if (!isPower2) ompmem[i].work_i.resize(nts);
        }
    }

  for (varID = 0; varID < nvars; varID++)
    {
      const auto missval = varList[varID].missval;
      const auto gridsize = varList[varID].gridsize;
      const auto nlevels = varList[varID].nlevels;
      for (levelID = 0; levelID < nlevels; levelID++)
        {
          if (use_fftw)
            fourier_fftw(sign, varID, levelID, nts, gridsize, missval, vars, ompmem);
          else
            fourier_intrinsic(sign, varID, levelID, nts, gridsize, missval, vars, ompmem);
        }
    }

#ifdef HAVE_LIBFFTW3
  if (use_fftw)
    {
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          fftw_free(ompmem[i].in_fft);
          fftw_free(ompmem[i].out_fft);
          fftw_destroy_plan(ompmem[i].plan);
        }
      fftw_cleanup();
    }
#endif

  for (tsID = 0; tsID < nts; tsID++)
    {
      taxisDefVdate(taxisID2, vdate[tsID]);
      taxisDefVtime(taxisID2, vtime[tsID]);
      cdoDefTimestep(streamID2, tsID);

      for (varID = 0; varID < nvars; varID++)
        {
          const auto nlevels = varList[varID].nlevels;
          for (levelID = 0; levelID < nlevels; levelID++)
            {
              if (!vars[tsID][varID][levelID].empty())
                {
                  nmiss = vars[tsID][varID][levelID].nmiss;
                  cdoDefRecord(streamID2, varID, levelID);
                  cdoWriteRecord(streamID2, vars[tsID][varID][levelID].vec.data(), nmiss);
                }
            }
        }
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
