!-------------------------------------------------------------------------
module ConjGradMod
!-------------------------------------------------------------------------

  use SysParams,             only :  double, cglogFile, &
                                       zero, half, one, two
  use pwHamMod,              only :  hamInfoT
  use EigenStatesMod,        only :  EigenStatesT
  use GridMod,               only :  GridT, CalculateHpsi

  implicit none

  ! Settings and work space of the conjugate-gradient minimizer.
  !   tolerance          - convergence tolerance; CGMinimize compares the
  !                        squared residual with 10 * tolerance**2
  !   period             - number of failed attempts on one band after which
  !                        CGBandByBandMinimize loosens the tolerance
  !   usePreconditioning - when true CGMinimize calls PreCondition
  !   checkConvergence   - read from the input tags in conjGradInit; not
  !                        referenced elsewhere in this module
  !   minVec, oldVec     - current and previous search-direction data
  ! The pointer components are default-initialized to null() so that their
  ! association status is defined before conjGradInit allocates them.
  type conjGradT                                          ! defaults
    real(double)                  :: tolerance            = 1.d-5
    integer                       :: period               = 3
    logical                       :: usePreconditioning   = .false.
    logical                       :: checkConvergence     = .true.
    type(minVecT),      pointer   :: minVec               => null()
    type(minVecT),      pointer   :: oldVec               => null()
  end type conjGradT

  ! Per-iteration data of the line minimization.
  !   cosine, sine - cos(theta) and sin(theta) of the mixing angle chosen in
  !                  EnergyMinimize
  !   StDesc       - steepest-descent (residual) vector
  !   Phi          - normalized conjugate search direction
  !   EPhi         - result of CalculateHpsi applied to Phi
  ! The arrays are default-initialized to null() so that their association
  ! status is defined before conjGradInit allocates them.
  type minVecT
     real(double)                 :: cosine, sine
     complex(double),   pointer   :: StDesc(:) => null()
     complex(double),   pointer   :: Phi(:)    => null()
     complex(double),   pointer   :: EPhi(:)   => null()
  end type minVecT

  contains


!------------------------------------------------------------------------
  subroutine conjGradInit(conjGrad, maxBasisSize)
!
!   Allocates the conjugate-gradient work structure, reads the CG settings
!   through TagValue, seeds the random number generator and opens the log
!   file 'cglog.txt' on unit cglogFile.
!
!   conjGrad      pointer; allocated here together with its work vectors
!                 minVec and oldVec (release with conjGradDestroy)
!   maxBasisSize  length given to every StDesc / Phi / EPhi array
!
!   Any allocation or file-open failure is reported through errorStop
!   (presumably terminating the run -- confirm in utilsMod).
!

    use TagHandlerMod,       only :  TagValue
    use utilsMod,            only :  errorStop

    implicit none

    type(conjGradT),    pointer   :: conjGrad
    integer,      intent(in)      :: maxBasisSize
    
    integer                       :: error


    ! Allocate conjGrad - - - - - - - - - - - - - - - - - - - - - - - - -
    !
    !   The allocation is done in three steps, parent before child: an
    ! object named in an ALLOCATE statement must not depend on the
    ! association status of another object allocated by the same statement.

    allocate(conjGrad, stat = error)
    if (error /= 0) &
      call errorStop('@ConjGradInit, conjGrad allocation failed.')    

    allocate(conjGrad%minVec, &
             conjGrad%oldVec, stat = error)
    if (error /= 0) &
      call errorStop('@ConjGradInit, conjGrad allocation failed.')    

    allocate(conjGrad%minVec%StDesc( maxBasisSize), &
             conjGrad%minVec%Phi(    maxBasisSize), &
             conjGrad%minVec%EPhi(   maxBasisSize), &
             conjGrad%oldVec%StDesc( maxBasisSize), &
             conjGrad%oldVec%Phi(    maxBasisSize), &
             conjGrad%oldVec%EPhi(   maxBasisSize), stat = error )
    if (error /= 0) &
      call errorStop('@ConjGradInit, conjGrad allocation failed.')    


    ! Read values - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    !
    !   The components carry the type's default values at this point;
    ! presumably TagValue overwrites them only when the tag is present --
    ! confirm in TagHandlerMod.

    call TagValue('CGTolerance',         doubl = conjGrad%tolerance)
    call TagValue('CGIterationPeriod',   intgr = conjGrad%period)
    call TagValue('CGPreConditioner',    logic = conjGrad%usePreconditioning)
    call TagValue('CGConvergenceSwitch', logic = conjGrad%checkConvergence)


    ! randomize - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

    call random_seed()


    ! open conjugate gradient log file  - - - - - - - - - - - - - - - - -

    open(unit = cglogFile, file = 'cglog.txt', status = 'replace', &
         action = 'write', position = 'rewind', iostat = error)
    if (error /= 0) &
      call errorStop('@ConjGradInit, could not open cglog.txt.')

    write(cglogFile, *) 'Eigenvalues:'

  end subroutine conjGradInit


