import time
from typing import TYPE_CHECKING, Optional, Type
from .iteration_update import IterationUpdate
if TYPE_CHECKING:
from . import Context
class ConvergenceCriteria:
    '''
    Abstract interface used by `iterate_ctx_se` to decide when to stop.

    `iterate_ctx_se` instantiates the class given by its ``convergence``
    argument with the four arguments below. It then calls `is_converged` on
    every iteration after the initial ``Nscatter`` lambda iterations. When no
    class is supplied, `DefaultConvergenceCriteria` is used.

    Parameters
    ----------
    ctx : Context
        The context being iterated.
    JTol : float
        The value of JTol passed to `iterate_ctx_se`.
    popsTol : float
        The value of popsTol passed to `iterate_ctx_se`.
    rhoTol : float or None
        The value of rhoTol passed to `iterate_ctx_se`.
    '''

    def __init__(self, ctx: 'Context', JTol: float, popsTol: float,
                 rhoTol: Optional[float]):
        # Concrete subclasses must supply their own constructor. The base
        # class itself is never meant to be instantiated.
        raise NotImplementedError()

    def is_converged(self, JUpdate: IterationUpdate,
                     popsUpdate: IterationUpdate,
                     prdUpdate: Optional[IterationUpdate]) -> bool:
        '''
        Decide whether the Context has converged sufficiently.

        Parameters
        ----------
        JUpdate : IterationUpdate
            Result of `ctx.formal_sol_gamma_matrices` for this iteration.
        popsUpdate : IterationUpdate
            Result of `ctx.stat_equil` for this iteration.
        prdUpdate : IterationUpdate or None
            Result of `ctx.prd_redistribute`, or None when PRD
            subiterations were not performed (or returned nothing).

        Returns
        -------
        bool
            True when iteration should stop.
        '''
        raise NotImplementedError()
class DefaultConvergenceCriteria(ConvergenceCriteria):
    '''
    Default ConvergenceCriteria implementation. Usually sufficient for
    statistical equilibrium problems, but you may occasionally need to override
    this.

    The Context is judged converged when every supplied IterationUpdate
    satisfies both of the following:

    - ``dJMax < JTol`` and ``dPopsMax < popsTol``;
    - if a PRD update is supplied and ``rhoTol`` is not None,
      ``dRhoMax < rhoTol``.

    In addition, ``ctx.crswDone`` must be truthy.

    Parameters
    ----------
    ctx : Context
        The context being iterated.
    JTol : float
        The value of JTol passed to `iterate_ctx_se`.
    popsTol : float
        The value of popsTol passed to `iterate_ctx_se`.
    rhoTol : float or None
        The value of rhoTol passed to `iterate_ctx_se`.
    '''

    def __init__(self, ctx: 'Context', JTol: float, popsTol: float, rhoTol: Optional[float]):
        # NOTE: deliberately does not call ConvergenceCriteria.__init__, which
        # raises NotImplementedError.
        self.ctx = ctx
        self.JTol = JTol
        self.popsTol = popsTol
        self.rhoTol = rhoTol

    def is_converged(self, JUpdate: IterationUpdate, popsUpdate: IterationUpdate,
                     prdUpdate: Optional[IterationUpdate]) -> bool:
        '''
        Returns whether the context is converged.

        Every update is checked against every tolerance, not only the one
        matching the step that produced it.
        NOTE(review): this presumes an IterationUpdate leaves the fields it
        does not measure at a value below tolerance (e.g. 0). Confirm against
        IterationUpdate.

        Parameters
        ----------
        JUpdate : IterationUpdate
            Update from `ctx.formal_sol_gamma_matrices`.
        popsUpdate : IterationUpdate
            Update from `ctx.stat_equil`.
        prdUpdate : IterationUpdate or None
            Update from `ctx.prd_redistribute`, or None if PRD was not run.

        Returns
        -------
        bool
            True if all tolerances are met and ``ctx.crswDone`` is truthy.
        '''
        updates = [JUpdate, popsUpdate]
        # Explicit None checks (rather than truthiness) so that an update
        # object's own __bool__/__len__ can never disable the rho check.
        checkRho = prdUpdate is not None and self.rhoTol is not None
        if prdUpdate is not None:
            updates.append(prdUpdate)

        for update in updates:
            # Written as `not (x < tol)` rather than `x >= tol` so that a NaN
            # change still counts as *not* converged.
            if not (update.dJMax < self.JTol):
                return False
            if not (update.dPopsMax < self.popsTol):
                return False
            if checkRho and not (update.dRhoMax < self.rhoTol):
                return False
        # Loop-invariant: only needs reading once, after all tolerances pass.
        return bool(self.ctx.crswDone)
def _collect_final_updates(JUpdate: Optional['IterationUpdate'],
                           popsUpdate: Optional['IterationUpdate'],
                           dRhoUpdate: Optional['IterationUpdate']) -> list:
    '''
    Return the last iteration's IterationUpdates that were actually computed,
    in the order J, populations, rho (None entries are dropped).
    '''
    return [update for update in (JUpdate, popsUpdate, dRhoUpdate)
            if update is not None]


def _print_final_iteration(it: int, updates: list) -> None:
    '''
    Print the boxed "Final Iteration" header for iteration `it`, followed by
    the compact representation of each update in `updates`.
    '''
    line = '-' * 80
    print(line)
    print(f'Final Iteration: {it}')
    print(line)
    for update in updates:
        print(update.compact_representation())


