
fitting

fit_dataset(model, dataset, maxiter=10, chitol=1e-05, mode='tod')

Fit a model to TODs. This uses a modified Levenberg–Marquardt fitter with flat priors. This function is MPI aware.

Parameters:

- `model` (`Model`, required): The model object that defines the model and grid we are fitting with.
- `dataset` (`TODVec | SolutionSet`, required): The data to fit. The `dataset.comm` object is used to fit in an MPI aware way.
- `maxiter` (`int`, default: `10`): The maximum number of iterations to fit.
- `chitol` (`float`, default: `1e-5`): The delta chisq to use as the convergence criterion.
- `mode` (`str`, default: `"tod"`): The type of data we compile this function for. Should be either `"tod"` or `"map"`.

Returns:

- `model` (`Model`): Model with the final set of fit parameters, errors, and chisq.
- `final_iter` (`int`): The number of iterations the fitter ran for.
- `delta_chisq` (`float`): The final delta chisq.

Source code in witch/fitting.py
@partial(jax.jit, static_argnums=(2, 3, 4))
def fit_dataset(
    model: Model,
    dataset: TODVec | SolutionSet,
    maxiter: int = 10,
    chitol: float = 1e-5,
    mode: str = "tod",
) -> tuple[Model, int, float]:
    """
    Fit a model to TODs.
    This uses a modified Levenberg–Marquardt fitter with flat priors.
    This function is MPI aware.

    Parameters
    ----------
    model : Model
        The model object that defines the model and grid we are fitting with.
    dataset : TODVec | SolutionSet
        The data to fit.
        The `dataset.comm` object is used to fit in an MPI aware way.
    maxiter : int, default: 10
        The maximum number of iterations to fit.
    chitol : float, default: 1e-5
        The delta chisq to use as the convergence criterion.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".

    Returns
    -------
    model : Model
        Model with the final set of fit parameters, errors, and chisq.
    final_iter : int
        The number of iterations the fitter ran for.
    delta_chisq : float
        The final delta chisq.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    zero = jnp.array(0.0)

    def _cond_func(val):
        i, delta_chisq, lmd, *_ = val
        iterbool = jax.lax.lt(i, maxiter)
        chisqbool = jax.lax.ge(delta_chisq, chitol) + jax.lax.gt(lmd, zero)
        return iterbool * chisqbool

    def _body_func(val):
        i, delta_chisq, lmd, model, curve, grad = val
        curve_use = curve.at[:].add(lmd * jnp.diag(jnp.diag(curve)))
        # Get the step
        step = jnp.dot(invscale(curve_use), grad)
        new_pars, to_fit = _prior_pars_fit(
            model.priors, model.pars.at[:].add(step), jnp.array(model.to_fit)
        )
        # Get errs
        errs = jnp.where(to_fit, jnp.sqrt(jnp.diag(invscale(curve_use))), 0)
        # Now lets get an updated model
        new_model, new_grad, new_curve = objective(new_pars, model, dataset, errs, mode)

        new_delta_chisq = model.chisq - new_model.chisq
        model, grad, curve, delta_chisq, lmd = jax.lax.cond(
            new_delta_chisq > 0,
            _success,
            _failure,
            model,
            new_model,
            grad,
            new_grad,
            curve,
            new_curve,
            delta_chisq,
            new_delta_chisq,
            lmd,
        )

        return (i + 1, delta_chisq, lmd, model, curve, grad)

    pars, _ = _prior_pars_fit(model.priors, model.pars, jnp.array(model.to_fit))
    model, grad, curve = objective(pars, model, dataset, model.errs, mode)
    i, delta_chisq, _, model, *_ = jax.lax.while_loop(
        _cond_func, _body_func, (0, jnp.inf, zero, model, curve, grad)
    )

    return model, i, delta_chisq
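To make the damping step in `_body_func` concrete, here is a small standalone sketch (not part of witch) of the Levenberg–Marquardt-style step: the diagonal of the curvature matrix is inflated by the damping factor `lmd` before the step is solved for. The toy `curve` and `grad` arrays below are made up for illustration.

```python
import jax.numpy as jnp

# Toy curvature matrix and gradient, standing in for the outputs of objective().
curve = jnp.array([[4.0, 1.0], [1.0, 3.0]])
grad = jnp.array([0.5, -0.2])
lmd = 0.1  # damping factor

# Inflate the diagonal by lmd, then solve for the step (invscale plays this role above).
curve_damped = curve + lmd * jnp.diag(jnp.diag(curve))
step = jnp.linalg.solve(curve_damped, grad)
print(step)
```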

get_chisq(model, dataset, mode='tod')

Get the chi-squared of a model given data. This is an MPI aware function.

Parameters:

- `model` (`Model`, required): The model object we are using to fit.
- `dataset` (`TODVec | SolutionSet`, required): The data to fit against. This is what we use to compute our fit residuals.
- `mode` (`str`, default: `"tod"`): The type of data we compile this function for. Should be either `"tod"` or `"map"`.

Returns:

- `chisq` (`jax.Array`): The chi-squared of the model.

Source code in witch/fitting.py
@partial(jax.jit, static_argnames=("mode",))
def get_chisq(
    model: Model, dataset: TODVec | SolutionSet, mode: str = "tod"
) -> jax.Array:
    """
    Get the chi-squared of a model given data.
    This is an MPI aware function.

    Parameters
    ----------
    model : Model
        The model object we are using to fit.
    dataset : TODVec | SolutionSet
        The data to fit against.
        This is what we use to compute our fit residuals.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".

    Returns
    -------
    chisq : jax.Array
        The chi-squared of the model.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    token = mpi4jax.barrier(comm=dataset.comm)
    chisq = jnp.array(0)
    for data in dataset:
        if mode == "tod":
            x = data.x * wu.rad_to_arcsec
            y = data.y * wu.rad_to_arcsec
            pred_dat = model.to_tod(x, y)
        else:
            x, y = data.xy
            pred_dat = model.to_map(x * wu.rad_to_arcsec, y * wu.rad_to_arcsec)

        resid = data.data - pred_dat
        resid_filt = data.noise.apply_noise(resid)
        chisq += jnp.sum(resid * resid_filt)

    chisq, token = mpi4jax.allreduce(chisq, MPI.SUM, comm=dataset.comm, token=token)
    _ = mpi4jax.barrier(comm=dataset.comm, token=token)

    return chisq
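As a rough illustration of the chi-squared accumulation above, here is a standalone sketch with a simple inverse-variance weighting standing in for the general `data.noise.apply_noise` filter (the arrays are made up):

```python
import jax.numpy as jnp

# Toy data, model prediction, and per-sample inverse variance.
data = jnp.array([1.0, 2.0, 0.5])
pred = jnp.array([0.9, 2.1, 0.4])
ivar = jnp.array([4.0, 4.0, 1.0])

resid = data - pred
resid_filt = ivar * resid            # stands in for data.noise.apply_noise(resid)
chisq = jnp.sum(resid * resid_filt)  # chisq = sum(resid * N^-1(resid))
print(chisq)
```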

get_grad(model, dataset, mode='tod')

Get the gradient of chi-squared of a model given a set of TODs. This is an MPI aware function.

Parameters:

- `model` (`Model`, required): The model object we are using to fit.
- `dataset` (`TODVec | SolutionSet`, required): The data to fit against. This is what we use to compute our fit residuals.
- `mode` (`str`, default: `"tod"`): The type of data we compile this function for. Should be either `"tod"` or `"map"`.

Returns:

- `grad` (`jax.Array`): The gradient of the parameters at their current values. This is a `(npar,)` array.

Source code in witch/fitting.py
@partial(jax.jit, static_argnames=("mode",))
def get_grad(
    model: Model, dataset: TODVec | SolutionSet, mode: str = "tod"
) -> jax.Array:
    """
    Get the gradient of chi-squared of a model given a set of TODs.
    This is an MPI aware function.

    Parameters
    ----------
    model : Model
        The model object we are using to fit.
    dataset : TODVec | SolutionSet
        The data to fit against.
        This is what we use to compute our fit residuals.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".

    Returns
    -------
    grad : jax.Array
        The gradient of the parameters at their current values.
        This is a `(npar,)` array.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    token = mpi4jax.barrier(comm=dataset.comm)
    npar = len(model.pars)
    grad = jnp.zeros(npar)
    for data in dataset:
        if mode == "tod":
            x = data.x * wu.rad_to_arcsec
            y = data.y * wu.rad_to_arcsec
            pred_dat, grad_dat = model.to_tod_grad(x, y)
        else:
            x, y = data.xy
            pred_dat, grad_dat = model.to_map_grad(
                x * wu.rad_to_arcsec, y * wu.rad_to_arcsec
            )

        resid = data.data - pred_dat

        grad_filt = jnp.zeros_like(grad_dat)
        for i in range(npar):
            grad_filt = grad_filt.at[i].set(
                data.noise.apply_noise(grad_dat.at[i].get())
            )
        grad_filt = jnp.reshape(grad_filt, (npar, -1))
        grad_dat = jnp.reshape(grad_dat, (npar, -1))
        resid = resid.ravel()

        grad = grad.at[:].add(jnp.dot(grad_filt, jnp.transpose(resid)))

    grad, token = mpi4jax.allreduce(grad, MPI.SUM, comm=dataset.comm, token=token)
    _ = mpi4jax.barrier(comm=dataset.comm, token=token)

    return grad
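A rough sketch of the gradient accumulation above, again with a simple inverse-variance weighting standing in for `data.noise.apply_noise` (the toy arrays are made up):

```python
import jax.numpy as jnp

# Per-parameter model gradients, shape (npar, nsamp), plus residuals and weights.
grad_dat = jnp.array([[1.0, 0.5, 0.2],
                      [0.0, 1.0, 2.0]])
ivar = jnp.array([4.0, 4.0, 1.0])
resid = jnp.array([0.1, -0.1, 0.1])

grad_filt = grad_dat * ivar  # noise-filtered gradients
grad = grad_filt @ resid     # shape (npar,), matches the jnp.dot above
print(grad)
```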

hmc(params, log_prob, log_prob_grad, num_steps, num_leaps, step_size, comm, key)

Runs Hamiltonian Monte Carlo using a leapfrog integrator to approximate Hamiltonian dynamics. This is a naive implementation that will be replaced in the future.

The parallelism model employed here is different from most samplers, where each task runs a subset of the chain; instead, since the rest of WITCH employs a model where the data is distributed across tasks, we do that here as well. In this model the chain evolves simultaneously in all tasks, but only rank 0 actually stores the chain.

Parameters:

- `params` (`jax.Array`, required): The initial parameters to start the chain at.
- `log_prob` (`Callable[[jax.Array], jax.Array]`, required): Function that returns the log probability of the model for a given set of params. This should take `params` as its first argument; all other arguments should be fixed ahead of time (i.e., using `functools.partial`).
- `log_prob_grad` (`Callable[[jax.Array], jax.Array]`, required): Function that returns the gradient of the log probability of the model for a given set of params. This should take `params` as its first argument; all other arguments should be fixed ahead of time (i.e., using `functools.partial`). The returned gradient should have shape `(len(params),)`.
- `num_steps` (`int`, required): The number of steps to run the chain for.
- `num_leaps` (`int`, required): The number of leapfrog steps to run at each step of the chain.
- `step_size` (`float`, required): The step size to use. At each leapfrog step the parameters will evolve by `step_size * momentum`.
- `comm` (`Intracomm`, required): The MPI comm object to use.
- `key` (`jax.Array`, required): The PRNG key used to draw the random momenta.

Returns:

- `chain` (`jax.Array`): The chain of samples. Will have shape `(num_steps, len(params))` in the rank 0 task. Note that on tasks with rank other than 0 the actual chain is not returned; instead a dummy array of size `(0,)` is returned.

Source code in witch/fitting.py
def hmc(
    params: jax.Array,
    log_prob: Callable[[jax.Array], jax.Array],
    log_prob_grad: Callable[[jax.Array], jax.Array],
    num_steps: int,
    num_leaps: int,
    step_size: float,
    comm: MPI.Intracomm,
    key: jax.Array,
) -> jax.Array:
    """
    Runs Hamiltonian Monte Carlo using a leapfrog integrator to approximate Hamiltonian dynamics.
    This is a naive implementation that will be replaced in the future.

    The parallelism model employed here is different from most samplers, where each task runs
    a subset of the chain; instead, since the rest of WITCH employs a model where the data is
    distributed across tasks, we do that here as well.
    In this model the chain evolves simultaneously in all tasks,
    but only rank 0 actually stores the chain.

    Parameters
    ----------
    params : jax.Array
        The initial parameters to start the chain at.
    log_prob : Callable[[jax.Array], jax.Array]
        Function that returns the log probability of the model
        for a given set of params. This should take `params` as its
        first argument; all other arguments should be fixed ahead of time
        (i.e., using `functools.partial`).
    log_prob_grad : Callable[[jax.Array], jax.Array]
        Function that returns the gradient of the log probability of the model
        for a given set of params. This should take `params` as its
        first argument; all other arguments should be fixed ahead of time
        (i.e., using `functools.partial`). The returned gradient should have
        shape `(len(params),)`.
    num_steps : int
        The number of steps to run the chain for.
    num_leaps : int
        The number of leapfrog steps to run at each step of the chain.
    step_size : float
        The step size to use.
        At each leapfrog step the parameters will evolve by `step_size`*`momentum`.
    comm : MPI.Intracomm
        The MPI comm object to use.
    key : jax.Array
        The PRNG key used to draw the random momenta.

    Returns
    -------
    chain : jax.Array
        The chain of samples.
        Will have shape `(num_steps, len(params))` in the rank 0 task.
        Note that on tasks with rank other than 0 the actual
        chain is not returned; instead a dummy array of size `(0,)` is
        returned.
    """
    rank = comm.Get_rank()
    vnorm = jax.vmap(
        partial(jax.random.normal, shape=params[0].shape, dtype=params.dtype)
    )
    npar = len(params)
    ones = jnp.ones(npar, dtype=bool)

    @jax.jit
    def _leap(_, args):
        params, momentum = args
        momentum = momentum.at[:].add(0.5 * step_size * log_prob_grad(params))  # kick
        params = params.at[:].add(step_size * momentum)  # drift
        momentum = momentum.at[:].add(0.5 * step_size * log_prob_grad(params))  # kick

        return params, momentum

    @jax.jit
    def _sample(key, params):
        token = mpi4jax.barrier(comm=comm)
        key, token = mpi4jax.bcast(key, 0, comm=comm, token=token)
        key, uniform_key = jax.random.split(key, 2)

        # generate random momentum
        momentum = vnorm(jax.random.split(key, npar))
        new_params, new_momentum = jax.lax.fori_loop(
            0, num_leaps, _leap, (params, momentum)
        )

        # MH correction
        dpe = log_prob(new_params) - log_prob(params)
        dke = -0.5 * (jnp.sum(new_momentum**2) - jnp.sum(momentum**2))
        log_accept = dke + dpe
        accept_prob = jnp.minimum(jnp.exp(log_accept), 1)
        accept = jax.random.uniform(uniform_key) < accept_prob
        params = jax.lax.select(accept * ones, new_params, params)

        return key, params, accept_prob

    t0 = time.time()
    l_sample = _sample.lower(key, params)
    c_sample = l_sample.compile()
    t1 = time.time()
    if rank == 0:
        print(f"Compiled MC sample function in {t1-t0} s")

    chain = []
    accept_prob = []
    for _ in tqdm(range(num_steps), disable=(rank != 0)):
        key, params, prob = c_sample(key, params)
        if rank == 0:
            chain += [params]
            accept_prob += [prob]
    if rank == 0:
        chain = jnp.vstack(chain)
        accept_prob = jnp.array(accept_prob)
        print(f"Accepted {accept_prob.mean():.2%} of samples")
    else:
        chain = jnp.zeros(0)
    return chain
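For reference, here is a standalone, single-task sketch of the leapfrog update used in `_leap`, targeting a unit Gaussian. The toy `log_prob`, step size, and number of leaps are made up for illustration, and there is no MPI here:

```python
import jax
import jax.numpy as jnp


def log_prob(p):
    # Unit Gaussian target: log p(x) = -0.5 * sum(x^2) + const.
    return -0.5 * jnp.sum(p**2)


log_prob_grad = jax.grad(log_prob)


def leapfrog(params, momentum, step_size, num_leaps):
    for _ in range(num_leaps):
        momentum = momentum + 0.5 * step_size * log_prob_grad(params)  # kick
        params = params + step_size * momentum                         # drift
        momentum = momentum + 0.5 * step_size * log_prob_grad(params)  # kick
    return params, momentum


key = jax.random.PRNGKey(0)
params = jnp.zeros(3)
momentum = jax.random.normal(key, (3,))
new_params, new_momentum = leapfrog(params, momentum, 0.1, 10)
print(new_params, new_momentum)
```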

invsafe(matrix, thresh=1e-14)

Safe SVD based pseudo-inversion of the matrix. This zeros out modes that are too small when inverting. Use with caution in cases where you really care about what the inverse is.

Parameters:

- `matrix` (`jax.Array`, required): The matrix to invert. Should be a `(n, n)` array.
- `thresh` (`float`, default: `1e-14`): Threshold at which to zero out a mode.

Returns:

- `invmat` (`jax.Array`): The inverted matrix. Same shape as `matrix`.

Source code in witch/fitting.py
@jax.jit
def invsafe(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Safe SVD based pseudo-inversion of the matrix.
    This zeros out modes that are too small when inverting.
    Use with caution in cases where you really care about what the inverse is.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold at which to zero out a mode.

    Returns
    -------
    invmat: jax.Array
        The inverted matrix.
        Same shape as `matrix`.
    """
    u, s, v = jnp.linalg.svd(matrix, False)
    s_inv = jnp.array(jnp.where(jnp.abs(s) < thresh * jnp.max(s), 0, 1 / s))

    return jnp.dot(jnp.transpose(v), jnp.dot(jnp.diag(s_inv), jnp.transpose(u)))
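A quick standalone sketch of what the thresholding buys (the matrix below is made up): for a numerically rank-deficient matrix, the tiny singular value is zeroed instead of being inverted into a huge, noise-dominated mode.

```python
import jax.numpy as jnp

# An effectively rank-1 matrix: the second singular value is ~0.
m = jnp.array([[1.0, 1.0], [1.0, 1.0]])

u, s, v = jnp.linalg.svd(m, False)
s_inv = jnp.where(jnp.abs(s) < 1e-14 * jnp.max(s), 0.0, 1.0 / s)  # zero the tiny mode
invmat = jnp.dot(jnp.transpose(v), jnp.dot(jnp.diag(s_inv), jnp.transpose(u)))
print(invmat)               # a pseudo-inverse; a plain inverse would blow up
print(invmat @ m @ invmat)  # ~invmat, as expected for a pseudo-inverse
```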

invscale(matrix, thresh=1e-14)

Invert and rescale a matrix by the diagonal. This uses invsafe for the inversion.

Parameters:

- `matrix` (`jax.Array`, required): The matrix to invert and scale. Should be a `(n, n)` array.
- `thresh` (`float`, default: `1e-14`): Threshold for `invsafe`. See that function for more info.

Returns:

- `invmat` (`jax.Array`): The inverted and rescaled matrix. Same shape as `matrix`.

Source code in witch/fitting.py
@jax.jit
def invscale(matrix: jax.Array, thresh: float = 1e-14) -> jax.Array:
    """
    Invert and rescale a matrix by the diagonal.
    This uses `invsafe` for the inversion.

    Parameters
    ----------
    matrix : jax.Array
        The matrix to invert and scale.
        Should be a `(n, n)` array.
    thresh : float, default: 1e-14
        Threshold for `invsafe`.
        See that function for more info.

    Returns
    -------
    invmat: jax.Array
        The inverted and rescaled matrix.
        Same shape as `matrix`.
    """
    diag = jnp.diag(matrix)
    vec = jnp.array(jnp.where(diag != 0, 1.0 / jnp.sqrt(jnp.abs(diag)), 1e-10))
    mm = jnp.outer(vec, vec)

    return mm * invsafe(mm * matrix, thresh)
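A standalone sketch of the diagonal rescaling step (the matrix below is made up): the matrix is preconditioned by the outer product of 1/sqrt(|diag|) so that the rescaled matrix has a unit diagonal before the SVD-based inversion.

```python
import jax.numpy as jnp

# A badly scaled curvature-like matrix.
m = jnp.array([[1e6, 1e2], [1e2, 1.0]])

diag = jnp.diag(m)
vec = jnp.where(diag != 0, 1.0 / jnp.sqrt(jnp.abs(diag)), 1e-10)
mm = jnp.outer(vec, vec)
print(jnp.diag(mm * m))  # ~[1, 1]: much better conditioned for the inversion
```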

objective(pars, model, dataset, errs, mode='tod')

Objective function to minimize when fitting. This is also responsible for updating our model with the current guess. This is an MPI aware function.

Parameters:

- `pars` (`jax.Array`, required): New parameters for our model.
- `model` (`Model`, required): The model object we are using to fit.
- `dataset` (`TODVec | SolutionSet`, required): The data to fit against. This is what we use to compute our fit residuals.
- `errs` (`jax.Array`, required): The error on `pars`, used to update the model state.
- `mode` (`str`, default: `"tod"`): The type of data we compile this function for. Should be either `"tod"` or `"map"`.

Returns:

- `new_model` (`Model`): An updated model object. This contains the newly computed chisq for the input `pars`. It is also updated with the input `pars` and `errs`.
- `grad` (`jax.Array`): The gradient of the parameters at their current values. This is a `(npar,)` array.
- `curve` (`jax.Array`): The curvature of the parameter space at the current values. This is a `(npar, npar)` array.

Source code in witch/fitting.py
@partial(jax.jit, static_argnames=("mode",))
def objective(
    pars: jax.Array,
    model: Model,
    dataset: TODVec | SolutionSet,
    errs: jax.Array,
    mode: str = "tod",
) -> tuple[Model, jax.Array, jax.Array]:
    """
    Objective function to minimize when fitting.
    This is also responsible for updating our model with the current guess.
    This is an MPI aware function.

    Parameters
    ----------
    pars : jax.Array
        New parameters for our model.
    model : Model
        The model object we are using to fit.
    dataset: TODVec | SolutionSet
        The data to fit against.
        This is what we use to compute our fit residuals.
    errs : jax.Array
        The error on `pars`, used to update the model state.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".

    Returns
    -------
    new_model : Model
        An updated model object.
        This contains the newly computed `chisq` for the input `pars`.
        It is also updated with the input `pars` and `errs`.
    grad : jax.Array
        The gradient of the parameters at their current values.
        This is a `(npar,)` array.
    curve : jax.Array
        The curvature of the parameter space at the current values.
        This is a `(npar, npar)` array.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    npar = len(pars)
    new_model = model.update(pars, errs, model.chisq)
    chisq = jnp.array(0)
    grad = jnp.zeros(npar)
    curve = jnp.zeros((npar, npar))
    for data in dataset:
        if mode == "tod":
            x = data.x * wu.rad_to_arcsec
            y = data.y * wu.rad_to_arcsec
            pred_dat, grad_dat = new_model.to_tod_grad(x, y)
        else:
            x, y = data.xy
            pred_dat, grad_dat = new_model.to_map_grad(
                x * wu.rad_to_arcsec, y * wu.rad_to_arcsec
            )

        resid = data.data - pred_dat
        resid_filt = data.noise.apply_noise(resid)
        chisq += jnp.sum(resid * resid_filt)

        grad_filt = jnp.zeros_like(grad_dat)
        for i in range(npar):
            grad_filt = grad_filt.at[i].set(
                data.noise.apply_noise(grad_dat.at[i].get())
            )
        grad_filt = jnp.reshape(grad_filt, (npar, -1))
        grad_dat = jnp.reshape(grad_dat, (npar, -1))
        resid = resid.ravel()

        grad = grad.at[:].add(jnp.dot(grad_filt, jnp.transpose(resid)))
        curve = curve.at[:].add(jnp.dot(grad_filt, jnp.transpose(grad_dat)))

    chisq, token = mpi4jax.allreduce(chisq, MPI.SUM, comm=dataset.comm)
    grad, token = mpi4jax.allreduce(grad, MPI.SUM, comm=dataset.comm, token=token)
    curve, _ = mpi4jax.allreduce(curve, MPI.SUM, comm=dataset.comm, token=token)

    new_model = new_model.update(pars, errs, chisq)

    return new_model, grad, curve
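As a rough illustration of the curvature accumulation above, here is a standalone sketch with a simple inverse-variance weighting standing in for `data.noise.apply_noise` (the toy arrays are made up):

```python
import jax.numpy as jnp

# Per-parameter model gradients, shape (npar, nsamp), and per-sample weights.
grad_dat = jnp.array([[1.0, 0.5, 0.2],
                      [0.0, 1.0, 2.0]])
ivar = jnp.array([4.0, 4.0, 1.0])

grad_filt = grad_dat * ivar
curve = grad_filt @ grad_dat.T  # (npar, npar) Gauss-Newton curvature, as in the loop above
print(curve)
```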

run_mcmc(model, dataset, num_steps=5000, num_leaps=10, step_size=0.02, sample_which=-1, mode='tod')

Run MCMC using Hamiltonian Monte Carlo (see hmc) to estimate the posterior for our model. Currently this function only supports flat priors, but more will be supported down the line. In order to ensure accuracy of the noise model used, it is recommended that you run at least one round of fit_tods followed by noise reestimation before this function.

This is MPI aware. Eventually this will be replaced with something more jaxy.

Parameters:

- `model` (`Model`, required): The model to run MCMC on. We expect that all parameters in this model have priors defined.
- `dataset` (`TODVec | SolutionSet`, required): The data to compute the likelihood of the model with.
- `num_steps` (`int`, default: `5000`): The number of steps to run MCMC for.
- `num_leaps` (`int`, default: `10`): The number of leapfrog steps to take at each sample.
- `step_size` (`float`, default: `0.02`): The step size to use in the leapfrog algorithm. This should be tuned to get an acceptance fraction of ~0.65.
- `sample_which` (`int`, default: `-1`): Sets which parameters to sample. If this is >= 0 then we will sample whichever parameters were fit in that round of fitting. If this is -1 then we will sample whichever parameters were fit in the last round of fitting. If this is -2 then any parameters that were ever fit will be sampled. If this is <= -3 or >= `model.n_rounds` then all parameters are sampled.
- `mode` (`str`, default: `"tod"`): The type of data to run this function on. Should be either `"tod"` or `"map"`.

Returns:

- `model` (`Model`): The model with MCMC estimated parameters and errors. The parameters are estimated as the median of the samples. The errors are estimated as the standard deviation. This also has the chi-squared of the estimated parameters.
- `flat_samples` (`jax.Array`): Array of samples from running MCMC.

Source code in witch/fitting.py
def run_mcmc(
    model: Model,
    dataset: TODVec | SolutionSet,
    num_steps: int = 5000,
    num_leaps: int = 10,
    step_size: float = 0.02,
    sample_which: int = -1,
    mode: str = "tod",
) -> tuple[Model, jax.Array]:
    """
    Run MCMC using Hamiltonian Monte Carlo (see `hmc`) to estimate the posterior for our model.
    Currently this function only supports flat priors, but more will be supported
    down the line. In order to ensure accuracy of the noise model used, it is
    recommended that you run at least one round of `fit_tods` followed by noise
    reestimation before this function.

    This is MPI aware.
    Eventually this will be replaced with something more jaxy.

    Parameters
    ----------
    model : Model
        The model to run MCMC on.
        We expect that all parameters in this model have priors defined.
    dataset : TODVec | SolutionSet
        The data to compute the likelihood of the model with.
    num_steps : int, default: 5000
        The number of steps to run MCMC for.
    num_leaps : int, default: 10
        The number of leapfrog steps to take at each sample.
    step_size : float, default: 0.02
        The step size to use in the leapfrog algorithm.
        This should be tuned to get an acceptance fraction of ~0.65.
    sample_which : int, default: -1
        Sets which parameters to sample.
        If this is >= 0 then we will sample whichever parameters were
        fit in that round of fitting.
        If this is -1 then we will sample whichever parameters were fit
        in the last round of fitting.
        If this is -2 then any parameters that were ever fit will be sampled.
        If this is <= -3 or >= `model.n_rounds` then all parameters are sampled.
    mode : str, default: "tod"
        The type of data to run this function on.
        Should be either "tod" or "map".

    Returns
    -------
    model : Model
        The model with MCMC estimated parameters and errors.
        The parameters are estimated as the median of the samples.
        The errors are estimated as the standard deviation.
        This also has the chi-squared of the estimated parameters.
    flat_samples : jax.Array
        Array of samples from running MCMC.

    """
    token = mpi4jax.barrier(comm=dataset.comm)
    rank = dataset.comm.Get_rank()

    if sample_which >= 0 and sample_which < model.n_rounds:
        model.cur_round = sample_which
        to_fit = model.to_fit
    elif sample_which == -1:
        to_fit = model.to_fit
    elif sample_which == -2:
        to_fit = model.to_fit_ever
    else:
        to_fit = jnp.ones_like(model.to_fit_ever, dtype=bool)
    to_fit = jnp.array(to_fit)
    model = model.add_round(jnp.array(to_fit))

    init_pars = jnp.array(model.pars)
    init_errs = jnp.zeros_like(model.pars)
    final_pars = init_pars.copy()
    final_errs = init_errs.copy()

    prior_l, prior_u = model.priors
    scale = (jnp.abs(prior_l) + jnp.abs(prior_u)) / 2.0
    scale = jnp.where(scale == 0, 1, scale)
    init_pars = init_pars.at[:].multiply(1.0 / scale)
    npar = jnp.sum(to_fit)

    def _is_inf(pars, model):
        _ = (pars, model)
        return -1 * jnp.inf

    def _not_inf(pars, model):
        pars, _ = mpi4jax.bcast(pars, 0, comm=dataset.comm)
        temp_model = model.update(pars, init_errs, model.chisq)
        log_like = -0.5 * get_chisq(temp_model, dataset, mode)
        return log_like

    @jax.jit
    def _log_prob(pars, model=model, init_pars=init_pars):
        full_pars = init_pars.at[to_fit].set(pars)
        full_pars = full_pars.at[:].multiply(scale)
        _, in_bounds = _prior_pars_fit(model.priors, full_pars, jnp.array(model.to_fit))
        log_prior = jnp.sum(
            jnp.where(in_bounds.at[model.to_fit].get(), 0, -1 * jnp.inf)
        )
        return jax.lax.cond(
            jnp.isfinite(log_prior),
            _not_inf,
            _is_inf,
            full_pars,
            model,
        )

    def _is_inf_grad(pars, model, scale):
        _ = (pars, model, scale)
        return jnp.inf * jnp.ones(npar)

    def _not_inf_grad(pars, model, scale):
        pars, _ = mpi4jax.bcast(pars, 0, comm=dataset.comm)
        temp_model = model.update(pars, init_errs, model.chisq)
        grad = get_grad(temp_model, dataset, mode)
        grad = grad.at[:].multiply(scale)
        return grad.at[to_fit].get().ravel()

    @jax.jit
    def _log_prob_grad(pars, model=model, init_pars=init_pars):
        full_pars = init_pars.at[to_fit].set(pars)
        full_pars = full_pars.at[:].multiply(scale)
        _, in_bounds = _prior_pars_fit(model.priors, full_pars, jnp.array(model.to_fit))
        log_prior = jnp.sum(jnp.where(in_bounds.at[to_fit].get(), 0, -1 * jnp.inf))
        return jax.lax.cond(
            jnp.isfinite(log_prior),
            _not_inf_grad,
            _is_inf_grad,
            full_pars,
            model,
            scale,
        )

    key = jax.random.PRNGKey(0)
    key, token = mpi4jax.bcast(key, 0, comm=dataset.comm, token=token)
    chain = hmc(
        init_pars.at[to_fit].get().ravel(),
        _log_prob,
        _log_prob_grad,
        num_steps=num_steps,
        num_leaps=num_leaps,
        step_size=step_size,
        comm=dataset.comm,
        key=key,
    )
    flat_samples = chain.at[:].multiply(scale.at[to_fit].get())
    if rank == 0:
        final_pars = final_pars.at[to_fit].set(jnp.median(flat_samples, axis=0).ravel())
        final_errs = final_errs.at[to_fit].set(jnp.std(flat_samples, axis=0).ravel())
    final_pars, token = mpi4jax.bcast(final_pars, 0, comm=dataset.comm, token=token)
    final_errs, _ = mpi4jax.bcast(final_errs, 0, comm=dataset.comm, token=token)
    model = model.update(
        final_pars.block_until_ready(), final_errs.block_until_ready(), model.chisq
    )
    chisq = get_chisq(model, dataset, mode)
    model = model.update(final_pars, final_errs, chisq)

    return model, flat_samples
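One implementation detail worth calling out is the prior-based rescaling near the top of the function: sampled parameters are divided by a per-parameter scale derived from the flat-prior bounds so the sampler works in roughly unit-scaled coordinates. A standalone sketch (the bounds and parameter values below are made up):

```python
import jax.numpy as jnp

# Flat-prior bounds per parameter; a zero-width entry stands in for an unfit parameter.
prior_l = jnp.array([0.0, -5.0, 0.0])
prior_u = jnp.array([2.0, 5.0, 0.0])

scale = (jnp.abs(prior_l) + jnp.abs(prior_u)) / 2.0
scale = jnp.where(scale == 0, 1, scale)  # avoid dividing by zero
pars = jnp.array([1.5, -2.0, 3.0])
print(pars / scale)  # unit-scaled parameters handed to hmc; samples are rescaled afterwards
```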