!-------------------------------------------------------------------------
  subroutine conjGradDestroy(conjGrad)
!
!   Closes the conjugate-gradient log file and releases everything that
!   conjGradInit allocated.
!
!   conjGrad   pointer; disassociated on return
!

    type(conjGradT),    pointer   :: conjGrad

    close(unit = cglogFile)

    ! Nothing to release if the structure was never allocated.
    if (.not. associated(conjGrad)) return

    !   Release children before parents, in separate statements: an object
    ! named in a DEALLOCATE statement must not depend on the association
    ! status of another object released by the same statement.

    deallocate(conjGrad%oldVec%EPhi,   &
               conjGrad%oldVec%Phi,    &
               conjGrad%oldVec%StDesc, &
               conjGrad%minVec%EPhi,   &
               conjGrad%minVec%Phi,    &
               conjGrad%minVec%StDesc)

    deallocate(conjGrad%oldVec,        &
               conjGrad%minVec)

    deallocate(conjGrad)

  end subroutine conjGradDestroy


!-------------------------------------------------------------------------
  subroutine CGBandByBandMinimize(conjGrad, grid, hamInfo, eigenStates, kpt)
!
!   Minimizes the bands of k-point kpt one after the other.  Each attempt
!   on a band starts from a fresh random vector, orthonormalized against
!   the lower bands, and is refined by CGMinimize.  A band that does not
!   converge is attempted again; after conjGrad%period failed attempts the
!   tolerance is multiplied by 100.
!
!   conjGrad      CG settings and work vectors
!   grid, hamInfo passed through to CalculateHpsi / CGMinimize
!   eigenStates   receives eigenvectors and eigenvalues; hSize and hDim
!                 are set here for this k-point
!   kpt           k-point index
!

    type(conjGradT),    pointer   :: conjGrad
    type(GridT),        pointer   :: grid 
    type(hamInfoT),     pointer   :: hamInfo
    type(eigenStatesT), pointer   :: eigenStates
    integer,      intent(in)      :: kpt

    complex(double),    pointer   :: trialPsi(:), trialHpsi(:)
    integer                       :: iBand, attempts, startBand
    real(double)                  :: currentTol


    write(cglogFile, *) 'k:', kpt

    eigenStates%hSize = eigenStates%numBasisVectors(kpt)
    eigenStates%hDim  = eigenStates%maxBasisVectors

    iBand      = 1
    attempts   = 1
    currentTol = conjGrad%tolerance

    BandLoop: do

      if (iBand > eigenStates%numStates) exit BandLoop

      startBand = iBand

      ! Clear the search-direction history of the previous attempt
      call MinVecInit(conjGrad%minVec)
      call MinVecInit(conjGrad%oldVec)

      ! Random start vector, orthonormal to all lower bands
      call EigenStatesInit(eigenStates, iBand, kpt)
      call EigenStatesOrth(eigenStates, iBand, kpt)

      ! H|psi> of the start vector goes into the ePsi work column
      trialPsi  => eigenStates%eigenVectors(:, iBand, kpt)
      trialHpsi => eigenStates%ePsi(:, kpt)
      call CalculateHpsi(grid, hamInfo, eigenStates, kpt, trialPsi, trialHpsi)

      ! Initial Ritz value <psi|H|psi>
      eigenStates%eigenValues(iBand, kpt) = dot_product(trialPsi, trialHpsi)

      ! CGMinimize lowers iBand by one when the band did not converge ...
      call CGMinimize(conjGrad, grid, hamInfo, eigenStates, iband, kpt, &
                      currentTol)

      ! ... so after this increment an unconverged band is simply repeated.
      iBand = iBand + 1

      if (iBand == startBand) then

        ! Same band again: count the failed attempt and, once every
        ! conjGrad%period failures, loosen the tolerance.
        if (attempts == conjGrad%period) then
          attempts   = 1
          currentTol = 100 * currentTol
          write(cglogFile, *) &
                   'Failed to converge in', conjGrad%period, &
                   ' steps: Increasing tolerance to ', currentTol
          print *, 'Failed to converge in', conjGrad%period, &
                   ' steps: Increasing tolerance to ', currentTol
        else
          attempts = attempts + 1
        end if

      else

        ! Band accepted: the next band starts with the user tolerance.
        attempts   = 1
        currentTol = conjGrad%tolerance

      end if

    end do BandLoop
 
  end subroutine CGBandByBandMinimize


