
core

Core module for generating models and their gradients.

model = jax.jit(model, static_argnums=model_static) (module attribute)

Generically create models with substructure.

Arguments:

xyz: Coordinate grid to compute profile on.

n_struct: Number of structures of each type to use.
          Should be in the same order as `order`.

dz: Factor to scale by while integrating.
    Since it is a global factor it can contain unit conversions.
    Historically equal to y2K_RJ * dr * da * XMpc / me.

beam: Beam to convolve with; should be a 2D array.

params: 1D array of model parameters.
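The roles of `dz` and `beam` can be illustrated with a minimal sketch. It assumes the profile is evaluated on a 3D grid whose last axis is the integration direction, that the integral is approximated by a plain sum scaled by `dz`, and that the beam convolution keeps the map's shape. `project_and_smooth` is a hypothetical helper, not part of this module; the real code may order axes or treat map edges differently.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def project_and_smooth(profile_3d, dz, beam):
    # Assumption: the integration axis is the last axis of the 3D grid, and
    # the integral is approximated by a sum scaled by the global factor dz.
    projected = jnp.sum(profile_3d, axis=-1) * dz
    # Convolve the 2D map with the 2D beam, keeping the map's shape.
    return convolve2d(projected, beam, mode="same")


# A unit-valued column at the map centre, integrated and then smoothed.
profile = jnp.zeros((7, 7, 4)).at[3, 3, :].set(1.0)
beam = jnp.array([[0.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 0.0]]) / 8.0
smoothed = project_and_smooth(profile, 0.5, beam)
```

Because this beam sums to one, the smoothed map keeps the integrated total (here 4 cells times 0.5, i.e. 2.0) and simply spreads it over the beam footprint.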

Returns:

model: The model with the specified substructure evaluated on the grid.
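Arguments that control the shape of the computation, such as `n_struct`, are the natural candidates for `static_argnums`; the sketch below assumes that is what `model_static` covers. `toy_model`, its Gaussian "structures" and `N_PAR_PER_STRUCT` are hypothetical and only illustrate how a static structure count can drive slicing of a flat 1D `params` array. Because the count feeds a Python `range`, it must be a concrete value at trace time, and each distinct value triggers a new compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical: each structure is a 1D Gaussian bump described by
# (amplitude, center, width), so it consumes 3 entries of `params`.
N_PAR_PER_STRUCT = 3


@partial(jax.jit, static_argnums=(1,))
def toy_model(x, n_struct, params):
    # n_struct is a plain Python int (static), so this loop is unrolled
    # while tracing; a traced value here would make range() fail.
    out = jnp.zeros_like(x)
    for i in range(n_struct):
        start = N_PAR_PER_STRUCT * i
        amp, cen, wid = params[start : start + N_PAR_PER_STRUCT]
        out = out + amp * jnp.exp(-0.5 * ((x - cen) / wid) ** 2)
    return out


x = jnp.linspace(-5.0, 5.0, 101)
params = jnp.array([1.0, -1.0, 0.5, 2.0, 2.0, 1.0])
y = toy_model(x, 2, params)  # both bumps; toy_model(x, 1, params) uses only the first
```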

model_grad = jax.jit(model_grad, static_argnums=model_grad_static) (module attribute)

A wrapper around `model` that also returns the gradients of the model. Only the additional arguments are described here; see `model` for the others. Note that the additional arguments are passed before the `*params` argument.

Arguments:

argnums: The arguments to compute the gradient with respect to.

Returns:

model: The model with the specified substructure.

grad: The gradient of the model with respect to the model parameters.
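The wrapper pattern described above can be sketched with `jax.jacfwd`. `toy_model` and `toy_model_grad` are hypothetical, and `argnums` is assumed to follow JAX's convention of indexing positional arguments (so `x` is position 0 and the first parameter is position 1); the actual `model_grad` may count positions differently. As in the documented signature, the extra `argnums` argument comes before `*params`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_model(x, *params):
    # Hypothetical model with scalar parameters passed as *params:
    # amplitude * exp(-x / scale) + offset.
    amp, scale, offset = params
    return amp * jnp.exp(-x / scale) + offset


@partial(jax.jit, static_argnums=(1,))
def toy_model_grad(x, argnums, *params):
    # `argnums` holds positions in toy_model's argument list; it is a
    # hashable tuple and is marked static so jit can specialize on it.
    value = toy_model(x, *params)
    grads = jax.jacfwd(toy_model, argnums=argnums)(x, *params)
    # jacfwd returns one array per selected argument; stack them so the
    # first axis indexes the parameter.
    return value, jnp.stack(grads)


x = jnp.linspace(0.0, 2.0, 5)
value, grad = toy_model_grad(x, (1, 2, 3), 2.0, 1.0, 0.5)
```

Since static arguments must be hashable, `argnums` is passed as a tuple; a list would be rejected by `jax.jit`.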