diff --git a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java index d8cb8f2..5ad9ce3 100644 --- a/src/main/java/qupath/ext/instanseg/core/InstanSeg.java +++ b/src/main/java/qupath/ext/instanseg/core/InstanSeg.java @@ -15,6 +15,7 @@ import qupath.lib.objects.PathCellObject; import qupath.lib.objects.PathDetectionObject; import qupath.lib.objects.PathObject; +import qupath.lib.objects.PathObjectTools; import qupath.lib.objects.utils.ObjectMerger; import qupath.lib.objects.utils.ObjectProcessor; import qupath.lib.objects.utils.OverlapFixer; @@ -30,11 +31,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.stream.IntStream; public class InstanSeg { @@ -45,8 +47,8 @@ public class InstanSeg { private final int padding; private final int[] outputChannels; private final boolean randomColors; - private final ImageData imageData; - private final Collection channels; + private final boolean makeMeasurements; + private final List inputChannels; private final InstanSegModel model; private final Device device; private final TaskRunner taskRunner; @@ -55,60 +57,72 @@ public class InstanSeg { // This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently private final int boundaryThreshold = 1; - private InstanSeg(int tileDims, double downsample, int padding, int[] outputChannels, ImageData imageData, - Collection channels, InstanSegModel model, Device device, TaskRunner taskRunner, - Class preferredOutputClass, boolean randomColors) { - this.tileDims = tileDims; - this.downsample = downsample; // Optional... and not advised (use the model spec instead); set <= 0 to ignore - this.padding = padding; - this.outputChannels = outputChannels == null ? null : outputChannels.clone(); - this.imageData = imageData; - this.channels = channels; - this.model = model; - this.device = device; - this.taskRunner = taskRunner; - this.preferredOutputClass = preferredOutputClass; - this.randomColors = randomColors; + private InstanSeg(Builder builder) { + this.tileDims = builder.tileDims; + this.downsample = builder.downsample; // Optional... and not advised (use the model spec instead); set <= 0 to ignore + this.padding = builder.padding; + this.outputChannels = builder.outputChannels == null ? null : builder.outputChannels.clone(); + this.inputChannels = builder.channels == null ? Collections.emptyList() : List.copyOf(builder.channels); + this.model = builder.model; + this.device = builder.device; + this.taskRunner = builder.taskRunner; + this.preferredOutputClass = builder.preferredOutputClass; + this.randomColors = builder.randomColors; + this.makeMeasurements = builder.makeMeasurements; } /** - * Run inference for the currently selected PathObjects. + * Run inference for the currently selected PathObjects in the current image. */ public InstanSegResults detectObjects() { - return detectObjects(imageData.getHierarchy().getSelectionModel().getSelectedObjects()); + return detectObjects(QP.getCurrentImageData()); } /** - * Run inference for the currently selected PathObjects, then measure the new objects that were created. + * Run inference for the currently selected PathObjects in the specified image. */ - public InstanSegResults detectObjectsAndMeasure() { - return detectObjectsAndMeasure(imageData.getHierarchy().getSelectionModel().getSelectedObjects()); + public InstanSegResults detectObjects(ImageData imageData) { + Objects.requireNonNull(imageData, "No imageData available"); + return detectObjects(imageData, imageData.getHierarchy().getSelectionModel().getSelectedObjects()); } /** - * Run inference for the specified selected PathObjects, then measure the new objects that were created. + * Run inference for a collection of PathObjects from the current image. */ - public InstanSegResults detectObjectsAndMeasure(Collection pathObjects) { - var results = detectObjects(pathObjects); - for (var pathObject: pathObjects) { - makeMeasurements(imageData, pathObject.getChildObjects()); + public InstanSegResults detectObjects(Collection pathObjects) { + var imageData = QP.getCurrentImageData(); + var results = runInstanSeg(imageData, pathObjects); + if (makeMeasurements) { + for (var pathObject : pathObjects) { + makeMeasurements(imageData, pathObject.getChildObjects()); + } } return results; } /** - * Get the imageData from an InstanSeg object. - * @return The imageData used for the model. + * Run inference for a collection of PathObjects associated with the specified image. + * @throws IllegalArgumentException if the image or objects are null, or if the objects are not found within the image's hierarchy */ - public ImageData getImageData() { - return imageData; + public InstanSegResults detectObjects(ImageData imageData, Collection pathObjects) + throws IllegalArgumentException { + validateImageAndObjectsOrThrow(imageData, pathObjects); + var results = runInstanSeg(imageData, pathObjects); + if (makeMeasurements) { + for (var pathObject : pathObjects) { + makeMeasurements(imageData, pathObject.getChildObjects()); + } + } + return results; } - /** - * Run inference for a collection of PathObjects. - */ - public InstanSegResults detectObjects(Collection pathObjects) { - return runInstanSeg(pathObjects); + private void validateImageAndObjectsOrThrow(ImageData imageData, Collection pathObjects) { + Objects.requireNonNull(imageData, "No imageData available"); + Objects.requireNonNull(pathObjects, "No objects available"); + var hierarchy = imageData.getHierarchy(); + if (pathObjects.stream().anyMatch(p -> !PathObjectTools.hierarchyContainsObject(hierarchy, p))) { + throw new IllegalArgumentException("Objects must be contained in the image hierarchy!"); + } } /** @@ -125,7 +139,7 @@ public static Builder builder() { * @param imageData The ImageData for making measurements. * @param detections The objects to measure. */ - public void makeMeasurements(ImageData imageData, Collection detections) { + private void makeMeasurements(ImageData imageData, Collection detections) { double downsample = model.getPreferredDownsample(imageData.getServer().getPixelCalibration()); DetectionMeasurer.builder() .downsample(downsample) @@ -133,7 +147,7 @@ public void makeMeasurements(ImageData imageData, Collection pathObjects) { + private InstanSegResults runInstanSeg(ImageData imageData, Collection pathObjects) { long startTime = System.currentTimeMillis(); @@ -175,6 +189,9 @@ private InstanSegResults runInstanSeg(Collection pathObjec } } + // If no input channels are specified, use all channels + var inputChannels = getInputChannels(imageData); + try (var model = Criteria.builder() .setTypes(Mat.class, Mat.class) .optModelUrls(String.valueOf(modelPath.toUri())) @@ -199,9 +216,9 @@ private InstanSegResults runInstanSeg(Collection pathObjec (BaseNDManager)baseManager.getParentManager()); int sizeWithoutPadding = (int) Math.ceil(downsample * (tileDims - (double) padding*2)); - var predictionProcessor = new TilePredictionProcessor(predictors, channels, tileDims, tileDims, padToInputSize); + var predictionProcessor = new TilePredictionProcessor(predictors, inputChannels, tileDims, tileDims, padToInputSize); var processor = OpenCVProcessor.builder(predictionProcessor) - .imageSupplier((parameters) -> ImageOps.buildImageDataOp(channels) + .imageSupplier((parameters) -> ImageOps.buildImageDataOp(inputChannels) .apply(parameters.getImageData(), parameters.getRegionRequest())) .tiler(Tiler.builder(sizeWithoutPadding) .alignCenter() @@ -237,6 +254,23 @@ private InstanSegResults runInstanSeg(Collection pathObjec } } + /** + * Get the input channels to use; if we don't have any specified, use all of them + * @param imageData + * @return + */ + private List getInputChannels(ImageData imageData) { + if (inputChannels == null || inputChannels.isEmpty()) { + List channels = new ArrayList<>(); + for (int i = 0; i < imageData.getServer().nChannels(); i++) { + channels.add(ColorTransforms.createChannelExtractor(i)); + } + return channels; + } else { + return inputChannels; + } + } + private static ObjectProcessor createPostProcessor() { var merger = ObjectMerger.createIoMinMerger(0.5); var fixer = OverlapFixer.builder() @@ -269,17 +303,17 @@ public static final class Builder { private static final Logger logger = LoggerFactory.getLogger(Builder.class); private static final int MIN_TILE_DIMS = 256; - private static final int MAX_TILE_DIMS = 2048; + private static final int MAX_TILE_DIMS = 4096; private int tileDims = 512; private double downsample = -1; // Optional - we usually get this from the model private int padding = 80; // Previous default of 40 could miss large objects private int[] outputChannels = null; private boolean randomColors = true; + private boolean makeMeasurements = false; private Device device = Device.fromName("cpu"); private TaskRunner taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(); - private ImageData imageData; - private Collection channels; + private Collection channels; private InstanSegModel model; private Class preferredOutputClass; @@ -288,7 +322,7 @@ public static final class Builder { /** * Set the width and height of tiles * @param tileDims The tile width and height - * @return A modified builder + * @return this builder */ public Builder tileDims(int tileDims) { if (tileDims < MIN_TILE_DIMS) { @@ -306,7 +340,7 @@ public Builder tileDims(int tileDims) { /** * Set the downsample to be used in region requests * @param downsample The downsample to be used - * @return A modified builder + * @return this builder */ public Builder downsample(double downsample) { this.downsample = downsample; @@ -316,7 +350,7 @@ public Builder downsample(double downsample) { /** * Set the padding (overlap) between tiles * @param padding The extra size added to tiles to allow overlap - * @return A modified builder + * @return this builder */ public Builder interTilePadding(int padding) { if (padding < 0) { @@ -337,38 +371,19 @@ public Builder interTilePadding(int padding) { * so it can be much cheaper to eliminate channels now, rather than discard unwanted detections later. * * @param outputChannels 0-based indices of the output channels, or leave empty to use all channels - * @return A modified builder + * @return this builder */ public Builder outputChannels(int... outputChannels) { this.outputChannels = outputChannels.clone(); return this; } - /** - * Set the imageData to be used - * @param imageData An imageData instance - * @return A modified builder - */ - public Builder imageData(ImageData imageData) { - this.imageData = imageData; - return this; - } - - /** - * Set the imageData to be used as the current image data. - * @return A modified builder - */ - public Builder currentImageData() { - this.imageData = QP.getCurrentImageData(); - return this; - } - /** * Set the channels to be used in inference * @param channels A collection of channels to be used in inference - * @return A modified builder + * @return this builder */ - public Builder channels(Collection channels) { + public Builder inputChannels(Collection channels) { this.channels = channels; return this; } @@ -376,34 +391,31 @@ public Builder channels(Collection channels) { /** * Set the channels to be used in inference * @param channels Channels to be used in inference - * @return A modified builder + * @return this builder */ - public Builder channels(ColorTransforms.ColorTransform channel, ColorTransforms.ColorTransform... channels) { - var l = Arrays.asList(channels); + public Builder inputChannels(ColorTransforms.ColorTransform channel, ColorTransforms.ColorTransform... channels) { + var l = new ArrayList(); l.add(channel); + l.addAll(Arrays.asList(channels)); this.channels = l; return this; } /** - * Set the model to use all channels for inference - * @return A modified builder + * Request that all input channels be used in inference + * @return this builder */ - public Builder allChannels() { - // assignment is just to suppress IDE suggestion for void return val - var tmp = channelIndices( - IntStream.range(0, imageData.getServer().nChannels()) - .boxed() - .toList()); + public Builder allInputChannels() { + this.channels = Collections.emptyList(); return this; } /** * Set the channels using indices * @param channels Integers used to specify the channels used - * @return A modified builder + * @return this builder */ - public Builder channelIndices(Collection channels) { + public Builder inputChannelIndices(Collection channels) { this.channels = channels.stream() .map(ColorTransforms::createChannelExtractor) .toList(); @@ -413,9 +425,9 @@ public Builder channelIndices(Collection channels) { /** * Set the channels using indices * @param channels Integers used to specify the channels used - * @return A modified builder + * @return this builder */ - public Builder channelIndices(int channel, int... channels) { + public Builder inputChannels(int channel, int... channels) { List l = new ArrayList<>(); l.add(ColorTransforms.createChannelExtractor(channel)); for (int i: channels) { @@ -428,9 +440,9 @@ public Builder channelIndices(int channel, int... channels) { /** * Set the channel names to be used * @param channels A set of channel names - * @return A modified builder + * @return this builder */ - public Builder channelNames(Collection channels) { + public Builder inputChannelNames(Collection channels) { this.channels = channels.stream() .map(ColorTransforms::createChannelExtractor) .toList(); @@ -440,9 +452,9 @@ public Builder channelNames(Collection channels) { /** * Set the channel names to be used * @param channels A set of channel names - * @return A modified builder + * @return this builder */ - public Builder channelNames(String channel, String... channels) { + public Builder inputChannelNames(String channel, String... channels) { List l = new ArrayList<>(); l.add(ColorTransforms.createChannelExtractor(channel)); for (String s: channels) { @@ -454,7 +466,7 @@ public Builder channelNames(String channel, String... channels) { /** * Request that random colors be used for the output objects. - * @return + * @return this builder */ public Builder randomColors() { return randomColors(true); @@ -463,7 +475,7 @@ public Builder randomColors() { /** * Optionally request that random colors be used for the output objects. * @param doRandomColors - * @return + * @return this builder */ public Builder randomColors(boolean doRandomColors) { this.randomColors = doRandomColors; @@ -473,7 +485,7 @@ public Builder randomColors(boolean doRandomColors) { /** * Set the number of threads used * @param nThreads The number of threads to be used - * @return A modified builder + * @return this builder */ public Builder nThreads(int nThreads) { this.taskRunner = TaskRunnerUtils.getDefaultInstance().createTaskRunner(nThreads); @@ -483,7 +495,7 @@ public Builder nThreads(int nThreads) { /** * Set the TaskRunner * @param taskRunner An object that will run tasks and show progress - * @return A modified builder + * @return this builder */ public Builder taskRunner(TaskRunner taskRunner) { this.taskRunner = taskRunner; @@ -493,7 +505,7 @@ public Builder taskRunner(TaskRunner taskRunner) { /** * Set the specific model to be used * @param model An already instantiated InstanSeg model. - * @return A modified builder + * @return this builder */ public Builder model(InstanSegModel model) { this.model = model; @@ -503,7 +515,7 @@ public Builder model(InstanSegModel model) { /** * Set the specific model by path * @param path A path on disk to create an InstanSeg model from. - * @return A modified builder + * @return this builder */ public Builder modelPath(Path path) throws IOException { return model(InstanSegModel.fromPath(path)); @@ -512,7 +524,7 @@ public Builder modelPath(Path path) throws IOException { /** * Set the specific model by path * @param path A path on disk to create an InstanSeg model from. - * @return A modified builder + * @return this builder */ public Builder modelPath(String path) throws IOException { return modelPath(Path.of(path)); @@ -521,7 +533,7 @@ public Builder modelPath(String path) throws IOException { /** * Set the device to be used * @param deviceName The name of the device to be used (eg, "gpu", "mps"). - * @return A modified builder + * @return this builder */ public Builder device(String deviceName) { this.device = Device.fromName(deviceName); @@ -531,7 +543,7 @@ public Builder device(String deviceName) { /** * Set the device to be used * @param device The {@link Device} to be used - * @return A modified builder + * @return this builder */ public Builder device(Device device) { this.device = device; @@ -540,7 +552,7 @@ public Builder device(Device device) { /** * Specify cells as the output class, possibly without nuclei - * @return A modified builder + * @return this builder */ public Builder outputCells() { this.preferredOutputClass = PathCellObject.class; @@ -549,7 +561,7 @@ public Builder outputCells() { /** * Specify (possibly nested) detections as the output class - * @return A modified builder + * @return this builder */ public Builder outputDetections() { this.preferredOutputClass = PathDetectionObject.class; @@ -558,37 +570,28 @@ public Builder outputDetections() { /** * Specify (possibly nested) annotations as the output class - * @return A modified builder + * @return this builder */ public Builder outputAnnotations() { this.preferredOutputClass = PathAnnotationObject.class; return this; } + /** + * Request to make measurements from the objects created by InstanSeg. + * @return this builder + */ + public Builder makeMeasurements(boolean doMeasure) { + this.makeMeasurements = doMeasure; + return this; + } + /** * Build the InstanSeg instance. * @return An InstanSeg instance ready for object detection. */ public InstanSeg build() { - if (imageData == null) { - // assignment is just to suppress IDE suggestion for void return - var tmp = currentImageData(); - } - if (channels == null) { - var tmp = allChannels(); - } - return new InstanSeg( - this.tileDims, - this.downsample, - this.padding, - this.outputChannels, - this.imageData, - this.channels, - this.model, - this.device, - this.taskRunner, - this.preferredOutputClass, - this.randomColors); + return new InstanSeg(this); } } diff --git a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java index 451a2bb..cc97cfb 100644 --- a/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java +++ b/src/main/java/qupath/ext/instanseg/core/TilePredictionProcessor.java @@ -22,6 +22,7 @@ import java.awt.image.BufferedImage; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -58,7 +59,7 @@ class TilePredictionProcessor implements Processor { private final Map normalization = Collections.synchronizedMap(new WeakHashMap<>()); TilePredictionProcessor(BlockingQueue> predictors, - Collection channels, + Collection channels, int inputWidth, int inputHeight, boolean doPadding) { this.predictors = predictors; this.channels = List.copyOf(channels); diff --git a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java index eaf7bf5..676fa2b 100644 --- a/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java +++ b/src/main/java/qupath/ext/instanseg/ui/ChannelSelectItem.java @@ -51,6 +51,9 @@ String getConstructor() { } static String toConstructorString(Collection items) { - return "[" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + "]"; + if (items == null || items.isEmpty()) + return "allInputChannels()"; + else + return "inputChannels([" + items.stream().map(ChannelSelectItem::getConstructor).collect(Collectors.joining(", ")) + "])"; } } diff --git a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java index d710d34..d3589a6 100644 --- a/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java +++ b/src/main/java/qupath/ext/instanseg/ui/InstanSegController.java @@ -739,12 +739,12 @@ protected Void call() { var instanSeg = InstanSeg.builder() .model(model) - .imageData(imageData) .device(deviceChoices.getSelectionModel().getSelectedItem()) + .inputChannels(channels.stream().map(ChannelSelectItem::getTransform).toList()) .outputChannels(outputChannels) - .channels(channels.stream().map(ChannelSelectItem::getTransform).toList()) .tileDims(InstanSegPreferences.tileSizeProperty().get()) .taskRunner(taskRunner) + .makeMeasurements(makeMeasurementsCheckBox.isSelected()) .build(); boolean makeMeasurements = makeMeasurementsCheckBox.isSelected(); @@ -752,30 +752,25 @@ protected Void call() { qupath.ext.instanseg.core.InstanSeg.builder() .modelPath("%s") .device("%s") + .%s .outputChannels(%s) - .channels(%s) .tileDims(%d) .nThreads(%d) + .makeMeasurements(%s) .build() - .%s + .detectObjects() """, path.get(), deviceChoices.getSelectionModel().getSelectedItem(), + ChannelSelectItem.toConstructorString(channels), outputChannels.length == 0 ? "" : Arrays.stream(outputChannels) .mapToObj(Integer::toString) .collect(Collectors.joining(", ")), - ChannelSelectItem.toConstructorString(channels), InstanSegPreferences.tileSizeProperty().get(), InstanSegPreferences.numThreadsProperty().getValue(), - makeMeasurements ? "detectObjectsAndMeasure()" : "detectObjects()" - ); - InstanSegResults results; - if (makeMeasurements) { - results = instanSeg.detectObjectsAndMeasure(selectedObjects); - } else { - results = instanSeg.detectObjects(selectedObjects); - } - + makeMeasurements + ).strip(); + InstanSegResults results = instanSeg.detectObjects(imageData, selectedObjects); imageData.getHierarchy().fireHierarchyChangedEvent(this); imageData.getHistoryWorkflow() .addStep(