!-------------------------------------------------------------------------
  subroutine CGMinimize(conjGrad, grid, hamInfo, eigenStates, iband, kpt, &
                        tolerance)
!
!   Refines the eigenvector of band iband at k-point kpt by conjugate-
!   gradient line minimizations, at most eigenStates%hSize of them.
!
!   conjGrad      CG settings and the minVec / oldVec work vectors
!   grid, hamInfo passed through to EnergyMinimize (and PreCondition)
!   eigenStates   eigenvector, eigenvalue and ePsi are updated in place
!   iband         band index; decremented by one on return when the
!                 squared residual is still above 10 * tolerance**2, which
!                 makes the caller repeat the band
!   kpt           k-point index
!   tolerance     convergence tolerance for this attempt
!

    type(conjGradT),    pointer   :: conjGrad
    type(GridT),        pointer   :: grid 
    type(hamInfoT),     pointer   :: hamInfo
    type(eigenStatesT), pointer   :: eigenStates
    integer,      intent(inout)   :: iband
    integer,      intent(in)      :: kpt
    real(double), intent(in)      :: tolerance

    integer                       :: iter
    real(double)                  :: residue, tolsq


    tolsq = 10.d0 * tolerance ** 2

    !  residue is read after the loop.  Give it a defined value in case
    !  the loop runs zero times (hSize < 1); with nothing to minimize the
    !  band is then reported as converged.
    residue = zero

    !  Iterate the Eigen Vector until convergence before moving over to
    !  the next higher band.
    EigenVectorLoop: do iter = 1, eigenStates%hSize

      !  Evaluate the Steepest Descent direction as the negative gradient
      !  In this case it is StDesc = -(H-lambda)|Psi>        
      call StDescentVec(eigenStates, conjGrad%minVec, iBand, kpt)

      !  The residue is the Steepest Descent Vector, so its square is 
      !  calculated, and the iterations are carried out until the user
      !  defined convergence is reached, OR hSize iterations have completed.
      residue = real(dot_product(conjGrad%minVec%StDesc, &
                                 conjGrad%minVec%StDesc), double)

      if (residue <= tolsq) &
        exit EigenVectorLoop

      !  Orthogonalize the steepest descent vector against all the LOWER
      !  Eigen Vectors.
      call LowerBandOrthogonalize(eigenStates, conjGrad%minVec%StDesc, &
                                  iBand, kpt)

      !  PreCondition the Steepest Descent Vector
      if (conjGrad%usePreconditioning) &
        call PreCondition(conjGrad%minVec, hamInfo, eigenStates, iBand, kpt)

      !  Orthogonalize the PreConditioned Vector against all bands, lower 
      !  than and including the current band.
      call BandOrthogonalize(eigenStates, conjGrad%minVec%StDesc, iBand, kpt)

      !  Evaluate the conjugate direction orthogonal to the current band 
      !  and normalized.
      call ConjDirection(eigenStates, conjGrad%minVec, conjGrad%oldVec, &
                         iBand, kpt)

      !  Do the line minimization by searching for the energy minimum,
      !  and update the band.
      call EnergyMinimize(grid, conjGrad%minVec, hamInfo, eigenStates, &
                          iBand, kpt)

    end do EigenVectorLoop

    ! Print the eigenvalue and residue.  When the loop ran to completion
    ! the residue is the one evaluated at the start of the last iteration
    ! and iter equals hSize + 1.
    write(cglogFile, 11) iBand, eigenStates%eigenValues(iBand, kpt), &
                        sqrt(residue), iter

