Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions data_models/polarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def polmatrixmultiply(cm, vec, polaxis=1):

For an image vec has axes [nchan, npol, ny, nx] and polaxis=1
For visibility vec has axes [row, nchan, npol] and polaxis=2
For blockvisibility vec has axes [row, ant, ant, nchan, npol] and polaxis=4

:param cm: matrix to apply
:param vec: array to be multiplied [...,:]
Expand All @@ -127,10 +128,19 @@ def polmatrixmultiply(cm, vec, polaxis=1):
return numpy.dot(cm, vec)
else:
# This tensor swaps the first two axes so we need to tranpose back
# e.g. if polaxis=2 1000, 3, 4 becomes 4, 1000, 3
result = numpy.tensordot(cm, vec, axes=(1, polaxis))
permut = list(range(len(result.shape)))
permut[0], permut[polaxis] = permut[polaxis], permut[0]
return numpy.transpose(result, axes=permut)
permut = list(range(len(vec.shape)))
assert polaxis < 4 and polaxis > 0, "Error in polarisation conversion logic"
if polaxis == 1:
permut[0], permut[1] = permut[1], permut[0]
elif polaxis == 2:
permut[0], permut[1], permut[2] = permut[1], permut[2], permut[0]
elif polaxis == 3:
permut[0], permut[1], permut[2] = permut[1], permut[2], permut[0]
transposed = numpy.transpose(result, axes=permut)
assert transposed.shape == vec.shape
return transposed


def convert_stokes_to_linear(stokes, polaxis=1):
Expand Down
6 changes: 2 additions & 4 deletions processing_components/imaging/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
log = logging.getLogger(__name__)


def shift_vis_to_image(vis: Visibility, im: Image, tangent: bool = True, inverse: bool = False) \
-> Visibility:
def shift_vis_to_image(vis: Union[Visibility, BlockVisibility], im: Image, tangent: bool = True, inverse: bool = False) \
-> Union[Visibility, BlockVisibility]:
"""Shift visibility to the FFT phase centre of the image in place

:param vis: Visibility data
Expand Down Expand Up @@ -78,8 +78,6 @@ def shift_vis_to_image(vis: Visibility, im: Image, tangent: bool = True, inverse
vis = phaserotate_visibility(vis, image_phasecentre, tangent=tangent, inverse=inverse)
vis.phasecentre = im.phasecentre

assert isinstance(vis, Visibility), "after phase_rotation, vis is not a Visibility"

return vis


Expand Down
200 changes: 80 additions & 120 deletions processing_components/imaging/ng.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
"""
Functions that aid fourier transform processing. These are built on top of the core
functions in processing_library.fourier_transforms.
Functions that implement prediction of and imaging from visibilities using the nifty gridder.

The measurement equation for a sufficently narrow field of view interferometer is:
https://gitlab.mpcdf.mpg.de/ift/nifty_gridder

.. math::
This performs all necessary w term corrections, to high precision.

V(u,v,w) =\\int I(l,m) e^{-2 \\pi j (ul+vm)} dl dm


The measurement equation for a wide field of view interferometer is:

.. math::

V(u,v,w) =\\int \\frac{I(l,m)}{\\sqrt{1-l^2-m^2}} e^{-2 \\pi j (ul+vm + w(\\sqrt{1-l^2-m^2}-1))} dl dm

This and related modules contain various approachs for dealing with the wide-field problem where the
extra phase term in the Fourier transform cannot be ignored.
"""

import logging
Expand All @@ -26,63 +14,54 @@

from data_models.memory_data_models import Visibility, BlockVisibility, Image
from data_models.parameters import get_parameter
from ..visibility.base import copy_visibility
from data_models.polarisation import convert_pol_frame
from processing_components.image.operations import copy_image
from processing_components.imaging.base import shift_vis_to_image, normalize_sumwt
from ..visibility.base import copy_visibility

log = logging.getLogger(__name__)

try:
import nifty_gridder as ng

def predict_ng(bvis: Union[BlockVisibility, Visibility], model: Image, gcfcf=None, **kwargs) -> \
def predict_ng(bvis: Union[BlockVisibility, Visibility], model: Image, **kwargs) -> \
Union[BlockVisibility, Visibility]:
""" Predict using convolutional degridding.
Nifty-gridder version.

