fitting

fit_dataset(model, dataset, maxiter=10, chitol=1e-05)

Fit a model to TODs. This uses a modified Levenberg–Marquardt fitter with flat priors. This function is MPI aware.

Parameters:

model (Model, required)
    The model object that defines the model and grid we are fitting with.
dataset (DataSet, required)
    The dataset to fit. The dataset.datavec.comm object is used to fit in an MPI-aware way.
maxiter (int, default: 10)
    The maximum number of iterations to fit.
chitol (float, default: 1e-5)
    The delta chisq to use as the convergence criterion.

Returns:

model (Model)
    Model with the final set of fit parameters, errors, and chisq.
final_iter (int)
    The number of iterations the fitter ran for.
delta_chisq (float)
    The final delta chisq.

Source code in witch/fitting.py
@partial(jax.jit, static_argnums=(2, 3))
def fit_dataset(
    model: Model,
    dataset: DataSet,
    maxiter: int = 10,
    chitol: float = 1e-5,
) -> tuple[Model, int, float]:
    """
    Fit a model to TODs.
    This uses a modified Levenberg–Marquardt fitter with flat priors.
    This function is MPI aware.

    Parameters
    ----------
    model : Model
        The model object that defines the model and grid we are fitting with.
    dataset : DataSet
        The dataset to fit.
        The `dataset.datavec.comm` object is used to fit in an MPI aware way.
    maxiter : int, default: 10
        The maximum number of iterations to fit.
    chitol : float, default: 1e-5
        The delta chisq to use as the convergence criterion.

    Returns
    -------
    model : Model
        Model with the final set of fit parameters, errors, and chisq.
    final_iter : int
        The number of iterations the fitter ran for.
    delta_chisq : float
        The final delta chisq.
    """
    if dataset.mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    zero = jnp.array(0.0)

    def _cond_func(val):
        i, delta_chisq, lmd, *_ = val
        iterbool = jax.lax.lt(i, maxiter)
        chisqbool = jax.lax.ge(delta_chisq, chitol) + jax.lax.gt(lmd, zero)
        return iterbool * chisqbool

    def _body_func(val):
        i, delta_chisq, lmd, model, curve, grad = val
        curve_use = curve.at[:].add(lmd * jnp.diag(jnp.diag(curve)))
        # Get the step
        step = jnp.dot(invscale(curve_use), grad)
        new_pars, to_fit = _prior_pars_fit(
            model.priors, model.pars.at[:].add(step), jnp.array(model.to_fit)
        )
        # Get errs
        errs = jnp.where(to_fit, jnp.sqrt(jnp.diag(invscale(curve_use))), 0)
        # Now lets get an updated model
        new_model = deepcopy(model).update(new_pars, model.errs, model.chisq)
        new_chisq, new_grad, new_curve = dataset.objective(
            new_model, dataset.datavec, dataset.mode, True, True, True
        )
        new_model = new_model.update(new_pars, errs, new_chisq)

        new_delta_chisq = model.chisq - new_model.chisq
        model, grad, curve, delta_chisq, lmd = jax.lax.cond(
            new_delta_chisq > 0,
            _success,
            _failure,
            model,
            new_model,
            grad,
            new_grad,
            curve,
            new_curve,
            delta_chisq,
            new_delta_chisq,
            lmd,
        )

        return (i + 1, delta_chisq, lmd, model, curve, grad)

    pars, _ = _prior_pars_fit(model.priors, model.pars, jnp.array(model.to_fit))
    model = model.update(pars, model.errs, model.chisq)
    chisq, grad, curve = dataset.objective(
        model, dataset.datavec, dataset.mode, True, True, True
    )
    model = model.update(pars, model.errs, chisq)
    i, delta_chisq, _, model, *_ = jax.lax.while_loop(
        _cond_func, _body_func, (0, jnp.inf, zero, model, curve, grad)
    )

    return model, i, delta_chisq
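
Below is a minimal, self-contained sketch of the damped step taken inside _body_func, using a hand-written toy chisq in place of dataset.objective and a plain solve in place of invscale. The toy function and its minimum are purely illustrative.

import jax
import jax.numpy as jnp

def toy_chisq(pars):
    # Toy quadratic chisq with its minimum at (1, -2)
    resid = pars - jnp.array([1.0, -2.0])
    return jnp.sum(resid**2 / jnp.array([0.1, 0.5]))

pars = jnp.zeros(2)
grad = -jax.grad(toy_chisq)(pars)                    # downhill direction
curve = jax.hessian(toy_chisq)(pars)                 # curvature of chisq
lmd = 0.0                                            # LM damping, raised after a rejected step
curve_use = curve + lmd * jnp.diag(jnp.diag(curve))  # same damping as in _body_func
step = jnp.linalg.solve(curve_use, grad)
print(pars + step)                                   # -> [1. -2.], the minimum, since lmd = 0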

hmc(params, log_prob, log_prob_grad, num_steps, num_leaps, step_size, comm, key)

Runs Hamiltonian Monte Carlo using a leapfrog integrator to approximate Hamiltonian dynamics. This is a naive implementation that will be replaced in the future.

The parallelism model employed here differs from that of most samplers, where each task runs a subset of the chain. Since the rest of WITCH distributes the data across tasks, we do the same here: the chain evolves simultaneously in all tasks, but only rank 0 actually stores the chain.

Parameters:

params (jax.Array, required)
    The initial parameters to start the chain at.
log_prob (Callable[[jax.Array], jax.Array], required)
    Function that returns the log probability of the model for a given set of params.
    This should take params as its first argument; all other arguments should be fixed
    ahead of time (i.e., using functools.partial).
log_prob_grad (Callable[[jax.Array], jax.Array], required)
    Function that returns the gradient of the log probability of the model for a given
    set of params. This should take params as its first argument; all other arguments
    should be fixed ahead of time (i.e., using functools.partial). The returned gradient
    should have shape (len(params),).
num_steps (int, required)
    The number of steps to run the chain for.
num_leaps (int, required)
    The number of leapfrog steps to run at each step of the chain.
step_size (float, required)
    The step size to use. At each leapfrog step the parameters will evolve by step_size*momentum.
comm (MPI.Intracomm, required)
    The MPI comm object to use.
key (jax.Array, required)
    The PRNG key used to draw the leapfrog momenta and the Metropolis-Hastings acceptance draw.

Returns:

chain (jax.Array)
    The chain of samples. Will have shape (num_steps, len(params)) on the rank 0 task.
    Note that on tasks with rank other than 0 the actual chain is not returned; instead
    a dummy array of size (0,) is returned.

Source code in witch/fitting.py
def hmc(
    params: jax.Array,
    log_prob: Callable[[jax.Array], jax.Array],
    log_prob_grad: Callable[[jax.Array], jax.Array],
    num_steps: int,
    num_leaps: int,
    step_size: float,
    comm: MPI.Intracomm,
    key: jax.Array,
) -> jax.Array:
    """
    Runs Hamiltonian Monte Carlo using a leapfrog integrator to approximate Hamiltonian dynamics.
    This is a naive implementation that will be replaced in the future.

    The parallelism model employed here differs from that of most samplers,
    where each task runs a subset of the chain. Since the rest of WITCH
    distributes the data across tasks, we do the same here: the chain evolves
    simultaneously in all tasks, but only rank 0 actually stores the chain.

    Parameters
    ----------
    params : jax.Array
        The initial parameters to start the chain at.
    log_prob : Callable[[jax.Array], jax.Array]
        Function that returns the log probability of the model
        for a given set of params. This should take `params` as its
        first argument; all other arguments should be fixed ahead of time
        (i.e., using `functools.partial`).
    log_prob_grad : Callable[[jax.Array], jax.Array]
        Function that returns the gradient of the log probability of the model
        for a given set of params. This should take `params` as its
        first argument; all other arguments should be fixed ahead of time
        (i.e., using `functools.partial`). The returned gradient should have
        shape `(len(params),)`.
    num_steps : int
        The number of steps to run the chain for.
    num_leaps : int
        The number of leapfrog steps to run at each step of the chain.
    step_size : float
        The step size to use.
        At each leapfrog step the parameters will evolve by `step_size`*`momentum`.
    comm : MPI.Intracomm
        The MPI comm object to use.
    key : jax.Array
        The PRNG key used to draw the leapfrog momenta and the
        Metropolis-Hastings acceptance draw.

    Returns
    -------
    chain : jax.Array
        The chain of samples.
        Will have shape `(num_steps, len(params))` in the rank 0 task.
        Note that on tasks with rank other than 0 the actual
        chain is not returned, instead a dummy array of size `(0,)` is
        returned.
    """
    rank = comm.Get_rank()
    vnorm = jax.vmap(
        partial(jax.random.normal, shape=params[0].shape, dtype=params.dtype)
    )
    npar = len(params)
    ones = jnp.ones(npar, dtype=bool)

    @jax.jit
    def _leap(_, args):
        params, momentum = args
        momentum = momentum.at[:].add(0.5 * step_size * log_prob_grad(params))  # kick
        params = params.at[:].add(step_size * momentum)  # drift
        momentum = momentum.at[:].add(0.5 * step_size * log_prob_grad(params))  # kick

        return params, momentum

    @jax.jit
    def _sample(key, params):
        token = mpi4jax.barrier(comm=comm)
        key, token = mpi4jax.bcast(key, 0, comm=comm, token=token)
        key, uniform_key = jax.random.split(key, 2)

        # generate random momentum
        momentum = vnorm(jax.random.split(key, npar))
        new_params, new_momentum = jax.lax.fori_loop(
            0, num_leaps, _leap, (params, momentum)
        )

        # MH correction
        dpe = log_prob(new_params) - log_prob(params)
        dke = -0.5 * (jnp.sum(new_momentum**2) - jnp.sum(momentum**2))
        log_accept = dke + dpe
        accept_prob = jnp.minimum(jnp.exp(log_accept), 1)
        accept = jax.random.uniform(uniform_key) < accept_prob
        params = jax.lax.select(accept * ones, new_params, params)

        return key, params, accept_prob

    t0 = time.time()
    l_sample = _sample.lower(key, params)
    c_sample = l_sample.compile()
    t1 = time.time()
    if rank == 0:
        print(f"Compiled MC sample function in {t1-t0} s")

    chain = []
    accept_prob = []
    for _ in tqdm(range(num_steps), disable=(rank != 0)):
        key, params, prob = c_sample(key, params)
        if rank == 0:
            chain += [params]
            accept_prob += [prob]
    if rank == 0:
        chain = jnp.vstack(chain)
        accept_prob = jnp.array(accept_prob)
        print(f"Accepted {accept_prob.mean():.2%} of samples")
    else:
        chain = jnp.zeros(0)
    return chain
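
A toy usage sketch, sampling a standard normal target. It assumes a working MPI environment (e.g. launched with mpirun) with mpi4jax installed; the Gaussian log_prob here is purely illustrative and not part of WITCH.

import jax
import jax.numpy as jnp
from mpi4py import MPI
from witch.fitting import hmc

def log_prob(p):
    return -0.5 * jnp.sum(p**2)  # standard normal target

chain = hmc(
    params=jnp.zeros(3),
    log_prob=log_prob,
    log_prob_grad=jax.grad(log_prob),
    num_steps=1000,
    num_leaps=10,
    step_size=0.1,
    comm=MPI.COMM_WORLD,
    key=jax.random.PRNGKey(0),
)
# On rank 0, chain has shape (1000, 3); other ranks get a dummy (0,) array.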

invsafe(matrix, thresh=1e-14)

Safe SVD-based pseudo-inversion of the matrix. This zeros out modes that are too small when inverting. Use with caution in cases where you really care about what the inverse is.

Parameters:

matrix (jax.Array, required)
    The matrix to invert. Should be an (n, n) array.
thresh (float, default: 1e-14)
    Threshold at which to zero out a mode.

Returns:

invmat (jax.Array)
    The inverted matrix. Same shape as matrix.

Source code in witch/fitting.py
@jax.jit
def invsafe(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Safe SVD-based pseudo-inversion of the matrix.
    This zeros out modes that are too small when inverting.
    Use with caution in cases where you really care about what the inverse is.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold at which to zero out a mode.

    Returns
    -------
    invmat: jax.Array
        The inverted matrix.
        Same shape as `matrix`.
    """
    u, s, v = jnp.linalg.svd(matrix, False)
    s_inv = jnp.array(jnp.where(jnp.abs(s) < thresh * jnp.max(s), 0, 1 / s))

    return jnp.dot(jnp.transpose(v), jnp.dot(jnp.diag(s_inv), jnp.transpose(u)))
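
A quick illustration with a rank-deficient matrix, where a plain inverse would fail but the SVD-based pseudo-inverse zeros the vanishing mode:

import jax.numpy as jnp
from witch.fitting import invsafe

m = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # singular: the second singular value is 0
m_inv = invsafe(m)                        # -> [[0.25, 0.25], [0.25, 0.25]]
print(m @ m_inv @ m)                      # recovers m, as expected for a pseudo-inverse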

invscale(matrix, thresh=1e-14)

Invert and rescale a matrix by the diagonal. This uses invsafe for the inversion.

Parameters:

matrix (jax.Array, required)
    The matrix to invert and scale. Should be an (n, n) array.
thresh (float, default: 1e-14)
    Threshold for invsafe. See that function for more info.

Returns:

invmat (jax.Array)
    The inverted and rescaled matrix. Same shape as matrix.

Source code in witch/fitting.py
@jax.jit
def invscale(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Invert and rescale a matrix by the diagonal.
    This uses `invsafe` for the inversion.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert and scale.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold for `invsafe`.
        See that function for more info.

    Returns
    -------
    invmat: jax.Array
        The inverted and rescaled matrix.
        Same shape as `matrix`.
    """
    diag = jnp.diag(matrix)
    vec = jnp.array(jnp.where(diag != 0, 1.0 / jnp.sqrt(jnp.abs(diag)), 1e-10))
    mm = jnp.outer(vec, vec)

    return mm * invsafe(mm * matrix, thresh)
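
For a well-conditioned matrix the result matches a plain inverse; the diagonal rescaling mainly helps numerically when the parameters (and hence the curvature diagonal) span very different scales. A small check:

import jax.numpy as jnp
from witch.fitting import invscale

m = jnp.array([[4.0, 1.0], [1.0, 3.0]])
print(invscale(m))        # agrees with the plain inverse here
print(jnp.linalg.inv(m))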

run_mcmc(model, dataset, num_steps=5000, num_leaps=10, step_size=0.02, sample_which=-1)

Run MCMC, using the Hamiltonian Monte Carlo sampler above (hmc), to estimate the posterior for our model. Currently this function only supports flat priors, but more will be supported down the line. In order to ensure the accuracy of the noise model used, it is recommended that you run at least one round of fit_tods followed by noise re-estimation before this function.

This is MPI aware. Eventually this will be replaced with something more jaxy.

Parameters:

model (Model, required)
    The model to run MCMC on. We expect that all parameters in this model have priors defined.
dataset (DataSet, required)
    The dataset to compute the model posterior with. The dataset.datavec.comm object is used to fit in an MPI-aware way.
num_steps (int, default: 5000)
    The number of steps to run MCMC for.
num_leaps (int, default: 10)
    The number of leapfrog steps to take at each sample.
step_size (float, default: 0.02)
    The step size to use in the leapfrog algorithm. This should be tuned to get an acceptance fraction of ~0.65.
sample_which (int, default: -1)
    Sets which parameters to sample.
    If this is >= 0 then we will sample whichever parameters were fit in that round of fitting.
    If this is -1 then we will sample whichever parameters were fit in the last round of fitting.
    If this is -2 then any parameters that were ever fit will be sampled.
    If this is <= -3 or >= model.n_rounds then all parameters are sampled.

Returns:

model (Model)
    The model with MCMC estimated parameters and errors.
    The parameters are estimated as the median of the samples.
    The errors are estimated as the standard deviation.
    This also has the chi-squared of the estimated parameters.
flat_samples (jax.Array)
    Array of samples from running MCMC.

Source code in witch/fitting.py
def run_mcmc(
    model: Model,
    dataset: DataSet,
    num_steps: int = 5000,
    num_leaps: int = 10,
    step_size: float = 0.02,
    sample_which: int = -1,
) -> tuple[Model, jax.Array]:
    """
    Run MCMC, using the Hamiltonian Monte Carlo sampler above (`hmc`), to estimate
    the posterior for our model. Currently this function only supports flat priors,
    but more will be supported down the line. In order to ensure the accuracy of the
    noise model used, it is recommended that you run at least one round of `fit_tods`
    followed by noise re-estimation before this function.

    This is MPI aware.
    Eventually this will be replaced with something more jaxy.

    Parameters
    ----------
    model : Model
        The model to run MCMC on.
        We expect that all parameters in this model have priors defined.
    dataset : DataSet
        The dataset to compute the model posterior with.
        The `dataset.datavec.comm` object is used to fit in an MPI aware way.
    num_steps : int, default: 5000
        The number of steps to run MCMC for.
    num_leaps : int, default: 10
        The number of leapfrog steps to take at each sample.
    step_size : float, default: 0.02
        The step size to use in the leapfrog algorithm.
        This should be tuned to get an acceptance fraction of ~0.65.
    sample_which : int, default: -1
        Sets which parameters to sample.
        If this is >= 0 then we will sample whichever parameters were
        fit in that round of fitting.
        If this is -1 then we will sample whichever parameters were fit
        in the last round of fitting.
        If this is -2 then any parameters that were ever fit will be sampled.
        If this is <= -3 or >= `model.n_rounds` then all parameters are sampled.

    Returns
    -------
    model : Model
        The model with MCMC estimated parameters and errors.
        The parameters are estimated as the median of the samples.
        The errors are estimated as the standard deviation.
        This also has the chi-squared of the estimated parameters.
    flat_samples : jax.Array
        Array of samples from running MCMC.

    """
    token = mpi4jax.barrier(comm=dataset.datavec.comm)
    rank = dataset.datavec.comm.Get_rank()

    if sample_which >= 0 and sample_which < model.n_rounds:
        model.cur_round = sample_which
        to_fit = model.to_fit
    elif sample_which == -1:
        to_fit = model.to_fit
    elif sample_which == -2:
        to_fit = model.to_fit_ever
    else:
        to_fit = jnp.ones_like(model.to_fit_ever, dtype=bool)
    to_fit = jnp.array(to_fit)
    model = model.add_round(jnp.array(to_fit))

    init_pars = jnp.array(model.pars)
    init_errs = jnp.zeros_like(model.pars)
    final_pars = init_pars.copy()
    final_errs = init_errs.copy()

    prior_l, prior_u = model.priors
    scale = (jnp.abs(prior_l) + jnp.abs(prior_u)) / 2.0
    scale = jnp.where(scale == 0, 1, scale)
    init_pars = init_pars.at[:].multiply(1.0 / scale)
    npar = jnp.sum(to_fit)

    def _is_inf(pars, model):
        _ = (pars, model)
        return -1 * jnp.inf

    def _not_inf(pars, model):
        pars, _ = mpi4jax.bcast(pars, 0, comm=dataset.datavec.comm)
        temp_model = model.update(pars, init_errs, model.chisq)
        chisq, *_ = dataset.objective(
            temp_model, dataset.datavec, dataset.mode, True, False, False
        )
        log_like = -0.5 * chisq
        return log_like

    @jax.jit
    def _log_prob(pars, model=model, init_pars=init_pars):
        full_pars = init_pars.at[to_fit].set(pars)
        full_pars = full_pars.at[:].multiply(scale)
        _, in_bounds = _prior_pars_fit(model.priors, full_pars, jnp.array(model.to_fit))
        log_prior = jnp.sum(
            jnp.where(in_bounds.at[model.to_fit].get(), 0, -1 * jnp.inf)
        )
        return jax.lax.cond(
            jnp.isfinite(log_prior),
            _not_inf,
            _is_inf,
            full_pars,
            model,
        )

    def _is_inf_grad(pars, model, scale):
        _ = (pars, model, scale)
        return jnp.inf * jnp.ones(npar)

    def _not_inf_grad(pars, model, scale):
        pars, _ = mpi4jax.bcast(pars, 0, comm=dataset.datavec.comm)
        temp_model = model.update(pars, init_errs, model.chisq)
        _, grad, _ = dataset.objective(
            temp_model, dataset.datavec, dataset.mode, False, True, False
        )
        grad = grad.at[:].multiply(scale)
        return grad.at[to_fit].get().ravel()

    @jax.jit
    def _log_prob_grad(pars, model=model, init_pars=init_pars):
        full_pars = init_pars.at[to_fit].set(pars)
        full_pars = full_pars.at[:].multiply(scale)
        _, in_bounds = _prior_pars_fit(model.priors, full_pars, jnp.array(model.to_fit))
        log_prior = jnp.sum(jnp.where(in_bounds.at[to_fit].get(), 0, -1 * jnp.inf))
        return jax.lax.cond(
            jnp.isfinite(log_prior),
            _not_inf_grad,
            _is_inf_grad,
            full_pars,
            model,
            scale,
        )

    key = jax.random.PRNGKey(0)
    key, token = mpi4jax.bcast(key, 0, comm=dataset.datavec.comm, token=token)
    chain = hmc(
        init_pars.at[to_fit].get().ravel(),
        _log_prob,
        _log_prob_grad,
        num_steps=num_steps,
        num_leaps=num_leaps,
        step_size=step_size,
        comm=dataset.datavec.comm,
        key=key,
    )
    flat_samples = chain.at[:].multiply(scale.at[to_fit].get())
    if rank == 0:
        final_pars = final_pars.at[to_fit].set(jnp.median(flat_samples, axis=0).ravel())
        final_errs = final_errs.at[to_fit].set(jnp.std(flat_samples, axis=0).ravel())
    final_pars, token = mpi4jax.bcast(
        final_pars, 0, comm=dataset.datavec.comm, token=token
    )
    final_errs, _ = mpi4jax.bcast(final_errs, 0, comm=dataset.datavec.comm, token=token)
    model = model.update(
        final_pars.block_until_ready(), final_errs.block_until_ready(), model.chisq
    )
    chisq, *_ = dataset.objective(
        model, dataset.datavec, dataset.mode, True, False, False
    )
    model = model.update(final_pars, final_errs, chisq)

    return model, flat_samples
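
A hypothetical usage sketch; it assumes `model` and `dataset` are WITCH Model and DataSet objects built elsewhere (e.g. by the usual config machinery) and that priors are defined for every sampled parameter.

from witch.fitting import run_mcmc

model, flat_samples = run_mcmc(
    model,
    dataset,
    num_steps=5000,
    num_leaps=10,
    step_size=0.02,   # tune for an acceptance fraction of roughly 0.65
    sample_which=-1,  # sample the parameters fit in the last round
)
# On rank 0, flat_samples has shape (num_steps, n_sampled); model.pars now holds
# the per-parameter medians and model.errs the standard deviations of the samples.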