11    format(1X,'Band:',I3,1X,'EigenValues:',E15.5,1X,'Norm Resed:',E15.5, &
             1X,'Iterations:',I3)

    ! Not converged: step the band index back so that the caller repeats it
    if (residue > tolsq) &
      iBand = iBand - 1

  end subroutine CGMinimize


!-------------------------------------------------------------------------
  subroutine PreCondition(minVec, hamInfo, eigenStates, iBand, kpt)
!
!   Scales every component of minVec%StDesc by a rational function of
!   x = (kinetic energy of that basis vector) / (kinetic energy of band
!   iBand), so that components with large x are damped.
!
!   minVec        StDesc is rescaled in place
!   hamInfo       supplies kineticEnergyG
!   eigenStates   supplies the current eigenvector and the basis size
!   iBand, kpt    band and k-point index
!

    type(minVecT),         pointer   :: minVec
    type(hamInfoT),        pointer   :: hamInfo
    type(eigenStatesT),    pointer   :: eigenStates
    integer,         intent(in)      :: iBand, kpt
    
    real(double)                     :: bandKinetic, shifted, ratio
    real(double)                     :: cubic, weight
    integer                          :: ig, nBasis


    nBasis = eigenStates%numBasisVectors(kpt)

    ! Kinetic energy of the band in its present state:
    ! sum over basis vectors of T(G) * |c(G)|**2

    bandKinetic = zero
    do ig = 1, nBasis
      bandKinetic = bandKinetic &
        + hamInfo%kineticEnergyG(ig) &
        * conjg(eigenStates%eigenVectors(ig, iBand, kpt)) &
        * eigenStates%eigenVectors(ig, iBand, kpt)
    end do

    ! The small shift keeps the ratio finite for a vanishing band energy
    shifted = bandKinetic + 1.d-30

    ! Weight K(x) = P(x) / (P(x) + 16 x**4),
    ! with P(x) = 27 + 18 x + 12 x**2 + 8 x**3

    do ig = 1, nBasis

      ratio  = hamInfo%kineticEnergyG(ig) &
             / shifted

      cubic  = 27.d0 + 18.d0 * ratio + 12.d0 * ratio**2 + 8.d0 * ratio**3
      weight = cubic &
             / (cubic + 16.d0 * ratio**4)

      minVec%StDesc(ig) = weight * minVec%StDesc(ig)

    end do

  end subroutine PreCondition


!-------------------------------------------------------------------------
  subroutine EnergyMinimize(grid, minVec, hamInfo, eigenStates, iBand, kpt)
