
objective

Module for the objective functions used by LM fitting and MCMC. All objective functions should return a log-likelihood (modulo a DC offset) as well as the gradient and curvature of the log-likelihood with respect to the model parameters.

Note that everything is done in analogy to chi-squared, so a factor of -2 is applied as needed to the non-chi-squared distributions.
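As a reminder of that convention (a standard relation, not specific to this module; here r = d - m is the residual and N the noise covariance):

```latex
\chi^2 = r^\mathsf{T} N^{-1} r,
\qquad
\ln\mathcal{L}_\mathrm{Gauss} = -\tfrac{1}{2}\,\chi^2 + \mathrm{const}
\quad\Longrightarrow\quad
-2\ln\mathcal{L}_\mathrm{Gauss} = \chi^2 + \mathrm{const}.
```

Returning -2 times the log-likelihood for non-Gaussian distributions therefore lets those objectives add directly with chi-squared.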

chisq_objective(model, datavec, mode='tod', do_loglike=True, do_grad=True, do_curve=True)

Objective function to minimize when fitting a dataset where a Gaussian distribution is reasonable. This is an MPI-aware function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Model` | The model object we are using to fit. | required |
| `datavec` | `TODVec \| SolutionSet` | The data to fit against. This is what we use to compute our fit residuals. | required |
| `mode` | `str` | The type of data we compile this function for. Should be either "tod" or "map". | `"tod"` |
| `do_loglike` | `bool` | If True then we will compute the chi-squared between the model and the data. | `True` |
| `do_grad` | `bool` | If True then compute the gradient of chi-squared with respect to the model parameters. | `True` |
| `do_curve` | `bool` | If True then compute the curvature of chi-squared with respect to the model parameters. | `True` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `chisq` | `Array` | The chi-squared between the model and data. If `do_loglike` is False then this is `jnp.array(0)`. |
| `grad` | `Array` | The gradient of the parameters at their current values. If `do_grad` is False then this is an array of zeros. This is a `(npar,)` array. |
| `curve` | `Array` | The curvature of the parameter space at the current values. If `do_curve` is False then this is an array of zeros. This is a `(npar, npar)` array. |

Source code in witch/objective.py
```python
@partial(jax.jit, static_argnames=("mode", "do_loglike", "do_grad", "do_curve"))
def chisq_objective(
    model: Model,
    datavec: TODVec | SolutionSet,
    mode: str = "tod",
    do_loglike: bool = True,
    do_grad: bool = True,
    do_curve: bool = True,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """
    Objective function to minimize when fitting a dataset where a Gaussian distribution is reasonable.
    This is an MPI aware function.

    Parameters
    ----------
    model : Model
        The model object we are using to fit.
    datavec: TODVec | SolutionSet
        The data to fit against.
        This is what we use to compute our fit residuals.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".
    do_loglike : bool, default: True
        If True then we will compute the chi-squared between
        the model and the data.
    do_grad : bool, default: True
        If True then compute the gradient of chi-squared with
        respect to the model parameters.
    do_curve : bool, default: True
        If True then compute the curvature of chi-squared with
        respect to the model parameters.

    Returns
    -------
    chisq : jax.Array
        The chi-squared between the model and data.
        If `do_loglike` is `False` then this is `jnp.array(0)`.
    grad : jax.Array
        The gradient of the parameters at their current values.
        If `do_grad` is `False` then this is an array of zeros.
        This is a `(npar,)` array.
    curve : jax.Array
        The curvature of the parameter space at the current values.
        If `do_curve` is `False` then this is an array of zeros.
        This is a `(npar, npar)` array.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    npar = len(model.pars)
    chisq = jnp.array(0)
    grad = jnp.zeros(npar)
    curve = jnp.zeros((npar, npar))

    zero = jnp.zeros((1, 1))
    only_chisq = not (do_grad or do_curve)

    for data in datavec:
        if mode == "tod":
            x = data.x * wu.rad_to_arcsec
            y = data.y * wu.rad_to_arcsec
            if only_chisq:
                pred_dat = model.to_tod(x, y)
                grad_dat = zero
            else:
                pred_dat, grad_dat = model.to_tod_grad(x, y)
        else:
            x, y = data.xy
            if only_chisq:
                pred_dat = model.to_map(x * wu.rad_to_arcsec, y * wu.rad_to_arcsec)
                grad_dat = zero
            else:
                pred_dat, grad_dat = model.to_map_grad(
                    x * wu.rad_to_arcsec, y * wu.rad_to_arcsec
                )

        resid = data.data - pred_dat
        resid_filt = data.noise.apply_noise(resid)
        if do_loglike:
            chisq += jnp.sum(resid * resid_filt)

        if only_chisq:
            continue

        grad_filt = jnp.zeros_like(grad_dat)
        for i in range(npar):
            grad_filt = grad_filt.at[i].set(
                data.noise.apply_noise(grad_dat.at[i].get())
            )
        grad_filt = jnp.reshape(grad_filt, (npar, -1))
        grad_dat = jnp.reshape(grad_dat, (npar, -1))
        resid = resid.ravel()

        if do_grad:
            grad = grad.at[:].add(jnp.dot(grad_filt, jnp.transpose(resid)))
        if do_curve:
            curve = curve.at[:].add(jnp.dot(grad_filt, jnp.transpose(grad_dat)))

    token = mpi4jax.barrier(comm=datavec.comm)
    if do_loglike:
        chisq, token = mpi4jax.allreduce(chisq, MPI.SUM, comm=datavec.comm, token=token)
    if do_grad:
        grad, token = mpi4jax.allreduce(grad, MPI.SUM, comm=datavec.comm, token=token)
    if do_curve:
        curve, token = mpi4jax.allreduce(curve, MPI.SUM, comm=datavec.comm, token=token)
    _ = token

    return chisq, grad, curve
```
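A minimal usage sketch (the `model` and `todvec` objects are assumed to have been constructed elsewhere; the names are placeholders, not part of this module):

```python
# Hypothetical usage: `model` is a Model and `todvec` a TODVec built elsewhere.
chisq, grad, curve = chisq_objective(model, todvec, mode="tod")

# When only the goodness of fit is needed (e.g. inside an MCMC likelihood
# evaluation), skip the gradient and curvature to avoid the extra work.
chisq_only, _, _ = chisq_objective(
    model, todvec, mode="tod", do_grad=False, do_curve=False
)
```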

poisson_objective(model, datavec, mode='tod', do_loglike=True, do_grad=True, do_curve=True)

Objective function to minimize when fitting a dataset where a Poisson distribution is reasonable. This is an MPI-aware function.
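For reference, the quantity being summed and the gradient it implies (matching the `data.data / pred_dat - 1` factor in the source below) are, with d_i the data and p_i the model prediction:

```latex
\ln\mathcal{L}_\mathrm{Pois} = \sum_i \left[ d_i \ln p_i - p_i - \ln(d_i!) \right],
\qquad
\frac{\partial \ln\mathcal{L}_\mathrm{Pois}}{\partial \theta}
  = \sum_i \left( \frac{d_i}{p_i} - 1 \right) \frac{\partial p_i}{\partial \theta}.
```

The returned values carry an overall factor of -2 so that they add with the chi-squared objective.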

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Model` | The model object we are using to fit. | required |
| `datavec` | `TODVec \| SolutionSet` | The data to fit against. This is what we use to compute our fit residuals. | required |
| `mode` | `str` | The type of data we compile this function for. Should be either "tod" or "map". | `"tod"` |
| `do_loglike` | `bool` | If True then we will compute the log-likelihood between the model and the data. | `True` |
| `do_grad` | `bool` | If True then compute the gradient of chi-squared with respect to the model parameters. | `True` |
| `do_curve` | `bool` | If True then compute the curvature of chi-squared with respect to the model parameters. | `True` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `loglike` | `Array` | The log-likelihood between the model and data. If `do_loglike` is False then this is `jnp.array(0)`. Note that there is a factor of -2 here to make it add with chi-squared. |
| `grad` | `Array` | The gradient of the parameters at their current values. If `do_grad` is False then this is an array of zeros. This is a `(npar,)` array. Note that there is a factor of -2 here to make it add with the gradient of chi-squared. |
| `curve` | `Array` | The curvature of the parameter space at the current values. If `do_curve` is False then this is an array of zeros. This is a `(npar, npar)` array. Note that there is a factor of -2 here to make it add with the curvature of chi-squared. |

Source code in witch/objective.py
```python
@partial(jax.jit, static_argnames=("mode", "do_loglike", "do_grad", "do_curve"))
def poisson_objective(
    model: Model,
    datavec: TODVec | SolutionSet,
    mode: str = "tod",
    do_loglike: bool = True,
    do_grad: bool = True,
    do_curve: bool = True,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """
    Objective function to minimize when fitting a dataset where a Poisson distribution is reasonable.
    This is an MPI aware function.

    Parameters
    ----------
    model : Model
        The model object we are using to fit.
    datavec: TODVec | SolutionSet
        The data to fit against.
        This is what we use to compute our fit residuals.
    mode : str, default: "tod"
        The type of data we compile this function for.
        Should be either "tod" or "map".
    do_loglike : bool, default: True
        If True then we will compute the log-likelihood between
        the model and the data.
    do_grad : bool, default: True
        If True then compute the gradient of chi-squared with
        respect to the model parameters.
    do_curve : bool, default: True
        If True then compute the curvature of chi-squared with
        respect to the model parameters.

    Returns
    -------
    loglike : jax.Array
        The log-likelihood between the model and data.
        If `do_loglike` is `False` then this is `jnp.array(0)`.
        Note that there is a factor of -2 here to make it add with chi-squared.
    grad : jax.Array
        The gradient of the parameters at their current values.
        If `do_grad` is `False` then this is an array of zeros.
        This is a `(npar,)` array.
        Note that there is a factor of -2 here to make it add with the gradient of chi-squared.
    curve : jax.Array
        The curvature of the parameter space at the current values.
        If `do_curve` is `False` then this is an array of zeros.
        This is a `(npar, npar)` array.
        Note that there is a factor of -2 here to make it add with the curvature of chi-squared.
    """
    if mode not in ["tod", "map"]:
        raise ValueError("Invalid mode")
    npar = len(model.pars)
    loglike = jnp.array(0)
    grad = jnp.zeros(npar)
    curve = jnp.zeros((npar, npar))

    zero = jnp.zeros((1, 1))
    only_loglike = not (do_grad or do_curve)

    for data in datavec:
        if mode == "tod":
            x = data.x * wu.rad_to_arcsec
            y = data.y * wu.rad_to_arcsec
            if only_loglike:
                pred_dat = model.to_tod(x, y)
                grad_dat = zero
            else:
                pred_dat, grad_dat = model.to_tod_grad(x, y)
        else:
            x, y = data.xy
            if only_loglike:
                pred_dat = model.to_map(x * wu.rad_to_arcsec, y * wu.rad_to_arcsec)
                grad_dat = zero
            else:
                pred_dat, grad_dat = model.to_map_grad(
                    x * wu.rad_to_arcsec, y * wu.rad_to_arcsec
                )

        resid = (data.data / pred_dat) - 1
        if do_loglike:
            loglike += jnp.sum(
                data.data * jnp.log(pred_dat) - pred_dat - jnp.log(factorial(data.data))
            )

        if only_loglike:
            continue

        grad_filt = jnp.zeros_like(grad_dat)
        for i in range(npar):
            grad_filt = grad_filt.at[i].set(
                data.noise.apply_noise(grad_dat.at[i].get())
            )
        grad_filt = jnp.reshape(grad_filt, (npar, -1))
        grad_dat = jnp.reshape(grad_dat, (npar, -1))
        resid = resid.ravel()

        if do_grad:
            grad = grad.at[:].add(jnp.dot(grad_dat, jnp.transpose(resid)))
        if do_curve:
            # Dropping the second term here, see Jon's note for justification
            curve = curve.at[:].add(
                jnp.dot(
                    -1 * grad_dat * (data.data / (pred_dat**2)).ravel(),
                    jnp.transpose(grad_dat),
                )
            )

    token = mpi4jax.barrier(comm=datavec.comm)
    if do_loglike:
        loglike, token = mpi4jax.allreduce(
            loglike, MPI.SUM, comm=datavec.comm, token=token
        )
    if do_grad:
        grad, token = mpi4jax.allreduce(grad, MPI.SUM, comm=datavec.comm, token=token)
    if do_curve:
        curve, token = mpi4jax.allreduce(curve, MPI.SUM, comm=datavec.comm, token=token)
    _ = token

    return -2 * loglike, -2 * grad, -2 * curve
```
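An illustrative sketch of why the factor of -2 matters (the dataset names are placeholders and this is not the library's fitter, just a combination of the two objectives): because both functions return chi-squared-like quantities, their log-likelihood parts can be summed directly when fitting Gaussian and Poisson datasets together.

```python
# Hypothetical datasets built elsewhere; names are illustrative only.
chisq, _, _ = chisq_objective(
    model, gauss_data, mode="tod", do_grad=False, do_curve=False
)
neg2_loglike, _, _ = poisson_objective(
    model, poisson_data, mode="map", do_grad=False, do_curve=False
)

# Both terms are on the -2*log-likelihood scale, so their sum is the total
# -2 ln(likelihood) of the combined fit, up to a constant offset.
total = chisq + neg2_loglike
```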