
grid

Functions for building and working with the model grid.

make_grid(r_map, dx, dy=None, dz=None, x0=0, y0=0)

Make coordinate grids to build models in. All grids are sparse, and each has 2*int(r_map / d) points along its non-sparse dimension, where d is dx, dy, or dz for the x, y, and z grids respectively.

Parameters:

r_map (float, required): Size of the grid radially (half-width). The linspace along each axis runs from -r_map to r_map.

dx (float, required): Grid resolution in x, in the same units as r_map.

dy (Optional[float], default None): Grid resolution in y, in the same units as r_map. If None, dy is set to dx.

dz (Optional[float], default None): Grid resolution in z, in the same units as r_map. If None, dz is set to dx.

x0 (float, default 0): Origin of the grid in RA, assumed to be in the same units as r_map.

y0 (float, default 0): Origin of the grid in Dec, assumed to be in the same units as r_map.

Returns:

x (Array): Grid of x coordinates in the same units as r_map. Has shape (2*int(r_map / dx), 1, 1).

y (Array): Grid of y coordinates in the same units as r_map. Has shape (1, 2*int(r_map / dy), 1).

z (Array): Grid of z coordinates in the same units as r_map. Has shape (1, 1, 2*int(r_map / dz)).

x0 (float): Origin of the grid in RA, in the same units as r_map.

y0 (float): Origin of the grid in Dec, in the same units as r_map.
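The following is a minimal NumPy sketch, not part of witch, of the axis lengths and sparse shapes that make_grid produces. It assumes np.linspace and np.meshgrid behave like the jax.numpy calls in the source (they share these signatures). The helper name sketch_grid_shapes is hypothetical, and the RA stretch and origin shift are left out. Because each count is 2 * int(r_map / d), every axis has an even number of points.

```python
import numpy as np


def sketch_grid_shapes(r_map, dx, dy=None, dz=None):
    """Hypothetical helper: axis lengths and sparse layout of make_grid only.

    The RA stretch and origin shift applied in make_grid are omitted.
    """
    dy = dx if dy is None else dy
    dz = dx if dz is None else dz
    # Each axis runs from -r_map to r_map with 2 * int(r_map / d) samples.
    x = np.linspace(-r_map, r_map, 2 * int(r_map / dx))
    y = np.linspace(-r_map, r_map, 2 * int(r_map / dy))
    z = np.linspace(-r_map, r_map, 2 * int(r_map / dz))
    # sparse=True with "ij" indexing keeps one non-singleton axis per array.
    return np.meshgrid(x, y, z, sparse=True, indexing="ij")


x, y, z = sketch_grid_shapes(r_map=10.0, dx=0.5, dz=1.0)
print(x.shape, y.shape, z.shape)  # (40, 1, 1) (1, 40, 1) (1, 1, 20)

# Arithmetic broadcasts to the full 3D volume while only 100 values are stored.
r = np.sqrt(x**2 + y**2 + z**2)
print(r.shape)  # (40, 40, 20)
```

Storing three 1D-like arrays instead of three dense cubes is the reason the grids are sparse: a model profile evaluated as a function of r still fills the whole volume through broadcasting.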

Source code in witch/grid.py
def make_grid(
    r_map: float,
    dx: float,
    dy: Optional[float] = None,
    dz: Optional[float] = None,
    x0: float = 0,
    y0: float = 0,
) -> Grid:
    """
    Make coordinate grids to build models in.
    All grids are sparse and have `2*int(r_map / d)` points along their non-sparse dimension,
    where `d` is `dx`, `dy`, or `dz` for the x, y, and z grids respectively.

    Parameters
    ----------
    r_map : float
        Size of grid radially.
    dx : float
        Grid resolution in x, should be in same units as r_map.
    dy : Optional[float], default: None
        Grid resolution in y, should be in same units as r_map.
        If None then dy is set to dx.
    dz : Optional[float], default: None
        Grid resolution in z, should be in same units as r_map.
        If None then dz is set to dx.
    x0 : float, default: 0
        Origin of grid in RA, assumed to be in same units as r_map.
    y0 : float, default: 0
        Origin of grid in Dec, assumed to be in same units as r_map.

    Returns
    -------
    x : jax.Array
        Grid of x coordinates in same units as r_map.
        Has shape (`2*int(r_map / dx)`, 1, 1).
    y : jax.Array
        Grid of y coordinates in same units as r_map.
        Has shape (1, `2*int(r_map / dy)`, 1).
    z : jax.Array
        Grid of z coordinates in same units as r_map.
        Has shape (1, 1, `2*int(r_map / dz)`).
    x0 : float
        Origin of grid in RA, in same units as r_map.
    y0 : float
        Origin of grid in Dec, in same units as r_map.
    """
    if dy is None:
        dy = dx
    if dz is None:
        dz = dx

    # Sample each axis from -r_map to r_map at resolution dx, dy, or dz;
    # x is also divided by cos(y0) (y0 converted to radians) and shifted by x0
    x = (
        jnp.linspace(-1 * r_map, r_map, 2 * int(r_map / dx))
        / jnp.cos(y0 / rad_to_arcsec)
        + x0
    )
    y = jnp.linspace(-1 * r_map, r_map, 2 * int(r_map / dy)) + y0
    z = jnp.linspace(-1 * r_map, r_map, 2 * int(r_map / dz))
    x, y, z = jnp.meshgrid(x, y, z, sparse=True, indexing="ij")

    return (x, y, z, x0, y0)
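The x axis in make_grid is divided by cos(y0 / rad_to_arcsec) before x0 is added. Below is a minimal NumPy sketch of that single expression. It assumes rad_to_arcsec is the number of arcseconds per radian, which would mean y0 and r_map are expected in arcseconds; the source suggests this, but the page does not say so. The helper sketch_x_axis is hypothetical. At a declination of 60 degrees, cos(Dec) = 0.5, so the RA offsets span twice the range they do at the equator. This keeps the on-sky extent equal to r_map.

```python
import numpy as np

# Assumption: witch's rad_to_arcsec is arcseconds per radian (~206264.8).
RAD_TO_ARCSEC = np.degrees(1.0) * 3600.0


def sketch_x_axis(r_map, dx, x0=0.0, y0=0.0):
    """Hypothetical helper mirroring the x-axis expression in make_grid."""
    offsets = np.linspace(-r_map, r_map, 2 * int(r_map / dx))
    # Stretch RA offsets by 1/cos(Dec) so they cover r_map on the sky.
    return offsets / np.cos(y0 / RAD_TO_ARCSEC) + x0


x_equator = sketch_x_axis(r_map=60.0, dx=1.0)
x_dec60 = sketch_x_axis(r_map=60.0, dx=1.0, y0=60.0 * 3600.0)  # Dec = 60 deg, in arcsec
print(x_equator[-1], x_dec60[-1])  # approximately 60.0 and 120.0
```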

make_grid_from_wcs(wcs, nx, ny, z_map, dz, x0=None, y0=None)

Make coordinate grids to build models in from the WCS of a map, such as a minkasi skymap. All grids are sparse. The x and y grids match the input map's pixels, and the z grid has 2*int(z_map / dz) points. Unlike make_grid, all coordinates here are assumed to be in radians.

Parameters:

wcs (WCS, required): The WCS to base the grid on.

nx (int, required): The number of pixels in x.

ny (int, required): The number of pixels in y.

z_map (float, required): Size of the grid along the line of sight (LOS), in radians.

dz (float, required): Grid resolution along the LOS, in radians.

x0 (Optional[float], default None): Map x center in radians. If None, the grid center is used.

y0 (Optional[float], default None): Map y center in radians. If None, the grid center is used.

Returns:

x (Array): Grid of RA offsets from x0, in radians. Has shape (nx, 1, 1).

y (Array): Grid of Dec offsets from y0, in radians. Has shape (1, ny, 1).

z (Array): Grid of z coordinates in radians. Has shape (1, 1, 2*int(z_map / dz)).

x0 (float): Origin of the grid in RA, in radians.

y0 (float): Origin of the grid in Dec, in radians.

Source code in witch/grid.py
def make_grid_from_wcs(
    wcs: WCS,
    nx: int,
    ny: int,
    z_map: float,
    dz: float,
    x0: Optional[float] = None,
    y0: Optional[float] = None,
) -> Grid:
    """
    Make coordinate grids to build models in from the WCS of a map, such as a minkasi skymap.
    All grids are sparse; x and y match the input map's pixels and z has `2*int(z_map / dz)` points.
    Unlike `make_grid`, all coordinates here are assumed to be in radians.

    Parameters
    ----------
    wcs : WCS
        The WCS to base the grid on.
    nx : int
        The number of pixels in x.
    ny : int
        The number of pixels in y.
    z_map : float
        Size of grid along LOS, in radians.
    dz : float
        Grid resolution along LOS, in radians.
    x0 : Optional[float], default: None
        Map x center in radians.
        If None, grid center is used.
    y0 : Optional[float], default: None
        Map y center in radians. If None, grid center is used.

    Returns
    -------
    x : jax.Array
        Grid of RA offsets from x0, in radians.
        Has shape (`nx`, 1, 1).
    y : jax.Array
        Grid of Dec offsets from y0, in radians.
        Has shape (1, `ny`, 1).
    z : jax.Array
        Grid of z coordinates in radians.
        Has shape (1, 1, `2*int(z_map / dz)`).
    x0 : float
        Origin of grid in RA, in radians.
    y0 : float
        Origin of grid in Dec, in radians.
    """
    # make grid
    _x = jnp.arange(nx, dtype=float)
    _y = jnp.arange(ny, dtype=float)
    _z = jnp.linspace(-1 * z_map, z_map, 2 * int(z_map / dz), dtype=float)
    x, y, z = jnp.meshgrid(_x, _y, _z, sparse=True, indexing="ij")

    # Pad so we don't need to broadcast
    x_flat = x.ravel()
    y_flat = y.ravel()
    len_diff = len(x_flat) - len(y_flat)
    if len_diff > 0:
        y_flat = jnp.pad(y_flat, (0, len_diff), "edge")
    elif len_diff < 0:
        x_flat = jnp.pad(x_flat, (0, abs(len_diff)), "edge")

    # Convert x and y to ra/dec
    ra_dec = wcs.wcs_pix2world(jnp.column_stack((x_flat, y_flat)), 0, ra_dec_order=True)
    ra_dec = np.deg2rad(ra_dec)
    ra = ra_dec[:, 0]
    dec = ra_dec[:, 1]

    # Remove padding
    if len_diff > 0:
        dec = dec[: (-1 * len_diff)]
    elif len_diff < 0:
        ra = ra[:len_diff]

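    # Default to the center of the covered range; compare with None so that
    # an explicit origin of 0.0 is not overridden
    if x0 is None:
        x0 = (np.max(ra) + np.min(ra)) / 2
    if y0 is None:
        y0 = (np.max(dec) + np.min(dec)) / 2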

    if x0 is None or y0 is None:
        raise TypeError("Origin still None")

    ra -= x0
    dec -= y0

    # Sparse indexing to save mem
    x = x.at[:, 0, 0].set(ra)
    y = y.at[0, :, 0].set(dec)

    return x, y, z, float(x0), float(y0)
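make_grid_from_wcs pairs the x pixel indices with the y pixel indices so that a single wcs_pix2world call converts both. It pads the shorter list with its last value and then trims the result. The NumPy sketch below shows that pad, convert, trim sequence, followed by the default midpoint origin. To stay self-contained, it replaces astropy's WCS.wcs_pix2world with a hypothetical linear mapping, hypothetical_pix2world, in which RA depends only on x and Dec only on y. In a general WCS projection, RA also varies with y, so the converted values depend on that pairing.

```python
import numpy as np


def hypothetical_pix2world(pix):
    """Stand-in for WCS.wcs_pix2world: a made-up linear map, returns degrees."""
    ra = 150.0 - 0.001 * pix[:, 0]
    dec = 2.0 + 0.001 * pix[:, 1]
    return np.column_stack((ra, dec))


nx, ny = 5, 3
x_flat = np.arange(nx, dtype=float)
y_flat = np.arange(ny, dtype=float)

# Pad the shorter axis with its last value so the two can be stacked column-wise.
len_diff = len(x_flat) - len(y_flat)
if len_diff > 0:
    y_flat = np.pad(y_flat, (0, len_diff), "edge")
elif len_diff < 0:
    x_flat = np.pad(x_flat, (0, -len_diff), "edge")

ra_dec = np.deg2rad(hypothetical_pix2world(np.column_stack((x_flat, y_flat))))
ra, dec = ra_dec[:, 0], ra_dec[:, 1]

# Trim the padded entries back off.
if len_diff > 0:
    dec = dec[:-len_diff]
elif len_diff < 0:
    ra = ra[:len_diff]

# Default origin: midpoint of the covered range (what x0=None / y0=None give).
x0 = (ra.max() + ra.min()) / 2
y0 = (dec.max() + dec.min()) / 2
print(ra.shape, dec.shape)  # (5,) (3,)
```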

tod_to_index(xi, yi, x0, y0, grid, conv_factor=1.0)

Convert RA/Dec TODs to index space.

Parameters:

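xi (NDArray[floating], required): RA TOD, usually in radians.

yi (NDArray[floating], required): Dec TOD, usually in radians.

x0 (float, required): Ignored in the current implementation; the origin is read from grid[-2:].

y0 (float, required): Ignored in the current implementation; the origin is read from grid[-2:].

grid (Grid, required): The grid to index on.

conv_factor (float, default 1.0): Conversion factor to put RA and Dec in the same units as the grid.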

Returns:

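idx (Array): The RA TOD in index space. Indices outside [0, grid[0].shape[0]) are replaced with 2 * grid[0].shape[0].

idy (Array): The Dec TOD in index space. Indices outside [0, grid[1].shape[1]) are replaced with 2 * grid[1].shape[1].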

Source code in witch/grid.py
def tod_to_index(
    xi: NDArray[np.floating],
    yi: NDArray[np.floating],
    x0: float,
    y0: float,
    grid: Grid,
    conv_factor: float = 1.0,
) -> tuple[jax.Array, jax.Array]:
    """
    Convert RA/Dec TODs to index space.

    Parameters
    ----------
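    xi : NDArray[np.floating]
        RA TOD, usually in radians.
    yi : NDArray[np.floating]
        Dec TOD, usually in radians.
    x0 : float
        Ignored; the origin is read from `grid[-2:]`.
    y0 : float
        Ignored; the origin is read from `grid[-2:]`.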
    grid : Grid
        The grid to index on.
    conv_factor : float, default: 1.
        Conversion factor to put RA and Dec in same units as the grid.

    Returns
    -------
    idx : jax.Array
        The RA TOD in index space
    idy : jax.Array
        The Dec TOD in index space.
    """
    x0, y0 = grid[-2:]
    dx = (xi - x0) * jnp.cos(yi)
    dy = yi - y0

    dx *= conv_factor
    dy *= conv_factor

    # Assuming sparse indexing here
    idx = np.digitize(dx, grid[0].ravel())
    idy = np.digitize(dy, grid[1].ravel())

    idx = np.rint(idx).astype(int)
    idy = np.rint(idy).astype(int)

    # Ensure out of bounds for stuff not in grid
    idx = jnp.where((idx < 0) + (idx >= grid[0].shape[0]), 2 * grid[0].shape[0], idx)
    idy = jnp.where((idy < 0) + (idy >= grid[1].shape[1]), 2 * grid[1].shape[1], idy)

    return idx, idy
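tod_to_index bins each sample with np.digitize against the flattened sparse axis. It then flags indices outside [0, n) with the sentinel 2 * n. The sketch below repeats those two steps on hypothetical 1D bin edges that stand in for grid[0].ravel(). np.digitize returns i when edges[i-1] <= value < edges[i], so it never returns a negative index. As a result, samples below the first edge land at index 0 rather than at the sentinel, and only samples at or above the last edge are flagged.

```python
import numpy as np

# Hypothetical edges standing in for grid[0].ravel() of a sparse grid.
edges = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
offsets = np.array([-3.0, -1.5, 0.2, 1.9, 5.0])

idx = np.digitize(offsets, edges)
print(idx.tolist())  # [0, 1, 3, 4, 5]

# Same rule as tod_to_index: indices outside [0, n) become the sentinel 2 * n.
n = edges.shape[0]
idx = np.where((idx < 0) | (idx >= n), 2 * n, idx)
print(idx.tolist())  # [0, 1, 3, 4, 10]
```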

transform_grid(dx, dy, dz, r_1, r_2, r_3, theta, xyz)

Shift, rotate, and apply ellipticity to coordinate grid. Note that the Grid type is an alias for tuple[jax.Array, jax.Array, jax.Array, float, float].

Parameters:

dx (float, required): Amount to move the grid origin in x.

dy (float, required): Amount to move the grid origin in y.

dz (float, required): Amount to move the grid origin in z.

r_1 (float, required): Scale factor along the x-axis; x coordinates are divided by r_1.

r_2 (float, required): Scale factor along the y-axis; y coordinates are divided by r_2.

r_3 (float, required): Scale factor along the z-axis; z coordinates are divided by r_3.

theta (float, required): Angle to rotate in the xy-plane, in radians.

xyz (Grid, required): Coordinate grid to transform.

Returns:

transformed (Grid): Transformed coordinate grid, with origin (x0 - dx, y0 - dy).

Source code in witch/grid.py
@jax.jit
def transform_grid(
    dx: float,
    dy: float,
    dz: float,
    r_1: float,
    r_2: float,
    r_3: float,
    theta: float,
    xyz: Grid,
):
    """
    Shift, rotate, and apply ellipticity to coordinate grid.
    Note that the `Grid` type is an alias for `tuple[jax.Array, jax.Array, jax.Array, float, float]`.

    Parameters
    ----------
    dx : float
        Amount to move grid origin in x
    dy : float
        Amount to move grid origin in y
    dz : float
        Amount to move grid origin in z
    r_1 : float
        Amount to scale along x-axis
    r_2 : float
        Amount to scale along y-axis
    r_3 : float
        Amount to scale along z-axis
    theta : float
        Angle to rotate in xy-plane in radians
    xyz : Grid
        Coordinate grid to transform

    Returns
    -------
    transformed : Grid
        Transformed coordinate grid.
    """
    # Get origin
    x0, y0 = xyz[3], xyz[4]
    # Shift origin
    x = (xyz[0] - (x0 + dx / jnp.cos(y0 / rad_to_arcsec))) * jnp.cos(
        (y0 + dy) / rad_to_arcsec
    )
    y = xyz[1] - (y0 + dy)
    z = xyz[2] - dz

    # Rotate
    xx = x * jnp.cos(theta) + y * jnp.sin(theta)
    yy = y * jnp.cos(theta) - x * jnp.sin(theta)

    # Apply ellipticity
    x = xx / r_1
    y = yy / r_2
    z = z / r_3

    return x, y, z, x0 - dx, y0 - dy
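You can check the rotate-and-scale part of transform_grid by hand. The NumPy sketch below applies the same two formulas through a hypothetical helper, sketch_rotate_scale. It omits the shift step, which amounts to assuming dx = dy = dz = 0 and a zero origin so that the cos factors equal 1. A point at distance r_1 along the direction theta maps to approximately (1, 0, 0). A spherically symmetric profile evaluated on the transformed grid therefore becomes an ellipsoid with semi-axes r_1, r_2, and r_3, with its r_1 axis rotated by theta from x toward y.

```python
import numpy as np


def sketch_rotate_scale(x, y, z, theta, r_1, r_2, r_3):
    """Hypothetical helper: the rotate and ellipticity steps of transform_grid."""
    # Rotate by theta in the xy-plane (same formulas as the source).
    xx = x * np.cos(theta) + y * np.sin(theta)
    yy = y * np.cos(theta) - x * np.sin(theta)
    # Dividing by r_i stretches a profile by r_i along that axis.
    return xx / r_1, yy / r_2, z / r_3


theta = np.pi / 6
r_1 = 2.0
# A point r_1 away from the origin along the direction theta.
px, py = r_1 * np.cos(theta), r_1 * np.sin(theta)
tx, ty, tz = sketch_rotate_scale(px, py, 0.0, theta, r_1=r_1, r_2=1.0, r_3=1.0)
print(tx, ty, tz)  # approximately 1.0 0.0 0.0
```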