!
!   Line minimization along the search direction minVec%Phi: chooses a
!   mixing angle theta and replaces the band vector by
!   cos(theta) * psi + sin(theta) * Phi, updating ePsi (the stored H|psi>)
!   and the eigenvalue accordingly.
!
!   grid, hamInfo passed through to CalculateHpsi
!   minVec        Phi is read; EPhi, sine and cosine are written
!   eigenStates   eigenvector, ePsi and eigenvalue of the band are updated
!   iBand, kpt    band and k-point index
!

    type(GridT),           pointer   :: grid
    type(minVecT),         pointer   :: minVec
    type(hamInfoT),        pointer   :: hamInfo
    type(eigenStatesT),    pointer   :: eigenStates
    integer,         intent(in)      :: iBand, kpt

    real(double)                     :: e0, theta, ePhi
    complex(double)                  :: phiHPsi


    ! Current Ritz value <psi|H|psi>.  Converted with an explicit kind:
    ! cmplx() without a kind argument would round the value through
    ! default (single) precision before it enters the denominator below.
    e0 = real(eigenStates%eigenValues(iBand, kpt), double)

    ! <Phi|H|psi>; ePsi still holds H|psi> of the current band vector
    phiHPsi = dot_product(minVec%Phi, eigenStates%ePsi(:, kpt))

    call CalculateHpsi(grid, hamInfo, eigenStates, kpt, &
           minVec%Phi, &                               ! psi   (in)
           minVec%EPhi)                                ! Hpsi  (out)

    ! Clear the part of EPhi beyond the active basis size
    minVec%EPhi(eigenStates%hSize + 1:eigenStates%hDim) &
      = cmplx(zero, zero, kind = double)

    ! <Phi|H|Phi>
    ePhi = real(dot_product(minVec%Phi, minVec%EPhi), double)

    ! Mixing angle; the small constant keeps the quotient finite when
    ! e0 equals ePhi
    theta = half * abs(atan((two * real(phiHPsi) &
          / (e0 - ePhi + 10.d-30))))

    minVec%sine   = sin(theta)
    minVec%cosine = cos(theta)

    ! Rotate the band vector towards the search direction
    eigenStates%eigenVectors(:, iBand, kpt) &
      = eigenStates%eigenVectors(:, iBand, kpt) * minVec%cosine &
      + minVec%Phi * minVec%sine

    ! Apply the same rotation to H|psi>, avoiding another CalculateHpsi call
    eigenStates%ePsi(:, kpt) &
      = minVec%cosine * eigenStates%ePsi(:, kpt) &
      + minVec%sine * minVec%EPhi

    ! Store the RitzValue in eigenStates
    eigenStates%eigenValues(iBand, kpt) &
      = dot_product(eigenStates%eigenVectors(:, iBand, kpt), &
                    eigenStates%ePsi(:, kpt))

  end subroutine EnergyMinimize


!-------------------------------------------------------------------------
  subroutine EigenStatesInit(eigenStates, iband, kpt)
!
!   Fills the eigenvector of band iband at k-point kpt with random complex
!   numbers whose real and imaginary parts are derived from random_number
!   as 2 * r - 1, and clears the ePsi work column of the k-point.  Entries
!   beyond numBasisVectors(kpt) are set to zero.
!
!   eigenStates   eigenVectors(:, iband, kpt) and ePsi(:, kpt) are written
!   iband, kpt    band and k-point index
!

    type(EigenStatesT),           pointer   :: eigenStates
    integer,                intent(in)      :: iband, kpt

    real(double)                            :: rnd1, rnd2
    integer                                 :: i

    eigenStates%ePsi(:, kpt) = cmplx(zero, zero, kind = double)
    eigenStates%eigenVectors(:, iband, kpt) = cmplx(zero, zero, kind = double)

    do i = 1, eigenStates%numBasisVectors(kpt)
      call random_number(rnd1)
      call random_number(rnd2)
      ! kind = double: cmplx() without a kind argument would round the
      ! components to default (single) precision.
      eigenStates%eigenVectors(i, iband, kpt) &
        = cmplx(two * rnd1 - one, two * rnd2 - one, kind = double)
    end do

  end subroutine EigenStatesInit


!----------------------------------------------------------------------------
  subroutine MinVecInit(minVec)
