model

Data classes for describing models in a structured way.

Model dataclass

Dataclass to describe a model. This includes some caching features that improve performance when fitting. Note that because of the caching, dynamically modifying which structures compose the model may not work as intended, so beware.
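The caveat about caching can be illustrated with a minimal sketch. `TinyModel` below is a hypothetical stand-in (not part of witch) that uses only the standard library; it shows how a `cached_property` can keep returning an old value after the structures are changed, until the cached entry is dropped from `__dict__`.

```python
from dataclasses import dataclass
from functools import cached_property


@dataclass
class TinyModel:
    """Hypothetical stand-in for Model; not part of witch."""

    structures: tuple[str, ...]

    @cached_property
    def struct_names(self) -> list[str]:
        # Computed on first access, then stored in self.__dict__
        return list(self.structures)


m = TinyModel(("a", "b"))
first = m.struct_names                 # fills the cache
m.structures = ("a",)                  # change the structures afterwards
stale = m.struct_names                 # still the cached two-element list
m.__dict__.pop("struct_names", None)   # drop the cached value by hand
fresh = m.struct_names                 # recomputed from the new structures
```

In this sketch the stale value persists until the cached entry is removed, which is the same `__dict__.pop` mechanism that `remove_struct` uses in the source listing.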

Attributes:

Name Type Description
name str

The name of the model. This is used for display purposes only.

structures tuple[Structure, ...]

The structures that compose the model. Will be sorted to match core.ORDER once initialized.

xyz tuple[Array, Array, Array, float, float]

Defines the grid used by model computation. The first three elements are a sparse 3D grid in arcseconds, with the first two elements being RA and Dec respectively and the third element being the LOS. The last two elements are the model center in RA and Dec (also in arcseconds). The structure functions use this center as the reference coordinate for dx and dy.

dz float

The LOS integration factor. Should minimally be the pixel size in arcseconds along the LOS, but can also include additional factors for performing unit conversions.

n_rounds int

How many rounds of fitting to perform.

cur_round int, default: 0

Which round of fitting we are currently in. Rounds are 0-indexed.

to_run tuple[bool, bool, bool, bool], default: (True, True, True, True)

The model stages to run. See core.make_to_run for details.

chisq Array, default: jnp.inf

The chi-squared of this model relative to some data, stored as a scalar float array. Used when fitting.

Source code in witch/containers/model.py
@register_pytree_node_class
@dataclass
class Model:
    """
    Dataclass to describe a model.
    This includes some caching features that improve performance when fitting.
    Note that because of the caching, dynamically modifying which structures compose
    the model may not work as intended, so beware.

    Attributes
    ----------
    name : str
        The name of the model.
        This is used for display purposes only.
    structures : tuple[Structure, ...]
        The structures that compose the model.
        Will be sorted to match `core.ORDER` once initialized.
    xyz : tuple[jax.Array, jax.Array, jax.Array, float, float]
        Defines the grid used by model computation.
        The first three elements are a sparse 3D grid in arcseconds,
        with the first two elements being RA and Dec respectively and the third
        element being the LOS. The last two elements are the model center in RA and Dec
        (also in arcseconds). The structure functions use this center as the reference
        coordinate for `dx` and `dy`.
    dz : float
        The LOS integration factor.
        Should minimally be the pixel size in arcseconds along the LOS,
        but can also include additional factors for performing unit conversions.
    n_rounds : int
        How many rounds of fitting to perform.
    cur_round : int, default: 0
        Which round of fitting we are currently in.
        Rounds are 0-indexed.
    to_run : tuple[bool, bool, bool, bool], default: (True, True, True, True)
        The model stages to run.
        See `core.make_to_run` for details.
    chisq : jax.Array, default: jnp.inf
        The chi-squared of this model relative to some data,
        stored as a scalar float array.
        Used when fitting.
    """

    name: str
    structures: tuple[Structure, ...]
    xyz: tuple[jax.Array, jax.Array, jax.Array, float, float]  # arcseconds
    dz: float  # arcseconds
    n_rounds: int
    cur_round: int = 0
    to_run: tuple[bool, bool, bool, bool] = field(default_factory=core.make_to_run)
    chisq: jax.Array = field(
        default_factory=jnp.array(jnp.inf).copy
    )  # scalar float array

    def check_compatibility(self, other: Self) -> bool:
        """
        Check whether 'other' (a model loaded from a checkpoint) is compatible with the current model defined in config.
        Checks that structures, xyz, and dz are equal.

        Parameters
        ----------
        other : Model
            Model to compare against

        Returns
        -------
        compatible : bool
            True if compatible,
            False if not.
        """
        # Structure name match
        struct_names = tuple(s.name for s in self.structures)
        ostruct_names = tuple(s.name for s in other.structures)
        if struct_names != ostruct_names:
            print(
                f"Model structure mismatch. "
                f"Config structures = {struct_names}, "
                f"Checkpoint structures = {ostruct_names}"
            )
            return False

        # Parameter counts per structure
        for self_struct, other_struct in zip(self.structures, other.structures):
            self_params = self_struct.parameters
            other_params = other_struct.parameters
            if len(self_params) != len(other_params):
                print(
                    f"Parameter count mismatch in structure '{self_struct.name}': "
                    f"{len(self_params)} (config) vs {len(other_params)} (ckpt)"
                )
                return False

        # Parameter order match
        if self.par_names != other.par_names:
            print("Parameter ordering mismatch between config and checkpoint models")
            return False

        # xyz compatibility
        if len(self.xyz) != len(other.xyz):
            print("xyz length mismatch between config and checkpoint models")
            return False

        for i, (a, b) in enumerate(zip(self.xyz, other.xyz)):
            if isinstance(a, (float, int)):  # check the two float elements
                if not isinstance(b, (float, int)):
                    print(f"xyz[{i}] element type mismatch")
                    return False
            else:  # jax arrays
                if a.shape != b.shape:
                    print(f"xyz[{i}] array shape mismatch: {a.shape} vs {b.shape}")
                    return False

        # dz compatibility
        if self.dz != other.dz:  # Checks same numerical value
            print(f"dz value mismatch: {self.dz} vs {other.dz}")
            return False

        return True

    def __setattr__(self, name, value):
        if name == "cur_round" or name == "xyz":
            self.__dict__.pop("model_grad", None)
            self.__dict__.pop("model", None)
        if name == "n_rounds":
            self.__dict__.pop("to_fit_ever", None)
        return super().__setattr__(name, value)

    def __repr__(self) -> str:
        rep = self.name + ":\n"
        rep += f"Round {self.cur_round + 1} out of {self.n_rounds}\n"
        for struct in self.structures:
            rep += "\t" + struct.name + ":\n"
            for par in struct.parameters:
                rep += (
                    "\t\t"
                    + par.name
                    + "*" * par.fit[self.cur_round]
                    + " ["
                    + str(par.prior[0])
                    + ", "
                    + str(par.prior[1])
                    + "]"
                    + " = "
                    + str(par.val)
                    + " ± "
                    + str(par.err)
                    + " ("
                    + str(jnp.abs(par.val / par.err))
                    + " σ)"
                    + "\n"
                )
        rep += f"chisq is {self.chisq}"
        return rep

    @cached_property
    def n_struct(self) -> list[int]:
        """
        Number of each type of structures in the model.
        Note that this is cached.

        Returns
        -------
        n_struct : list[int]
            `n_struct[i]` is the number of `core.ORDER[i]`
            structures in this model.
        """
        n_struct = [0] * len(core.ORDER)
        for structure in self.structures:
            idx = core.ORDER.index(structure.structure)
            n_struct[idx] += 1
        return n_struct

    @cached_property
    def n_rbins(self) -> list[int]:
        """
        Number of r bins for nonparametric structures.
        Note that this is cached.

        Returns
        -------
        n_rbins : list[int]
            `n_rbins[i]` is the number of r bins in `structures[i]`.
        """
        n_rbins = [structure.n_rbins for structure in self.structures]

        return n_rbins

    @property
    def pars(self) -> jax.Array:
        """
        Get the current parameter values.

        Returns
        -------
        pars :  jax.Array
            The parameter values in the order expected by `core.model`.
        """
        pars = jnp.array([])
        for structure in self.structures:
            for parameter in structure.parameters:
                pars = jnp.append(pars, parameter.val.ravel())
        return jnp.array(pars)

    @cached_property
    def par_names(self) -> list[str]:
        """
        Get the names of all parameters.
        Note that this is cached.

        Returns
        -------
        par_names : list[str]
            Parameter names in the same order as `pars`.
        """
        par_names = []
        for structure in self.structures:
            for parameter in structure.parameters:
                if len(parameter.val) > 1:
                    for i in range(len(parameter.val)):
                        par_names += [parameter.name + "_{}".format(i)]
                else:
                    par_names += [parameter.name]

        return par_names

    @cached_property
    def par_struct_names(self) -> list[str]:
        """
        Get the struct names of all parameters.
        Note that this is cached.

        Returns
        -------
        par_struct_names : list[str]
            Parameter struct names in the same order as `pars`.
        """
        par_struct_names = []
        for structure in self.structures:
            for parameter in structure.parameters:
                par_struct_names += [structure.name] * len(parameter.val)

        return par_struct_names

    @cached_property
    def struct_names(self) -> list[str]:
        """
        Get the names of all structures in this model
        """
        return [struct.name for struct in self.structures]

    @property
    def errs(self) -> jax.Array:
        """
        Get the current parameter errors.

        Returns
        -------
        errs : jax.Array
            The errors in the same order as `pars`.
        """
        errs = jnp.array([])
        for structure in self.structures:
            for parameter in structure.parameters:
                errs = jnp.append(errs, parameter.err.ravel())
        return jnp.array(errs)

    @cached_property
    def priors(self) -> tuple[jax.Array, jax.Array]:
        """
        Get the priors for all parameters.
        Note that this is cached.

        Returns
        -------
        priors : tuple[jax.Array, jax.Array]
            Parameter priors in the same order as `pars`.
            This is a tuple with the first element being an array
            of lower bounds and the second being upper.
        """
        lower = []
        upper = []
        for structure in self.structures:
            for parameter in structure.parameters:
                lower += [parameter.prior[0]] * len(parameter.val)
                upper += [parameter.prior[1]] * len(parameter.val)
        priors = (jnp.array(lower), jnp.array(upper))
        return priors

    @property
    def to_fit(self) -> tuple[bool, ...]:
        """
        Get which parameters we want to fit for the current round.

        Returns
        -------
        to_fit : tuple[bool, ...]
            `to_fit[i]` is True if we want to fit the `i`'th parameter
            in the current round.
            This is in the same order as `pars`.
        """

        to_fit = []
        for structure in self.structures:
            for parameter in structure.parameters:
                to_fit += [parameter.fit[self.cur_round]] * len(parameter.val)
                # to_fit = jnp.append(to_fit, jnp.array([parameter.fit[self.cur_round]] * len(parameter.val)).ravel())

        return tuple(to_fit)  # jnp.ravel(jnp.array(to_fit))

    @cached_property
    def to_fit_ever(self) -> jax.Array:
        """
        Check which parameters we ever fit.
        Note that this is cached.

        Returns
        -------
        to_fit_ever : jax.Array
            `to_fit_ever[i]` is True if we ever want to fit the `i`'th parameter.
            This is in the same order as `pars`.
        """
        to_fit = jnp.array([], dtype=bool)
        for structure in self.structures:
            for parameter in structure.parameters:
                to_fit = jnp.append(
                    to_fit,
                    jnp.array(
                        [parameter.fit_ever] * len(parameter.val), dtype=bool
                    ).ravel(),
                )

        return jnp.ravel(jnp.array(to_fit))

    @cached_property
    def model(self) -> jax.Array:
        """
        The evaluated model, see `core.model` for details.
        Note that this is cached, but is automatically reset whenever
        `update` is called or `cur_round` or `xyz` changes.

        Returns
        -------
        model : jax.Array
            The model evaluated on `xyz` with the current values of `pars`.
        """
        return core.model(
            self.xyz,
            tuple(self.n_struct),
            tuple(self.n_rbins),
            self.dz,
            self.to_run,
            *self.pars,
        )

    @cached_property
    def model_grad(self) -> tuple[jax.Array, jax.Array]:
        """
        The evaluated model and its gradient, see `core.model_grad` for details.
        Note that this is cached, but is automatically reset whenever
        `update` is called or `cur_round` or `xyz` changes.

        Returns
        -------
        model : jax.Array
            The model evaluated on `xyz` with the current values of `pars`.
        grad : jax.Array
            The gradient evaluated on `xyz` with the current values of `pars`.
            Has shape `(len(pars),) + model.shape`.
        """
        argnums = tuple(np.where(self.to_fit)[0] + core.ARGNUM_SHIFT)
        return core.model_grad(
            self.xyz,
            tuple(self.n_struct),
            tuple(self.n_rbins),
            self.dz,
            self.to_run,
            argnums,
            *self.pars,
        )

    def update(self, vals: jax.Array, errs: jax.Array, chisq: jax.Array) -> Self:
        """
        Update the parameter values and errors as well as the model chi-squared.
        This also resets the cache on `model` and `model_grad`
        if `vals` is different than `self.pars`.

        Parameters
        ----------
        vals : jax.Array
            The new parameter values.
            Should be in the same order as `pars`.
        errs : jax.Array
            The new parameter errors.
            Should be in the same order as `pars`.
        chisq : jax.Array
            The new chi-squared.
            Should be a scalar float array.

        Returns
        -------
        updated : Model
            The updated model.
            While nominally the model will update in place, returning it
            allows us to use this function in JITed functions.
        """
        if not np.array_equal(self.pars, vals):
            self.__dict__.pop("model", None)
            self.__dict__.pop("model_grad", None)
        n = 0
        for struct in self.structures:
            for par in struct.parameters:
                for i in range(len(par.val)):
                    par.val = par.val.at[i].set(vals[n])
                    par.err = par.err.at[i].set(errs[n])
                    n += 1
        self.chisq = chisq

        return self

    def add_round(self, to_fit) -> Self:
        """
        Add an additional round to the model.

        Parameters
        ----------
        to_fit : jax.Array
            Boolean array denoting which parameters to fit this round.
            Should be in the same order as `pars`.

        Returns
        -------
        updated : Model
            The updated model with the new round.
            While nominally the model will update in place, returning it
            allows us to use this function in JITed functions.
        """
        self.n_rounds = self.n_rounds + 1
        self.cur_round = self.cur_round + 1
        n = 0
        for struct in self.structures:
            for par in struct.parameters:
                par.fit = par.fit + tuple((to_fit[n].item(),))
                n += 1
        return self

    def remove_struct(self, struct_name: str):
        """
        Remove structure by name.

        Parameters
        ----------
        struct_name : str
            Name of struct to be removed.
        """
        self.structures = tuple(
            structure for structure in self.structures if structure.name != struct_name
        )

        to_pop = ["to_fit_ever", "n_struct", "priors", "par_names"]
        for key in to_pop:
            if key in self.__dict__:  # Pop keys if they are in dict
                self.__dict__.pop(key)

        self.__dict__.pop("model", None)
        self.__dict__.pop("model_grad", None)

    def save(self, path: str):
        """
        Serialize the model to a file with dill.

        Parameters
        ----------
        path : str
            The file to save to.
            Does not check to see if the path is valid.
        """
        with open(path, "wb") as f:
            dill.dump(self, f)

    @classmethod
    def load(cls, path: str) -> Self:
        """
        Load the model from a file with dill.

        Parameters
        ----------
        path : str
            The path to the saved model.
            Does not check to see if the path is valid.

        Returns
        -------
        model : Model
            The loaded model.
        """
        with open(path, "rb") as f:
            return dill.load(f)

    @classmethod
    def from_cfg(
        cls,
        cfg: dict,
        model_field: str = "model",
        generate_name: bool = True,
        remove_structs: bool = False,
    ) -> Self:
        """
        Create an instance of model from a witcher config.

        Parameters
        ----------
        cfg : dict
            The config loaded into a dict.
        model_field : str, default: "model"
            The name of the model in the config.
        generate_name : bool, default: True
            If True generate a name for the model.
            If False use `model_field`.
        remove_structs : bool, default: False
            If True then don't include structures marked for removal.

        Returns
        -------
        model : Model
            The model described by the config.
        """
        # Do imports
        for module, name in cfg.get("imports", {}).items():
            mod = import_module(module)
            if isinstance(name, str):
                globals()[name] = mod
            elif isinstance(name, list):
                for n in name:
                    globals()[n] = getattr(mod, n)
            else:
                raise TypeError("Expect import name to be a string or a list")

        # Load constants
        constants = {
            name: eval(str(const)) for name, const in cfg.get("constants", {}).items()
        }
        constants = constants

        # Get jax device
        dev_id = cfg.get("jax_device", 0)
        device = jax.devices()[dev_id]

        # Set up the coordinate grid
        r_map = eval(str(cfg["coords"]["r_map"]))
        dr = eval(str(cfg["coords"]["dr"]))
        dz = eval(str(cfg["coords"].get("dz", dr)))
        x0 = eval(str(cfg["coords"]["x0"]))
        y0 = eval(str(cfg["coords"]["y0"]))

        xyz_host = wg.make_grid(
            r_map, dr, dr, dz, x0 * wg.rad_to_arcsec, y0 * wg.rad_to_arcsec
        )
        xyz = jax.device_put(xyz_host, device)
        xyz[0].block_until_ready()
        xyz[1].block_until_ready()
        xyz[2].block_until_ready()

        n_rounds = cfg.get("n_rounds", 1)

        structures = []
        for name, structure in cfg[model_field]["structures"].items():
            if structure.get("to_remove", False) and remove_structs:
                continue
            n_rbins = structure.get("n_rbins", 0)
            parameters = []
            for par_name, param in structure["parameters"].items():
                val = eval(str(param["value"]))
                fit = param.get("to_fit", [False] * max(1, n_rounds))
                if isinstance(fit, bool):
                    fit = [fit] * max(1, n_rounds)
                if len(fit) != max(1, n_rounds):
                    raise ValueError(
                        f"to_fit has {len(fit)} entries but we only have {n_rounds} rounds"
                    )
                priors = param.get("priors", None)
                if priors is not None:
                    priors = eval(str(priors))
                else:
                    priors = (-1 * np.inf, np.inf)
                parameters.append(
                    Parameter(
                        par_name,
                        tuple(fit),
                        jnp.atleast_1d(jnp.array(val, dtype=float)),
                        jnp.zeros_like(jnp.atleast_1d(jnp.array(val)), dtype=float),
                        jnp.array(priors, dtype=float),
                    )
                )
            structures.append(
                Structure(
                    name,
                    structure["structure"],
                    parameters,
                    n_rbins=n_rbins,
                )
            )

        name = model_field
        if generate_name:
            name = "-".join([structure.name for structure in structures])

        # Make sure the structure is in the order that core expects
        structure_idx = jnp.argsort(
            jnp.array(
                [core.ORDER_DICT[structure.structure] for structure in structures]
            )
        ).ravel()
        structures = tuple([structures[i] for i in structure_idx])

        return cls(name, structures, xyz, dz, n_rounds)

    # Functions for making this a pytree
    # Don't call this on your own
    def tree_flatten(self) -> tuple[tuple, tuple]:
        children = (self.structures, self.xyz, self.dz, self.chisq)
        aux_data = (
            self.name,
            self.n_rounds,
            self.cur_round,
            self.to_run,
        )

        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> Self:
        name, n_rounds, cur_round, to_run = aux_data
        structures, xyz, dz, chisq = children

        return cls(
            name,
            structures,
            xyz,
            dz,
            n_rounds,
            cur_round,
            to_run,
            chisq,
        )

errs property

Get the current parameter errors.

Returns:

Name Type Description
errs Array

The errors in the same order as pars.

model cached property

The evaluated model, see core.model for details. Note that this is cached, but is automatically reset whenever update is called or cur_round or xyz changes.

Returns:

Name Type Description
model Array

The model evaluated on xyz with the current values of pars.
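The automatic reset described above relies on intercepting attribute assignment. A minimal sketch of that pattern, assuming a hypothetical `RoundCache` class (not part of witch) and standard-library tools only:

```python
from functools import cached_property


class RoundCache:
    """Hypothetical miniature of the reset-on-assignment pattern."""

    def __init__(self, cur_round: int):
        self.evaluations = 0
        self.cur_round = cur_round

    def __setattr__(self, name, value):
        # Assigning to cur_round discards the cached result, if any
        if name == "cur_round":
            self.__dict__.pop("model", None)
        super().__setattr__(name, value)

    @cached_property
    def model(self) -> int:
        self.evaluations += 1  # count how often we really compute
        return self.cur_round * 10


r = RoundCache(1)
a = r.model          # computed once
b = r.model          # served from the cache
r.cur_round = 2      # assignment clears the cached value
c = r.model          # computed again with the new round
```

The expensive computation only reruns after an assignment that the `__setattr__` hook knows about, which is why changes made by other routes can leave a cached value out of date.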

model_grad cached property

The evaluated model and its gradient, see core.model_grad for details. Note that this is cached, but is automatically reset whenever update is called or cur_round or xyz changes.

Returns:

Name Type Description
model Array

The model evaluated on xyz with the current values of pars.

grad Array

The gradient evaluated on xyz with the current values of pars. Has shape (len(pars),) + model.shape.

n_rbins cached property

Number of r bins for nonparametric structures. Note that this is cached.

Returns:

Name Type Description
n_rbins list[int]

n_rbins[i] is the number of r bins in structures[i].

n_struct cached property

Number of each type of structures in the model. Note that this is cached.

Returns:

Name Type Description
n_struct list[int]

n_struct[i] is the number of core.ORDER[i] structures in this model.
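As a rough illustration of the counting rule, the sketch below tallies structure kinds against a fixed ordering. The `ORDER` tuple and the kind names are made-up placeholders, not the real contents of `core.ORDER`, and `count_structures` is a hypothetical helper.

```python
# Placeholder ordering; the real core.ORDER is defined by witch
ORDER = ("kind_a", "kind_b", "kind_c")


def count_structures(kinds: list[str], order: tuple[str, ...] = ORDER) -> list[int]:
    """Return counts[i] = number of entries in `kinds` equal to order[i]."""
    counts = [0] * len(order)
    for kind in kinds:
        counts[order.index(kind)] += 1
    return counts


counts = count_structures(["kind_b", "kind_c", "kind_b"])
```

The result always has one slot per entry in the ordering, so kinds that are absent from the model appear as zeros.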

par_names cached property

Get the names of all parameters. Note that this is cached.

Returns:

Name Type Description
par_names list[str]

Parameter names in the same order as pars.
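The naming rule in the source gives multi-element parameters an index suffix so that there is one name per flattened value. A minimal sketch, where `flat_names` is a hypothetical helper and each parameter is reduced to a `(name, size)` pair for illustration:

```python
def flat_names(params: list[tuple[str, int]]) -> list[str]:
    """Build one name per flattened value from (name, size) pairs."""
    names: list[str] = []
    for name, size in params:
        if size > 1:
            # Multi-element parameters get an index suffix per element
            names += [f"{name}_{i}" for i in range(size)]
        else:
            names.append(name)
    return names


names = flat_names([("amp", 1), ("profile", 3)])
```

Here `names` lines up element by element with a flattened value array holding one entry for `amp` and three for `profile`.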

par_struct_names cached property

Get the struct names of all parameters. Note that this is cached.

Returns:

Name Type Description
par_struct_names list[str]

Parameter struct names in the same order as pars.

pars property

Get the current parameter values.

Returns:

Name Type Description
pars Array

The parameter values in the order expected by core.model.

priors cached property

Get the priors for all parameters. Note that this is cached.

Returns:

Name Type Description
priors tuple[Array, Array]

Parameter priors in the same order as pars. This is a tuple with the first element being an array of lower bounds and the second being upper.

struct_names cached property

Get the names of all structures in this model

to_fit property

Get which parameters we want to fit for the current round.

Returns:

Name Type Description
to_fit tuple[bool, ...]

to_fit[i] is True if we want to fit the i'th parameter in the current round. This is in the same order as pars.
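A minimal sketch of how per-round flags can be expanded to one flag per flattened value. `flags_for_round` is a hypothetical helper, and each parameter is reduced to a `(fit, size)` pair where `fit` holds one boolean per round:

```python
def flags_for_round(
    params: list[tuple[tuple[bool, ...], int]], cur_round: int
) -> tuple[bool, ...]:
    """Pick each parameter's flag for `cur_round`, repeated once per element."""
    flags: list[bool] = []
    for fit, size in params:
        flags += [fit[cur_round]] * size
    return tuple(flags)


params = [((True, False), 1), ((False, True), 2)]
round0 = flags_for_round(params, 0)
round1 = flags_for_round(params, 1)
```

Returning a tuple keeps the result hashable, which can be convenient when the flags are passed as a static argument.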

to_fit_ever cached property

Check which parameters we ever fit. Note that this is cached.

Returns:

Name Type Description
to_fit_ever Array

to_fit_ever[i] is True if we ever want to fit the i'th parameter. This is in the same order as pars.

add_round(to_fit)

Add an additional round to the model.

Parameters:

Name Type Description Default
to_fit Array

Boolean array denoting which parameters to fit this round. Should be in the same order as pars.

required

Returns:

Name Type Description
updated Model

The updated model with the new round. While nominally the model will update in place, returning it allows us to use this function in JITed functions.

Source code in witch/containers/model.py
def add_round(self, to_fit) -> Self:
    """
    Add an additional round to the model.

    Parameters
    ----------
    to_fit : jax.Array
        Boolean array denoting which parameters to fit this round.
        Should be in the same order as `pars`.

    Returns
    -------
    updated : Model
        The updated model with the new round.
        While nominally the model will update in place, returning it
        allows us to use this function in JITed functions.
    """
    self.n_rounds = self.n_rounds + 1
    self.cur_round = self.cur_round + 1
    n = 0
    for struct in self.structures:
        for par in struct.parameters:
            par.fit = par.fit + tuple((to_fit[n].item(),))
            n += 1
    return self

check_compatibility(other)

Check whether 'other' (a model loaded from a checkpoint) is compatible with the current model defined in config. Checks that structures, xyz, and dz are equal.

Parameters:

Name Type Description Default
other Model

Model to compare against.

required

Returns:

Name Type Description
compatible bool

True if compatible, False if not.

Source code in witch/containers/model.py
def check_compatibility(self, other: Self) -> bool:
    """
    Check whether 'other' (a model loaded from a checkpoint) is compatible with the current model defined in config.
    Checks that structures, xyz, and dz are equal.

    Parameters
    ----------
    other : Model
        Model to compare against

    Returns
    -------
    compatible : bool
        True if compatible,
        False if not.
    """
    # Structure name match
    struct_names = tuple(s.name for s in self.structures)
    ostruct_names = tuple(s.name for s in other.structures)
    if struct_names != ostruct_names:
        print(
            f"Model structure mismatch. "
            f"Config structures = {struct_names}, "
            f"Checkpoint structures = {ostruct_names}"
        )
        return False

    # Parameter counts per structure
    for self_struct, other_struct in zip(self.structures, other.structures):
        self_params = self_struct.parameters
        other_params = other_struct.parameters
        if len(self_params) != len(other_params):
            print(
                f"Parameter count mismatch in structure '{self_struct.name}': "
                f"{len(self_params)} (config) vs {len(other_params)} (ckpt)"
            )
            return False

    # Parameter order match
    if self.par_names != other.par_names:
        print("Parameter ordering mismatch between config and checkpoint models")
        return False

    # xyz compatibility
    if len(self.xyz) != len(other.xyz):
        print("xyz length mismatch between config and checkpoint models")
        return False

    for i, (a, b) in enumerate(zip(self.xyz, other.xyz)):
        if isinstance(a, (float, int)):  # check the two float elements
            if not isinstance(b, (float, int)):
                print(f"xyz[{i}] element type mismatch")
                return False
        else:  # jax arrays
            if a.shape != b.shape:
                print(f"xyz[{i}] array shape mismatch: {a.shape} vs {b.shape}")
                return False

    # dz compatibility
    if self.dz != other.dz:  # Checks same numerical value
        print(f"dz value mismatch: {self.dz} vs {other.dz}")
        return False

    return True

from_cfg(cfg, model_field='model', generate_name=True, remove_structs=False) classmethod

Create an instance of model from a witcher config.

Parameters:

Name Type Description Default
cfg dict

The config loaded into a dict.

required
model_field str

The name of the model in the config.

"model"
generate_name bool

If True generate a name for the model. If False use model_field.

True
remove_structs bool

If True then don't include structures marked for removal.

False

Returns:

Name Type Description
model Model

The model described by the config.
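As a rough illustration of the config layout `from_cfg` reads, the sketch below builds a minimal dict with the keys accessed in the source (`coords`, `n_rounds`, and `structures` under the model field) and mirrors how a parameter's `to_fit` entry is expanded to one flag per round. The structure name, structure type, parameter names, and values are hypothetical, and `normalize_to_fit` is a helper written for this example, not part of witch.

```python
# Hypothetical config fragment; the keys mirror those read by from_cfg.
cfg = {
    "n_rounds": 2,
    "coords": {"r_map": "3.0*60", "dr": "1.0", "x0": "0.0", "y0": "0.0"},
    "model": {
        "structures": {
            "cluster": {  # hypothetical structure name
                "structure": "isobeta",  # assumed structure type
                "parameters": {
                    "amp": {"value": "1e-3", "to_fit": True},
                    "beta": {"value": 0.7, "to_fit": [False, True]},
                    "theta": {"value": 0.0},  # no to_fit: never fit
                },
            }
        }
    },
}


def normalize_to_fit(param: dict, n_rounds: int) -> tuple:
    """Expand `to_fit` to one bool per round, following the logic in from_cfg."""
    n = max(1, n_rounds)
    fit = param.get("to_fit", [False] * n)
    if isinstance(fit, bool):
        fit = [fit] * n
    if len(fit) != n:
        raise ValueError(f"to_fit has {len(fit)} entries, expected {n}")
    return tuple(fit)


params = cfg["model"]["structures"]["cluster"]["parameters"]
flags = {name: normalize_to_fit(p, cfg["n_rounds"]) for name, p in params.items()}
print(flags)
```

Note that `from_cfg` passes values such as `r_map` and `value` through `eval(str(...))`, which is why string expressions like `"3.0*60"` are usable in the config.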

Source code in witch/containers/model.py
@classmethod
def from_cfg(
    cls,
    cfg: dict,
    model_field: str = "model",
    generate_name: bool = True,
    remove_structs: bool = False,
) -> Self:
    """
    Create an instance of model from a witcher config.

    Parameters
    ----------
    cfg : dict
        The config loaded into a dict.
    model_field : str, default: "model"
        The name of the model in the config.
    generate_name : bool, default: True
        If True generate a name for the model.
        If False use `model_field`.
    remove_structs : bool, default: False
        If True then don't include structures marked for removal.

    Returns
    -------
    model : Model
        The model described by the config.
    """
    # Do imports
    for module, name in cfg.get("imports", {}).items():
        mod = import_module(module)
        if isinstance(name, str):
            globals()[name] = mod
        elif isinstance(name, list):
            for n in name:
                globals()[n] = getattr(mod, n)
        else:
            raise TypeError("Expect import name to be a string or a list")

    # Load constants
    constants = {
        name: eval(str(const)) for name, const in cfg.get("constants", {}).items()
    }
    constants = constants

    # Get jax device
    dev_id = cfg.get("jax_device", 0)
    device = jax.devices()[dev_id]

    # Setup coordinate stuff
    r_map = eval(str(cfg["coords"]["r_map"]))
    dr = eval(str(cfg["coords"]["dr"]))
    dz = eval(str(cfg["coords"].get("dz", dr)))
    x0 = eval(str(cfg["coords"]["x0"]))
    y0 = eval(str(cfg["coords"]["y0"]))

    xyz_host = wg.make_grid(
        r_map, dr, dr, dz, x0 * wg.rad_to_arcsec, y0 * wg.rad_to_arcsec
    )
    xyz = jax.device_put(xyz_host, device)
    xyz[0].block_until_ready()
    xyz[1].block_until_ready()
    xyz[2].block_until_ready()

    n_rounds = cfg.get("n_rounds", 1)

    structures = []
    for name, structure in cfg[model_field]["structures"].items():
        if structure.get("to_remove", False) and remove_structs:
            continue
        n_rbins = structure.get("n_rbins", 0)
        parameters = []
        for par_name, param in structure["parameters"].items():
            val = eval(str(param["value"]))
            fit = param.get("to_fit", [False] * max(1, n_rounds))
            if isinstance(fit, bool):
                fit = [fit] * max(1, n_rounds)
            if len(fit) != max(1, n_rounds):
                raise ValueError(
                    f"to_fit has {len(fit)} entries but we only have {n_rounds} rounds"
                )
            priors = param.get("priors", None)
            if priors is not None:
                priors = eval(str(priors))
            else:
                priors = (-1 * np.inf, np.inf)
            parameters.append(
                Parameter(
                    par_name,
                    tuple(fit),
                    jnp.atleast_1d(jnp.array(val, dtype=float)),
                    jnp.zeros_like(jnp.atleast_1d(jnp.array(val)), dtype=float),
                    jnp.array(priors, dtype=float),
                )
            )
        structures.append(
            Structure(
                name,
                structure["structure"],
                parameters,
                n_rbins=n_rbins,
            )
        )

    name = model_field
    if generate_name:
        name = "-".join([structure.name for structure in structures])

    # Make sure the structure is in the order that core expects
    structure_idx = jnp.argsort(
        jnp.array(
            [core.ORDER_DICT[structure.structure] for structure in structures]
        )
    ).ravel()
    structures = tuple([structures[i] for i in structure_idx])

    return cls(name, structures, xyz, dz, n_rounds)

load(path) classmethod

Load the model from a file with dill.

Parameters:

Name Type Description Default
path str

The path to the saved model. Does not check to see if the path is valid.

required

Returns:

Name Type Description
model Model

The loaded model.

Source code in witch/containers/model.py
@classmethod
def load(cls, path: str) -> Self:
    """
    Load the model from a file with dill.

    Parameters
    ----------
    path : str
        The path to the saved model.
        Does not check to see if the path is valid.

    Returns
    -------
    model : Model
        The loaded model.
    """
    with open(path, "rb") as f:
        return dill.load(f)

remove_struct(struct_name)

Remove structure by name.

Parameters:

Name Type Description Default
struct_name str

Name of struct to be removed.

required
Source code in witch/containers/model.py
def remove_struct(self, struct_name: str):
    """
    Remove structure by name.

    Parameters
    ----------
    struct_name : str
        Name of struct to be removed.
    """
    self.structures = tuple(
        structure for structure in self.structures if structure.name != struct_name
    )

    to_pop = ["to_fit_ever", "n_struct", "priors", "par_names"]
    for key in to_pop:
        if key in self.__dict__:  # Pop keys if they are in dict
            self.__dict__.pop(key)

    self.__dict__.pop("model", None)
    self.__dict__.pop("model_grad", None)

save(path)

Serialize the model to a file with dill.

Parameters:

Name Type Description Default
path str

The file to save to. Does not check to see if the path is valid.

required
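Since neither `save` nor `load` validates the path, a caller may want to check it first. The sketch below is one way to do that under stated assumptions: `save_checked` and `load_checked` are hypothetical helpers, a plain dict stands in for a `Model`, and the standard-library `pickle` is used in place of `dill` (the two share the `dump`/`load` call shape) so the example runs without extra dependencies.

```python
import os
import pickle
import tempfile


def save_checked(obj, path: str) -> None:
    """Serialize obj after confirming the target directory exists."""
    parent = os.path.dirname(os.path.abspath(path))
    if not os.path.isdir(parent):
        raise FileNotFoundError(f"Directory does not exist: {parent}")
    with open(path, "wb") as f:
        pickle.dump(obj, f)  # Model.save calls dill.dump at this step


def load_checked(path: str):
    """Deserialize after confirming the file exists."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No saved model at: {path}")
    with open(path, "rb") as f:
        return pickle.load(f)  # Model.load calls dill.load at this step


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    state = {"name": "demo", "chisq": 12.5}  # stand-in for a Model
    save_checked(state, path)
    restored = load_checked(path)

    # Saving into a directory that does not exist is caught up front.
    missing = os.path.join(tmp, "nope", "model.pkl")
    try:
        save_checked(state, missing)
        failed = False
    except FileNotFoundError:
        failed = True

print(restored, failed)
```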
Source code in witch/containers/model.py
def save(self, path: str):
    """
    Serialize the model to a file with dill.

    Parameters
    ----------
    path : str
        The file to save to.
        Does not check to see if the path is valid.
    """
    with open(path, "wb") as f:
        dill.dump(self, f)

update(vals, errs, chisq)

Update the parameter values and errors as well as the model chi-squared. This also resets the cache on model and model_grad if vals is different than self.pars.

Parameters:

Name Type Description Default
vals Array

The new parameter values. Should be in the same order as pars.

required
errs Array

The new parameter errors. Should be in the same order as pars.

required
chisq Array

The new chi-squared. Should be a scalar float array.

required

Returns:

Name Type Description
updated Model

The updated model. While nominally the model will update in place, returning it allows us to use this function in JITed functions.
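The cache reset described above relies on how `functools.cached_property` stores its result in the instance `__dict__`. The sketch below shows that mechanism in isolation; `ToyModel` is a hypothetical stand-in for `Model`, with a simple sum in place of the real model evaluation.

```python
from functools import cached_property


class ToyModel:
    """Hypothetical stand-in showing the cache reset used by update."""

    def __init__(self, pars):
        self.pars = list(pars)
        self.n_evals = 0  # counts how often the model is recomputed

    @cached_property
    def model(self):
        self.n_evals += 1
        return sum(self.pars)  # placeholder for the expensive evaluation

    def update(self, vals):
        if list(vals) != self.pars:
            # cached_property keeps its result in the instance __dict__,
            # so popping the key forces a recompute on the next access.
            self.__dict__.pop("model", None)
        self.pars = list(vals)
        return self


m = ToyModel([1.0, 2.0])
first = m.model                      # computed
again = m.model                      # served from the cache
same = m.update([1.0, 2.0]).model    # unchanged values keep the cache
new = m.update([3.0, 4.0]).model     # changed values trigger a recompute
print(first, again, same, new, m.n_evals)
```

Returning `self` from `update` is what allows the chained `m.update(...).model` calls here.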

Source code in witch/containers/model.py
def update(self, vals: jax.Array, errs: jax.Array, chisq: jax.Array) -> Self:
    """
    Update the parameter values and errors as well as the model chi-squared.
    This also resets the cache on `model` and `model_grad`
    if `vals` is different than `self.pars`.

    Parameters
    ----------
    vals : jax.Array
        The new parameter values.
        Should be in the same order as `pars`.
    errs : jax.Array
        The new parameter errors.
        Should be in the same order as `pars`.
    chisq : jax.Array
        The new chi-squared.
        Should be a scalar float array.

    Returns
    -------
    updated : Model
        The updated model.
        While nominally the model will update in place, returning it
        allows us to use this function in JITed functions.
    """
    if not np.array_equal(self.pars, vals):
        self.__dict__.pop("model", None)
        self.__dict__.pop("model_grad", None)
    n = 0
    for struct in self.structures:
        for par in struct.parameters:
            for i in range(len(par.val)):
                par.val = par.val.at[i].set(vals[n])
                par.err = par.err.at[i].set(errs[n])
                n += 1
    self.chisq = chisq

    return self

Model_xfer dataclass

Bases: Model

Source code in witch/containers/model.py
@dataclass
class Model_xfer(Model):

    ks: jax.Array = field(default_factory=jnp.array([0]).copy)  # scalar float array
    xfer_vals: jax.Array = field(
        default_factory=jnp.array([1]).copy
    )  # scalar float array

    def __post_init__(self):
        pass

    @classmethod
    def from_parent(cls, parent, xfer_str) -> Self:
        xfer = load_xfer(xfer_str)

        pixsize = np.abs(parent.xyz[1][0][1] - parent.xyz[1][0][0])
        ks = xfer[0, 0:] * pixsize  # This picks up an extra dim?
        xfer_vals = xfer[1, 0:]

        my_dict = {}
        for key in parent.__dataclass_fields__.keys():
            if parent.__dataclass_fields__[key].init:
                my_dict[key] = deepcopy(parent.__dict__[key])

        return cls(**my_dict, ks=ks.ravel(), xfer_vals=xfer_vals)

    @cached_property
    def model(self) -> jax.Array:
        cur_map = core.model(
            self.xyz,
            tuple(self.n_struct),
            tuple(self.n_rbins),
            self.dz,
            *self.pars,
        )
        # Code from JMP, whoever that is, by way of Charles
        farr = np.fft.fft2(cur_map)
        nx, ny = cur_map.shape
        kx = np.outer(np.fft.fftfreq(nx), np.zeros(ny).T + 1.0)
        ky = np.outer(np.zeros(nx).T + 1.0, np.fft.fftfreq(ny))
        k = np.sqrt(kx * kx + ky * ky)

        filt = self.table_filter_2d(k)
        farr *= filt

        return np.real(np.fft.ifft2(farr))

    def table_filter_2d(self, k) -> jax.Array:
        f = interpolate.interp1d(self.ks, self.xfer_vals)
        kbin_min = self.ks.min()
        kbin_max = self.ks.max()

        filt = k * 0.0
        filt[(k >= kbin_min) & (k <= kbin_max)] = f(
            k[(k >= kbin_min) & (k <= kbin_max)]
        )
        filt[(k < kbin_min)] = self.xfer_vals[self.ks == kbin_min]
        filt[(k > kbin_max)] = self.xfer_vals[self.ks == kbin_max]

        return filt

    @cached_property
    def model_grad(self) -> None:
        """
        The evaluated model and its gradient, see `core.model_grad` for details.
        Note that this is cached, but is automatically reset whenever
        `update` is called or `cur_round` changes. Currently computing
        grad for models with transfer function is not supported.

        Returns
        -------
        None
        """

        raise TypeError(
            "Error; Grad cannot currently be computed on Models with transfer function"
        )
        return None  # Shouldn't get here
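The filtering step in `Model_xfer.model` can be sketched on its own as follows. This is not the library's implementation: `apply_transfer` is a hypothetical helper, the transfer-function tables are made up for illustration, `ks` is assumed to be increasing and in the same units as `np.fft.fftfreq` (cycles per pixel), and `np.interp` is swapped in for `interpolate.interp1d` because it holds the end values outside the tabulated range, which is the clamping that `table_filter_2d` does by hand.

```python
import numpy as np


def apply_transfer(img, ks, xfer_vals):
    """Filter a 2D map in Fourier space with a tabulated radial transfer function."""
    nx, ny = img.shape
    kx = np.fft.fftfreq(nx)[:, None]  # cycles per pixel along axis 0
    ky = np.fft.fftfreq(ny)[None, :]  # cycles per pixel along axis 1
    k = np.sqrt(kx**2 + ky**2)
    # np.interp returns the first/last table value outside [ks[0], ks[-1]],
    # so no separate handling of k < kbin_min or k > kbin_max is needed.
    filt = np.interp(k, ks, xfer_vals)
    return np.real(np.fft.ifft2(np.fft.fft2(img) * filt))


rng = np.random.default_rng(0)
img = rng.normal(size=(16, 16))

# Hypothetical transfer functions for illustration.
ks = np.array([0.0, 0.25, 0.5])
unity = apply_transfer(img, ks, np.ones(3))                  # all-ones table
no_dc = apply_transfer(img, ks, np.array([0.0, 1.0, 1.0]))   # zero at k = 0
print(np.allclose(unity, img), abs(no_dc.mean()) < 1e-10)
```

A table of ones leaves the map unchanged, while a table that is zero at `k = 0` removes the map mean, which is a quick sanity check on the filter's orientation.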

model_grad cached property

The evaluated model and its gradient, see core.model_grad for details. Note that this is cached, but is automatically reset whenever update is called or cur_round changes. Currently computing grad for models with transfer function is not supported.

Returns:

Type Description
None