Skip to content

utils

A set of utility functions and constants used for unit conversions and adding generic structure common to multiple models.

K_CMB2K_RJ(freq)

Convert from K_CMB to K_RJ.

Parameters:

Name Type Description Default
freq float

The observing frequency in Hz.

required

Returns:

Name Type Description
K_CMB2K_RJ float

Conversion factor from K_CMB to K_RJ.

Source code in witch/utils.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@partial(jax.jit, static_argnums=(0,))
def K_CMB2K_RJ(freq: float) -> float:
    """
    Convert from K_CMB to K_RJ.

    Evaluates ``exp(x) * x**2 / (exp(x) - 1)**2`` with the dimensionless
    frequency ``x = freq * h / kb / Tcmb``.

    Parameters
    ----------
    freq : float
        The observing frequency in Hz.
        Static under jit, so it must be a hashable Python scalar.

    Returns
    -------
    K_CMB2K_RJ : float
        Conversion factor from K_CMB to K_RJ.
    """
    x = freq * h / kb / Tcmb
    # expm1 keeps (exp(x) - 1) accurate when x is small.
    numerator = jnp.exp(x) * x * x
    return numerator / jnp.expm1(x) ** 2

beam_double_gauss(dr, fwhm1, amp1, fwhm2, amp2)

Helper function to generate a double gaussian beam.

Parameters:

Name Type Description Default
dr float

Pixel size.

required
fwhm1 float

Full width half max of the primary gaussian in the same units as dr.

required
amp1 float

Amplitude of the primary gaussian.

required
fwhm2 float

Full width half max of the secondary gaussian in the same units as dr.

required
amp2 float

Amplitude of the secondary gaussian.

required

Returns:

Name Type Description
beam Array

Double gaussian beam, normalized so that its pixels sum to 1.
Source code in witch/utils.py
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
def beam_double_gauss(
    dr: float, fwhm1: float, amp1: float, fwhm2: float, amp2: float
) -> jax.Array:
    """
    Helper function to generate a double gaussian beam.

    The beam is sampled on a square grid of pixel offsets spanning roughly
    +/- 1.5 * fwhm1 in steps of `dr`, and is normalized so its pixels sum to 1.

    Parameters
    ----------
    dr : float
        Pixel size.
    fwhm1 : float
        Full width half max of the primary gaussian in the same units as `dr`.
    amp1 : float
        Amplitude of the primary gaussian.
    fwhm2 : float
        Full width half max of the secondary gaussian in the same units as `dr`.
    amp2 : float
        Amplitude of the secondary gaussian.

    Returns
    -------
    beam : jax.Array
        2D double gaussian beam, normalized to unit sum.
    """
    half_width = 1.5 * fwhm1
    # Floor division picks whole-pixel offsets; arange excludes the upper end.
    offsets = jnp.arange(-half_width // dr, half_width // dr) * dr
    grid_x, grid_y = jnp.meshgrid(offsets, offsets)
    radius = jnp.sqrt(grid_x**2 + grid_y**2)

    # FWHM -> gaussian exponent: exp(-4 ln2 r^2 / fwhm^2).
    ln2 = jnp.log(2)
    primary = amp1 * jnp.exp(-4 * ln2 * radius**2 / fwhm1**2)
    secondary = amp2 * jnp.exp(-4 * ln2 * radius**2 / fwhm2**2)
    total = primary + secondary
    return total / total.sum()

bilinear_interp(x, y, xp, yp, fp)

JAX implementation of bilinear interpolation. Out of bounds values are set to 0. Using the repeated linear interpolation method here, see https://en.wikipedia.org/wiki/Bilinear_interpolation#Repeated_linear_interpolation.

Parameters:

Name Type Description Default
x Array

X values to return interpolated values at.

required
y Array

Y values to return interpolated values at.

required
xp Array

X values to interpolate with, should be 1D. Assumed to be sorted.

required
yp Array

Y values to interpolate with, should be 1D. Assumed to be sorted.

required
fp Array

Function values at (xp, yp), should have shape (len(xp), len(yp)). Note that if you are using meshgrid, we assume 'ij' indexing.

required

Returns:

Name Type Description
f Array

The interpolated values.

Source code in witch/utils.py
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
@jax.jit
def bilinear_interp(
    x: jax.Array, y: jax.Array, xp: jax.Array, yp: jax.Array, fp: jax.Array
) -> jax.Array:
    """
    JAX implementation of bilinear interpolation.
    Out of bounds values are set to 0.
    Using the repeated linear interpolation method here,
    see https://en.wikipedia.org/wiki/Bilinear_interpolation#Repeated_linear_interpolation.

    Parameters
    ----------
    x : jax.Array
        X values to return interpolated values at.
    y : jax.Array
        Y values to return interpolated values at.
        Must broadcast against `x`.
    xp : jax.Array
        X values to interpolate with, should be 1D with at least 2 points.
        Assumed to be sorted.
    yp : jax.Array
        Y values to interpolate with, should be 1D with at least 2 points.
        Assumed to be sorted.
    fp : jax.Array
        Function values at `(xp, yp)`, should have shape `(len(xp), len(yp))`.
        Note that if you are using meshgrid, we assume `'ij'` indexing.

    Returns
    -------
    f : jax.Array
        The interpolated values.

    Raises
    ------
    ValueError
        If `xp` or `yp` is not 1D or has fewer than 2 points,
        or if `fp` does not have shape `(len(xp), len(yp))`.
    """
    # Shapes are static under jit, so these checks run once at trace time.
    if len(xp.shape) != 1:
        raise ValueError("xp must be 1D")
    if len(yp.shape) != 1:
        raise ValueError("yp must be 1D")
    if xp.shape[0] < 2 or yp.shape[0] < 2:
        # A cell needs two bracketing points; with one the clip below is invalid.
        raise ValueError("xp and yp must each have at least 2 points")
    if fp.shape != xp.shape + yp.shape:
        # Format eagerly: ValueError does not %-interpolate extra arguments.
        raise ValueError(
            f"Incompatible shapes for fp, xp, yp: {fp.shape}, {xp.shape}, {yp.shape}"
        )

    # Figure out bounds and mapping
    # This breaks if xp, yp is not sorted
    ix = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    iy = jnp.clip(jnp.searchsorted(yp, y, side="right"), 1, len(yp) - 1)
    q_11 = fp[ix - 1, iy - 1]
    q_21 = fp[ix, iy - 1]
    q_12 = fp[ix - 1, iy]
    q_22 = fp[ix, iy]

    # Interpolate in x to start
    denom_x = xp[ix] - xp[ix - 1]
    dx_1 = x - xp[ix - 1]
    dx_2 = xp[ix] - x
    f_xy1 = (dx_2 * q_11 + dx_1 * q_21) / denom_x
    f_xy2 = (dx_2 * q_12 + dx_1 * q_22) / denom_x

    # Now do y as well
    denom_y = yp[iy] - yp[iy - 1]
    dy_1 = y - yp[iy - 1]
    dy_2 = yp[iy] - y
    f = (dy_2 * f_xy1 + dy_1 * f_xy2) / denom_y

    # Zero out the out of bounds values
    out_of_bounds = (x < xp[0]) | (x > xp[-1]) | (y < yp[0]) | (y > yp[-1])
    f = jnp.where(out_of_bounds, 0.0, f)

    return f

fft_conv(image, kernel)

Perform a convolution using FFTs for speed with jax.

Parameters:

Name Type Description Default
image ArrayLike

Data to be convolved.

required
kernel ArrayLike

Convolution kernel.

required

Returns:

Name Type Description
convolved_map Array

Image convolved with kernel.

Source code in witch/utils.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@jax.jit
def fft_conv(image: ArrayLike, kernel: ArrayLike) -> jax.Array:
    """
    Perform a convolution using FFTs for speed with jax.

    Both inputs are `fftshift`-ed, multiplied in 2D Fourier space, and the
    real part of the inverse transform is `fftshift`-ed back. The
    convolution is therefore circular (periodic boundaries). The two 2D
    spectra are multiplied elementwise, so the inputs must be
    broadcast-compatible, typically the same shape.
    NOTE(review): on even-sized axes a kernel centered at index n // 2 leaves
    the image unshifted; on odd-sized axes the three fftshifts leave a net
    offset. Confirm before convolving odd-sized inputs.

    Parameters
    ----------
    image : ArrayLike
        Data to be convolved.
    kernel : ArrayLike
        Convolution kernel.

    Returns
    -------
    convolved_map : jax.Array
        Image convolved with kernel.
    """
    shift = jnp.fft.fftshift
    image_spectrum = jnp.fft.fft2(shift(image))
    kernel_spectrum = jnp.fft.fft2(shift(kernel))
    product = jnp.fft.ifft2(image_spectrum * kernel_spectrum)
    return shift(product.real)

get_da(z)

Get factor to convert from arcseconds to Mpc.

Parameters:

Name Type Description Default
z float

The redshift at which to compute the factor.

required

Returns:

Name Type Description
da float

Conversion factor from arcseconds to Mpc.

Source code in witch/utils.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def get_da(z: float) -> float:
    """
    Get factor to convert from arcseconds to Mpc.

    The value is linearly interpolated from the module-level lookup table
    (`dzline` -> `daline`). `jnp.interp` clamps to the end values for
    redshifts outside the tabulated range.

    Parameters
    ----------
    z : float
        The redshift at which to compute the factor.

    Returns
    -------
    da : float
        Conversion factor from arcseconds to Mpc.
    """
    da = jnp.interp(z, dzline, daline)
    return float(da)

get_hz(z)

Get the dimensionless hubble constant, h, at a given redshift.

Parameters:

Name Type Description Default
z float

The redshift at which to compute h.

required

Returns:

Name Type Description
hz float

h at the given z.

Source code in witch/utils.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def get_hz(z: float) -> float:
    """
    Get the dimensionless hubble constant, h, at a given redshift.

    The value is linearly interpolated from the module-level lookup table
    (`dzline` -> `hzline`). `jnp.interp` clamps to the end values for
    redshifts outside the tabulated range.

    Parameters
    ----------
    z : float
        The redshift at which to compute h.

    Returns
    -------
    hz : float
        h at the given z.
    """
    hz = jnp.interp(z, dzline, hzline)
    return float(hz)

get_nz(z)

Get the critical density at a given redshift.

Parameters:

Name Type Description Default
z float

The redshift at which to compute the critical density.

required

Returns:

Name Type Description
nz float

Critical density at the given z.

Source code in witch/utils.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def get_nz(z: float) -> float:
    """
    Get the critical density at a given redshift.

    The value is linearly interpolated from the module-level lookup table
    (`dzline` -> `nzline`). `jnp.interp` clamps to the end values for
    redshifts outside the tabulated range.

    Parameters
    ----------
    z : float
        The redshift at which to compute the critical density.

    Returns
    -------
    nz : float
        Critical density at the given z.
    """
    nz = jnp.interp(z, dzline, nzline)
    return float(nz)

make_grid(r_map, dx, dy=None, dz=None, x0=0, y0=0)

Make coordinate grids to build models in. All grids are sparse and have 2 * int(r_map / d) points along their non-sparse dimension, where d is the resolution (dx, dy, or dz) of that axis.

Parameters:

Name Type Description Default
r_map float

Size of grid radially.

required
dx float

Grid resolution in x, should be in same units as r_map.

required
dy Optional[float]

Grid resolution in y, should be in same units as r_map. If None then dy is set to dx.

None
dz Optional[float]

Grid resolution in z, should be in same units as r_map. If None then dz is set to dx.

None
x0 float

Origin of grid in RA, assumed to be in same units as r_map.

0
y0 float

Origin of grid in Dec, assumed to be in same units as r_map.

0

Returns:

Name Type Description
x Array

Grid of x coordinates in same units as r_map. Has shape (2 * int(r_map / dx), 1, 1).

y Array

Grid of y coordinates in same units as r_map. Has shape (1, 2 * int(r_map / dy), 1).

z Array

Grid of z coordinates in same units as r_map. Has shape (1, 1, 2 * int(r_map / dz)).

x0 float

Origin of grid in RA, in same units as r_map.

y0 float

Origin of grid in Dec, in same units as r_map.

Source code in witch/utils.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def make_grid(
    r_map: float,
    dx: float,
    dy: Optional[float] = None,
    dz: Optional[float] = None,
    x0: float = 0,
    y0: float = 0,
) -> Grid:
    """
    Make coordinate grids to build models in.
    All grids are sparse and have `2 * int(r_map / d)` points along their
    non-sparse dimension, where `d` is that axis' resolution (dx, dy or dz).
    Each axis is sampled with `jnp.linspace` from `-r_map` to `r_map`
    inclusive, so the actual spacing is `2 * r_map / (n - 1)`, not exactly `d`.

    Parameters
    ----------
    r_map : float
        Size of grid radially.
    dx : float
        Grid resolution in x, should be in same units as r_map.
    dy : Optional[float], default: None
        Grid resolution in y, should be in same units as r_map.
        If None then dy is set to dx.
    dz : Optional[float], default: None
        Grid resolution in z, should be in same units as r_map.
        If None then dz is set to dx.
    x0 : float, default: 0
        Origin of grid in RA, assumed to be in same units as r_map.
    y0 : float, default: 0
        Origin of grid in Dec, assumed to be in same units as r_map.
        NOTE(review): it is divided by `rad_to_arcsec` before taking the
        cosine, which presumes arcseconds; confirm.

    Returns
    -------
    x : jax.Array
        Grid of x coordinates in same units as r_map.
        Has shape (`2 * int(r_map / dx)`, 1, 1).
    y : jax.Array
        Grid of y coordinates in same units as r_map.
        Has shape (1, `2 * int(r_map / dy)`, 1).
    z : jax.Array
        Grid of z coordinates in same units as r_map.
        Has shape (1, 1, `2 * int(r_map / dz)`).
    x0 : float
        Origin of grid in RA, in same units as r_map.
    y0 : float
        Origin of grid in Dec, in same units as r_map.
    """
    dy = dx if dy is None else dy
    dz = dx if dz is None else dz

    def _axis(step: float) -> jax.Array:
        # Symmetric samples over [-r_map, r_map], both endpoints included.
        return jnp.linspace(-1 * r_map, r_map, 2 * int(r_map / step))

    # The x axis is divided by cos(dec) at the origin, presumably so it spans
    # the same on-sky extent as y at that declination.
    x_axis = _axis(dx) / jnp.cos(y0 / rad_to_arcsec) + x0
    y_axis = _axis(dy) + y0
    z_axis = _axis(dz)
    x, y, z = jnp.meshgrid(x_axis, y_axis, z_axis, sparse=True, indexing="ij")
    return x, y, z, x0, y0

make_grid_from_skymap(skymap, z_map, dz, x0=None, y0=None)

Make coordinate grids to build models in from a minkasi skymap. All grids are sparse: the x and y grids match the input map, and the z grid has 2 * int(z_map / dz) points. Unlike make_grid, here all quantities are assumed to be in radians.

Parameters:

Name Type Description Default
skymap Skymap

The map to base the grid off of.

required
z_map float

Size of grid along LOS, in radians.

required
dz float

Grid resolution along LOS, in radians.

required
x0 Optional[float]

Map x center in radians. If None, grid center is used.

None
y0 Optional[float]

Map y center in radians. If None, grid center is used.

None

Returns:

Name Type Description
x Array

Grid of x coordinates in radians. Has shape (skymap.nx, 1, 1).

y Array

Grid of y coordinates in radians. Has shape (1, skymap.ny, 1).

z Array

Grid of z coordinates in radians. Has shape (1, 1, 2 * int(z_map / dz)).

x0 float

Origin of grid in RA, in radians.

y0 float

Origin of grid in Dec, in radians.

Source code in witch/utils.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
def make_grid_from_skymap(
    skymap,
    z_map: float,
    dz: float,
    x0: Optional[float] = None,
    y0: Optional[float] = None,
) -> Grid:
    """
    Make coordinate grids to build models in from a minkasi skymap.
    All grids are sparse: x and y match the input map, and z has
    `2 * int(z_map / dz)` points.
    Unlike `make_grid`, here all quantities are assumed to be in radians.

    Parameters
    ----------
    skymap : minkasi.maps.Skymap
        The map to base the grid off of.
        Uses its `nx`, `ny`, `lims` and `wcs` attributes.
    z_map : float
        Size of grid along LOS, in radians.
    dz : float
        Grid resolution along LOS, in radians.
    x0 : Optional[float], default: None
        Map x center in radians.
        If None, the midpoint of `skymap.lims[0:2]` is used.
        An explicit 0.0 is respected.
    y0 : Optional[float], default: None
        Map y center in radians.
        If None, the midpoint of `skymap.lims[2:4]` is used.
        An explicit 0.0 is respected.

    Returns
    -------
    x : jax.Array
        Grid of x coordinates in radians, relative to `x0`.
        Has shape (`skymap.nx`, 1, 1).
    y : jax.Array
        Grid of y coordinates in radians, relative to `y0`.
        Has shape (1, `skymap.ny`, 1).
    z : jax.Array
        Grid of z coordinates in radians.
        Has shape (1, 1, `2 * int(z_map / dz)`).
    x0 : float
        Origin of grid in RA, in radians.
    y0 : float
        Origin of grid in Dec, in radians.
    """
    # Pixel-index axes for x/y; the LOS axis is symmetric about 0.
    _x = jnp.arange(skymap.nx, dtype=float)
    _y = jnp.arange(skymap.ny, dtype=float)
    _z = jnp.linspace(-1 * z_map, z_map, 2 * int(z_map / dz), dtype=float)
    x, y, z = jnp.meshgrid(_x, _y, _z, sparse=True, indexing="ij")

    # wcs_pix2world takes (x, y) pixel pairs, so pad the shorter axis with its
    # last value until both axes have the same length.
    x_flat = x.ravel()
    y_flat = y.ravel()
    len_diff = len(x_flat) - len(y_flat)
    if len_diff > 0:
        y_flat = jnp.pad(y_flat, (0, len_diff), "edge")
    elif len_diff < 0:
        x_flat = jnp.pad(x_flat, (0, abs(len_diff)), "edge")

    # Convert x and y to ra/dec (degrees from the wcs, converted to radians).
    # NOTE(review): pixel i is converted as the pair (x_i, y_i). The code thus
    # treats RA as depending only on x and Dec only on y; confirm this holds
    # for the projections in use.
    ra_dec = skymap.wcs.wcs_pix2world(
        jnp.column_stack((x_flat, y_flat)), 0, ra_dec_order=True
    )
    ra_dec = np.deg2rad(ra_dec)
    ra = ra_dec[:, 0]
    dec = ra_dec[:, 1]

    # Remove padding
    if len_diff > 0:
        dec = dec[: (-1 * len_diff)]
    elif len_diff < 0:
        ra = ra[:len_diff]

    # Fall back to the map center only when no origin was given. An explicit
    # origin of 0.0 is a valid coordinate and must not be overridden.
    if x0 is None:
        x0 = (skymap.lims[1] + skymap.lims[0]) / 2
    if y0 is None:
        y0 = (skymap.lims[3] + skymap.lims[2]) / 2

    ra -= x0
    dec -= y0

    # Sparse indexing to save mem
    x = x.at[:, 0, 0].set(ra)
    y = y.at[0, :, 0].set(dec)

    return x, y, z, float(x0), float(y0)

tod_hi_pass(tod, N_filt)

High pass a tod with a tophat

Parameters:

Name Type Description Default
tod Array

TOD to high pass.

required
N_filt int

Number of lowest-frequency Fourier modes (starting with the mean) removed by the tophat.

required

Returns:

Name Type Description
tod_filtered Array

High pass filtered TOD

Source code in witch/utils.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
@partial(jax.jit, static_argnums=(1,))
def tod_hi_pass(tod: jax.Array, N_filt: int) -> jax.Array:
    """
    High pass a tod with a tophat.

    Along the last axis, every Fourier mode with frequency index |k| < N_filt
    (including the mean) is set to zero, and the result is transformed back.

    Parameters
    ----------
    tod : jax.Array
        Real-valued TOD to high pass; filtered along its last axis.
    N_filt : int
        Number of lowest-frequency Fourier modes (starting with the mean)
        to remove. Static under jit, so it must be a hashable Python int.

    Returns
    -------
    tod_filtered : jax.Array
        High pass filtered TOD, same shape as `tod`.
    """
    n_samp = tod.shape[-1]
    # Use the real FFT so that removing a mode removes both its positive- and
    # negative-frequency halves. Zeroing only the first N_filt bins of a full
    # FFT and then taking the real part merely halves those modes.
    Ftod = jnp.fft.rfft(tod, axis=-1)
    Ftod_filtered = Ftod.at[..., :N_filt].set(0.0)
    # Pass n explicitly so odd-length TODs keep their length.
    tod_filtered = jnp.fft.irfft(Ftod_filtered, n=n_samp, axis=-1)
    return tod_filtered

tod_to_index(xi, yi, x0, y0, grid, conv_factor=1.0)

Convert RA/Dec TODs to index space.

Parameters:

Name Type Description Default
xi NDArray[floating]

RA TOD, usually in radians

required
yi NDArray[floating]

Dec TOD, usually in radians

required
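x0 float

Ignored; the origin is read from grid instead.

required
y0 float

Ignored; the origin is read from grid instead.

required
grid Grid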

The grid to index on.

required
conv_factor float

Conversion factor to put RA and Dec in same units as the grid.

1.

Returns:

Name Type Description
idx Array

The RA TOD in index space

idy Array

The Dec TOD in index space.

Source code in witch/utils.py
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
def tod_to_index(
    xi: NDArray[np.floating],
    yi: NDArray[np.floating],
    x0: float,
    y0: float,
    grid: Grid,
    conv_factor: float = 1.0,
) -> tuple[jax.Array, jax.Array]:
    """
    Convert RA/Dec TODs to index space.

    Offsets from the grid origin are binned against the grid coordinates with
    `np.digitize`. Samples that fall outside the grid, below the first or at or
    beyond the last coordinate, get the index `2 * n` for that axis, where `n`
    is the axis length.

    Parameters
    ----------
    xi : NDArray[np.floating]
        RA TOD, usually in radians
    yi : NDArray[np.floating]
        Dec TOD, usually in radians
        It is passed to `cos`, so radians are presumably required; confirm.
    x0 : float
        Ignored; the RA origin is read from `grid` instead.
        Kept for backward compatibility.
    y0 : float
        Ignored; the Dec origin is read from `grid` instead.
        Kept for backward compatibility.
    grid : Grid
        The grid to index on.
        Assumed to use sparse (`indexing="ij"`) coordinates.
    conv_factor : float, default: 1.
        Conversion factor to put RA and Dec in same units as the grid.
        NOTE(review): it is applied after subtracting the grid origin, so the
        origin must already be in the TOD's units; confirm.

    Returns
    -------
    idx : jax.Array
        The RA TOD in index space
    idy : jax.Array
        The Dec TOD in index space.
    """
    # The origin stored in the grid takes precedence over the x0/y0 arguments.
    x0, y0 = grid[-2:]
    dx = (xi - x0) * jnp.cos(yi)
    dy = yi - y0

    dx *= conv_factor
    dy *= conv_factor

    # Assuming sparse indexing here.
    # np.digitize returns i with bins[i-1] <= v < bins[i]. 0 therefore means
    # below the first grid coordinate, and len(bins) means at or above the
    # last one.
    idx = np.digitize(dx, grid[0].ravel()).astype(int)
    idy = np.digitize(dy, grid[1].ravel()).astype(int)

    # Ensure out of bounds for stuff not in grid. digitize never returns a
    # negative value, so "below the grid" is index 0, not a negative index.
    n_x = grid[0].shape[0]
    n_y = grid[1].shape[1]
    idx = jnp.where((idx < 1) | (idx >= n_x), 2 * n_x, idx)
    idy = jnp.where((idy < 1) | (idy >= n_y), 2 * n_y, idy)

    return idx, idy

transform_grid(dx, dy, dz, r_1, r_2, r_3, theta, xyz)

Shift, rotate, and apply ellipticity to coordinate grid. Note that the Grid type is an alias for tuple[jax.Array, jax.Array, jax.Array, float, float].

Parameters:

Name Type Description Default
dx float

Amount to move grid origin in x

required
dy float

Amount to move grid origin in y

required
dz float

Amount to move grid origin in z

required
r_1 float

Amount to scale along x-axis

required
r_2 float

Amount to scale along y-axis

required
r_3 float

Amount to scale along z-axis

required
theta float

Angle to rotate in xy-plane in radians

required
xyz Grid

Coordinate grid to transform

required

Returns:

Name Type Description
transformed Grid

Transformed coordinate grid.

Source code in witch/utils.py
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
@jax.jit
def transform_grid(
    dx: float,
    dy: float,
    dz: float,
    r_1: float,
    r_2: float,
    r_3: float,
    theta: float,
    xyz: Grid,
):
    """
    Shift, rotate, and apply ellipticity to coordinate grid.
    Note that the `Grid` type is an alias for `tuple[jax.Array, jax.Array, jax.Array, float, float]`.

    The steps are applied in order:
    1. The origin is moved to `(x0 + dx / cos(y0), y0 + dy)`, with angles
       divided by `rad_to_arcsec` before the cosine, and x is rescaled by
       `cos(y0 + dy)`. z is shifted by `dz`.
    2. x and y are rotated by `theta`.
    3. Each axis is divided by its scale factor.

    Parameters
    ----------
    dx : float
        Amount to move grid origin in x
    dy : float
        Amount to move grid origin in y
    dz : float
        Amount to move grid origin in z
    r_1 : float
        Amount to scale along x-axis
    r_2 : float
        Amount to scale along y-axis
    r_3 : float
        Amount to scale along z-axis
    theta : float
        Angle to rotate in xy-plane in radians
    xyz : Grid
        Coordinate grid to transform

    Returns
    -------
    transformed : Grid
        Transformed coordinate grid. Its origin is reported as
        `(x0 - dx, y0 - dy)`.
        NOTE(review): the grid itself is shifted toward `x0 + dx`, so this
        sign convention is worth confirming.
    """
    x_grid, y_grid, z_grid = xyz[0], xyz[1], xyz[2]
    x0, y0 = xyz[3], xyz[4]

    # Shift origin
    ra_center = x0 + dx / jnp.cos(y0 / rad_to_arcsec)
    dec_center = y0 + dy
    x_shift = (x_grid - ra_center) * jnp.cos(dec_center / rad_to_arcsec)
    y_shift = y_grid - dec_center
    z_shift = z_grid - dz

    # Rotate in the xy-plane
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    x_rot = x_shift * cos_t + y_shift * sin_t
    y_rot = y_shift * cos_t - x_shift * sin_t

    # Apply ellipticity
    return x_rot / r_1, y_rot / r_2, z_shift / r_3, x0 - dx, y0 - dy

y2K_CMB(freq, Te)

Convert from compton y to K_CMB.

Parameters:

Name Type Description Default
freq float

The observing frequency in Hz.

required
Te float

Electron temperature

required

Returns:

Name Type Description
y2K_CMB float

Conversion factor from compton y to K_CMB.

Source code in witch/utils.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@partial(jax.jit, static_argnums=(0, 1))
def y2K_CMB(freq: float, Te: float) -> float:
    """
    Convert from compton y to K_CMB.

    The factor is a power series in ``t = Te / me``, truncated after the
    fourth order: ``Y0 + t*(Y1 + t*(Y2 + t*(Y3 + t*Y4)))``. The series is then
    multiplied by ``Tcmb``. Each coefficient ``Y0``..``Y4`` is a polynomial in
    ``xt = x / tanh(x / 2)`` and ``st = x / sinh(x / 2)``. Here
    ``x = freq * h / kb / Tcmb`` is the dimensionless frequency.
    NOTE(review): the coefficients appear to follow the relativistic SZ
    expansion of Itoh et al. (1998), with ``Y0`` as the non-relativistic term.
    Confirm against that reference before editing any of them.

    Parameters
    ----------
    freq : float
        The observing frequency in Hz.
        Static under jit, so it must be a hashable Python scalar.
    Te : float
        Electron temperature.
        Must be in the same units as the module-level ``me`` so that
        ``Te / me`` is dimensionless (presumably keV; confirm).
        Static under jit, so it must be a hashable Python scalar.

    Returns
    -------
    y2K_CMB : float
        Conversion factor from compton y to K_CMB.
    """
    # Dimensionless frequency and the two hyperbolic combinations used below.
    x = freq * h / kb / Tcmb
    xt = x / jnp.tanh(0.5 * x)
    st = x / jnp.sinh(0.5 * x)
    # Series coefficients Y0..Y4, each a polynomial in xt and st.
    # fmt:off
    Y0 = -4.0 + xt
    Y1 = (-10.0
        + ((47.0 / 2.0) + (-(42.0 / 5.0) + (7.0 / 10.0) * xt) * xt) * xt
        + st * st * (-(21.0 / 5.0) + (7.0 / 5.0) * xt)
    )
    Y2 = ((-15.0 / 2.0)
        + ((1023.0 / 8.0) + ((-868.0 / 5.0) + ((329.0 / 5.0) + ((-44.0 / 5.0) + (11.0 / 30.0) * xt) * xt) * xt) * xt) * xt
        + ((-434.0 / 5.0) + ((658.0 / 5.0) + ((-242.0 / 5.0) + (143.0 / 30.0) * xt) * xt) * xt
        + (-(44.0 / 5.0) + (187.0 / 60.0) * xt) * (st * st)) * st * st
    )
    Y3 = ((15.0 / 2.0)
        + ((2505.0 / 8.0) + ((-7098.0 / 5.0) + ((14253.0 / 10.0) + ((-18594.0 / 35.0) 
         + ((12059.0 / 140.0) + ((-128.0 / 21.0) + (16.0 / 105.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt
        + (((-7098.0 / 10.0) + ((14253.0 / 5.0) + ((-102267.0 / 35.0) + ((156767.0 / 140.0)
         + ((-1216.0 / 7.0) + (64.0 / 7.0) * xt) * xt) * xt) * xt) * xt)
         + (((-18594.0 / 35.0) + ((205003.0 / 280.0) + ((-1920.0 / 7.0) + (1024.0 / 35.0) * xt) * xt) * xt)
          + ((-544.0 / 21.0) + (992.0 / 105.0) * xt) * st * st) * st * st) * st * st
    )
    Y4 = ((-135.0 / 32.0)
        + ((30375.0 / 128.0) + ((-62391.0 / 10.0) + ((614727.0 / 40.0) + ((-124389.0 / 10.0) + ((355703.0 / 80.0) + ((-16568.0 / 21.0)
         + ((7516.0 / 105.0) + ((-22.0 / 7.0) + (11.0 / 210.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt) * xt) * xt
        + ((-62391.0 / 20.0) + ((614727.0 / 20.0) + ((-1368279.0 / 20.0) + ((4624139.0 / 80.0) + ((-157396.0 / 7.0) + ((30064.0 / 7.0)
         + ((-2717.0 / 7.0) + (2761.0 / 210.0) * xt) * xt) * xt) * xt) * xt) * xt) * xt
         + ((-124389.0 / 10.0)
          + ((6046951.0 / 160.0) + ((-248520.0 / 7.0) + ((481024.0 / 35.0) + ((-15972.0 / 7.0) + (18689.0 / 140.0) * xt) * xt) * xt) * xt) * xt
          + ((-70414.0 / 21.0) + ((465992.0 / 105.0) + ((-11792.0 / 7.0) + (19778.0 / 105.0) * xt) * xt) * xt
           + ((-682.0 / 7.0) + (7601.0 / 210.0) * xt) * st * st) * st * st) * st * st) * st * st
    )
    # fmt:on
    # Horner-form evaluation of the series in Te / me.
    factor = Y0 + (Te / me) * (
        Y1 + (Te / me) * (Y2 + (Te / me) * (Y3 + (Te / me) * Y4))
    )
    return factor * Tcmb

y2K_RJ(freq, Te)

Convert from compton y to K_RJ.

Parameters:

Name Type Description Default
freq float

The observing frequency in Hz.

required
Te float

Electron temperature

required

Returns:

Name Type Description
y2K_RJ float

Conversion factor from compton y to K_RJ.

Source code in witch/utils.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@partial(jax.jit, static_argnums=(0, 1))
def y2K_RJ(freq: float, Te: float) -> float:
    """
    Convert from compton y to K_RJ.

    This chains `y2K_CMB` (y -> K_CMB) with `K_CMB2K_RJ` (K_CMB -> K_RJ).

    Parameters
    ----------
    freq : float
        The observing frequency in Hz.
        Static under jit, so it must be a hashable Python scalar.
    Te : float
        Electron temperature
        Static under jit, so it must be a hashable Python scalar.

    Returns
    -------
    y2K_RJ : float
        Conversion factor from compton y to K_RJ.
    """
    return y2K_CMB(freq, Te) * K_CMB2K_RJ(freq)