Skip to content

utils

A set of utility functions and constants used for unit conversions and adding generic structure common to multiple models.

K_CMB2K_RJ(freq)

Convert from K_CMB to K_RJ.

Arguments:

freq: The observing frequency in Hz.

Returns:

K_CMB2K_RJ: Conversion factor from K_CMB to K_RJ.
Source code in witch/utils.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
@partial(jax.jit, static_argnums=(0,))
def K_CMB2K_RJ(freq):
    """
    Convert from K_CMB to K_RJ.

    With the dimensionless frequency x = freq * h / kb / Tcmb, the factor
    returned is x**2 * exp(x) / (exp(x) - 1)**2.

    Arguments:

        freq: The observing frequency in Hz.
              Marked static for jax.jit, so it must be a hashable Python
              scalar rather than an array.

    Returns:

        K_CMB2K_RJ: Conversion factor from K_CMB to K_RJ.
    """
    x = freq * h / kb / Tcmb
    numerator = jnp.exp(x) * x * x
    # expm1 keeps (exp(x) - 1) accurate for small x.
    denominator = jnp.expm1(x) ** 2
    return numerator / denominator

beam_double_gauss(dr, fwhm1=9.735, amp1=0.9808, fwhm2=32.627, amp2=0.0192)

Helper function to generate a double gaussian beam.

Arguments:

dr: Pixel size.

fwhm1: Full width half max of the primary gaussian.

amp1: Amplitude of the primary gaussian.

fwhm2: Full width half max of the secondary gaussian.

amp2: Amplitude of the secondary gaussian.

Returns:

beam: Double gaussian beam.
Source code in witch/utils.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
def beam_double_gauss(dr, fwhm1=9.735, amp1=0.9808, fwhm2=32.627, amp2=0.0192):
    """
    Helper function to generate a double gaussian beam.

    The beam is sampled on a square grid that is symmetric about zero and
    extends to +/- floor(1.5 * fwhm1 / dr) pixels, so the peak lands on the
    central pixel (index n // 2 along each axis). The beam is normalized to
    unit sum.

    Arguments:

        dr: Pixel size, in the same units as fwhm1 and fwhm2.

        fwhm1: Full width half max of the primary gaussian.

        amp1: Amplitude of the primary gaussian.

        fwhm2: Full width half max of the secondary gaussian.

        amp2: Amplitude of the secondary gaussian.

    Returns:

        beam: Double gaussian beam, a 2D array with odd side length.
    """
    # Floor the half-width once and mirror it. Flooring the negative bound
    # separately would round it away from zero, making the grid asymmetric
    # and putting the peak one pixel off center.
    half_width = int(1.5 * fwhm1 // dr)
    x = jnp.arange(-half_width, half_width + 1, dtype=float) * dr
    beam_xx, beam_yy = jnp.meshgrid(x, x)
    beam_rr2 = beam_xx**2 + beam_yy**2
    # exp(-4 ln2 r^2 / fwhm^2) is a gaussian parameterized by its FWHM.
    beam = amp1 * jnp.exp(-4 * jnp.log(2) * beam_rr2 / fwhm1**2) + amp2 * jnp.exp(
        -4 * jnp.log(2) * beam_rr2 / fwhm2**2
    )
    return beam / jnp.sum(beam)

bilinear_interp(x, y, xp, yp, fp)

JAX implementation of bilinear interpolation. Out of bounds values are set to 0. Using the repeated linear interpolation method here, see https://en.wikipedia.org/wiki/Bilinear_interpolation#Repeated_linear_interpolation.

Arguments:

x: X values to return interpolated values at.

y: Y values to return interpolated values at.

xp: X values to interpolate with, should be 1D.
    Assumed to be sorted.

yp: Y values to interpolate with, should be 1D.
    Assumed to be sorted.

fp: Function values at (xp, yp), should have shape (len(xp), len(yp)).
    Note that if you are using meshgrid, we assume 'ij' indexing.

Return:

f: The interpolated values
Source code in witch/utils.py
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
@jax.jit
def bilinear_interp(x, y, xp, yp, fp):
    """
    JAX implementation of bilinear interpolation.
    Out of bounds values are set to 0.
    Using the repeated linear interpolation method here,
    see https://en.wikipedia.org/wiki/Bilinear_interpolation#Repeated_linear_interpolation.

    Arguments:

        x: X values to return interpolated values at.

        y: Y values to return interpolated values at.
            Combined elementwise with x, so the two must broadcast together.

        xp: X values to interpolate with, should be 1D with at least 2 points.
            Assumed to be sorted in increasing order.

        yp: Y values to interpolate with, should be 1D with at least 2 points.
            Assumed to be sorted in increasing order.

        fp: Function values at (xp, yp), should have shape (len(xp), len(yp)).
            Note that if you are using meshgrid, we assume 'ij' indexing.

    Returns:

        f: The interpolated values. Points outside
           [xp[0], xp[-1]] x [yp[0], yp[-1]] are set to 0.

    Raises:

        ValueError: If xp or yp is not 1D, or fp does not have shape
            xp.shape + yp.shape. Raised at trace time, since shapes are static.
    """
    if len(xp.shape) != 1:
        raise ValueError("xp must be 1D")
    if len(yp.shape) != 1:
        raise ValueError("yp must be 1D")
    if fp.shape != xp.shape + yp.shape:
        # Format the message explicitly; ValueError does not apply %-args.
        raise ValueError(
            f"Incompatible shapes for fp, xp, yp: {fp.shape}, {xp.shape}, {yp.shape}"
        )

    # Figure out bounds and mapping.
    # This breaks if xp, yp is not sorted.
    # Clipping to [1, len - 1] keeps ix - 1 and ix valid, so boundary points
    # (e.g. x == xp[-1]) use the last cell.
    ix = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    iy = jnp.clip(jnp.searchsorted(yp, y, side="right"), 1, len(yp) - 1)
    q_11 = fp[ix - 1, iy - 1]
    q_21 = fp[ix, iy - 1]
    q_12 = fp[ix - 1, iy]
    q_22 = fp[ix, iy]

    # Interpolate in x to start
    denom_x = xp[ix] - xp[ix - 1]
    dx_1 = x - xp[ix - 1]
    dx_2 = xp[ix] - x
    f_xy1 = (dx_2 * q_11 + dx_1 * q_21) / denom_x
    f_xy2 = (dx_2 * q_12 + dx_1 * q_22) / denom_x

    # Now do y as well
    denom_y = yp[iy] - yp[iy - 1]
    dy_1 = y - yp[iy - 1]
    dy_2 = yp[iy] - y
    f = (dy_2 * f_xy1 + dy_1 * f_xy2) / denom_y

    # Zero out the out of bounds values (clipping above extrapolated them)
    out_of_bounds = (x < xp[0]) | (x > xp[-1]) | (y < yp[0]) | (y > yp[-1])
    f = jnp.where(out_of_bounds, 0.0, f)

    return f

fft_conv(image, kernel)

Perform a convolution using FFTs for speed.

Arguments:

image: Data to be convolved

kernel: Convolution kernel

Returns:

convolved_map: Image convolved with kernel.
Source code in witch/utils.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
@jax.jit
def fft_conv(image, kernel):
    """
    Perform a convolution using FFTs for speed.

    This is a circular (periodic) convolution over the last two axes. The
    kernel is taken to be centered at index n // 2 along each of those axes
    (the numpy/jax fftshift convention), so a kernel that is a delta at the
    center returns the image unchanged for both odd and even sizes.

    Arguments:

        image: Data to be convolved.

        kernel: Convolution kernel, centered as described above. Its last
                two axes must match the image's.

    Returns:

        convolved_map: Image convolved with kernel.
    """
    # Move the kernel center to the origin. ifftshift is the exact inverse of
    # fftshift for both odd and even sizes. Shifting the image and the output
    # with fftshift instead only lines up for even sizes. Restrict the shift
    # to the axes fft2 transforms.
    centered_kernel = jnp.fft.ifftshift(kernel, axes=(-2, -1))
    Fmap = jnp.fft.fft2(image)
    Fkernel = jnp.fft.fft2(centered_kernel)
    convolved_map = jnp.real(jnp.fft.ifft2(Fmap * Fkernel))

    return convolved_map

get_da(z)

Get factor to convert from arcseconds to Mpc.

Arguments:

z: The redshift at which to compute the factor.

Returns:

da: Conversion factor from arcseconds to Mpc.
Source code in witch/utils.py
142
143
144
145
146
147
148
149
150
151
152
153
154
def get_da(z):
    """
    Get factor to convert from arcseconds to Mpc.

    The value is linearly interpolated from the module-level tables
    (dzline, daline). Outside the tabulated redshift range, jnp.interp
    returns the nearest endpoint value.

    Arguments:

        z: The redshift at which to compute the factor.

    Returns:

        da: Conversion factor from arcseconds to Mpc.
    """
    da = jnp.interp(z, dzline, daline)
    return da

get_hz(z)

Get h(z).

Arguments:

z: The redshift at which to compute the factor.

Returns:

hz: h at the given z.
Source code in witch/utils.py
172
173
174
175
176
177
178
179
180
181
182
183
184
def get_hz(z):
    """
    Get h(z).

    The value is linearly interpolated from the module-level tables
    (dzline, hzline). Outside the tabulated redshift range, jnp.interp
    returns the nearest endpoint value.

    Arguments:

        z: The redshift at which to compute the factor.

    Returns:

        hz: h at the given z.
    """
    hz = jnp.interp(z, dzline, hzline)
    return hz

get_nz(z)

Get n(z).

Arguments:

z: The redshift at which to compute the factor.

Returns:

nz: n at the given z.
Source code in witch/utils.py
157
158
159
160
161
162
163
164
165
166
167
168
169
def get_nz(z):
    """
    Get n(z).

    The value is linearly interpolated from the module-level tables
    (dzline, nzline). Outside the tabulated redshift range, jnp.interp
    returns the nearest endpoint value.

    Arguments:

        z: The redshift at which to compute the factor.

    Returns:

        nz: n at the given z.
    """
    nz = jnp.interp(z, dzline, nzline)
    return nz

make_grid(r_map, dx, dy=None, dz=None, x0=0, y0=0)

Make coordinate grids to build models in. All grids are sparse and have 2 * int(r_map / d) points along each dimension, where d is that dimension's resolution (dx, dy or dz).

Arguments:

r_map: Size of grid radially.

dx: Grid resolution in x, should be in same units as r_map.

dy: Grid resolution in y, should be in same units as r_map.
    If None then dy is set to dx.

dz: Grid resolution in z, should be in same units as r_map.
    If None then dz is set to dx.

x0: Origin of grid in RA, assumed to be in same units as r_map.

y0: Origin of grid in Dec, assumed to be in same units as r_map.

Returns:

x: Grid of x coordinates in same units as r_map.

y: Grid of y coordinates in same units as r_map

z: Grid of z coordinates in same units as r_map
Source code in witch/utils.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def make_grid(r_map, dx, dy=None, dz=None, x0=0, y0=0):
    """
    Make coordinate grids to build models in.

    Each axis spans [-r_map, r_map] with 2 * int(r_map / d) evenly spaced
    points, where d is that axis's resolution. The grids are sparse
    ('ij' indexing). The x axis is divided by cos(y0 / rad_to_arcsec)
    and then offset by x0. The y axis is offset by y0. The z axis is
    centered on zero.

    Arguments:

        r_map: Size of grid radially.

        dx: Grid resolution in x, should be in same units as r_map.

        dy: Grid resolution in y, should be in same units as r_map.
            If None then dy is set to dx.

        dz: Grid resolution in z, should be in same units as r_map.
            If None then dz is set to dx.

        x0: Origin of grid in RA, assumed to be in same units as r_map.

        y0: Origin of grid in Dec, assumed to be in same units as r_map.

    Returns:

        A 5-tuple (x, y, z, x0, y0):

        x: Grid of x coordinates in same units as r_map.

        y: Grid of y coordinates in same units as r_map.

        z: Grid of z coordinates in same units as r_map.

        x0, y0: The origin passed in, returned unchanged.
    """
    dy = dx if dy is None else dy
    dz = dx if dz is None else dz

    def _axis(step):
        # Evenly spaced samples from -r_map to r_map for one axis.
        return jnp.linspace(-1 * r_map, r_map, 2 * int(r_map / step))

    # NOTE(review): y0 / rad_to_arcsec implies y0 is in arcsec; confirm.
    x_axis = _axis(dx) / jnp.cos(y0 / rad_to_arcsec) + x0
    y_axis = _axis(dy) + y0
    z_axis = _axis(dz)

    x, y, z = jnp.meshgrid(x_axis, y_axis, z_axis, sparse=True, indexing="ij")
    return (x, y, z, x0, y0)

make_grid_from_skymap(skymap, z_map, dz, x0=None, y0=None)

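Make sparse coordinate grids to build models in, matching the pixels of skymap. The x and y grids have one entry per map pixel (skymap.nx and skymap.ny respectively); the z grid has 2 * int(z_map / dz) points.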

Arguments:

z_map: Size of grid along LOS, in radians.

dz: Grid resolution along LOS, in radians.

x0: Map x center in radians. If None, grid center is used.

y0: Map y center in radians. If None, grid center is used.

Returns:

x: Grid of x coordinates in radians.

y: Grid of y coordinates in radians.

z: Grid of z coordinates in radians.
Source code in witch/utils.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def make_grid_from_skymap(skymap, z_map, dz, x0=None, y0=None):
    """
    Make coordinate grids to build models in, matching the pixels of a sky map.

    All grids are sparse ('ij' indexing). x has skymap.nx entries, y has
    skymap.ny entries, and z has 2 * int(z_map / dz) entries spanning
    [-z_map, z_map]. The x and y entries are the RA and Dec of the map
    pixels (via the map's WCS, converted from degrees to radians), offset
    by (x0, y0).

    Arguments:

        skymap: Map to build the grid for. Must provide nx, ny, a wcs
                with wcs_pix2world, and lims, where lims[0:2] bound x and
                lims[2:4] bound y in radians.

        z_map: Size of grid along LOS, in radians.

        dz: Grid resolution along LOS, in radians.

        x0: Map x center in radians. If None, grid center is used.

        y0: Map y center in radians. If None, grid center is used.

    Returns:

        x: Grid of x coordinates in radians.

        y: Grid of y coordinates in radians.

        z: Grid of z coordinates in radians.
    """
    # make grid
    _x = jnp.arange(skymap.nx, dtype=float)
    _y = jnp.arange(skymap.ny, dtype=float)
    _z = jnp.linspace(-1 * z_map, z_map, 2 * int(z_map / dz), dtype=float)
    x, y, z = jnp.meshgrid(_x, _y, _z, sparse=True, indexing="ij")

    # Pad the shorter pixel axis so both can be stacked as columns for a
    # single wcs_pix2world call; the padding is stripped afterwards.
    x_flat = x.ravel()
    y_flat = y.ravel()
    len_diff = len(x_flat) - len(y_flat)
    if len_diff > 0:
        y_flat = jnp.pad(y_flat, (0, len_diff), "edge")
    elif len_diff < 0:
        x_flat = jnp.pad(x_flat, (0, abs(len_diff)), "edge")

    # Convert x and y to ra/dec
    ra_dec = skymap.wcs.wcs_pix2world(
        jnp.column_stack((x_flat, y_flat)), 0, ra_dec_order=True
    )
    ra_dec = np.deg2rad(ra_dec)
    ra = ra_dec[:, 0]
    dec = ra_dec[:, 1]

    # Remove padding
    if len_diff > 0:
        dec = dec[: (-1 * len_diff)]
    elif len_diff < 0:
        ra = ra[:len_diff]

    # Compare against None explicitly: a caller-supplied center of 0.0 is
    # valid and must not be replaced by the map center.
    if x0 is None:
        x0 = (skymap.lims[1] + skymap.lims[0]) / 2
    if y0 is None:
        y0 = (skymap.lims[3] + skymap.lims[2]) / 2

    ra -= x0
    dec -= y0

    # Sparse indexing to save mem
    x = x.at[:, 0, 0].set(ra)
    y = y.at[0, :, 0].set(dec)

    return x, y, z

tod_hi_pass(tod, N_filt)

High pass a tod with a tophat

Arguments:

tod: TOD to high pass

N_filt: Number of lowest-frequency Fourier modes (starting with the DC term) removed by the tophat filter.

Returns:

tod_filtered: Filtered TOD
Source code in witch/utils.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
@partial(jax.jit, static_argnums=(1,))
def tod_hi_pass(tod, N_filt):
    """
    High pass a tod with a tophat.

    The filter is applied along the last axis. It removes the N_filt
    lowest-frequency Fourier modes (the DC term included) and keeps all
    others unchanged.

    Arguments:

        tod: TOD to high pass. Real-valued; filtered along the last axis.

        N_filt: Number of low-frequency modes to zero. Marked static for
                jax.jit, so it must be a hashable Python int.

    Returns:

        tod_filtered: Filtered TOD, same shape as tod.
    """
    n_samp = tod.shape[-1]

    # Use the real FFT so each frequency has a single bin. Zeroing only the
    # positive-frequency bins of a full complex FFT would leave the negative
    # ones and merely halve low-frequency components. jax.ops.index_update no
    # longer exists; .at[...].set is the supported indexed update.
    Ftod = jnp.fft.rfft(tod)
    Ftod_filtered = Ftod.at[..., :N_filt].set(0.0)
    tod_filtered = jnp.fft.irfft(Ftod_filtered, n=n_samp)
    return tod_filtered

tod_to_index(xi, yi, x0, y0, grid, conv_factor)

Convert RA/Dec TODs to index space.

Arguments:

xi: RA TOD

yi: Dec TOD

x0: RA at center of model. Nominally the cluster center.

y0: Dec at center of model. Nominally the cluster center.

grid: The grid to index on.

conv_factor: Conversion factor to put RA and Dec in same units as r_map.
             Nominally (da * 180 * 3600) / pi

Returns:

idx: The RA TOD in index space

idy: The Dec TOD in index space.
Source code in witch/utils.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
def tod_to_index(xi, yi, x0, y0, grid, conv_factor):
    """
    Convert RA/Dec TODs to index space.

    Offsets from (x0, y0) are computed, with the RA offset scaled by
    cos(yi), and multiplied by conv_factor. They are then binned against
    the sparse x and y grid values with np.digitize. Indices at or beyond
    the grid size are replaced by twice the grid size, an out-of-bounds
    sentinel.

    Arguments:

        xi: RA TOD

        yi: Dec TOD. Passed straight to cos, so presumably in radians
            (TODO confirm).

        x0: RA at center of model. Nominally the cluster center.

        y0: Dec at center of model. Nominally the cluster center.

        grid: The grid to index on. grid[0] and grid[1] are the sparse x
              and y grids (shapes (nx, 1, ...) and (1, ny, ...)).

        conv_factor: Conversion factor to put RA and Dec in same units as r_map.
                     Nominally (da * 180 * 3600) / pi

    Returns:

        idx: The RA TOD in index space

        idy: The Dec TOD in index space.
    """
    offset_x = (xi - x0) * jnp.cos(yi) * conv_factor
    offset_y = (yi - y0) * conv_factor

    # Sparse grids: the coordinate values live along a single axis each.
    x_values = grid[0].ravel()
    y_values = grid[1].ravel()
    nx = grid[0].shape[0]
    ny = grid[1].shape[1]

    # np.digitize already returns integers; cast to the platform int.
    idx = np.digitize(offset_x, x_values).astype(int)
    idy = np.digitize(offset_y, y_values).astype(int)

    # NOTE(review): digitize never returns a negative index. Samples below
    # the first grid value get index 0, which is not flagged out of bounds;
    # only samples past the last value are. Confirm this asymmetry is intended.
    idx = jnp.where((idx < 0) | (idx >= nx), 2 * nx, idx)
    idy = jnp.where((idy < 0) | (idy >= ny), 2 * ny, idy)

    return idx, idy

transform_grid(dx, dy, dz, r_1, r_2, r_3, theta, xyz)

Shift, rotate, and apply ellipticity to coordinate grid.

Arguments:

dx: RA of cluster center relative to grid origin

dy: Dec of cluster center relative to grid origin

dz: Line of sight offset of cluster center relative to grid origin

r_1: Amount to scale along x-axis

r_2: Amount to scale along y-axis

r_3: Amount to scale along z-axis

theta: Angle to rotate in xy-plane

xyz: Coordinate grid to transform

Returns:

xyz: Transformed coordinate grid
Source code in witch/utils.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
@jax.jit
def transform_grid(dx, dy, dz, r_1, r_2, r_3, theta, xyz):
    """
    Shift, rotate, and apply ellipticity to coordinate grid.

    The grid is first shifted so that the point offset by (dx, dy, dz) from
    the grid origin becomes zero. The x offset is rescaled by cos of the new
    center's Dec. The result is rotated by theta in the xy-plane, and each
    axis is then divided by its scale factor.

    Arguments:

        dx: RA of cluster center relative to grid origin

        dy: Dec of cluster center relative to grid origin

        dz: Line of sight offset of cluster center relative to grid origin

        r_1: Amount to scale along x-axis

        r_2: Amount to scale along y-axis

        r_3: Amount to scale along z-axis

        theta: Angle to rotate in xy-plane

        xyz: Coordinate grid to transform, indexed as (x, y, z, x0, y0)
             where x0, y0 is the grid origin.

    Returns:

        xyz: Transformed coordinate grid as a tuple (x, y, z).
    """
    x_grid, y_grid, z_grid = xyz[0], xyz[1], xyz[2]
    x_origin, y_origin = xyz[3], xyz[4]

    # New center: the RA offset is stretched by 1 / cos(origin Dec), and the
    # shifted x coordinates are compressed by cos(center Dec).
    ra_center = x_origin + dx / jnp.cos(y_origin / rad_to_arcsec)
    dec_center = y_origin + dy
    x_shift = (x_grid - ra_center) * jnp.cos(dec_center / rad_to_arcsec)
    y_shift = y_grid - dec_center
    z_shift = z_grid - dz

    # Rotate in the xy-plane
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    x_rot = x_shift * cos_t + y_shift * sin_t
    y_rot = y_shift * cos_t - x_shift * sin_t

    # Apply ellipticity
    return x_rot / r_1, y_rot / r_2, z_shift / r_3

y2K_CMB(freq, Te)

Convert from compton y to K_CMB.

Arguments:

freq: The observing frequency in Hz.

Te: Electron temperature

Returns:

y2K_CMB: Conversion factor from compton y to K_CMB.
Source code in witch/utils.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
@partial(jax.jit, static_argnums=(0, 1))
def y2K_CMB(freq, Te):
    """
    Convert from compton y to K_CMB.

    The factor is a polynomial in Te / me:
        (Y0 + (Te/me) * Y1 + (Te/me)**2 * Y2 + (Te/me)**3 * Y3 + (Te/me)**4 * Y4) * Tcmb
    Each Yn is a polynomial in xt = x * coth(x / 2) and st = x / sinh(x / 2),
    with x = freq * h / kb / Tcmb. Y0 = xt - 4 is the Te -> 0 limit.
    NOTE(review): the coefficients appear to be the relativistic SZ expansion
    of Itoh et al. (1998); confirm against the reference before editing.

    Arguments:

        freq: The observing frequency in Hz.

        Te: Electron temperature. Presumably in the same energy units as the
            module constant me (e.g. keV), since only Te / me is used.
            TODO confirm.

        Both arguments are static for jax.jit, so they must be hashable
        Python scalars rather than arrays.

    Returns:

        y2K_CMB: Conversion factor from compton y to K_CMB.
    """
    # Dimensionless frequency and the two hyperbolic combinations used below.
    x = freq * h / kb / Tcmb
    xt = x / jnp.tanh(0.5 * x)
    st = x / jnp.sinh(0.5 * x)
    # fmt:off
    # Order-0..4 coefficients of the Te / me expansion (kept unformatted so
    # the nested Horner-style layout stays readable).
    Y0 = -4.0 + xt
    Y1 = (-10.0
        + ((47.0 / 2.0) + (-(42.0 / 5.0) + (7.0 / 10.0) * xt) * xt) * xt
        + st * st * (-(21.0 / 5.0) + (7.0 / 5.0) * xt)
    )
    Y2 = ((-15.0 / 2.0)
        + ((1023.0 / 8.0) + ((-868.0 / 5.0) + ((329.0 / 5.0) + ((-44.0 / 5.0) + (11.0 / 30.0) * xt) * xt) * xt) * xt) * xt
        + ((-434.0 / 5.0) + ((658.0 / 5.0) + ((-242.0 / 5.0) + (143.0 / 30.0) * xt) * xt) * xt
        + (-(44.0 / 5.0) + (187.0 / 60.0) * xt) * (st * st)) * st * st
    )
    Y3 = ((15.0 / 2.0)
        + ((2505.0 / 8.0) + ((-7098.0 / 5.0) + ((14253.0 / 10.0) + ((-18594.0 / 35.0) 
         + ((12059.0 / 140.0) + ((-128.0 / 21.0) + (16.0 / 105.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt
        + (((-7098.0 / 10.0) + ((14253.0 / 5.0) + ((-102267.0 / 35.0) + ((156767.0 / 140.0)
         + ((-1216.0 / 7.0) + (64.0 / 7.0) * xt) * xt) * xt) * xt) * xt)
         + (((-18594.0 / 35.0) + ((205003.0 / 280.0) + ((-1920.0 / 7.0) + (1024.0 / 35.0) * xt) * xt) * xt)
          + ((-544.0 / 21.0) + (992.0 / 105.0) * xt) * st * st) * st * st) * st * st
    )
    Y4 = ((-135.0 / 32.0)
        + ((30375.0 / 128.0) + ((-62391.0 / 10.0) + ((614727.0 / 40.0) + ((-124389.0 / 10.0) + ((355703.0 / 80.0) + ((-16568.0 / 21.0)
         + ((7516.0 / 105.0) + ((-22.0 / 7.0) + (11.0 / 210.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt) * xt) * xt
        + ((-62391.0 / 20.0) + ((614727.0 / 20.0) + ((-1368279.0 / 20.0) + ((4624139.0 / 80.0) + ((-157396.0 / 7.0) + ((30064.0 / 7.0)
         + ((-2717.0 / 7.0) + (2761.0 / 210.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt
         + ((-124389.0 / 10.0)
          + ((6046951.0 / 160.0) + ((-248520.0 / 7.0) + ((481024.0 / 35.0) + ((-15972.0 / 7.0) + (18689.0 / 140.0) * xt) * xt) * xt) * xt) * xt
          + ((-70414.0 / 21.0) + ((465992.0 / 105.0) + ((-11792.0 / 7.0) + (19778.0 / 105.0) * xt) * xt) * xt
           + ((-682.0 / 7.0) + (7601.0 / 210.0) * xt) * st * st) * st * st) * st * st) * st * st
    )
    # fmt:on
    # Horner evaluation of Y0 + t*Y1 + t^2*Y2 + t^3*Y3 + t^4*Y4 with t = Te / me.
    factor = Y0 + (Te / me) * (
        Y1 + (Te / me) * (Y2 + (Te / me) * (Y3 + (Te / me) * Y4))
    )
    # The dimensionless factor gives Delta T / T per unit y; scale by Tcmb for K_CMB.
    return factor * Tcmb

y2K_RJ(freq, Te)

Convert from compton y to K_RJ.

Arguments:

freq: The observing frequency in Hz.

Te: Electron temperature

Returns:

y2K_RJ: Conversion factor from compton y to K_RJ.
Source code in witch/utils.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@partial(jax.jit, static_argnums=(0, 1))
def y2K_RJ(freq, Te):
    """
    Convert from compton y to K_RJ.

    Chains the y -> K_CMB factor (y2K_CMB) with the K_CMB -> K_RJ factor
    (K_CMB2K_RJ) at the same frequency.

    Arguments:

        freq: The observing frequency in Hz.

        Te: Electron temperature

        Both arguments are static for jax.jit, so they must be hashable
        Python scalars.

    Returns:

        y2K_RJ: Conversion factor from compton y to K_RJ.
    """
    cmb_per_y = y2K_CMB(freq, Te)
    rj_per_cmb = K_CMB2K_RJ(freq)
    return cmb_per_y * rj_per_cmb