/* Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#include "topenum.h"
#include "tools.h"

/* Builds an enumerator over the given mu / rdiag data.

   mu, rdiag, targetcoord : stored through the member initializers below.
   maxVolume              : copied into this->maxVolume; enumNext compares
                            the values produced by costEstimate against it.
   minLevel               : stored as kmin, the lowest level enumNext will
                            descend to.

   The working dimension d is the number of rows of mu.  An empty
   targetcoord selects SVP mode (solveSVP = true); in that case the choice
   of the starting level is deferred to the first call of enumNext
   (svpInitNeeded).  Otherwise the search state is initialised here at the
   top level d - 1, centred on targetcoord[d - 1]. */
Enumerator::Enumerator(const FloatMatrix& mu, const FloatVect& rdiag,
  const FloatVect& targetcoord, const Float& maxVolume, int minLevel) :
mu(mu), rdiag(rdiag), targetcoord(targetcoord), kmin(minLevel)
{
  d = mu.GetNumRows();
  this->maxVolume.set(maxVolume);

  // Per-level working vectors, one entry per level.
  // (zeroVect is defined elsewhere; presumably it sizes the vector to d
  // entries and zeroes them -- confirm against tools.h.)
  zeroVect(center, d);
  zeroVect(dist, d);
  zeroVect(x, d);
  dx.resize(d);
  ddx.resize(d);

  solveSVP = targetcoord.empty();
  svpInitNeeded = solveSVP;
  if (solveSVP) {
    // k and kmax are assigned by the first call to enumNext.
    return;
  }

  // A target is given: start at the top level, on the rounding of the
  // target's last coordinate.
  const int top = d - 1;
  k = top;
  kmax = d;
  center[top].set(targetcoord[top]);
  x[top].rnd(center[top]);
  dx[top].set(0.0);
  // The initial sign of the second-order step depends on which side of the
  // center the rounded coordinate lies.
  if (center[top] >= x[top])
    ddx[top].set(-1.0);
  else
    ddx[top].set(1.0);
}

bool Enumerator::enumNext(const Float& maxsqrlength) {
  Float newdist, newcenter, y, volume, rtmp1;
  bool notFound = true;

  if (svpInitNeeded) {
    for (k = d - 1; k > kmin; k--) {
      costEstimate(volume, maxsqrlength, rdiag, k - 1);
      if (volume <= maxVolume) break;
    }
    kmax = k;
    svpInitNeeded = false;
  }
  if (k >= d) return false;

  while (notFound) {
    TRACE("Level k=" << k << " dist_k=" << dist[k] << " x_k=" << x[k]);
    y.sub(center[k], x[k]);
    newdist.mul(y, y);
    newdist.mul(newdist, rdiag[k]);
    newdist.add(newdist, dist[k]);

    if (newdist <= maxsqrlength) {
      rtmp1.sub(maxsqrlength, newdist);
      costEstimate(volume, rtmp1, rdiag, k - 1);
      if (k > kmin && volume >= maxVolume) {
        k--;
        TRACE("  Go down, newdist=" << newdist);

        if (solveSVP)
          newcenter.set(0.0);
        else
          newcenter.set(targetcoord[k]);
        for (int j = d - 1; j > k; j--)
          newcenter.submul(x[j], mu(j, k));

        center[k].set(newcenter);
        dist[k].set(newdist);
        x[k].rnd(newcenter);
        dx[k].set(0.0);
        ddx[k].set(newcenter >= x[k] ? -1.0 : 1.0);
        continue;
      }
      subTree.resize(d - k);
      for (unsigned int j = 0; j < subTree.size(); j++)
        subTree[j].set(x[j + k]);
      TRACE("  SubTree approx_size=" << volume << " coord=" << subTree);
      notFound = false;
    }
    else {
      TRACE("  Go up");
      k++;
    }
    if (k < kmax) {
      ddx[k].neg(ddx[k]);
      dx[k].sub(ddx[k], dx[k]);
      x[k].add(x[k], dx[k]);
    }
    else {
      if (k >= d) break;
      kmax = k;
      rtmp1.set(1.0);
      x[k].add(x[k], rtmp1);
    }
    TRACE("  x[" << k << "]=" << x[k]);
  }
  return !notFound;
}