def iterate_ctx_se(ctx: 'Context', Nscatter: int=3, NmaxIter: int=2000,
                   prd: bool=False, JTol: float=5e-3, popsTol: float=1e-3,
                   rhoTol: Optional[float]=None, prdIterTol: float=1e-2,
                   maxPrdSubIter: int=3, printInterval: float=0.2,
                   quiet: bool=False,
                   convergence: Optional[Type['ConvergenceCriteria']]=None,
                   returnFinalConvergence: bool=False):
    '''
    Iterate a configured Context towards statistical equilibrium solution.

    Parameters
    ----------
    ctx : Context
        The context to iterate.
    Nscatter : int, optional
        The number of lambda iterations to perform for an initial estimate of J
        (default: 3). During these iterations only the formal solution is
        computed. Statistical equilibrium, PRD and the convergence test are
        skipped.
    NmaxIter : int, optional
        The maximum number of iterations (including Nscatter) to take (default:
        2000). Must be at least 1.
    prd : bool, optional
        Whether to perform PRD subiterations to estimate rho for PRD lines
        (default: False).
    JTol : float, optional
        The maximum relative change in J from one iteration to the next
        (default: 5e-3).
    popsTol : float, optional
        The maximum relative change in an atomic population from one iteration
        to the next (default: 1e-3).
    rhoTol : float, optional
        The maximum relative change in rho for a PRD line on the final
        subiteration from one iteration to the next. If None, the change in rho
        will not be considered in judging convergence (default: None).
    prdIterTol : float, optional
        The maximum relative change in rho for a PRD line below which PRD
        subiterations will cease for this iteration (default: 1e-2).
    maxPrdSubIter : int, optional
        The maximum number of PRD subiterations to make, whether or not rho has
        reached the tolerance of prdIterTol (which isn't necessary every
        iteration). (Default: 3)
    printInterval : float, optional
        The interval between printing the update size information in seconds. A
        value of 0.0 will print every iteration (default: 0.2).
    quiet : bool, optional
        Overrides any other print arguments and iterates silently if True.
        (Default: False).
    convergence : derived ConvergenceCriteria class, optional
        The ConvergenceCriteria version to be used in determining convergence.
        Will be instantiated by this function, and the `is_converged` method
        will then be used. (Default: DefaultConvergenceCriteria).
    returnFinalConvergence : bool, optional
        Whether to return the IterationUpdate objects used in the final
        convergence decision. If True, these will be returned in a list as the
        second return value. (Default: False).

    Returns
    -------
    it : int
        The zero-based index of the final iteration performed, i.e. the
        number of iterations performed minus one.
    finalIterationUpdates : List[IterationUpdate], optional
        The IterationUpdates from the final iteration, if requested by
        `returnFinalConvergence`. The list always holds the J update. It also
        holds the populations update, unless every iteration was a lambda
        iteration (NmaxIter <= Nscatter). It holds the rho update when `prd`
        is set and one was returned.

    Raises
    ------
    ValueError
        If NmaxIter < 1 (no iteration would run, so there is nothing to
        report or return).
    '''
    if NmaxIter < 1:
        raise ValueError(f'NmaxIter must be at least 1 (got {NmaxIter}).')

    prevPrint = 0.0
    printNow = True
    alwaysPrint = (printInterval == 0.0)
    startTime = time.time()
    if convergence is None:
        convergence = DefaultConvergenceCriteria
    conv = convergence(ctx, JTol, popsTol, rhoTol)

    # Initialised up front so the summary/return code after the loop stays
    # valid even when every iteration was a lambda iteration
    # (NmaxIter <= Nscatter). In that case stat_equil/prd_redistribute are
    # never called.
    JUpdate: Optional[IterationUpdate] = None
    popsUpdate: Optional[IterationUpdate] = None
    dRhoUpdate: Optional[IterationUpdate] = None
    converged = False
    for it in range(NmaxIter):
        JUpdate = ctx.formal_sol_gamma_matrices()
        # Throttle progress output to at most once per printInterval seconds.
        if (not quiet and
                (alwaysPrint or ((now := time.time()) >= prevPrint + printInterval))):
            printNow = True
            if not alwaysPrint:
                prevPrint = now
        if not quiet and printNow:
            print(f'-- Iteration {it}:')
            print(JUpdate.compact_representation())

        if it < Nscatter:
            if not quiet and printNow:
                print(' (Lambda iterating background)')
            # NOTE(cmo): reset print state
            printNow = False
            continue

        popsUpdate = ctx.stat_equil()
        if not quiet and printNow:
            print(popsUpdate.compact_representation())

        if prd:
            dRhoUpdate = ctx.prd_redistribute(maxIter=maxPrdSubIter, tol=prdIterTol)
            if not quiet and printNow and dRhoUpdate is not None:
                print(dRhoUpdate.compact_representation())
        else:
            dRhoUpdate = None

        if conv.is_converged(JUpdate, popsUpdate, dRhoUpdate):
            # Leave printNow untouched: it records whether this final
            # iteration has already been printed.
            converged = True
            break
        # NOTE(cmo): reset print state
        printNow = False

    finalUpdates = _collect_final_updates(JUpdate, popsUpdate, dRhoUpdate)
    if not quiet:
        duration = time.time() - startTime
        line = '-' * 80
        if converged:
            if printNow:
                print('Final Iteration shown above.')
            else:
                _print_final_iteration(it, finalUpdates)
            print(line)
            print(f'Context converged to statistical equilibrium in {it}'
                  f' iterations after {duration:.2f} s.')
            print(line)
        else:
            _print_final_iteration(it, finalUpdates)
            print(line)
            print(f'Context FAILED to converge to statistical equilibrium after {it}'
                  f' iterations (took {duration:.2f} s).')
            print(line)

    if returnFinalConvergence:
        return it, finalUpdates
    return it