Skip to content

Commit

Permalink
Merge pull request #416 from frheault/concatenate_dwi
Browse files Browse the repository at this point in the history
Concatenate DWI
  • Loading branch information
arnaudbore authored Feb 25, 2021
2 parents adc0bb5 + 691a3cd commit c620b82
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 1 deletion.
94 changes: 94 additions & 0 deletions scripts/scil_concatenate_dwi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Concatenate DWI, bval and bvecs together. File must be specified in matching
order. Default data type will be the same as the first input DWI.
"""

import argparse

from dipy.io.gradients import read_bvals_bvecs
from dipy.io.utils import is_header_compatible
import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('out_dwi',
help='The name of the output DWI file.')
p.add_argument('out_bval',
help='The name of the output b-values.')
p.add_argument('out_bvec',
help='The name of the output b-vectors.')

p.add_argument('--in_dwis', nargs='+',
help='The DWI file (.nii) to concatenate.')
p.add_argument('--in_bvals', nargs='+',
help='The b-values in FSL format.')
p.add_argument('--in_bvecs', nargs='+',
help='The b-vectors in FSL format.')

p.add_argument('--data_type',
help='Data type of the output image. Use the format: '
'uint8, int16, int/float32, int/float64.')

add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()

if len(args.in_dwis) != len(args.in_bvals) \
or len(args.in_dwis) != len(args.in_bvecs):
parser.error('DWI, bvals and bvecs must have the same length')

assert_inputs_exist(parser, args.in_dwis + args.in_bvals + args.in_bvecs)
assert_outputs_exist(parser, args, [args.out_dwi, args.out_bval,
args.out_bvec])

all_bvals = []
all_bvecs = []
total_size = 0
for i in range(len(args.in_dwis)):
bvals, bvecs = read_bvals_bvecs(args.in_bvals[i], args.in_bvecs[i])
if len(bvals) != len(bvecs):
raise ValueError('Paired bvals and bvecs must have the same size.')
total_size += len(bvals)
all_bvals.append(bvals)
all_bvecs.append(bvecs)
all_bvals = np.concatenate(all_bvals)
all_bvecs = np.concatenate(all_bvecs)

ref_dwi = nib.load(args.in_dwis[0])
all_dwi = np.zeros(ref_dwi.shape[0:3] + (total_size,),
dtype=args.data_type)
last_count = ref_dwi.shape[-1]
all_dwi[..., 0:last_count] = ref_dwi.get_fdata()
for i in range(1, len(args.in_dwis)):
curr_dwi = nib.load(args.in_dwis[i])
if not is_header_compatible(curr_dwi, ref_dwi):
raise ValueError('All DWI must have the compatible header.')
curr_size = curr_dwi.shape[-1]
all_dwi[..., last_count:last_count+curr_size] = \
curr_dwi.get_fdata()

np.savetxt(args.out_bval, all_bvals, '%d')
np.savetxt(args.out_bvec, all_bvecs.T, '%0.15f')
nib.save(nib.Nifti1Image(all_dwi, ref_dwi.affine, header=ref_dwi.header),
args.out_dwi)


if __name__ == "__main__":
main()
5 changes: 4 additions & 1 deletion scripts/scil_extract_dwi_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
assert_inputs_exist, assert_outputs_exist)
from scilpy.utils.bvec_bval_tools import extract_dwi_shell


# TODO switch from parser to p
# TODO switch to in_*
# TODO switch to out_*
def _build_arg_parser():
parser = argparse.ArgumentParser(
description=__doc__,
Expand Down Expand Up @@ -100,6 +102,7 @@ def main():

np.savetxt(args.output_bvals, new_bvals, '%d')
np.savetxt(args.output_bvecs, new_bvecs.T, '%0.15f')
# use named argument header=
nib.save(nib.Nifti1Image(shell_data, img.affine, img.header),
args.output_dwi)

Expand Down
32 changes: 32 additions & 0 deletions scripts/tests/test_concatenate_dwi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import tempfile

from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['processing.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_concatenate_dwi.py', '--help')
assert ret.success


def test_execution_processing_concatenate(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_dwi = os.path.join(get_home(), 'processing',
'dwi_crop.nii.gz')
in_bval = os.path.join(get_home(), 'processing',
'dwi.bval')
in_bvec = os.path.join(get_home(), 'processing',
'dwi.bvec')
ret = script_runner.run('scil_concatenate_dwi.py', 'dwi_concat.nii.gz',
'concat.bval', 'concat.bvec',
'--in_dwi', in_dwi, in_dwi,
'--in_bvals', in_bval, in_bval,
'--in_bvecs', in_bvec, in_bvec)
assert ret.success

0 comments on commit c620b82

Please sign in to comment.