!
!   Resets a search-direction record: all three vectors are zeroed and the
!   mixing angle is set to zero (cosine = 1, sine = 0).
!

    type(minVecT),      pointer   :: minVec

    ! Vectors
    minVec%EPhi   = cmplx(zero, zero, kind = double)
    minVec%Phi    = cmplx(zero, zero, kind = double)
    minVec%StDesc = cmplx(zero, zero, kind = double)

    ! Mixing angle theta = 0
    minVec%sine   = zero
    minVec%cosine = one

  end subroutine MinVecInit

!----------------------------------------------------------------------------
  subroutine StDescentVec(eigenStates, minVec, iBand, kpt)
!
!   Stores the steepest-descent vector of band iBand in minVec%StDesc:
!   StDesc = lambda * psi - H psi, where lambda is the current eigenvalue
!   and H psi is taken from eigenStates%ePsi(:, kpt).
!

    type(eigenStatesT),        pointer   :: eigenStates
    type(minVecT),             pointer   :: minVec
    integer,             intent(in)      :: iBand, kpt

    complex(double),           pointer   :: bandPsi(:), bandHpsi(:)
    real(double)                         :: lambda

    lambda   =  eigenStates%eigenValues(iBand, kpt)
    bandPsi  => eigenStates%eigenVectors(:, iBand, kpt)
    bandHpsi => eigenStates%ePsi(:, kpt)

    minVec%StDesc = lambda * bandPsi - bandHpsi

  end subroutine StDescentVec


!----------------------------------------------------------------------------
  subroutine ConjDirection(eigenStates, minVec, oldVec, iBand, kpt)
!
!   Builds the new search direction minVec%Phi from the current steepest-
!   descent vector and the previous direction oldVec%Phi, makes it
!   orthogonal to the band vector, normalizes it, and finally copies the
!   whole of minVec into oldVec for the next iteration.
!
!   eigenStates   supplies the band vector used for orthogonalization
!   minVec        StDesc is read, Phi is written
!   oldVec        previous iteration on entry, copy of minVec on return
!   iBand, kpt    band and k-point index
!

    type(eigenStatesT),        pointer   :: eigenStates
    type(minVecT),             pointer   :: minVec, oldVec
    integer,             intent(in)      :: iBand, kpt

    real(double)                         :: mixing, top, bottom

    !  Mixing coefficient built from the change of the descent vector.
    !  oldVec%Phi is zero right after MinVecInit, so the first iteration
    !  of a band reduces to a plain steepest-descent step.
    top    = dot_product(minVec%StDesc - oldVec%StDesc, minVec%StDesc)
    bottom = dot_product(minVec%StDesc - oldVec%StDesc, oldVec%Phi)

    !  The small constant protects against a vanishing denominator
    mixing = -top / (bottom + 10.d-40)

    minVec%Phi = minVec%StDesc + mixing * oldVec%Phi

    !  Remove the component along the state-vector of band 'iBand',
    !  then scale Phi to unit length
    call SelfBandOrthogonalize(eigenStates, minVec%Phi, iBand, kpt)
    call VecNormalize(minVec%Phi)

    !  Keep a copy of this iteration for the next call
    oldVec%StDesc = minVec%StDesc
    oldVec%Phi    = minVec%Phi
    oldVec%EPhi   = minVec%EPhi
    oldVec%sine   = minVec%sine
    oldVec%cosine = minVec%cosine

  end subroutine ConjDirection


!--------------------------------------------------------------------------
  subroutine EigenStatesOrth(eigenStates, iBand, kpt)
!
!   Prepares the start vector of band iBand at k-point kpt: it is made
!   orthogonal to every band below iBand and then scaled to unit length.
!

    type(eigenStatesT),         pointer   :: eigenStates
    integer,              intent(in)      :: iBand, kpt

    complex(double),            pointer   :: startPsi(:)

    startPsi => eigenStates%eigenVectors(:, iBand, kpt)

    !  Project out the LOWER bands, then normalize
    call LowerBandOrthogonalize(eigenStates, startPsi, iBand, kpt)
    call VecNormalize(startPsi)

  end subroutine EigenStatesOrth