Nifty-gridder version. https://gitlab.mpcdf.mpg.de/ift/nifty_gridder

:param bvis: BlockVisibility to be predicted
:param model: model image

:return: resulting BlockVisibility (in place works)
"""

assert isinstance(bvis, BlockVisibility), bvis


if model is None:
return bvis

nthreads = get_parameter(kwargs, "threads", 4)
epsilon = get_parameter(kwargs, "epsilon", 6.0e-6)
epsilon = get_parameter(kwargs, "epsilon", 1e-12)
do_wstacking = get_parameter(kwargs, "do_wstacking", True)
verbosity = get_parameter(kwargs, "verbosity", 0)

assert isinstance(bvis, BlockVisibility), bvis
verbosity = get_parameter(kwargs, "verbosity", 2)

newbvis = copy_visibility(bvis, zero=True)

# Extracting data from BlockVisibility
freq = bvis.frequency # frequency, Hz
nants = bvis.uvw.shape[1]
ntimes = bvis.uvw.shape[0]
nbaselines = nants * (nants - 1) // 2
v_nchan = bvis.vis.shape[-2]
v_npol = bvis.vis.shape[-1]

uvw = numpy.zeros([ntimes * nbaselines, 3])
ms = numpy.zeros([ntimes * nbaselines, v_nchan, v_npol], dtype='complex')
nrows, nants, _, vnchan, vnpol = bvis.vis.shape

iflat = 0
for it in range(ntimes):
for iant1 in range(nants):
for iant2 in range(iant1 + 1, nants):
uvw[iflat, :] = newbvis.data['uvw'][it, iant2, iant1, :]
iflat += 1
uvw = newbvis.data['uvw'].reshape([nrows * nants * nants, 3])
vis = newbvis.data['vis'].reshape([nrows * nants * nants, vnchan, vnpol])

ms[:, :, :] = 0.0 + 0.0j # Make all vis data equal to 0 +0j
wgt = numpy.ones((ms.shape[0], ms.shape[2])) # All weights equal to 1.0
vis[...] = 0.0 + 0.0j # Make all vis data equal to 0 +0j

# Get the image properties
m_nchan, m_npol, ny, nx = model.data.shape
# Check if the number of frequency channels matches in bvis and a model
# assert (m_nchan == v_nchan)
assert (m_npol == v_npol)
# assert (m_nchan == v_nchan)
assert (m_npol == vnpol)

fuvw = uvw.copy()
# We need to flip the u and w axes. The flip in w is equivalent to the conjugation of the
Expand All @@ -94,37 +73,32 @@ def predict_ng(bvis: Union[BlockVisibility, Visibility], model: Image, gcfcf=Non
pixsize = numpy.abs(numpy.radians(model.wcs.wcs.cdelt[0]))

# Make de-gridding over a frequency range and pol fields
imchan = numpy.round(model.wcs.sub([4]).wcs_world2pix(freq, 0)[0]).astype('int')
for i in range(v_nchan):
for j in range(v_npol):
ngvis = ng.dirty2ms(fuvw.astype(numpy.float64),
freq[i:i + 1].astype(numpy.float64),
model.data[imchan[i], j, :, :].T.astype(numpy.float64),
wgt=wgt,
vis_to_im = numpy.round(model.wcs.sub([4]).wcs_world2pix(freq, 0)[0]).astype('int')
for vchan in range(vnchan):
imchan = vis_to_im[vchan]
for vpol in range(vnpol):
vis[..., vchan, vpol] = ng.dirty2ms(fuvw.astype(numpy.float64),
freq[vchan:vchan + 1].astype(numpy.float64),
model.data[imchan, vpol, :, :].T.astype(numpy.float64),
pixsize_x=pixsize,
pixsize_y=pixsize,
epsilon=epsilon,
do_wstacking=do_wstacking,
nthreads=nthreads,
verbosity=verbosity)
iflat = 0
for it in range(ntimes):
for iant1 in range(nants):
for iant2 in range(iant1 + 1, nants):
newbvis.data['vis'][it, iant2, iant1, i, j] = ngvis[iflat]
newbvis.data['vis'][it, iant1, iant2, i, j] = numpy.conjugate(ngvis[iflat])
iflat += 1
verbosity=verbosity)[:,0]

vis = convert_pol_frame(vis, model.polarisation_frame, bvis.polarisation_frame, polaxis=2)
newbvis.data['vis'] = vis.reshape([nrows, nants, nants, vnchan, vnpol])

# Now we can shift the visibility from the image frame to the original visibility frame
# sbvis = shift_vis_to_image(bvis, model, tangent=True, inverse=True)

return newbvis
return shift_vis_to_image(newbvis, model, tangent=True, inverse=True)


def invert_ng(bvis: BlockVisibility, model: Image, dopsf: bool = False, normalize: bool = True, gcfcf=None,
**kwargs) -> (
Image, numpy.ndarray):
def invert_ng(bvis: BlockVisibility, model: Image, dopsf: bool = False, normalize: bool = True,
**kwargs) -> (Image, numpy.ndarray):
""" Invert using nifty-gridder module

