"""main module to instantiate deepCR models and use them
"""
from os import path, mkdir
import math
import shutil
import numpy as np
import torch
import torch.nn as nn
from torch import from_numpy
from joblib import Parallel, delayed
from joblib import dump, load
from joblib import wrap_non_picklable_objects
from deepCR.unet import WrappedModel, UNet2Sigmoid
from deepCR.util import medmask
from learned_models import mask_dict, inpaint_dict, default_model_path
__all__ = ['deepCR']
[docs]class deepCR():
def __init__(self, mask='ACS-WFC-F606W-2-32', inpaint=None, device='CPU', hidden=32):
"""
Instantiation of deepCR with specified model configurations
Parameters
----------
mask : str
Either name of existing deepCR-mask model, or file path of your own model (incl. '.pth')
inpaint : (optional) str
Name of existing inpainting model to use. If left as None then by default use a simple 5x5 median mask
sampling for inpainting
device : str
One of 'CPU' or 'GPU'
hidden : int
Number of hidden channel for first deepCR-mask layer. Specify only if using custom deepCR-mask model.
Returns
-------
None
"""
if device == 'GPU':
self.dtype = torch.cuda.FloatTensor
self.dint = torch.cuda.ByteTensor
wrapper = nn.DataParallel
else:
self.dtype = torch.FloatTensor
self.dint = torch.ByteTensor
wrapper = WrappedModel
if mask in mask_dict.keys():
self.scale = mask_dict[mask][2]
mask_path = default_model_path + '/mask/' + mask + '.pth'
self.maskNet = wrapper(mask_dict[mask][0](*mask_dict[mask][1]))
else:
self.scale = 1
mask_path = mask
self.maskNet = wrapper(UNet2Sigmoid(1, 1, hidden))
self.maskNet.type(self.dtype)
if device != 'GPU':
self.maskNet.load_state_dict(torch.load(mask_path, map_location='cpu'))
else:
self.maskNet.load_state_dict(torch.load(mask_path))
self.maskNet.eval()
for p in self.maskNet.parameters():
p.required_grad = False
if inpaint is not None:
inpaint_path = default_model_path + '/inpaint/' + inpaint + '.pth'
self.inpaintNet = wrapper(inpaint_dict[inpaint][0](*inpaint_dict[inpaint][1])).type(self.dtype)
if device != 'GPU':
self.inpaintNet.load_state_dict(torch.load(inpaint_path, map_location='cpu'))
else:
self.inpaintNet.load_state_dict(torch.load(inpaint_path))
self.inpaintNet.eval()
for p in self.inpaintNet.parameters():
p.required_grad = False
else:
self.inpaintNet = None
[docs] def clean(self, img0, threshold=0.5, inpaint=True, binary=True, segment=False,
patch=256, parallel=False, n_jobs=-1):
"""
Identify cosmic rays in an input image, and (optionally) inpaint with the predicted cosmic ray mask
:param img0: (np.ndarray) 2D input image conforming to model requirements. For HST ACS/WFC, must be from
_flc.fits and in units of electrons in native resolution.
:param threshold: (float; [0, 1]) applied to probabilistic mask to generate binary mask
:param inpaint: (bool) return clean, inpainted image only if True
:param binary: return binary CR mask if True. probabilistic mask if False
:param segment: (bool) if True, segment input image into chunks of patch * patch before performing CR rejection.
Used for memory control.
:param patch: (int) Use 256 unless otherwise required. if segment==True, segment image into chunks of
patch * patch.
:param parallel: (bool) run in parallel if True and segment==True
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for
larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""
# data pre-processing
img0 = img0.astype(np.float32) / self.scale
if not segment and not parallel:
return self.clean_(img0, threshold=threshold,
inpaint=inpaint, binary=binary)
else:
if not parallel:
return self.clean_large(img0, threshold=threshold,
inpaint=inpaint, binary=binary, patch=patch)
else:
return self.clean_large_parallel(img0, threshold=threshold,
inpaint=inpaint, binary=binary, patch=patch,
n_jobs=n_jobs)
[docs] def clean_(self, img0, threshold=0.5, inpaint=True, binary=True):
"""
given input image
return cosmic ray mask and (optionally) clean image
mask could be binary or probabilistic
:param img0: (np.ndarray) 2D input image
:param threshold: for creating binary mask from probabilistic mask
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:return: CR mask and (optionally) clean inpainted image
"""
shape = img0.shape
pad_x = 4 - shape[0] % 4
pad_y = 4 - shape[1] % 4
if pad_x == 4:
pad_x = 0
if pad_y == 4:
pad_y = 0
img0 = np.pad(img0, ((pad_x, 0), (pad_y, 0)), mode='constant')
shape = img0.shape[-2:]
img0 = from_numpy(img0).type(self.dtype).view(1, -1, shape[0], shape[1])
mask = self.maskNet(img0)
if not binary:
return mask.detach().cpu().view(shape[0], shape[1]).numpy()[pad_x:, pad_y:]
binary_mask = (mask > threshold).type(self.dtype)
if inpaint:
if self.inpaintNet is not None:
cat = torch.cat((img0 * (1 - binary_mask), binary_mask), dim=1)
img1 = self.inpaintNet(cat)
img1 = img1.detach()
inpainted = img1 * binary_mask + img0 * (1 - binary_mask)
binary_mask = binary_mask.detach().cpu().view(shape[0], shape[1]).numpy()
inpainted = inpainted.detach().cpu().view(shape[0], shape[1]).numpy()
else:
binary_mask = binary_mask.detach().cpu().view(shape[0], shape[1]).numpy()
img0 = img0.detach().cpu().view(shape[0], shape[1]).numpy()
img1 = medmask(img0, binary_mask)
inpainted = img1 * binary_mask + img0 * (1 - binary_mask)
if binary:
return binary_mask[pad_x:, pad_y:], inpainted[pad_x:, pad_y:] * 100
else:
mask = mask.detach().cpu().view(shape[0], shape[1]).numpy()
return mask[pad_x:, pad_y:], inpainted[pad_x:, pad_y:] * 100
else:
if binary:
binary_mask = binary_mask.detach().cpu().view(shape[0], shape[1]).numpy()
return binary_mask[pad_x:, pad_y:]
else:
mask = mask.detach().cpu().view(shape[0], shape[1]).numpy()
return mask[pad_x:, pad_y:]
[docs] def clean_large_parallel(self, img0, threshold=0.5, inpaint=True, binary=True,
patch=256, n_jobs=-1):
"""
given input image
return cosmic ray mask and (optionally) clean image
mask could be binary or probabilistic
:param img0: (np.ndarray) 2D input image
:param threshold: for creating binary mask from probabilistic mask
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:param patch: (int) Use 256 unless otherwise required. patch size to run deepCR on.
:param n_jobs: (int) number of jobs to run in parallel, passed to `joblib.` Beware of memory overflow for
larger n_jobs.
:return: CR mask and (optionally) clean inpainted image
"""
folder = './joblib_memmap'
try:
mkdir(folder)
except FileExistsError:
pass
im_shape = img0.shape
img0_dtype = img0.dtype
hh = int(math.ceil(im_shape[0]/patch))
ww = int(math.ceil(im_shape[1]/patch))
img0_filename_memmap = path.join(folder, 'img0_memmap')
dump(img0, img0_filename_memmap)
img0 = load(img0_filename_memmap, mmap_mode='r')
if inpaint:
img1_filename_memmap = path.join(folder, 'img1_memmap')
img1 = np.memmap(img1_filename_memmap, dtype=img0.dtype,
shape=im_shape, mode='w+')
else:
img1 = None
mask_filename_memmap = path.join(folder, 'mask_memmap')
mask = np.memmap(mask_filename_memmap, dtype=np.int8 if binary else img0_dtype,
shape=im_shape, mode='w+')
@wrap_non_picklable_objects
def fill_values(i, j, img0, img1, mask, patch, inpaint, threshold, binary):
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
if inpaint:
mask_, clean_ = self.clean_(img, threshold=threshold, inpaint=True, binary=binary)
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
img1[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = clean_
else:
mask_ = self.clean_(img, threshold=threshold, inpaint=False, binary=binary)
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
results = Parallel(n_jobs=n_jobs, verbose=0)\
(delayed(fill_values)(i, j, img0, img1, mask, patch, inpaint, threshold, binary)
for i in range(hh) for j in range(ww))
mask = np.array(mask)
if inpaint:
img1 = np.array(img1)
try:
shutil.rmtree(folder)
except:
print('Could not clean-up automatically.')
if inpaint:
return mask, img1
else:
return mask
[docs] def clean_large(self, img0, threshold=0.5, inpaint=True, binary=True,
patch=256):
"""
given input image
return cosmic ray mask and (optionally) clean image
mask could be binary or probabilistic
:param img0: (np.ndarray) 2D input image
:param threshold: for creating binary mask from probabilistic mask
:param inpaint: return clean image only if True
:param binary: return binary mask if True. probabilistic mask otherwise.
:return: mask or binary mask; or None if internal call
"""
im_shape = img0.shape
hh = int(math.ceil(im_shape[0]/patch))
ww = int(math.ceil(im_shape[1]/patch))
img1 = np.zeros((im_shape[0], im_shape[1]))
mask = np.zeros((im_shape[0], im_shape[1]))
if inpaint:
for i in range(hh):
for j in range(ww):
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
mask_, clean_ = self.clean_(img, threshold=threshold, inpaint=True, binary=binary)
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
img1[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = clean_
return mask, img1
else:
for i in range(hh):
for j in range(ww):
img = img0[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])]
mask_ = self.clean_(img, threshold=threshold, inpaint=False, binary=binary)
mask[i * patch: min((i + 1) * patch, im_shape[0]), j * patch: min((j + 1) * patch, im_shape[1])] = mask_
return mask
[docs] def inpaint(self, img0, mask):
"""
inpaint img0 under mask
:param img0: (np.ndarray) input image
:param mask: (np.ndarray) inpainting mask
:return: inpainted clean image
"""
img0 = img0.astype(np.float32) / 100
mask = mask.astype(np.float32)
shape = img0.shape[-2:]
if self.inpaintNet is not None:
img0 = from_numpy(img0).type(self.dtype). \
view(1, -1, shape[0], shape[1])
mask = from_numpy(mask).type(self.dtype). \
view(1, -1, shape[0], shape[1])
cat = torch.cat((img0 * (1 - mask), mask), dim=1)
img1 = self.inpaintNet(cat)
img1 = img1.detach()
inpainted = img1 * mask + img0 * (1 - mask)
inpainted = inpainted.detach().cpu(). \
view(shape[0], shape[1]).numpy()
else:
img1 = medmask(img0, mask)
inpainted = img1 * mask + img0 * (1 - mask)
return inpainted * 100