Skip to content

fitting

fit_tods(model, todvec, maxiter=10, chitol=1e-05)

Fit a model to TODs. This uses a modified Levenberg–Marquardt fitter with flat priors. This function is MPI aware.

Parameters:

Name Type Description Default
model Model

The model object that defines the model and grid we are fitting with.

required
todvec TODVec

The TODs to fit. The todvec.comm object is used to fit in an MPI aware way.

required
maxiter int

The maximum number of iterations to fit.

10
chitol float

The delta chisq to use as the convergence criterion.

1e-5

Returns:

Name Type Description
model Model

Model with the final set of fit parameters, errors, and chisq.

final_iter int

The number of iterations the fitter ran for.

delta_chisq float

The final delta chisq.

Source code in witch/fitting.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
@partial(jax.jit, static_argnums=(2, 3))
def fit_tods(
    model: Model, todvec: TODVec, maxiter: int = 10, chitol: float = 1e-5
) -> tuple[Model, int, float]:
    """
    Fit a model to TODs.
    This uses a modified Levenberg–Marquardt fitter with flat priors.
    This function is MPI aware.

    Parameters
    ----------
    model : Model
        The model object that defines the model and grid we are fitting with.
    todvec : TODVec
        The TODs to fit.
        The `todvec.comm` object is used to fit in an MPI aware way.
    maxiter : int, default: 10
        The maximum number of iterations to fit.
    chitol : float, default: 1e-5
        The delta chisq to use as the convergence criterion.

    Returns
    -------
    model : Model
        Model with the final set of fit parameters, errors, and chisq.
    final_iter : int
        The number of iterations the fitter ran for.
    delta_chisq : float
        The final delta chisq.
    """
    zero = jnp.array(0.0)

    def _cond_func(val):
        # Continue while below the iteration cap AND (the last delta chisq is
        # still at least `chitol` OR the damping parameter is still positive).
        i, delta_chisq, lmd, *_ = val
        below_maxiter = i < maxiter
        not_converged = jnp.logical_or(delta_chisq >= chitol, lmd > zero)
        return jnp.logical_and(below_maxiter, not_converged)

    def _body_func(val):
        i, delta_chisq, lmd, model, curve, grad = val
        # Damp the curvature by scaling its diagonal by (1 + lmd).
        curve_use = curve.at[:].add(lmd * jnp.diag(jnp.diag(curve)))
        # Invert once; the inverse supplies both the step and the errors.
        curve_inv = invscale(curve_use)
        # Get the step
        step = jnp.dot(curve_inv, grad)
        new_pars, to_fit = _prior_pars_fit(
            model.priors, model.pars.at[:].add(step), jnp.array(model.to_fit)
        )
        # Get errs (zero for parameters that are not being fit).
        # NOTE(review): these come from the *damped* curvature `curve_use`,
        # so they include the lmd inflation — confirm this is intended.
        errs = jnp.where(to_fit, jnp.sqrt(jnp.diag(curve_inv)), 0)
        # Now lets get an updated model
        new_model, new_grad, new_curve = objective(new_pars, model, todvec, errs)

        # Positive means the trial step lowered chisq.
        new_delta_chisq = model.chisq - new_model.chisq
        model, grad, curve, delta_chisq, lmd = jax.lax.cond(
            new_delta_chisq > 0,
            _success,
            _failure,
            model,
            new_model,
            grad,
            new_grad,
            curve,
            new_curve,
            delta_chisq,
            new_delta_chisq,
            lmd,
        )

        return (i + 1, delta_chisq, lmd, model, curve, grad)

    # Evaluate the starting point so the loop begins with a valid chisq,
    # gradient, and curvature.
    pars, _ = _prior_pars_fit(model.priors, model.pars, jnp.array(model.to_fit))
    model, grad, curve = objective(pars, model, todvec, model.errs)
    i, delta_chisq, _, model, *_ = jax.lax.while_loop(
        _cond_func, _body_func, (0, jnp.inf, zero, model, curve, grad)
    )

    return model, i, delta_chisq

invsafe(matrix, thresh=1e-14)

Safe SVD-based pseudo-inversion of the matrix. This zeros out modes that are too small when inverting. Use with caution in cases where you really care about what the inverse is.

Parameters:

Name Type Description Default
matrix Array

The matrix to invert. Should be a (n, n) array.

required
thresh float

Threshold at which to zero out a mode.

1e-14

Returns:

Name Type Description
invmat Array

The inverted matrix. Same shape as matrix.