https://gitlab.mpcdf.mpg.de/ift/nifty_gridder

Use the image im as a template. Do PSF in a separate call.

Expand All @@ -134,50 +108,34 @@ def invert_ng(bvis: BlockVisibility, model: Image, dopsf: bool = False, normaliz
:param bvis: BlockVisibility to be inverted
:param im: image template (not changed)
:param normalize: Normalize by the sum of weights (True)
:return: resulting image
:return: sum of the weights for each frequency and polarization
:return: (resulting image, sum of the weights for each frequency and polarization)

"""

im = copy_image(model)


normalize = True

assert isinstance(bvis, BlockVisibility), bvis


im = copy_image(model)

nthreads = get_parameter(kwargs, "threads", 4)
epsilon = get_parameter(kwargs, "epsilon", 6.0e-6)
datacube = get_parameter(kwargs, "datacube", True)
epsilon = get_parameter(kwargs, "epsilon", 1e-12)
do_wstacking = get_parameter(kwargs, "do_wstacking", True)
verbosity = get_parameter(kwargs, "verbosity", 0)

sbvis = copy_visibility(bvis)

# sbvis = shift_vis_to_image(sbvis, im, tangent=True, inverse=False)
sbvis = copy_visibility(bvis)
sbvis = shift_vis_to_image(sbvis, im, tangent=True, inverse=False)

vis = bvis.vis

# Extracting data from BlockVisibility
freq = sbvis.frequency # frequency, Hz
uvw_nonzero = numpy.nonzero(sbvis.uvw[:, :, :, 0])
uvw = sbvis.uvw[uvw_nonzero] # UVW, meters [:,3]
ms = sbvis.vis[uvw_nonzero] # Visibility data [:,nfreq,npol]
# wgt = numpy.ones((ms.shape[0], ms.shape[2])) # All weights equal to 1.0
wgt = sbvis.imaging_weight[uvw_nonzero]

# Add up XX and YY if polarized data
if ms.shape[2] == 1: # Scalar
idx = [0] # Only I
else: # Polar
idx = [0, 3] # XX and YY
ms = numpy.sum(ms[:, :, idx], axis=2)

nrows, nants, _, vnchan, vnpol = vis.shape
uvw = sbvis.uvw.reshape([nrows * nants * nants, 3])
ms = vis.reshape([nrows * nants * nants, vnchan, vnpol])
wgt = sbvis.imaging_weight.reshape([nrows * nants * nants, vnchan, vnpol])

if dopsf:
ms[...] = 1.0 + 0.0j

wgt = numpy.sum(wgt[:, :, idx], axis=2)
# wgt = 1 / numpy.sum(1 / wgt, axis=1)

# Assign the weights to all frequencies
# wgt = numpy.repeat(wgt[:, None], len(freq), axis=1)
if epsilon > 5.0e-6:
ms = ms.astype("c8")
wgt = wgt.astype("f4")
Expand All @@ -186,38 +144,40 @@ def invert_ng(bvis: BlockVisibility, model: Image, dopsf: bool = False, normaliz
npixdirty = im.nwidth
pixsize = numpy.abs(numpy.radians(im.wcs.wcs.cdelt[0]))

# If non-spectral image
if im.nchan == 1:
datacube = False
# Else check if the number of frequencies in the image and MS match
else:
assert (im.nchan == len(freq))

sumwt = numpy.ones((im.nchan, im.npol))
fuvw = uvw.copy()
# We need to flip the u and w axes.
fuvw[:, 0] *= -1.0
fuvw[:, 2] *= -1.0
if not datacube:
dirty = ng.ms2dirty(
fuvw, freq, ms, wgt, npixdirty, npixdirty, pixsize, pixsize, epsilon,
do_wstacking=do_wstacking, nthreads=nthreads, verbosity=verbosity)
sumwt[0, 0] = numpy.sum(wgt)
if normalize:
dirty = dirty / sumwt[0, 0]
im.data[0][0] = dirty.T
else:
for i in range(len(freq)):
print(i, freq[i], freq[i:i + 1].shape, ms[:, i:i + 1].shape, wgt[:, i:i + 1].shape)
dirty = ng.ms2dirty(
fuvw, freq[i:i + 1], ms[:, i:i + 1], wgt[:, i:i + 1], npixdirty, npixdirty, pixsize, pixsize,
epsilon,
do_wstacking=do_wstacking, nthreads=nthreads, verbosity=verbosity)
sumwt[i, 0] = numpy.sum(wgt[:, i:i + 1])
if normalize:
dirty = dirty / sumwt[i, 0]
im.data[i][0] = dirty.T

nchan, npol, ny, nx = im.shape
im.data[...] = 0.0
sumwt = numpy.zeros([nchan, npol])

ms = convert_pol_frame(ms, bvis.polarisation_frame, im.polarisation_frame, polaxis=2)
# There's a latent problem here with the weights.
# wgt = numpy.real(convert_pol_frame(wgt, bvis.polarisation_frame, im.polarisation_frame, polaxis=2))

# Set up the conversion from visibility channels to image channels
vis_to_im = numpy.round(model.wcs.sub([4]).wcs_world2pix(freq, 0)[0]).astype('int')
for vchan in range(vnchan):
ichan = vis_to_im[vchan]
for pol in range(npol):
# Nifty gridder likes to receive contiguous arrays
ms_1d = numpy.array([ms[row, vchan:vchan+1, pol] for row in range(nrows * nants * nants)], dtype='complex')
ms_1d.reshape([ms_1d.shape[0], 1])
wgt_1d = numpy.array([wgt[row, vchan:vchan+1, pol] for row in range(nrows * nants * nants)])
wgt_1d.reshape([wgt_1d.shape[0], 1])
dirty = ng.ms2dirty(
fuvw, freq[vchan:vchan+1], ms_1d, wgt_1d,
npixdirty, npixdirty, pixsize, pixsize, epsilon, do_wstacking=do_wstacking,
nthreads=nthreads, verbosity=verbosity)
sumwt[ichan, pol] += numpy.sum(wgt[:, vchan, pol])
im.data[ichan, pol] += dirty.T

if normalize:
im = normalize_sumwt(im, sumwt)


return im, sumwt

except ImportError:
Expand Down
Loading