Skip to content

Commit

Permalink
Pass optional args
Browse files Browse the repository at this point in the history
  • Loading branch information
alanocallaghan committed Jan 8, 2025
1 parent 2a8a003 commit 21b5af9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
42 changes: 34 additions & 8 deletions src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ai.djl.ndarray.BaseNDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import java.util.HashMap;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
Expand Down Expand Up @@ -37,6 +38,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
Expand All @@ -57,11 +59,19 @@ public class InstanSeg {
private final Device device;
private final TaskRunner taskRunner;
private final Class<? extends PathObject> preferredOutputClass;
private final Map<String, Object> optionalArgs;

// This was previously an adjustable parameter, but it's now fixed at 1 because we handle overlaps differently.
// However we might want to reinstate it, possibly as a proportion of the padding amount.
private final int boundaryThreshold = 1;

/**
* Run inference for the currently selected PathObjects in the current image.
*/
public InstanSegResults detectObjects() {
return detectObjects(QP.getCurrentImageData());
}

private InstanSeg(Builder builder) {
this.tileDims = builder.tileDims;
this.downsample = builder.downsample; // Optional... and not advised (use the model spec instead); set <= 0 to ignore
Expand All @@ -74,13 +84,7 @@ private InstanSeg(Builder builder) {
this.preferredOutputClass = builder.preferredOutputClass;
this.randomColors = builder.randomColors;
this.makeMeasurements = builder.makeMeasurements;
}

/**
* Run inference for the currently selected PathObjects in the current image.
*/
public InstanSegResults detectObjects() {
return detectObjects(QP.getCurrentImageData());
this.optionalArgs = builder.optionalArgs;
}

/**
Expand Down Expand Up @@ -215,7 +219,7 @@ private InstanSegResults runInstanSeg(ImageData<BufferedImage> imageData, Collec
.optModelUrls(String.valueOf(modelPath.toUri()))
.optProgress(new ProgressBar())
.optDevice(device) // Remove this line if devices are problematic!
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray))
.optTranslator(new MatTranslator(layout, layoutOutput, outputChannelArray, optionalArgs))
.build()
.loadModel()) {

Expand Down Expand Up @@ -392,6 +396,7 @@ public static final class Builder {
private Collection<? extends ColorTransforms.ColorTransform> channels;
private InstanSegModel model;
private Class<? extends PathObject> preferredOutputClass;
private final Map<String, Object> optionalArgs = new HashMap<>();

Builder() {}

Expand Down Expand Up @@ -653,6 +658,27 @@ public Builder outputAnnotations() {
return this;
}

/**
* Set a number of optional arguments
* @param optionalArgs The argument names and values.
* @return A modified builder.
*/
public Builder args(Map<String, Object> optionalArgs) {
this.optionalArgs.putAll(optionalArgs);
return this;
}

/**
* Set a number of optional arguments
* @param name The argument name.
* @param value The argument value.
* @return A modified builder.
*/
public Builder arg(String name, Object value) {
optionalArgs.put(name, value);
return this;
}

/**
* Request to make measurements from the objects created by InstanSeg.
* @return this builder
Expand Down
25 changes: 24 additions & 1 deletion src/main/java/qupath/ext/instanseg/core/MatTranslator.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package qupath.ext.instanseg.core;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.bytedeco.opencv.opencv_core.Mat;
import qupath.ext.djl.DjlTools;

Expand All @@ -15,6 +21,7 @@ class MatTranslator implements Translator<Mat, Mat> {
private final String inputLayoutNd;
private final String outputLayoutNd;
private final int[] outputChannels;
private final Map<String, Object> optionalArgs;

/**
* Create a translator from InstanSeg input to output.
Expand All @@ -23,10 +30,11 @@ class MatTranslator implements Translator<Mat, Mat> {
* @param outputChannels Array of channels to output; if null or empty, output all channels.
* Values should be true for channels to output, false for channels to ignore.
*/
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels) {
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean[] outputChannels, Map<String, Object> optionalArgs) {
this.inputLayoutNd = inputLayoutNd;
this.outputLayoutNd = outputLayoutNd;
this.outputChannels = convertBooleanArray(outputChannels);
this.optionalArgs = optionalArgs;
}

private static int[] convertBooleanArray(boolean[] array) {
Expand Down Expand Up @@ -55,9 +63,24 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
var arrayCPU = array.toDevice(Device.cpu(), false);
out.add(arrayCPU);
}
List<NDArray> args = sanitizeOptionalArgs(optionalArgs, manager);
out.addAll(args);
return out;
}

private static List<NDArray> sanitizeOptionalArgs(Map<String, Object> optionalArgs, NDManager manager) {
List<NDArray> arrays = new ArrayList<>();
for (var es : optionalArgs.entrySet()) {
var val = es.getValue();
if (val instanceof Double || val instanceof BigDecimal) {
NDArray array = manager.create(((Number) val).floatValue());
array.setName("args." + es.getKey());
arrays.add(array);
}
}
return arrays;
}

@Override
public Mat processOutput(TranslatorContext ctx, NDList list) {
var array = list.getFirst();
Expand Down

0 comments on commit 21b5af9

Please sign in to comment.