
containers

Data classes for describing models in a structured way.

Model dataclass

Dataclass to describe a model. This includes some caching features that improve performance when fitting. Because of this caching, dynamically modifying which structures compose the model may not work as intended.
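A minimal sketch of the caching pattern, using only the standard library. functools.cached_property stores its result in the instance __dict__, so invalidating it means popping that key. Model.__setattr__ does this for model and model_grad when cur_round or xyz is assigned. ToyModel and its scale and n_evals fields are hypothetical stand-ins, not part of witch.

```python
from dataclasses import dataclass, field
from functools import cached_property


@dataclass
class ToyModel:
    # Hypothetical stand-in for Model: only the caching behaviour is reproduced.
    scale: float
    cur_round: int = 0
    n_evals: int = field(default=0, init=False)

    def __setattr__(self, name, value):
        # Mirror Model.__setattr__: assigning cur_round (or xyz) drops the cached model.
        if name in ("cur_round", "xyz"):
            self.__dict__.pop("model", None)
        super().__setattr__(name, value)

    @cached_property
    def model(self) -> float:
        # Count evaluations so the caching is observable.
        self.n_evals += 1
        return 2.0 * self.scale


m = ToyModel(scale=3.0)
print(m.model, m.model, m.n_evals)  # 6.0 6.0 1
m.cur_round = 1                     # watched attribute: cache is dropped
print(m.model, m.n_evals)           # 6.0 2
m.scale = 5.0                       # unwatched attribute: cache is kept
print(m.model)                      # 6.0 (stale)
```

Assigning an attribute that __setattr__ does not watch (here scale) leaves a stale cached value. This is the kind of surprise the caching note warns about.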

Attributes:

name (str): The name of the model. This is used for display purposes only.

structures (list[Structure]): The structures that compose the model. Will be sorted to match core.ORDER once initialized.

xyz (tuple[Array, Array, Array, float, float]): Defines the grid used by model computation. The first three elements are a sparse 3D grid in arcseconds, with the first two elements being RA and Dec respectively and the third element being the LOS. The last two elements are the model center in RA and Dec (also in arcseconds). The structure functions use this center as the reference point for dx and dy.

dz (float): The LOS integration factor. It should minimally be the pixel size in arcseconds along the LOS, but can also include additional factors for performing unit conversions.

beam (Array): The beam to convolve the model with.

n_rounds (int): How many rounds of fitting to perform.

cur_round (int, default: 0): Which round of fitting we are currently in; rounds are 0-indexed.

chisq (Array, default: inf): The chi-squared of this model relative to some data, stored as a scalar float array. Used when fitting.

original_order (list[int]): The order in which the structures in structures were originally input.
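Sorting the structures into core.ORDER and recovering the input order amounts to a permutation and its inverse. The sketch below is pure Python and uses a hypothetical ORDER list (the real one is core.ORDER, which is not shown here). Python's sorted() is stable, which corresponds to np.argsort(..., kind="stable"). For a permutation perm, np.argsort(perm) gives the same inverse that the loop below builds.

```python
# Hypothetical structure types; the real list is witch's core.ORDER.
ORDER = ["gaussian", "isobeta", "uniform"]


def sort_to_order(kinds):
    """Return (perm, inverse) for sorting structure kinds into ORDER.

    sorted_kinds[k] == kinds[perm[k]] and kinds[j] == sorted_kinds[inverse[j]].
    """
    # sorted() is stable, so structures of the same type keep their input order.
    perm = sorted(range(len(kinds)), key=lambda i: ORDER.index(kinds[i]))
    # Invert the permutation: input position j now sits at sorted position k.
    inverse = [0] * len(perm)
    for k, j in enumerate(perm):
        inverse[j] = k
    return perm, inverse


kinds = ["uniform", "gaussian", "isobeta", "gaussian"]
perm, inverse = sort_to_order(kinds)
sorted_kinds = [kinds[i] for i in perm]
print(perm, inverse)                        # [1, 3, 2, 0] [3, 0, 2, 1]
print([sorted_kinds[i] for i in inverse])   # back to the input order
```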

Source code in witch/containers.py
@register_pytree_node_class
@dataclass
class Model:
    """
    Dataclass to describe a model.
    This includes some caching features that improve performance when fitting.
    Because of this caching, dynamically modifying which structures compose
    the model may not work as intended.

    Attributes
    ----------
    name : str
        The name of the model.
        This is used for display purposes only.
    structures : list[Structure]
        The structures that compose the model.
        Will be sorted to match `core.ORDER` once initialized.
    xyz : tuple[jax.Array, jax.Array, jax.Array, float, float]
        Defines the grid used by model computation.
        The first three elements are a sparse 3D grid in arcseconds,
        with the first two elements being RA and Dec respectively and the third
        element being the LOS. The last two elements are the model center in RA and Dec
        (also in arcseconds). The structure functions use this as the coordinate to reference
        `dx` and `dy` to.
    dz : float
        The LOS integration factor.
        Should minimally be the pixel size in arcseconds along the LOS,
        but can also include additional factors for performing unit conversions.
    beam : jax.Array
        The beam to convolve the model with.
    n_rounds : int
        How many rounds of fitting to perform.
    cur_round : int, default: 0
        Which round of fitting we are currently in,
        rounds are 0 indexed.
    chisq : jax.Array, default: jnp.array(jnp.inf)
        The chi-squared of this model relative to some data.
        Used when fitting.
    original_order : list[int]
        The order in which the structures in `structures` were originally input.
    """

    name: str
    structures: list[Structure]
    xyz: tuple[jax.Array, jax.Array, jax.Array, float, float]  # arcseconds
    dz: float  # arcseconds * unknown
    beam: jax.Array
    n_rounds: int
    cur_round: int = 0
    chisq: jax.Array = field(
        default_factory=jnp.array(jnp.inf).copy
    )  # scalar float array
    original_order: list[int] = field(init=False)

    def __post_init__(self):
        # Make sure the structure is in the order that core expects
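        structure_idx = np.argsort(
            np.array(
                [core.ORDER.index(structure.structure) for structure in self.structures]
            ),
            kind="stable",
        )
        self.structures = [self.structures[i] for i in structure_idx]
        # structure_idx is a permutation, so its argsort is the inverse permutation:
        # original_order[j] is where the j'th input structure now sits.
        self.original_order = [int(i) for i in np.argsort(structure_idx)]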

    def __setattr__(self, name, value):
        if name == "cur_round" or name == "xyz":
            self.__dict__.pop("model_grad", None)
            self.__dict__.pop("model", None)
        return super().__setattr__(name, value)

    def __repr__(self) -> str:
        rep = self.name + ":\n"
        rep += f"Round {self.cur_round + 1} out of {self.n_rounds}\n"
        for i in self.original_order:
            struct = self.structures[i]
            rep += "\t" + struct.name + ":\n"
            for par in struct.parameters:
                rep += (
                    "\t\t"
                    + par.name
                    + "*" * par.fit[self.cur_round]
                    + " ["
                    + str(par.prior[0])
                    + ", "
                    + str(par.prior[1])
                    + "]"
                    + " = "
                    + str(par.val)
                    + " ± "
                    + str(par.err)
                    + " ("
                    + str(jnp.abs(par.val / par.err))
                    + " σ)"
                    + "\n"
                )
        rep += f"chisq is {self.chisq}"
        return rep

    @cached_property
    def n_struct(self) -> list[int]:
        """
        Number of each type of structures in the model.
        Note that this is cached.

        Returns
        -------
        n_struct : list[int]
            `n_struct[i]` is the number of `core.ORDER[i]`
            structures in this model.
        """
        n_struct = [0] * len(core.ORDER)
        for structure in self.structures:
            idx = core.ORDER.index(structure.structure)
            n_struct[idx] += 1
        return n_struct

    @property
    def pars(self) -> jax.Array:
        """
        Get the current parameter values.

        Returns
        -------
        pars :  jax.Array
            The parameter values in the order expected by `core.model`.
        """
        pars = []
        for structure in self.structures:
            pars += [parameter.val for parameter in structure.parameters]
        return jnp.array(pars)

    @cached_property
    def par_names(self) -> list[str]:
        """
        Get the names of all parameters.
        Note that this is cached.

        Returns
        -------
        par_names : list[str]
            Parameter names in the same order as `pars`.
        """
        par_names = []
        for structure in self.structures:
            par_names += [parameter.name for parameter in structure.parameters]
        return par_names

    @property
    def errs(self) -> jax.Array:
        """
        Get the current parameter errors.

        Returns
        -------
        errs : jax.Array
            The errors in the same order as `pars`.
        """
        errs = []
        for structure in self.structures:
            errs += [parameter.err for parameter in structure.parameters]
        return jnp.array(errs)

    @cached_property
    def priors(self) -> tuple[jax.Array, jax.Array]:
        """
        Get the priors for all parameters.
        Note that this is cached.

        Returns
        -------
        priors : tuple[jax.Array, jax.Array]
            Parameter priors in the same order as `pars`.
            This is a tuple with the first element being an array
            of lower bounds and the second being upper.
        """
        lower = []
        upper = []
        for structure in self.structures:
            lower += [parameter.prior[0] for parameter in structure.parameters]
            upper += [parameter.prior[1] for parameter in structure.parameters]
        priors = (jnp.array(lower), jnp.array(upper))
        return priors

    @property
    def to_fit(self) -> tuple[bool, ...]:
        """
        Get which parameters we want to fit for the current round.

        Returns
        -------
        to_fit : tuple[bool, ...]
            `to_fit[i]` is True if we want to fit the `i`'th parameter
            in the current round.
            This is in the same order as `pars`.
        """
        to_fit = []
        for structure in self.structures:
            to_fit += [
                parameter.fit[self.cur_round] for parameter in structure.parameters
            ]
        return tuple(to_fit)  # jnp.ravel(jnp.array(to_fit))

    @cached_property
    def to_fit_ever(self) -> jax.Array:
        """
        Check which parameters we ever fit.
        Note that this is cached.

        Returns
        -------
        to_fit_ever : jax.Array
            `to_fit_ever[i]` is True if we ever want to fit the `i`'th parameter.
            This is in the same order as `pars`.
        """
        to_fit = []
        for structure in self.structures:
            to_fit += [parameter.fit_ever for parameter in structure.parameters]
        return jnp.ravel(jnp.array(to_fit))

    @cached_property
    def model(self) -> jax.Array:
        """
        The evaluated model, see `core.model` for details.
        Note that this is cached, but is automatically reset whenever
        `update` is called or `cur_round` or `xyz` changes.

        Returns
        -------
        model : jax.Array
            The model evaluated on `xyz` with the current values of `pars`.
        """
        return core.model(
            self.xyz,
            tuple(self.n_struct),
            self.dz,
            self.beam,
            *self.pars,
        )

    @cached_property
    def model_grad(self) -> tuple[jax.Array, jax.Array]:
        """
        The evaluated model and its gradient, see `core.model_grad` for details.
        Note that this is cached, but is automatically reset whenever
        `update` is called or `cur_round` or `xyz` changes.

        Returns
        -------
        model : jax.Array
            The model evaluated on `xyz` with the current values of `pars`.
        grad : jax.Array
            The gradient evaluated on `xyz` with the current values of `pars`.
            Has shape `(len(pars),) + model.shape`.
        """
        argnums = tuple(np.where(self.to_fit)[0] + core.ARGNUM_SHIFT)
        return core.model_grad(
            self.xyz,
            tuple(self.n_struct),
            self.dz,
            self.beam,
            argnums,
            *self.pars,
        )

    def to_tod(self, dx: ArrayLike, dy: ArrayLike) -> jax.Array:
        """
        Project the model into a TOD.

        Parameters
        ----------
        dx : ArrayLike
            The RA TOD in arcseconds.
        dy : ArrayLike
            The Dec TOD in arcseconds.

        Returns
        -------
        tod : jax.Array
            The model as a TOD.
            Same shape as dx.
        """
        return wu.bilinear_interp(
            dx, dy, self.xyz[0].ravel(), self.xyz[1].ravel(), self.model
        )

    def to_tod_grad(self, dx: ArrayLike, dy: ArrayLike) -> tuple[jax.Array, jax.Array]:
        """
        Project the model and gradient into a TOD.

        Parameters
        ----------
        dx : ArrayLike
            The RA TOD in arcseconds.
        dy : ArrayLike
            The Dec TOD in arcseconds.

        Returns
        -------
        tod : jax.Array
            The model as a TOD.
            Same shape as dx.
        grad_tod : jax.Array
            The gradient as a TOD.
            Has shape `(len(pars),) + dx.shape`.
        """
        model, grad = self.model_grad
        tod = wu.bilinear_interp(
            dx, dy, self.xyz[0].ravel(), self.xyz[1].ravel(), model
        )
        grad_tod = jnp.array(
            [
                (
                    wu.bilinear_interp(
                        dx, dy, self.xyz[0].ravel(), self.xyz[1].ravel(), _grad
                    )
                    if _fit
                    else jnp.zeros_like(tod)
                )
                for _grad, _fit in zip(grad, self.to_fit)
            ]
        )

        return tod, grad_tod

    def update(self, vals: jax.Array, errs: jax.Array, chisq: jax.Array) -> Self:
        """
        Update the parameter values and errors as well as the model chi-squared.
        This also resets the cache on `model` and `model_grad`
        if `vals` is different from `self.pars`.

        Parameters
        ----------
        vals : jax.Array
            The new parameter values.
            Should be in the same order as `pars`.
        errs : jax.Array
            The new parameter errors.
            Should be in the same order as `pars`.
        chisq : jax.Array
            The new chi-squared.
            Should be a scalar float array.

        Returns
        -------
        updated : Model
            The updated model.
            While nominally the model will update in place, returning it
            allows us to use this function in JITed functions.
        """
        if not np.array_equal(self.pars, vals):
            self.__dict__.pop("model", None)
            self.__dict__.pop("model_grad", None)
        n = 0
        for struct in self.structures:
            for par in struct.parameters:
                par.val = vals[n]
                par.err = errs[n]
                n += 1
        self.chisq = chisq

        return self

    def remove_struct(self, struct_name: str):
        """
        Remove structure by name.

        Parameters
        ----------
        struct_name : str
            Name of struct to be removed.
        """
        n = None
        for i, structure in enumerate(self.structures):
            if str(structure.name) == str(struct_name):
                n = i
        if n is None:
            raise ValueError(f"Error: {struct_name} not in structure names")
        self.structures.pop(n)

        # Clear every cache that depends on the structure list.
        # A cached_property only appears in __dict__ once it has been accessed,
        # so pop with a default to avoid a KeyError for caches never computed.
        for key in (
            "to_fit_ever", "n_struct", "priors", "par_names", "model", "model_grad"
        ):
            self.__dict__.pop(key, None)
        self.__post_init__()

    def save(self, path: str):
        """
        Serialize the model to a file with dill.

        Parameters
        ----------
        path : str
            The file to save to.
            Does not check to see if the path is valid.
        """
        with open(path, "wb") as f:
            dill.dump(self, f)

    @classmethod
    def load(cls, path: str) -> Self:
        """
        Load the model from a file with dill.

        Parameters
        ----------
        path : str
            The path to the saved model.
            Does not check to see if the path is valid.

        Returns
        -------
        model : Model
            The loaded model.
        """
        with open(path, "rb") as f:
            return dill.load(f)

    @classmethod
    def from_cfg(cls, cfg: dict, beam: Optional[jax.Array] = None) -> Self:
        """
        Create an instance of model from a witcher config.

        Parameters
        ----------
        cfg : dict
            The config loaded into a dict.

        beam : Optional[jax.Array], default: None
            The beam to convolve the model with.
            If None, a 1x1 array of ones is used.

        Returns
        -------
        model : Model
            The model described by the config.
        """
        # Do imports
        for module, name in cfg.get("imports", {}).items():
            mod = import_module(module)
            if isinstance(name, str):
                locals()[name] = mod
            elif isinstance(name, list):
                for n in name:
                    locals()[n] = getattr(mod, n)
            else:
                raise TypeError("Expect import name to be a string or a list")

        # Load constants
        constants = {
            name: eval(str(const)) for name, const in cfg.get("constants", {}).items()
        }  # pyright: ignore [reportUnusedVariable]

        # Get jax device
        dev_id = cfg.get("jax_device", 0)
        device = jax.devices()[dev_id]

        # Set up the coordinate grid
        r_map = eval(str(cfg["coords"]["r_map"]))
        dr = eval(str(cfg["coords"]["dr"]))
        dz = eval(str(cfg["coords"].get("dz", dr)))
        x0 = eval(str(cfg["coords"]["x0"]))
        y0 = eval(str(cfg["coords"]["y0"]))

        xyz_host = wg.make_grid(
            r_map, dr, dr, dz, x0 * wg.rad_to_arcsec, y0 * wg.rad_to_arcsec
        )
        xyz = jax.device_put(xyz_host, device)
        xyz[0].block_until_ready()
        xyz[1].block_until_ready()
        xyz[2].block_until_ready()

        # Make beam
        if beam is None:
            beam = jnp.ones((1, 1))
        beam = jax.device_put(beam, device)
        if beam is None:
            raise ValueError("Beam somehow still None!")

        n_rounds = cfg.get("n_rounds", 1)
        dz = dz * eval(str(cfg["model"]["unit_conversion"]))

        structures = []
        for name, structure in cfg["model"]["structures"].items():
            parameters = []
            for par_name, param in structure["parameters"].items():
                val = eval(str(param["value"]))
                fit = param.get("to_fit", [False] * n_rounds)
                if isinstance(fit, bool):
                    fit = [fit] * n_rounds
                if len(fit) != n_rounds:
                    raise ValueError(
                        f"to_fit has {len(fit)} entries but we only have {n_rounds} rounds"
                    )
                priors = param.get("priors", None)
                if priors is not None:
                    priors = eval(str(priors))
                else:
                    priors = (-1 * np.inf, np.inf)
                parameters.append(
                    Parameter(
                        par_name,
                        tuple(fit),
                        jnp.array(val, dtype=float),
                        jnp.array(0.0, dtype=float),
                        jnp.array(priors, dtype=float),
                    )
                )
            structures.append(Structure(name, structure["structure"], parameters))
        name = cfg["model"].get(
            "name", "-".join([structure.name for structure in structures])
        )

        return cls(name, structures, xyz, dz, beam, n_rounds)

    # Functions for making this a pytree
    # Don't call this on your own
    def tree_flatten(self) -> tuple[tuple, tuple]:
        children = (tuple(self.structures), self.xyz, self.dz, self.beam, self.chisq)
        aux_data = (
            self.name,
            self.n_rounds,
            self.cur_round,
        )

        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> Self:
        name, n_rounds, cur_round = aux_data
        structures, xyz, dz, beam, chisq = children

        return cls(name, list(structures), xyz, dz, beam, n_rounds, cur_round, chisq)

errs property

Get the current parameter errors.

Returns:

errs (Array): The errors in the same order as pars.

model cached property

The evaluated model, see core.model for details. Note that this is cached, but is automatically reset whenever update is called or cur_round or xyz changes.

Returns:

model (Array): The model evaluated on xyz with the current values of pars.

model_grad cached property

The evaluated model and its gradient, see core.model_grad for details. Note that this is cached, but is automatically reset whenever update is called or cur_round or xyz changes.

Returns:

model (Array): The model evaluated on xyz with the current values of pars.

grad (Array): The gradient evaluated on xyz with the current values of pars. Has shape (len(pars),) + model.shape.
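model_grad differentiates only with respect to the parameters flagged in to_fit. It turns that boolean mask into positional-argument indices with np.where(self.to_fit)[0] + core.ARGNUM_SHIFT, where the shift skips the arguments that precede the parameters. JAX transforms such as jax.grad and jax.jacfwd accept argnums as an int or a tuple of ints. Below is a pure-Python sketch of the index computation; the shift value of 4 is an assumption, not the real core.ARGNUM_SHIFT.

```python
def fit_argnums(to_fit, argnum_shift):
    """Positional indices of the fitted parameters in a call like f(*fixed_args, *pars).

    Pure-Python equivalent of tuple(np.where(to_fit)[0] + argnum_shift).
    """
    return tuple(i + argnum_shift for i, fit in enumerate(to_fit) if fit)


# Hypothetical: four fixed leading arguments before the parameters.
ARGNUM_SHIFT = 4
print(fit_argnums((True, False, True, True), ARGNUM_SHIFT))  # (4, 6, 7)
```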

n_struct cached property

Number of structures of each type in the model. Note that this is cached.

Returns:

n_struct (list[int]): n_struct[i] is the number of core.ORDER[i] structures in this model.
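The counting done by n_struct can be sketched in a few lines. This assumes a hypothetical ORDER list in place of core.ORDER. Model.model passes the result to core.model as tuple(self.n_struct).

```python
# Hypothetical structure types; the real list is witch's core.ORDER.
ORDER = ["gaussian", "isobeta", "uniform"]


def count_structures(kinds, order=ORDER):
    """n_struct[i] is the number of structures whose type is order[i]."""
    n_struct = [0] * len(order)
    for kind in kinds:
        # list.index raises ValueError for a type not in order, as in Model.n_struct.
        n_struct[order.index(kind)] += 1
    return n_struct


print(count_structures(["isobeta", "gaussian", "isobeta"]))  # [1, 2, 0]
```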

par_names cached property

Get the names of all parameters. Note that this is cached.

Returns:

par_names (list[str]): Parameter names in the same order as pars.

pars property

Get the current parameter values.

Returns:

pars (Array): The parameter values in the order expected by core.model.
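pars and par_names both flatten the parameters structure by structure, then parameter by parameter, so index i refers to the same parameter in both. A minimal sketch with hypothetical Par and Struct stand-ins (plain floats instead of JAX arrays) shows that names can repeat across structures. Code that needs a unique key should therefore use the flat index rather than the name.

```python
from dataclasses import dataclass


@dataclass
class Par:  # hypothetical stand-in for witch's Parameter
    name: str
    val: float


@dataclass
class Struct:  # hypothetical stand-in for witch's Structure
    name: str
    parameters: list


def flatten(structures):
    """Values and names in structure order, then parameter order (as in pars/par_names)."""
    vals, names = [], []
    for structure in structures:
        vals += [p.val for p in structure.parameters]
        names += [p.name for p in structure.parameters]
    return vals, names


structures = [
    Struct("a", [Par("dx", 0.0), Par("amp", 1.5)]),
    Struct("b", [Par("dx", 2.0)]),
]
vals, names = flatten(structures)
print(list(zip(names, vals)))  # [('dx', 0.0), ('amp', 1.5), ('dx', 2.0)]
```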

priors cached property

Get the priors for all parameters. Note that this is cached.

Returns:

priors (tuple[Array, Array]): Parameter priors in the same order as pars. This is a tuple whose first element is an array of lower bounds and whose second element is an array of upper bounds.
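A sketch of how per-parameter (lower, upper) pairs become the two arrays returned by priors, plus a bounds check. The closed-interval check is an assumption for illustration; how witch enforces priors during fitting is not shown on this page. The (-inf, inf) default matches what from_cfg uses when a parameter has no "priors" entry.

```python
import math


def split_priors(priors):
    """Turn per-parameter (lower, upper) pairs into (lowers, uppers), like Model.priors."""
    lower = [p[0] for p in priors]
    upper = [p[1] for p in priors]
    return lower, upper


def in_bounds(vals, lower, upper):
    """True where each value lies inside its closed prior interval (assumed semantics)."""
    return [lo <= v <= hi for v, lo, hi in zip(vals, lower, upper)]


priors = [(0.0, 10.0), (-math.inf, math.inf), (1.0, 2.0)]
lower, upper = split_priors(priors)
print(in_bounds([5.0, -1e9, 3.0], lower, upper))  # [True, True, False]
```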

to_fit property

Get which parameters we want to fit for the current round.

Returns:

to_fit (tuple[bool, ...]): to_fit[i] is True if we want to fit the i'th parameter in the current round. This is in the same order as pars.
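Each parameter carries a fit tuple with one entry per round; from_cfg expands a single bool to n_rounds copies. to_fit simply picks the entry for cur_round from every parameter. A minimal sketch, with plain tuples standing in for Parameter objects:

```python
def to_fit_for_round(fit_schedules, cur_round):
    """to_fit[i] is fit_schedules[i][cur_round]; rounds are 0-indexed."""
    return tuple(fit[cur_round] for fit in fit_schedules)


# One schedule per parameter, each of length n_rounds = 3.
schedules = [
    (True, True, True),     # fitted in every round
    (False, True, True),    # held fixed in round 0, freed afterwards
    (False, False, False),  # never fitted
]
print(to_fit_for_round(schedules, 0))  # (True, False, False)
print(to_fit_for_round(schedules, 1))  # (True, True, False)
```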

to_fit_ever cached property

Check which parameters we ever fit. Note that this is cached.

Returns:

to_fit_ever (Array): to_fit_ever[i] is True if we ever want to fit the i'th parameter. This is in the same order as pars.
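to_fit_ever collects each parameter's fit_ever flag. Parameter.fit_ever is not defined on this page, so the sketch below assumes it means "fitted in at least one round", that is, any() over the fit schedule. The parameter names are hypothetical.

```python
def to_fit_ever(fit_schedules):
    """Assumes a parameter is 'ever fit' if any entry of its fit schedule is True."""
    return [any(fit) for fit in fit_schedules]


schedules = [(True, True, True), (False, True, True), (False, False, False)]
names = ["amp", "r_c", "beta"]  # hypothetical parameter names
ever = to_fit_ever(schedules)
print([n for n, e in zip(names, ever) if e])  # ['amp', 'r_c']
```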

from_cfg(cfg, beam=None) classmethod

Create an instance of model from a witcher config.

Parameters:

cfg (dict, required): The config loaded into a dict.

beam (Optional[Array], default: None): The beam to convolve the model with. If None, a 1x1 array of ones is used.

Returns:

model (Model): The model described by the config.
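The keys that from_cfg reads can be written out as a Python dict. All values below are hypothetical, including the structure type and parameter names, which must match what witch's core actually defines. from_cfg passes string values through eval, so expressions are allowed, and it treats x0 and y0 as radians (it multiplies them by wg.rad_to_arcsec). The normalize_to_fit helper mirrors from_cfg's handling of each parameter's "to_fit" entry.

```python
# Hypothetical example config; every key is one that from_cfg reads.
cfg = {
    "n_rounds": 2,
    "jax_device": 0,  # index into jax.devices(); optional, defaults to 0
    "coords": {
        "r_map": "3.0 * 60",  # map radius (presumably arcseconds)
        "dr": "1.0",          # pixel size in arcseconds
        # "dz" is optional and defaults to dr
        "x0": "0.0",          # map center RA in radians
        "y0": "0.0",          # map center Dec in radians
    },
    "model": {
        "name": "toy",  # optional; defaults to the structure names joined by "-"
        "unit_conversion": "1.0",
        "structures": {
            "cluster": {
                "structure": "gaussian",  # hypothetical; must be a type in core.ORDER
                "parameters": {
                    "amp": {"value": "1.0", "to_fit": True, "priors": "(0.0, 10.0)"},
                    "sigma": {"value": "30.0", "to_fit": [False, True]},
                    "offset": {"value": "0.0"},  # no to_fit: never fitted
                },
            },
        },
    },
}


def normalize_to_fit(fit, n_rounds):
    """Mirror of from_cfg: a bool is repeated per round; lists must have n_rounds entries."""
    if isinstance(fit, bool):
        fit = [fit] * n_rounds
    if len(fit) != n_rounds:
        raise ValueError(f"to_fit has {len(fit)} entries but there are {n_rounds} rounds")
    return tuple(fit)


n_rounds = cfg["n_rounds"]
params = cfg["model"]["structures"]["cluster"]["parameters"]
schedules = {
    name: normalize_to_fit(p.get("to_fit", [False] * n_rounds), n_rounds)
    for name, p in params.items()
}
print(schedules)  # {'amp': (True, True), 'sigma': (False, True), 'offset': (False, False)}
```

Because from_cfg evaluates these strings with eval, configs should only be loaded from trusted sources.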

load(path) classmethod

Load the model from a file with dill.

Parameters:

path (str, required): The path to the saved model. Does not check to see if the path is valid.

Returns:

model (Model): The loaded model.
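save and load wrap dill.dump and dill.load on files opened in "wb" and "rb" mode. dill exposes the same dump/load interface as the standard pickle module, so the sketch below uses pickle and an in-memory buffer as stand-ins. It shows that a cached_property value lives in the instance __dict__ and is therefore serialized along with the object. Toy is a hypothetical class, not part of witch.

```python
import io
import pickle
from dataclasses import dataclass
from functools import cached_property


@dataclass
class Toy:
    # Hypothetical stand-in for Model; only serialization behaviour is shown.
    scale: float

    @cached_property
    def model(self) -> float:
        return 2.0 * self.scale


toy = Toy(3.0)
_ = toy.model  # populate the cache

# pickle stands in for dill here; both provide dump(obj, file) and load(file).
buf = io.BytesIO()
pickle.dump(toy, buf)
buf.seek(0)
restored = pickle.load(buf)
print(restored.scale, "model" in restored.__dict__)  # 3.0 True
```

A saved model may therefore include any cached model arrays that were computed before saving.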

remove_struct(struct_name)

Remove structure by name.

Parameters:

struct_name (str, required): Name of the structure to be removed.
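After removing a structure, every cached property that depends on the structure list has to be cleared. A cached_property only appears in the instance __dict__ after its first access. Clearing with dict.pop(key, None) is therefore safe whether or not a cache was ever computed, whereas dict.pop(key) raises KeyError for a missing key. In the sketch below, a plain dict stands in for the instance __dict__; clear_caches is a hypothetical helper.

```python
CACHED = ("to_fit_ever", "n_struct", "priors", "par_names", "model", "model_grad")


def clear_caches(instance_dict, keys=CACHED):
    """Drop cached values if present; keys that were never cached are ignored."""
    for key in keys:
        instance_dict.pop(key, None)


state = {"n_struct": [1, 0], "name": "toy"}  # only one cache was ever computed
clear_caches(state)
print(state)  # {'name': 'toy'}

try:
    {}.pop("model")  # no default: a missing cache raises
except KeyError:
    print("KeyError for a cache that was never computed")
```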

save(path)

Serialize the model to a file with dill.

Parameters:

path (str, required): The file to save to. Does not check to see if the path is valid.

to_tod(dx, dy)

Project the model into a TOD.

Parameters:

dx (ArrayLike, required): The RA TOD in arcseconds.

dy (ArrayLike, required): The Dec TOD in arcseconds.

Returns:

tod (Array): The model as a TOD. Same shape as dx.
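to_tod samples the 2D model at each (dx, dy) pointing sample with wu.bilinear_interp, using the raveled RA and Dec axes of xyz. The helper below is a hypothetical pure-Python bilinear interpolator, not witch's implementation. It assumes grid[i][j] holds the value at (xs[i], ys[j]), and points outside the grid are linearly extrapolated from the nearest edge cell, which may differ from witch. Bilinear interpolation reproduces a linear function exactly, which makes the result easy to check.

```python
from bisect import bisect_right


def bilinear(x, y, xs, ys, grid):
    """Interpolate grid[i][j] (sampled at xs[i], ys[j], both strictly increasing) at (x, y)."""
    def cell(v, axis):
        # Left edge of the cell containing v, clamped so i and i + 1 are valid indices.
        i = bisect_right(axis, v) - 1
        return min(max(i, 0), len(axis) - 2)

    i, j = cell(x, xs), cell(y, ys)
    tx = (x - xs[i]) / (xs[i + 1] - xs[i])
    ty = (y - ys[j]) / (ys[j + 1] - ys[j])
    return (
        (1 - tx) * (1 - ty) * grid[i][j]
        + tx * (1 - ty) * grid[i + 1][j]
        + (1 - tx) * ty * grid[i][j + 1]
        + tx * ty * grid[i + 1][j + 1]
    )


xs, ys = [0.0, 1.0, 2.0], [0.0, 10.0]
grid = [[2 * x + 3 * y for y in ys] for x in xs]  # linear test map: 2x + 3y
dx, dy = [0.5, 1.5], [5.0, 0.0]                   # hypothetical pointing samples
tod = [bilinear(x, y, xs, ys, grid) for x, y in zip(dx, dy)]
print(tod)  # [16.0, 3.0]
```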

to_tod_grad(dx, dy)

Project the model and gradient into a TOD.

Parameters:

dx (ArrayLike, required): The RA TOD in arcseconds.

dy (ArrayLike, required): The Dec TOD in arcseconds.

Returns:

tod (Array): The model as a TOD. Same shape as dx.

grad_tod (Array): The gradient as a TOD. Has shape (len(pars),) + dx.shape.
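to_tod_grad keeps one row per parameter so that grad_tod always has shape (len(pars),) + dx.shape. It only interpolates the gradient for parameters fitted in the current round and fills the other rows with zeros. A pure-Python sketch of that pattern, where the hypothetical project callable stands in for the bilinear interpolation onto the TOD:

```python
def project_grads(grads, to_fit, project, n_samples):
    """One row per parameter: projected gradient if fitted, zeros otherwise."""
    return [project(g) if fit else [0.0] * n_samples for g, fit in zip(grads, to_fit)]


samples = [1.0, 2.0, 3.0]   # hypothetical per-sample weights
grads = [0.5, 9.0, -1.0]    # hypothetical per-parameter gradients
rows = project_grads(
    grads, (True, False, True), lambda g: [g * w for w in samples], len(samples)
)
print(rows)  # [[0.5, 1.0, 1.5], [0.0, 0.0, 0.0], [-1.0, -2.0, -3.0]]
```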

update(vals, errs, chisq)

Update the parameter values and errors as well as the model chi-squared. This also resets the cache on model and model_grad if vals is different from self.pars.

Parameters:

vals (Array, required): The new parameter values. Should be in the same order as pars.

errs (Array, required): The new parameter errors. Should be in the same order as pars.

chisq (Array, required): The new chi-squared. Should be a scalar float array.

Returns:

updated (Model): The updated model. While nominally the model will update in place, returning it allows us to use this function in JITed functions.
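A minimal sketch of the update pattern. Flat vals and errs are written back into the per-structure parameters in pars order, cached results are dropped only when the values actually change, and self is returned. Par and ToyModel are hypothetical stand-ins using plain floats and a dict named cache in place of the cached properties.

```python
from dataclasses import dataclass, field


@dataclass
class Par:  # hypothetical stand-in for Parameter
    val: float
    err: float = 0.0


@dataclass
class ToyModel:  # hypothetical stand-in for Model
    structures: list  # list of lists of Par, in pars order
    chisq: float = float("inf")
    cache: dict = field(default_factory=dict)

    @property
    def pars(self):
        return [p.val for s in self.structures for p in s]

    def update(self, vals, errs, chisq):
        # Only drop cached results when the parameter values actually change.
        if list(vals) != self.pars:
            self.cache.clear()
        n = 0
        for s in self.structures:
            for p in s:
                p.val, p.err = vals[n], errs[n]
                n += 1
        self.chisq = chisq
        return self  # returned so callers can use the result directly


m = ToyModel([[Par(1.0), Par(2.0)], [Par(3.0)]], cache={"model": "cached"})
m.update([1.0, 2.0, 3.0], [0.1, 0.1, 0.2], 10.0)
print(m.cache)                   # {'model': 'cached'}  (values unchanged)
m.update([1.0, 2.5, 3.0], [0.1, 0.1, 0.2], 9.0)
print(m.cache, m.pars, m.chisq)  # {} [1.0, 2.5, 3.0] 9.0
```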

Source code in witch/containers.py
def update(self, vals: jax.Array, errs: jax.Array, chisq: jax.Array) -> Self:
    """
    Update the parameter values and errors as well as the model chi-squared.
    This also resets the cache on `model` and `model_grad`
    if `vals` is different from `self.pars`.

    Parameters
    ----------
    vals : jax.Array
        The new parameter values.
        Should be in the same order as `pars`.
    errs : jax.Array
        The new parameter errors.
        Should be in the same order as `pars`.
    chisq : jax.Array
        The new chi-squared.
        Should be a scalar float array.

    Returns
    -------
    updated : Model
        The updated model.
        While nominally the model will update in place, returning it
        allows us to use this function in JITed functions.
    """
    if not np.array_equal(self.pars, vals):
        self.__dict__.pop("model", None)
        self.__dict__.pop("model_grad", None)
    n = 0
    for struct in self.structures:
        for par in struct.parameters:
            par.val = vals[n]
            par.err = errs[n]
            n += 1
    self.chisq = chisq

    return self
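The `self.__dict__.pop(...)` calls in `update` are the usual way to invalidate a `functools.cached_property`. A cached property stores its result in the instance `__dict__` under the property's name, so popping that key forces a recompute on the next access.

This page does not show how `model` and `model_grad` are defined, so the sketch assumes they are cached properties. It isolates the pattern with a hypothetical `ToyModel` whose `pars` is a plain array rather than being gathered from `structures`.

```python
from dataclasses import dataclass
from functools import cached_property

import numpy as np


@dataclass
class ToyModel:
    pars: np.ndarray
    n_evals: int = 0  # how many times the "expensive" model was computed

    @cached_property
    def model(self) -> float:
        # Stand-in for an expensive model evaluation
        self.n_evals += 1
        return float(np.sum(self.pars**2))

    def update(self, vals: np.ndarray) -> "ToyModel":
        # Compare before overwriting; only drop the cache if values changed
        if not np.array_equal(self.pars, vals):
            self.__dict__.pop("model", None)
        self.pars = vals
        return self


m = ToyModel(np.array([1.0, 2.0]))
print(m.model, m.n_evals)  # 5.0 1
m.update(np.array([1.0, 2.0]))  # same values: cache kept
print(m.model, m.n_evals)  # 5.0 1
m.update(np.array([3.0, 0.0]))  # new values: cache dropped and recomputed
print(m.model, m.n_evals)  # 9.0 2
```

The comparison has to happen before the new values are written, which is also the order used in `update` above.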

Parameter dataclass

Dataclass to represent a single parameter of a model.

Attributes:

Name Type Description
name str

The name of the parameter. This is used only for display purposes.

fit Array

Should be an array with length Model.n_rounds. fit[i] is True if this parameter should be fit in round i.

val float

The value of the parameter.

err float

The error on the parameter value.

prior tuple[float, float]

The prior on this parameter. Should be the tuple (lower_bound, upper_bound).

Source code in witch/containers.py
@register_pytree_node_class
@dataclass
class Parameter:
    """
    Dataclass to represent a single parameter of a model.

    Attributes
    ----------
    name : str
        The name of the parameter.
        This is used only for display purposes.
    fit : jax.Array
        Should be an array with length `Model.n_rounds`.
        `fit[i]` is True if this parameter should be fit in round `i`.
    val : float
        The value of the parameter.
    err : float
        The error on the parameter value.
    prior : tuple[float, float]
        The prior on this parameter.
        Should be the tuple `(lower_bound, upper_bound)`.
    """

    name: str
    fit: tuple[bool]  # 1d bool array
    val: jax.Array  # Scalar float array
    err: jax.Array  # Scalar float array
    prior: jax.Array  # 2 element float array

    @property
    def fit_ever(self) -> bool:
        """
        Check if this parameter is set to ever be fit.

        Returns
        -------
        fit_ever : bool
            True if this parameter is fit in at least one round.
        """
        return np.any(self.fit).item()

    # Functions for making this a pytree
    # Don't call this on your own
    def tree_flatten(self) -> tuple[tuple, tuple]:
        children = (self.val, self.err, self.prior)
        aux_data = (self.name, self.fit)

        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> Self:
        name, fit = aux_data
        return cls(name, fit, *children)

fit_ever property

Check if this parameter is set to ever be fit.

Returns:

Name Type Description
fit_ever bool

True if this parameter is fit in at least one round.
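`Parameter` splits its fields into two groups:

- Pytree children (`val`, `err`, `prior`), which JAX traces and can differentiate.
- Auxiliary data (`name`, `fit`), which JAX treats as static and which must therefore be hashable. This is presumably why `fit` is stored as a tuple of bools.

The sketch below reproduces the flatten/unflatten round trip and `fit_ever` without depending on JAX. `ToyParameter` is hypothetical; in witch, the same two methods are registered through `jax.tree_util.register_pytree_node_class`.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyParameter:
    name: str
    fit: tuple  # one bool per fitting round
    val: float
    err: float
    prior: tuple  # (lower_bound, upper_bound)

    @property
    def fit_ever(self) -> bool:
        # np.any(...) gives a NumPy bool; .item() converts it to a Python bool
        return np.any(self.fit).item()

    def tree_flatten(self):
        children = (self.val, self.err, self.prior)  # traced leaves
        aux_data = (self.name, self.fit)  # static, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        name, fit = aux_data
        return cls(name, fit, *children)


amp = ToyParameter("amp", (False, True, True), 1.5, 0.0, (0.0, 10.0))
children, aux = amp.tree_flatten()
rebuilt = ToyParameter.tree_unflatten(aux, children)
print(rebuilt == amp, amp.fit_ever)  # True True
print(amp.fit[0], hash(aux) == hash(("amp", (False, True, True))))  # False True
```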

Structure dataclass

Dataclass to represent a structure within the model.

Attributes:

Name Type Description
name str

The name of the structure. This is used only for display purposes.

structure str

The type of structure that this is an instance of. Should be a string that appears in core.ORDER; it is lowercased on initialization, so matching is case-insensitive.

parameters list[Parameter]

The model parameters for this structure.

Raises:

Type Description
ValueError

If structure is not a valid structure, or if the number of parameters does not match what that structure expects.

Source code in witch/containers.py
@register_pytree_node_class
@dataclass
class Structure:
    """
    Dataclass to represent a structure within the model.

    Attributes
    ----------
    name : str
        The name of the structure.
        This is used only for display purposes.
    structure : str
        The type of structure that this is an instance of.
        Should be a string that appears in `core.ORDER`.
    parameters : list[Parameter]
        The model parameters for this structure.

    Raises
    ------
    ValueError
        If `structure` is not a valid structure,
        or if the number of parameters does not match that structure.
    """

    name: str
    structure: str
    parameters: list[Parameter]

    def __post_init__(self):
        self.structure = self.structure.lower()
        # Check that this is a valid structure
        if self.structure not in STRUCT_N_PAR.keys():
            raise ValueError(f"{self.name} has invalid structure: {self.structure}")
        # Check that we have the correct number of params
        if len(self.parameters) != STRUCT_N_PAR[self.structure]:
            raise ValueError(
                f"{self.name} has incorrect number of parameters, expected {STRUCT_N_PAR[self.structure]} for {self.structure} but was given {len(self.parameters)}"
            )

    # Functions for making this a pytree
    # Don't call this on your own
    def tree_flatten(self) -> tuple[tuple, tuple]:
        children = tuple(self.parameters)
        aux_data = (self.name, self.structure)

        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> Self:
        name, structure = aux_data
        parameters = children

        return cls(name, structure, list(parameters))
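`__post_init__` lowercases `structure` and then checks it against `STRUCT_N_PAR`, a mapping from structure type to expected parameter count. The sketch below mimics that validation with a made-up table. The structure names and counts in `TOY_STRUCT_N_PAR` are hypothetical, not witch's actual values. Parameters are plain strings to keep the sketch self-contained.

```python
from dataclasses import dataclass

# Hypothetical table; witch's STRUCT_N_PAR defines the real names and counts
TOY_STRUCT_N_PAR = {"point": 3, "ring": 4}


@dataclass
class ToyStructure:
    name: str
    structure: str
    parameters: list

    def __post_init__(self):
        # Normalize case so "Point" and "point" are treated the same
        self.structure = self.structure.lower()
        if self.structure not in TOY_STRUCT_N_PAR:
            raise ValueError(f"{self.name} has invalid structure: {self.structure}")
        expected = TOY_STRUCT_N_PAR[self.structure]
        if len(self.parameters) != expected:
            raise ValueError(
                f"{self.name} has incorrect number of parameters, expected "
                f"{expected} for {self.structure} but was given {len(self.parameters)}"
            )


src = ToyStructure("src1", "Point", ["dx", "dy", "amp"])
print(src.structure)  # point

for bad in (("src2", "blob", []), ("src3", "ring", ["dx", "dy"])):
    try:
        ToyStructure(*bad)
    except ValueError as err:
        print(err)
# src2 has invalid structure: blob
# src3 has incorrect number of parameters, expected 4 for ring but was given 2
```

Validating in `__post_init__` means an invalid `Structure` can never be constructed. Configuration mistakes therefore surface when the model is built rather than partway through a fit.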