!--------------------------------------------------------------------------
  subroutine LowerBandOrthogonalize(eigenStates, vector, iBand, kpt)
!
!   Removes from vector its components along the eigenvectors of all
!   bands 1 .. iBand-1 at k-point kpt, one band after the other, and
!   clears the entries hSize+1 .. hDim of the result.
!
!   eigenStates   supplies the lower eigenvectors, hSize and hDim
!   vector        modified in place
!   iBand, kpt    band and k-point index
!

    type(eigenStatesT),         pointer   :: eigenStates
    complex(double),      intent(inout)   :: vector(:)
    integer,              intent(in)      :: iBand, kpt
    
    integer                               :: lower
    complex(double)                       :: overlap
    complex(double),            pointer   :: lowerPsi(:)

    do lower = 1, iBand - 1

      lowerPsi => eigenStates%eigenVectors(:, lower, kpt)

      !  Projection <psi_lower|vector> ...
      overlap = dot_product(lowerPsi, vector)

      !  ... which is subtracted along psi_lower
      vector = vector - overlap * lowerPsi

    end do
    
    vector(eigenStates%hSize + 1:eigenStates%hDim) &
      = cmplx(zero, zero, kind = double)

  end subroutine LowerBandOrthogonalize


!--------------------------------------------------------------------------
  subroutine VecNormalize(vector)
!
!   Scales vector in place to unit Euclidean length.  A tiny constant is
!   added to the norm so that an all-zero vector stays zero instead of
!   causing a division by zero.
!

    complex(double),      intent(inout)   :: vector(:)

    complex(double)                       :: length

    !  dot_product of complex arguments is complex, hence the complex norm
    length = sqrt(dot_product(vector, vector)) + 1.d-32
    vector = vector / length

  end subroutine VecNormalize


!--------------------------------------------------------------------------
  subroutine SelfBandOrthogonalize(eigenStates, vector, iBand, kpt)
!
!   Removes from vector its component along the eigenvector of band iBand
!   at k-point kpt and clears the entries hSize+1 .. hDim of the result.
!
!   eigenStates   supplies the eigenvector, hSize and hDim
!   vector        modified in place
!   iBand, kpt    band and k-point index
!

    type(eigenStatesT),         pointer   :: eigenStates
    complex(double),      intent(inout)   :: vector(:)
    integer,              intent(in)      :: iBand, kpt

    complex(double)                       :: overlap
    complex(double),            pointer   :: bandPsi(:)

    bandPsi => eigenStates%eigenVectors(:, iBand, kpt)

    !  Projection <psi_iBand|vector>, subtracted along psi_iBand
    overlap = dot_product(bandPsi, vector)
    vector  = vector - overlap * bandPsi
       
    vector(eigenStates%hSize + 1:eigenStates%hDim) &
      = cmplx(zero, zero, kind = double)

  end subroutine SelfBandOrthogonalize


!--------------------------------------------------------------------------
  subroutine BandOrthogonalize(eigenStates, vector, iBand, kpt)
!
!   Makes vector orthogonal to the eigenvectors of bands 1 .. iBand at
!   k-point kpt: band iBand itself is projected out first, the bands
!   below it afterwards.
!
!   eigenStates   supplies the eigenvectors
!   vector        modified in place
!   iBand, kpt    band and k-point index
! 

    type(eigenStatesT),         pointer   :: eigenStates
    complex(double),      intent(inout)   :: vector(:)
    integer,              intent(in)      :: iBand, kpt

    !  Step 1: the band itself
    call SelfBandOrthogonalize(eigenStates, vector, iBand, kpt)

    !  Step 2: every band below it
    call LowerBandOrthogonalize(eigenStates, vector, iBand, kpt)

  end subroutine BandOrthogonalize

end module ConjGradMod