import sys
import shutil
from pathlib import Path
from multiprocessing import cpu_count
import threading
import logging
import numpy as np
import tomopy
import dxchange
import h5py
import meta
from tomopy_cli import file_io
from tomopy_cli import config
from tomopy_cli import prep
from tomopy_cli import find_center
__all__ = ['rec', 'double_fov', 'double_fov_try', 'padded_rec', 'padding',
'unpadding', 'reconstruct', 'mask', 'reconstruction_folder']
log = logging.getLogger(__name__)
def rec(params):
    """Reconstruct a tomographic dataset according to *params*.

    Dispatches on ``params.reconstruction_type``:

    * ``"full"``  -- reconstruct rows ``start_row``..``end_row`` in chunks of
      ``nsino_per_chunk`` slices, writing each chunk as a tiff stack or into
      one HDF5 file.
    * ``"try"``   -- delegate to :func:`_try_rec`, which reconstructs one
      slice for a range of candidate rotation centers.
    * ``"slice"`` -- reconstruct a single (possibly binned) slice selected by
      the fractional position ``params.nsino``.
    """
    data_shape = file_io.get_dx_dims(params)
    # Read parameters from YAML file
    try:
        params = config.yaml_args(params, params.parameter_file, str(params.file_name))
    except KeyError:
        # no per-file entry in the YAML parameter file; keep CLI parameters
        pass
    # Read parameters from DXchange file if requested
    params = file_io.auto_read_dxchange(params)
    if params.rotation_axis <= 0:
        # no center supplied: fall back to the middle column of the projections
        params.rotation_axis = data_shape[2]/2
        log.warning(' *** *** No rotation center given: assuming the middle of the projections at %f' % float(params.rotation_axis))
    # Select sinogram range to reconstruct
    if (params.reconstruction_type == "full"):
        if params.start_row:
            sino_start = params.start_row
        else:
            sino_start = 0
        if params.end_row < 0:
            # negative end_row means "through the last row"
            sino_end = data_shape[1]
        else:
            sino_end = params.end_row
        # If params.nsino_per_chunk < 1, use # of processor cores
        if params.nsino_per_chunk < 1:
            params.nsino_per_chunk = cpu_count()
        # chunk size is expressed in raw (unbinned) detector rows
        nSino_per_chunk = params.nsino_per_chunk * pow(2, int(params.binning))
        chunks = int(np.ceil((sino_end - sino_start)/nSino_per_chunk))
    elif (params.reconstruction_type == 'try'):
        # "try" mode has its own driver; nothing else to do here
        _try_rec(params)
        return
    else: # "slice"
        nSino_per_chunk = pow(2, int(params.binning))
        chunks = 1
        # nsino is the fractional (0..1) detector-height position of the slice
        ssino = int(data_shape[1] * params.nsino)
        sino_start = ssino
        sino_end = sino_start + pow(2, int(params.binning))
    # Select the projection (angle) range to reconstruct
    if(params.start_proj):
        sproj = params.start_proj
    else:
        sproj = 0
    if(params.end_proj or params.end_proj<0):
        # NOTE(review): any non-zero end_proj is used as-is (the "< 0" test is
        # redundant because negative ints are truthy); 0 means "all projections"
        eproj = params.end_proj
    else:
        eproj = data_shape[0]
    pproj = (sproj, eproj)
    log.info(" *** reconstructing [%d] slices from slice [%d] to [%d] in [%d] chunks of [%d] slices each" % (
        (sino_end - sino_start) / pow(2, int(params.binning)),
        sino_start/pow(2, int(params.binning)),
        sino_end/pow(2, int(params.binning)),
        chunks, nSino_per_chunk/pow(2, int(params.binning))))
    # strt tracks the output slice index where the next chunk is written
    strt = sino_start
    write_threads = []
    if chunks == 0:
        log.warning(" *** 0 chunks selected for reconstruction, "
                    "check your *start_row*, "
                    "*end_row*, and *nsino_per_chunk*.")
    for iChunk in range(0, chunks):
        log.info('chunk # %i/%i' % (iChunk + 1, chunks))
        sino = _compute_sino(iChunk, sino_start, sino_end, nSino_per_chunk, chunks, params)
        # Read APS 32-BM raw data.
        proj, flat, dark, theta, rotation_axis = file_io.read_tomo(sino, pproj, params)
        # What if sino overruns the size of data?
        if sino[1] - sino[0] > proj.shape[1]:
            log.warning(" *** Chunk size > remaining data size.")
            sino = [sino[0], sino[0] + proj.shape[1]]
        # Apply all preprocessing functions
        data = prep.all(proj, flat, dark, params, sino)
        # free the raw arrays before reconstruction to reduce peak memory
        del(proj, flat, dark)
        # unpad after phase retrieval
        if params.retrieve_phase_method == "paganin":
            # phase_pad was set (in raw rows) by _compute_sino; scale to binned rows
            params.phase_pad //= pow(2, int(params.binning))
            sino -= params.phase_pad
            # phase_pad is (negative-or-zero, positive-or-zero), so this trims
            # the extra rows read for the phase filter from both ends
            data = data[:,-params.phase_pad[0]:data.shape[1]-params.phase_pad[1]]
            log.info(' *** unpadding after phase retrieval gives slices [%i,%i] ' % (sino[0],sino[1]))
        # Reconstruct: this is for "slice" and "full" methods
        rotation_axis_rec = rotation_axis
        if (params.file_type == 'double_fov'):
            if(rotation_axis<data.shape[-1]//2):
                #if rotation center is on the left side of the ROI
                data = data[:,:,::-1]
                rotation_axis_rec = data.shape[-1]-rotation_axis
            # double FOV by adding zeros
            data = double_fov(data,rotation_axis_rec)
        #Perform actual reconstruction
        rec = padded_rec(data, theta, rotation_axis_rec, params)
        # Save images
        recon_base_dir = reconstruction_folder(params)
        fpath = Path(params.file_name).resolve()
        if params.reconstruction_type == "full":
            recon_dir = recon_base_dir / "{}_rec".format(fpath.stem)
            if params.save_format == 'tiff':
                fname = recon_dir / 'recon'
                log.debug("Full tiff dir: %s", fname)
                # write asynchronously so reconstruction of the next chunk
                # can overlap with disk I/O
                write_thread = threading.Thread(target=dxchange.write_tiff_stack,
                                                args = (rec,),
                                                kwargs = {'fname': str(fname),
                                                          'start': strt,
                                                          'overwrite': True})
            elif params.save_format == "h5":
                # HDF5 output
                fname = "{}.hdf".format(recon_dir)
                # total dataset height in binned slices
                ds_end = int(np.ceil(sino_end / pow(2, int(params.binning))))
                write_thread = threading.Thread(target=file_io.write_hdf5,
                                                args = (rec,),
                                                kwargs = {'fname': str(fname),
                                                          'dname': '/exchange/data',
                                                          'dest_idx': slice(strt, strt+rec.shape[0]),
                                                          'maxsize': (ds_end, *rec.shape[1:]),
                                                          'overwrite': iChunk==0})
            else:
                log.error(" *** Unknown save_format '%s'", params.save_format)
                fname = "<Not saved (bad output-format)>"
                write_thread = None
            # Save the data to disk
            if write_thread is not None:
                write_thread.start()
                write_threads.append(write_thread)
            # Increment counter for which chunks to save
            strt += (sino[1] - sino[0])
        elif params.reconstruction_type == "slice":
            # Construct the path for where to save the tiffs
            fname = recon_base_dir / 'slice_rec' / 'recon_{}'.format(fpath.stem)
            dxchange.write_tiff(rec, fname=str(fname), overwrite=False)
        else:
            raise ValueError("Unknown value for *reconstruction type*: {}. "
                             "Valid options are {}"
                             "".format(params.reconstruction_type,
                                       config.SECTIONS['reconstruction']['reconstruction-type']['choices']))
        log.info(" *** reconstructions: %s" % fname)
    # Wait until the all threads are done writing data
    for thread in write_threads:
        thread.join()
    if params.save_format == "h5":
        # copy beamline metadata from the raw file into the reconstruction file.
        # NOTE(review): fname here is whatever the last loop iteration produced;
        # if reconstruction_type was "slice" this is a tiff path -- confirm this
        # branch is only reached for "full" + h5 runs.
        log.info('adding meta data from the raw to the recon hdf file')
        log.info(" *** raw hdf: %s" % params.file_name)
        log.info(" *** rec hdf: %s" % fname)
        tree, meta_dict = meta.read_hdf(params.file_name)
        with h5py.File(fname, 'a') as hf:
            for key, value in meta_dict.items():
                # value is (data, units-or-None)
                dset = hf.create_dataset(key, data=value[0])
                if value[1] is not None:
                    dset.attrs['units'] = value[1]
def _compute_sino(iChunk, sino_start, sino_end, nSino_per_chunk, chunks, params):
    """Return a 2-element array of (start, end) raw-data rows for chunk *iChunk*.

    The last chunk is clipped to *sino_end*.  When Paganin phase retrieval is
    selected, the range is widened by ``params.retrieve_phase_pad`` rows on
    each interior boundary (not past the first/last chunk) and the applied
    padding is stored on ``params.phase_pad`` so the caller can trim it again
    after filtering.
    """
    chunk_start = int(sino_start + iChunk * nSino_per_chunk)
    chunk_end = int(sino_start + (iChunk + 1) * nSino_per_chunk)
    if chunk_end > sino_end:
        log.warning(' *** asking to go to row {0:d}, but our end row is {1:d}'.format(chunk_end, sino_end))
        chunk_end = sino_end
    scale = pow(2, int(params.binning))
    log.info(' *** [%i, %i]' % (chunk_start/scale, chunk_end/scale))
    sino = np.array((chunk_start, chunk_end))
    # extra data for padded phase retrieval
    if params.retrieve_phase_method == "paganin":
        phase_pad = np.zeros(2, dtype=int)
        if iChunk > 0:
            # grow downward into the previous chunk's rows
            phase_pad[0] = -params.retrieve_phase_pad
        if iChunk < chunks - 1:
            # grow upward into the next chunk's rows
            phase_pad[1] = params.retrieve_phase_pad
        sino += phase_pad
        log.info(' *** extra padding for phase retrieval gives slices [%i,%i] to be read from memory ' % (sino[0],sino[1]))
        params.phase_pad = phase_pad
    return sino
def _try_rec(params):
    """Reconstruct one slice for a range of candidate rotation centers.

    Writes one tiff per candidate center into
    ``<reconstruction folder>/try_center/<file stem>/recon_<center>.tiff`` so
    the user can pick the best rotation axis visually.
    """
    log.info(" *** *** starting 'try' reconstruction")
    data_shape = file_io.get_dx_dims(params)
    # Select sinogram range to reconstruct
    nSino_per_chunk = pow(2, int(params.binning))
    # params.nsino is the fractional (0..1) detector-height position of the slice
    sino_start = int(data_shape[1] * params.nsino)
    sino_end = sino_start + pow(2, int(params.binning))
    if sino_end > data_shape[1]:
        log.warning(' *** *** *** binning would request row past end of data. Truncating.')
        sino_start = data_shape[1] - pow(2, int(params.binning))
        sino_end = data_shape[1]
    log.info("reconstructing a slice binned from raw data rows [%d] to [%d]" % \
             (sino_start, sino_end))
    log.info(' *** binned rows [%i, %i]' % (sino_start/pow(2, int(params.binning)), sino_end/pow(2, int(params.binning))))
    sino = (int(sino_start), int(sino_end))
    # Select the projection (angle) range
    if(params.start_proj):
        sproj = params.start_proj
    else:
        sproj = 0
    if(params.end_proj):
        eproj = params.end_proj
    if not params.end_proj or params.end_proj == -1:
        # NOTE(review): 0 and -1 both mean "use all projections"; this uses
        # data_shape[0] + 1 where rec() uses data_shape[0] -- confirm intended
        eproj = data_shape[0] + 1
    pproj = (sproj, eproj)
    # Set up the centers of rotation we will use
    # Read APS 32-BM raw data.
    proj, flat, dark, theta, rotation_axis = file_io.read_tomo(sino, pproj, params, True)
    # Apply all preprocessing functions
    data = prep.all(proj, flat, dark, params, sino)
    rec = []
    center_range = []
    # try passes an array of rotation centers and this is only supported by gridrec
    if (params.file_type == 'standard' or params.file_type == 'double_fov'):
        # candidate centers in binned pixels, swept in 0.5-pixel steps
        center_search_width = params.center_search_width/np.power(2, float(params.binning))
        center_range = np.arange(rotation_axis-center_search_width, rotation_axis+center_search_width, 0.5)
        if (params.blocked_views):
            st = params.blocked_views_start
            end = params.blocked_views_end
            # keep only angles outside the blocked [st, end) wedge (modulo pi)
            ids = np.where(((theta-st)%np.pi<0) + ((theta-st)%np.pi>end-st))[0]
            # NOTE(review): assumes prep.all already dropped the blocked views
            # from `data` so len(ids) matches data.shape[0] -- confirm
            stack = np.empty((len(center_range), len(ids), int(data_shape[2])))
        else:
            stack = np.empty((len(center_range), data.shape[0], int(data.shape[2])))
        # replicate the same single sinogram once per candidate center
        for i, axis in enumerate(center_range):
            stack[i] = data[:, 0, :]
        log.warning(' reconstruct slice [%d] with rotation axis range [%.2f - %.2f] in [%.2f] pixel steps'
                    % (sino_start, center_range[0], center_range[-1], center_range[1] - center_range[0]))
        center_range_rec = center_range
        if (params.file_type == 'double_fov'):
            if(rotation_axis<stack.shape[-1]//2):
                #if rotation center is on the left side of the ROI
                stack = stack[:,:,::-1]
                center_range_rec = stack.shape[-1]-center_range
            # double FOV by adding zeros
            stack = double_fov_try(stack,center_range_rec)
        if params.reconstruction_algorithm == 'gridrec':
            # gridrec accepts a vector of centers: one per stacked sinogram
            rec = padded_rec(stack, theta, center_range_rec, params)
        else:
            log.warning(" *** Doing try_center with '%s' instead of 'gridrec' is slow.", params.reconstruction_algorithm)
            # reconstruct the single sinogram once per candidate center
            rec = []
            for center in center_range_rec:
                rec.append(padded_rec(data[:, 0:1, :], theta, center, params))
            rec = np.asarray(rec)
    else:
        # flip-and-stitch (0-360 offset scan): sweep rotation_axis_flip instead
        rotation_axis = params.rotation_axis_flip // pow(2,int(params.binning))
        center_search_width = params.center_search_width/np.power(2, float(params.binning))
        center_range = np.arange(rotation_axis-center_search_width, rotation_axis+center_search_width, 0.5)
        stitched_data = []
        rot_centers = np.zeros_like(center_range)
        #Loop through the assumed rotation centers
        for i, rot_center in enumerate(center_range):
            params.rotation_axis_flip = rot_center
            # stitch with unit flats / zero darks since data is preprocessed
            temp = file_io.flip_and_stitch(params, data, np.ones_like(data[0,...]),
                                           np.zeros_like(data[0,...]), theta)
            stitched_data.append(temp[0])
            theta180 = temp[3]
            # flip_and_stitch updates params.rotation_axis as a side effect
            rot_centers[i] = params.rotation_axis
        # crop every stitched sinogram to the narrowest width so they stack
        total_cols = np.min([i.shape[2] for i in stitched_data])
        stack = np.empty((len(center_range), theta180.shape[0], total_cols))
        for i in range(center_range.shape[0]):
            stack[i] = stitched_data[i][:theta180.shape[0],0,:total_cols]
        del(stitched_data)
        rec = padded_rec(stack, theta180, rot_centers, params)
    # Save images to a temporary folder.
    fpath = Path(params.file_name).resolve()
    rec_dir = reconstruction_folder(params) / 'try_center' / fpath.stem
    for i,axis in enumerate(center_range):
        # file names report the center in unbinned pixels
        this_center = axis * np.power(2, float(params.binning))
        rfname = rec_dir / "recon_{:.2f}.tiff".format(this_center)
        dxchange.write_tiff(rec[i], fname=str(rfname), overwrite=True)
def double_fov(data, rotation_axis):
    """Extend an offset-scan sinogram to double width for 360-degree data.

    The last ``w`` columns (twice the distance from the rotation axis to the
    right edge) are feathered to zero with a 9th-order smoothstep weight, then
    the stack is zero-padded to twice its width along the last axis.  ``data``
    is tapered in place; the padded array returned is a new allocation.
    """
    ncols = data.shape[-1]
    # overlap width to feather; always taper at least one column
    taper_width = max(1, int(2 * (ncols - rotation_axis)))
    t = np.linspace(1, 0, taper_width, endpoint=False)
    # smooth 1 -> 0 ramp with vanishing derivatives at both ends
    weight = (126 - 420 * t + 540 * t**2 - 315 * t**3 + 70 * t**4) * t**5
    data[:, :, -taper_width:] *= weight
    # append zeros so the sinogram spans the full doubled field of view
    return np.pad(data, ((0, 0), (0, 0), (0, ncols)), 'constant')
def double_fov_try(data, rotation_axis):
    """Per-center variant of :func:`double_fov` for "try" reconstructions.

    ``data`` holds one copy of the sinogram per candidate rotation center
    along its first axis; each copy is feathered with the taper width implied
    by its own center, then the whole stack is zero-padded to double width.
    ``data`` is tapered in place; the padded array is a new allocation.
    """
    ncols = data.shape[-1]
    for idx, axis in enumerate(rotation_axis):
        taper_width = max(1, int(2 * (ncols - axis)))
        t = np.linspace(1, 0, taper_width, endpoint=False)
        # smooth 1 -> 0 ramp with vanishing derivatives at both ends
        weight = (126 - 420 * t + 540 * t**2 - 315 * t**3 + 70 * t**4) * t**5
        data[idx, :, -taper_width:] *= weight
    # append zeros so each sinogram spans the full doubled field of view
    return np.pad(data, ((0, 0), (0, 0), (0, ncols)), 'constant')
def padded_rec(data, theta, rotation_axis, params):
    """Reconstruction pipeline: pad, reconstruct, crop back, and mask.

    Width padding (see :func:`padding`) suppresses edge artifacts for the
    algorithms that support it; the result is cropped back to the original
    width and optionally circular-masked.
    """
    # remember the unpadded sinogram width so we can restore it afterwards
    original_width = data.shape[2]
    padded_data, padded_axis = padding(data, rotation_axis, params)
    # reconstruct object
    recon = reconstruct(padded_data, theta, padded_axis, params)
    # un-padding - restore shape
    recon = unpadding(recon, original_width, params)
    # mask each reconstructed slice with a circle
    return mask(recon, params)
def padding(data, rotation_axis, params):
    """Widen the sinogram by 50% (25% each side) with edge replication.

    Padding is applied only when the selected algorithm (gridrec or lprec)
    has its corresponding ``*_padding`` flag enabled; otherwise the inputs
    are returned untouched.

    Returns
    -------
    tuple
        ``(data, rot_center)`` -- the (possibly padded) sinogram stack and
        the rotation axis shifted to match the new width.
    """
    log.info(" *** padding")
    wants_padding = (
        (params.reconstruction_algorithm == 'gridrec' and params.gridrec_padding)
        or (params.reconstruction_algorithm == 'lprec' and params.lprec_padding)
    )
    if not wants_padding:
        log.warning(' *** *** OFF')
        return data, rotation_axis
    log.info(' *** *** ON')
    width = data.shape[2]
    padded = np.zeros([data.shape[0], data.shape[1], 3*width//2], dtype="float32")
    # center the original data and replicate the edge columns into the margins
    padded[:, :, width//4:5*width//4] = data
    padded[:, :, 0:width//4] = data[:, :, 0:1]
    padded[:, :, 5*width//4:] = data[:, :, -1:]
    # the axis shifts right by the size of the left margin
    return padded, rotation_axis + width//4
def unpadding(rec, N, params):
    """Crop a reconstruction back to its original ``N x N`` field of view.

    Inverse of :func:`padding`: when padding was enabled the reconstructed
    slices are ``3N/2`` wide, so the central ``N x N`` region is extracted;
    otherwise *rec* is returned unchanged.
    """
    log.info(" *** un-padding")
    was_padded = (
        (params.reconstruction_algorithm == 'gridrec' and params.gridrec_padding)
        or (params.reconstruction_algorithm == 'lprec' and params.lprec_padding)
    )
    if was_padded:
        log.info(' *** *** ON')
        return rec[:, N//4:5*N//4, N//4:5*N//4]
    log.warning(' *** *** OFF')
    return rec
def reconstruct(data, theta, rot_center, params):
    """Run the reconstruction selected by ``params.reconstruction_algorithm``.

    Parameters
    ----------
    data : np.ndarray
        Preprocessed (and possibly padded) projection or sinogram data.
    theta : np.ndarray
        Projection angles in radians.
    rot_center : float or np.ndarray
        Rotation center(s); an array of centers is used by the "try" mode
        with gridrec.
    params :
        Full configuration object.

    Returns
    -------
    np.ndarray
        Reconstructed slices.
    """
    # "try" mode feeds a stack of sinograms (one per candidate center), so the
    # data is already in sinogram order
    if(params.reconstruction_type == "try"):
        sinogram_order = True
    else:
        sinogram_order = False
    # Check for sane input values
    if not np.all(np.isfinite(data)):
        log.warning(" *** nan/inf found in input data. "
                    "Consider using ``--fix-nan-and-inf True``.")
    log.info(" *** algorithm: %s" % params.reconstruction_algorithm)
    # Apply the various reconstruction algorithms
    if params.reconstruction_algorithm == 'astrasirt':
        extra_options = {}
        # optional value constraints; skipped when not parseable as float
        try:
            extra_options['MinConstraint'] = float(params.astrasirt_min_constraint)
        except ValueError:
            log.warning(" *** *** invalid astrasirt_min_constraint value..."
                        "ignoring.")
        try:
            extra_options['MaxConstraint'] = float(params.astrasirt_max_constraint)
        except ValueError:
            log.warning(" *** *** invalid astrasirt_max_constraint value..."
                        "ignoring.")
        options = {'proj_type': params.astrasirt_proj_type,
                   'method': params.astrasirt_method,
                   'num_iter': params.astrasirt_num_iter,
                   'extra_options': extra_options,}
        if params.astrasirt_bootstrap:
            # a quick gridrec pass, smoothed, seeds the iterative solver
            log.info(' *** *** bootstrapping with gridrec')
            rec = tomopy.recon(data, theta,
                               center=rot_center,
                               sinogram_order=sinogram_order,
                               algorithm='gridrec',
                               filter_name=params.gridrec_filter)
            rec = tomopy.misc.corr.gaussian_filter(rec, axis=1)
            rec = tomopy.misc.corr.gaussian_filter(rec, axis=2)
        recon_kw = dict(center=rot_center, algorithm=tomopy.astra,
                        options=options)
        if params.astrasirt_bootstrap:
            log.info(' *** *** using gridrec to start astrasirt recon')
            recon_kw['init_recon'] = rec
        rec = tomopy.recon(data, theta, **recon_kw)
    elif params.reconstruction_algorithm == 'astrasart':
        extra_options = {}
        # optional value constraints; skipped when not parseable as float
        try:
            extra_options['MinConstraint'] = float(params.astrasart_min_constraint)
        except ValueError:
            log.warning(" *** *** invalid astrasart_min_constraint value..."
                        "ignoring.")
        try:
            extra_options['MaxConstraint'] = float(params.astrasart_max_constraint)
        except ValueError:
            log.warning(" *** *** invalid astrasart_max_constraint value..."
                        "ignoring.")
        options = {'proj_type': params.astrasart_proj_type,
                   'method': params.astrasart_method,
                   # NOTE(review): num_iter is scaled by the projection count;
                   # presumably SART counts one iteration per projection -- confirm
                   'num_iter': params.astrasart_num_iter * data.shape[0],
                   'extra_options': extra_options,}
        if params.astrasart_bootstrap:
            log.info(' *** *** bootstrapping with gridrec')
            rec = tomopy.recon(data, theta,
                               center=rot_center,
                               sinogram_order=sinogram_order,
                               algorithm='gridrec',
                               filter_name=params.gridrec_filter)
        # no center argument is passed to astra below: shift the data instead
        shift = (int((data.shape[2]/2 - rot_center)+.5))
        data = np.roll(data, shift, axis=2)
        if params.astrasart_bootstrap:
            log.info(' *** *** using gridrec to start astrasart recon')
            rec = tomopy.recon(data, theta, init_recon=rec, algorithm=tomopy.astra, options=options)
        else:
            rec = tomopy.recon(data, theta, algorithm=tomopy.astra, options=options)
        # BUG FIX: a stray unconditional tomopy.recon() call used to follow the
        # if/else above; it re-ran the whole reconstruction and silently
        # discarded the bootstrap init_recon result. It has been removed.
    elif params.reconstruction_algorithm == 'astracgls':
        extra_options = {}
        options = {'proj_type': params.astracgls_proj_type,
                   'method': params.astracgls_method,
                   'num_iter': params.astracgls_num_iter,
                   'extra_options': extra_options,}
        if params.astracgls_bootstrap:
            log.info(' *** *** bootstrapping with gridrec')
            rec = tomopy.recon(data, theta,
                               center=rot_center,
                               sinogram_order=sinogram_order,
                               algorithm='gridrec',
                               filter_name=params.gridrec_filter)
        # no center argument is passed to astra below: shift the data instead
        shift = (int((data.shape[2]/2 - rot_center)+.5))
        data = np.roll(data, shift, axis=2)
        if params.astracgls_bootstrap:
            log.info(' *** *** using gridrec to start astracgls recon')
            rec = tomopy.recon(data, theta, init_recon=rec, algorithm=tomopy.astra, options=options)
        else:
            rec = tomopy.recon(data, theta, algorithm=tomopy.astra, options=options)
    # gridrec
    elif params.reconstruction_algorithm == 'gridrec':
        log.warning(" *** *** sinogram_order: %s" % sinogram_order)
        if(params.reconstruction_type == "try"):
            # each chunk works with 1 rotation center
            nchunk = 1
        else:
            nchunk = None
        rec = tomopy.recon(data, theta,
                           center=rot_center,
                           sinogram_order=sinogram_order,
                           algorithm='gridrec',
                           filter_name=params.gridrec_filter,
                           nchunk=nchunk)
    # log-polar based method
    elif params.reconstruction_algorithm == 'lprec':
        log.warning(" *** *** sinogram_order: %s" % sinogram_order)
        lpmethod = params.lprec_method
        # only the fbp variant takes a filter
        if (lpmethod == 'fbp'):
            filter_name = params.lprec_fbp_filter
        else:
            filter_name = 'none'
        rec = tomopy.recon(data, theta,
                           center=rot_center,
                           sinogram_order=sinogram_order,
                           algorithm=tomopy.lprec,
                           lpmethod=lpmethod,
                           filter_name=filter_name,
                           ncore=1,
                           num_iter=params.lprec_num_iter,
                           reg_par=params.lprec_reg,
                           gpu_list=range(params.lprec_num_gpu))
    else:
        # unknown algorithm: fall back to gridrec rather than failing
        log.warning(" *** *** algorithm: %s is not supported yet" % params.reconstruction_algorithm)
        params.reconstruction_algorithm = 'gridrec'
        log.warning(" *** *** using: %s instead" % params.reconstruction_algorithm)
        log.warning(" *** *** sinogram_order: %s" % sinogram_order)
        rec = tomopy.recon(data, theta, center=rot_center, sinogram_order=sinogram_order, algorithm=params.reconstruction_algorithm, filter_name=params.gridrec_filter)
    # Check for sane values
    if np.all(np.isnan(rec)):
        log.error(" *** *** reconstruction produced all NaN")
    log.info(" *** reconstruction finished")
    return rec
def mask(data, params):
    """Optionally apply a circular mask to every reconstructed slice.

    When ``params.reconstruction_mask`` is set and the ratio lies in (0, 1],
    :func:`tomopy.circ_mask` zeroes everything outside the circle; an
    out-of-range ratio is reported and ignored.  Returns the (possibly
    masked) data.
    """
    log.info(" *** mask")
    if not params.reconstruction_mask:
        log.warning(' *** *** OFF')
        return data
    log.info(' *** *** ON')
    ratio = params.reconstruction_mask_ratio
    if 0 < ratio <= 1:
        log.warning(" *** mask ratio: %f " % ratio)
        data = tomopy.circ_mask(data, axis=0, ratio=ratio)
        log.info(' *** masking finished')
    else:
        log.error(" *** mask ratio must be between 0-1: %f is ignored" % ratio)
    return data
def reconstruction_folder(params):
    """Build the path to the folder that will receive the reconstruction.

    ``params.save_folder`` is treated as a format template; the placeholder
    ``{file_name_parent}`` (the directory containing ``params.file_name``, or
    the path itself when it is a directory) and any attribute of *params*
    may appear in it.
    """
    resolved = Path(params.file_name).resolve()
    # a directory input acts as its own parent for formatting purposes
    if resolved.is_dir():
        parent = resolved
    else:
        parent = resolved.parent
    # Format the folder name with the config parameters
    folder = params.save_folder.format(file_name_parent=parent, **params.__dict__)
    return Path(folder)