Source code in witch/fitting.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@jax.jit
def invsafe(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Safe SVD based pseudo-inversion of the matrix.
    This zeros out modes that are too small when inverting.
    Use with caution in cases where you really care about what the inverse is.

    A singular value ``s`` is inverted only when ``s >= thresh * max(s)``
    and ``s > 0``; every other mode contributes zero to the result.
    In particular an all-zero matrix maps to an all-zero matrix
    rather than to a matrix of ``inf``.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold, relative to the largest singular value,
        at which to zero out a mode.

    Returns
    -------
    invmat: jax.Array
        The inverted matrix.
        Same shape as `matrix`.
    """
    u, s, vh = jnp.linalg.svd(matrix, full_matrices=False)
    # Singular values are non-negative. The relative test alone keeps exact
    # zeros when max(s) == 0 (0 >= 0), so also require s > 0.
    keep = (s >= thresh * jnp.max(s)) & (s > 0)
    # Divide by a placeholder for dropped modes so `1 / s` never produces
    # inf, even in the branch `where` discards (avoids NaN gradients).
    s_inv = jnp.where(keep, 1.0 / jnp.where(keep, s, 1.0), 0.0)

    # matrix = u @ diag(s) @ vh  =>  pinv = vh.T @ diag(s_inv) @ u.T
    return jnp.dot(vh.T * s_inv, u.T)

invscale(matrix, thresh=1e-14)

Invert and rescale a matrix by the diagonal. This uses invsafe for the inversion.

Parameters:

Name Type Description Default
matrix Array

The matrix to invert and scale. Should be a (n, n) array.

required
thresh float

Threshold for invsafe. See that function for more info.

1e-14

Returns:

Name Type Description
invmat Array

The inverted and rescaled matrix. Same shape as matrix.

Source code in witch/fitting.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@jax.jit
def invscale(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Invert and rescale a matrix by the diagonal.
    This uses `invsafe` for the inversion.

    With ``D = diag(vec)`` and ``vec = 1 / sqrt(|diag(matrix)|)``, the
    matrix ``D @ matrix @ D`` (whose nonzero diagonal entries have magnitude
    1) is inverted, and the result is multiplied by ``outer(vec, vec)``.
    Since ``inv(D M D) = inv(D) inv(M) inv(D)``, this recovers ``inv(M)``.
    Zero diagonal entries are given a scale factor of ``1e-10``.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert and scale.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold for `invsafe`.
        See that function for more info.

    Returns
    -------
    invmat: jax.Array
        The inverted and rescaled matrix.
        Same shape as `matrix`.
    """
    diag = jnp.diag(matrix)
    nonzero = diag != 0
    # Substitute 1.0 for zero entries before sqrt/divide so the branch that
    # `where` discards never contains inf (which would yield NaN gradients).
    safe_diag = jnp.where(nonzero, jnp.abs(diag), 1.0)
    vec = jnp.where(nonzero, 1.0 / jnp.sqrt(safe_diag), 1e-10)
    mm = jnp.outer(vec, vec)

    return mm * invsafe(mm * matrix, thresh)

objective(pars, model, todvec, errs)

Objective function to minimize when fitting. This is also responsible for updating our model with the current guess. This is an MPI aware function.

Parameters:

Name Type Description Default
pars Array

New parameters for our model.

required
model Model

The model object we are using to fit.

required
todvec TODVec

The TODs to fit against. This is what we use to compute our fit residuals.

required
errs Array

The error on pars, used to update the model state.

required

Returns:

Name Type Description
new_model Model

An updated model object. It contains the newly computed chisq for the input pars and is also updated with the input pars and errs.

grad Array

The gradient of the parameters at their current values. This is a (npar,) array.

curve Array

The curvature of the parameter space at the current values. This is a (npar, npar) array.

Source code in witch/fitting.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@jax.jit
def objective(
    pars: jax.Array, model: Model, todvec: TODVec, errs: jax.Array
) -> tuple[Model, jax.Array, jax.Array]:
    """
    Objective function to minimize when fitting.
    This is also responsible for updating our model with the current guess.
    This is an MPI aware function.

    Chisq, gradient, and curvature are accumulated over the local TODs and
    then summed across `todvec.comm` with an MPI allreduce.

    Parameters
    ----------
    pars : jax.Array
        New parameters for our model.
    model : Model
        The model object we are using to fit.
    todvec : TODVec
        The TODs to fit against.
        This is what we use to compute our fit residuals.
    errs : jax.Array
        The error on `pars`, used to update the model state.

    Returns
    -------
    new_model : Model
        An updated model object.
        This contains the newly computed `chisq` for the input `pars`.
        Also is updated with input `pars` and `errs`.
    grad : jax.Array
        The gradient of the parameters at their current values.
        This is a `(npar,)` array.
    curve : jax.Array
        The curvature of the parameter space at the current values.
        This is a `(npar, npar)` array.
    """
    npar = len(pars)
    # Put the trial parameters into the model so predictions use them;
    # chisq is replaced once it has been computed below.
    new_model = model.update(pars, errs, model.chisq)
    # Use a float accumulator (like grad/curve). An integer `jnp.array(0)`
    # stays an int on a rank with no TODs, so ranks would allreduce
    # mismatched dtypes.
    chisq = jnp.array(0.0)
    grad = jnp.zeros(npar)
    curve = jnp.zeros((npar, npar))
    for tod in todvec:
        x = tod.x * wu.rad_to_arcsec
        y = tod.y * wu.rad_to_arcsec

        pred_tod, grad_tod = new_model.to_tod_grad(x, y)
        ndet, nsamp = tod.shape

        resid = tod.data - pred_tod
        resid_filt = tod.noise.apply_noise(resid)
        chisq = chisq + jnp.sum(resid * resid_filt)

        # Noise-filter the derivative TOD of each parameter separately.
        grad_filt = jnp.zeros_like(grad_tod)
        for i in range(npar):
            grad_filt = grad_filt.at[i].set(tod.noise.apply_noise(grad_tod[i]))
        # Flatten detectors x samples so the sums below are plain dot products.
        grad_filt = jnp.reshape(grad_filt, (npar, ndet * nsamp))
        grad_tod = jnp.reshape(grad_tod, (npar, ndet * nsamp))
        resid = jnp.reshape(resid, (ndet * nsamp,))

        grad = grad + jnp.dot(grad_filt, resid)
        curve = curve + jnp.dot(grad_filt, grad_tod.T)

    # Sum the per-rank contributions across the communicator.
    chisq, _ = mpi4jax.allreduce(chisq, MPI.SUM, comm=todvec.comm)
    grad, _ = mpi4jax.allreduce(grad, MPI.SUM, comm=todvec.comm)
    curve, _ = mpi4jax.allreduce(curve, MPI.SUM, comm=todvec.comm)

    new_model = new_model.update(pars, errs, chisq)

    return new_model, grad, curve