diff --git a/02-sgd/.gitignore b/02-sgd/.gitignore new file mode 100644 index 0000000..1b2512e --- /dev/null +++ b/02-sgd/.gitignore @@ -0,0 +1,4 @@ +data +runs +checkpoints +.snakemake \ No newline at end of file diff --git a/02-sgd/README.md b/02-sgd/README.md index d7bc9cf..1bade68 100644 --- a/02-sgd/README.md +++ b/02-sgd/README.md @@ -2,16 +2,64 @@ This is a complete example demonstrating how MPoL works using simulated data. -Before starting, you should have already run the scripts in the `generate-mock-baselines` folder to produce a mock sky image and interferometer baselines in a file called `mock_data.npz`. Then, you should copy that file to this repository under `data/mock_data.npz`. +# Prerequisites -This repository assumes that you will run all scripts from this `sgd` directory (the one containing `sgd/README.md`). Some aspects of the workflow are automated with Snakemake ([`Snakefile`](Snakefile)). +Before starting, you should have already run the scripts `00` and `01` folders to produce mock baselines. Then, you will need to copy the `mock_data.npz` into this directory in a new `data` folder. For example, from within this 02 folder, run + +```shell +$ mkdir data +$ cp ../01-generate-mock-baselines/data/mock_data.npz data/ +``` + +# Installation + +You can install necessary Python packages into your environment by +```shell +$ pip install -r requirements.txt +``` + +and then you can run the code by + +```shell +snakemake -c1 all +``` + +# Description of Contents + +This repository assumes that you will run all scripts from this `02` directory (the one containing `02-sgd/README.md`). Some aspects of the workflow are automated with Snakemake ([`Snakefile`](Snakefile)). First, we recommend looking at [`src/load_data.py`](src/load_data.py) to see how mock visibilities $\mathcal{V}(u,v)$ are generated from the mock image and baselines. Then, we recommend looking at [`src/plot_baselines.py`](src/plot_baselines.py) and [`src/dirty_image.py`](src/dirty_image.py) to make diagnostic plots of the baseline and a dirty image of the data, to check that everything appears as you might expect. +You can run these simple scripts using + +``` +$ snakemake -c1 all +``` + +![baselines](analysis/baselines.png) + +![Dirty Beam and Image](analysis/dirty_image.png) + +# RML imaging workflow + The RML imaging workflow is demonstrated in [`src/sgd.py`](src/sgd.py). We recommend looking through that file before reading the rest of this document. If you are new to PyTorch idioms, we recommend familiarizing yourself with the [PyTorch basics](https://mpol-dev.github.io/MPoL/background.html#pytorch) first. +The RML imaging workflow is not part of the Snakemake workflow, instead, one runs the script like + +``` +$ python src/sgd.py --epochs=5 +``` + +Note this will just result in a short test. Run `python src/sgd.py --help` to see all available command line arguments, and see below for configurations that will result in better images. + +One can visualize the results using Tensorboard via + +```shell +$ tensorboard --logdir runs +``` + # Validation Since this example uses mock data, we have the advantage of knowing the true sky image. This allows us to calculate a 'validation loss' between the synthesized image and the true sky. @@ -25,25 +73,42 @@ This approach cannot be used with real datasets, obviously, but in this case aff If the dataset lacks many long baselines, it is unrealistic for RML to recover the native resolution of the image. In this case, we can calculate the validation score at resolutions coarser than the source image. We do this by convolving both $I_\mathrm{true}$ and $I_\mathrm{syn}$ with a 2D Gaussian described by FWHMs of $\theta_a, \theta_b$ before computing $L_\mathrm{validation}$. -# (lack of) Regularization -To demonstrate why regularization is needed for imaging workflows, try running without any: +# Example Result +Here is an example image produced with + +```shell +$ mkdir checkpoints +$ python src/sgd.py --tensorboard-log-dir=runs/ent0 --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=5 --lam-ent=1e-5 +``` + +and plotted with +```shell +$ python src/plot_image.py checkpoints/ent0.pt analysis/butterfly.png ``` -python src/sgd.py --tensorboard-log-dir=runs/nolam0 --epochs=40 --log-interval=2 --save-checkpoint=checkpoints/nolam0.pt --lr 1e-2 + +![RML Butterfly](analysis/butterfly.png) + + +# More examples of training loops +To demonstrate why regularization is needed for imaging workflows, try running without any: + +```shell +python src/sgd.py --tensorboard-log-dir=runs/nolam0 --epochs=10 --log-interval=2 --save-checkpoint=checkpoints/nolam0.pt --lr 1e-2 ``` If run to convergence, you'll find a classic case of overfitting to the lower S/N visibilities at longer baselines / higher spatial frequencies. This manifests in the image as small splotches and/or individual pixels with very high flux concentrations. If we didn't enforce non-negative pixels by construction, this would probably manifest as high frequency "noise" similar to uniformly-weighted images. You can spot this behavior by monitoring the training loss and the validation loss with iteration. You will see the [classic textbook signature of overfitting](https://d2l.ai/chapter_linear-regression/generalization.html#underfitting-or-overfitting): the validation loss decreases for a while but eventually turns around and increases, while the training loss monotonically decreases as it fits the signal and then eventually tries to fit all the noise. One could attempt to regularize this behavior away using early stopping. However, in practice with real data we would not have access to a validation, so we look to alternative regularization techniques. -# Maximum Entropy Regularization +## Maximum Entropy Regularization One can obtain a decent image using Maximum Entropy Regularization. Here are a few examples that you can run, saving checkpoints and resuming from finished models. We recommend that you examine the output using Tensorboard after each run, and make adjustments accordingly. Initial run with no entropy: ```shell -python src/sgd.py --tensorboard-log-dir=runs/exp0 --save-checkpoint=checkpoints/0.pt --lr 1e-2 --FWHM 0.05 --epochs=50 +python src/sgd.py --tensorboard-log-dir=runs/exp0 --save-checkpoint=checkpoints/0.pt --lr 1e-2 --FWHM 0.05 --epochs=10 ``` Resuming from previous model, and speeding up learning rate @@ -61,5 +126,4 @@ Adding entropy regularization, and reducing learning rate slightly. python src/sgd.py --tensorboard-log-dir=runs/ent0 --load-checkpoint=checkpoints/2.pt --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5 ``` - -Note that we could have started directly with the entropy regularization if we wished. The previous just demonstrates an exploratory workflow. \ No newline at end of file +Note that we could have started directly with the entropy regularization if we wished. This collection just demonstrates an exploratory workflow. \ No newline at end of file diff --git a/02-sgd/Snakefile b/02-sgd/Snakefile index ecb413d..74a1d04 100644 --- a/02-sgd/Snakefile +++ b/02-sgd/Snakefile @@ -20,7 +20,9 @@ rule dirty_image: # python src/sgd.py --tensorboard-log-dir=runs/exp2 --load-checkpoint=checkpoints/1.pt --save-checkpoint=checkpoints/2.pt --lr 4e-1 --FWHM 0.05 --epochs=50 -# python src/sgd.py --tensorboard-log-dir=runs/ent0 --load-checkpoint=checkpoints/2.pt --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5 +# python src/sgd.py --tensorboard-log-dir=runs/ent0 --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5 -# vary fixed FWHM and entropy regularization to find best validation score. \ No newline at end of file +# vary fixed FWHM and entropy regularization to find best validation score. + +# \ No newline at end of file diff --git a/02-sgd/analysis/baselines.png b/02-sgd/analysis/baselines.png new file mode 100644 index 0000000..213c3a2 Binary files /dev/null and b/02-sgd/analysis/baselines.png differ diff --git a/02-sgd/analysis/butterfly.png b/02-sgd/analysis/butterfly.png new file mode 100644 index 0000000..9f25e7f Binary files /dev/null and b/02-sgd/analysis/butterfly.png differ diff --git a/02-sgd/analysis/dirty_image.png b/02-sgd/analysis/dirty_image.png new file mode 100644 index 0000000..4af4e93 Binary files /dev/null and b/02-sgd/analysis/dirty_image.png differ diff --git a/02-sgd/requirements.txt b/02-sgd/requirements.txt index e9c6109..d8b4f4e 100644 --- a/02-sgd/requirements.txt +++ b/02-sgd/requirements.txt @@ -1,3 +1,5 @@ mpol +tensorboard visread -matplotlib \ No newline at end of file +matplotlib +snakemake \ No newline at end of file diff --git a/02-sgd/src/plot_image.py b/02-sgd/src/plot_image.py new file mode 100644 index 0000000..431fa07 --- /dev/null +++ b/02-sgd/src/plot_image.py @@ -0,0 +1,55 @@ +import torch +import argparse +import matplotlib.pyplot as plt +from mpol import coordinates, images +from mpol.constants import arcsec +from astropy.visualization.mpl_normalize import simple_norm + +def main(): + parser = argparse.ArgumentParser(description="Compare image to DSHARP image") + parser.add_argument("load_checkpoint", metavar="load-checkpoint", help="Path to checkpoint from which to resume.") + parser.add_argument("plotfile") + args = parser.parse_args() + + # get the MPoL image from the checkpoint + coords = coordinates.GridCoords(cell_size=0.005, npix=1028) + checkpoint = torch.load(args.load_checkpoint, map_location=torch.device('cpu')) + + # get the image cube in packed format and run through an ImageCube to unpack + icube = images.ImageCube(coords=coords) + icube(checkpoint["model_state_dict"]["icube.packed_cube"]) + + # remove channel dimension + mpol_img = torch.squeeze(icube.sky_cube) + + lmargin = 1.0 + rmargin = lmargin + XX = 5. #in + ax_width = (XX - lmargin - rmargin) + ax_height = ax_width + + cax_sep = 0.05 + cax_width = 0.1 + tmargin = 0.05 + bmargin = 1.0 + YY = bmargin + ax_height + tmargin + + fig = plt.figure(figsize=(XX,YY)) + + ax = fig.add_axes((lmargin/XX, bmargin/YY, ax_width/XX, ax_height/YY)) + cax = fig.add_axes(((lmargin + ax_width + cax_sep)/XX, bmargin/YY, cax_width/XX, ax_height/YY)) + + im = ax.imshow(mpol_img, extent=coords.img_ext, origin="lower", cmap="inferno") + cbar = plt.colorbar(im, cax=cax) + cbar.ax.tick_params(labelsize=9) + cbar.set_label(r"Jy/arcsec$^2$") + + ax.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]") + ax.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]") + + fig.subplots_adjust(wspace=0.25) + fig.savefig(args.plotfile, dpi=300) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/02-sgd/src/sgd.py b/02-sgd/src/sgd.py index 4c34e89..52a9935 100644 --- a/02-sgd/src/sgd.py +++ b/02-sgd/src/sgd.py @@ -195,7 +195,7 @@ def main(): parser.add_argument( "--batch-size", type=int, - default=2000, + default=1000, help="input batch size for training", ) parser.add_argument( @@ -207,7 +207,7 @@ def main(): parser.add_argument( "--lr", type=float, - default=1e-3, + default=1e-2, help="learning rate", ) parser.add_argument("--FWHM", type=float, default=0.05, help="FWHM of Gaussian Base layer in arcseconds.") @@ -257,6 +257,7 @@ def main(): vis_data.uu, vis_data.vv, vis_data.weight, vis_data.data ) + print("running on ", device) print("total vis", len(train_dataset)) # set the batch sizes for the loaders