Skip to content

Commit

Permalink
Merge pull request #1252 from OldLipe/feat/dev-sits
Browse files Browse the repository at this point in the history
Fix vector cube bugs
  • Loading branch information
gilbertocamara authored Dec 16, 2024
2 parents ced2f92 + ef975f0 commit 3e80043
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 4 deletions.
9 changes: 7 additions & 2 deletions R/api_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@
prediction <- .classify_ts_gpu(
pred = pred,
ml_model = ml_model,
gpu_memory = gpu_memory)
gpu_memory = gpu_memory
)
else
prediction <- .classify_ts_cpu(
pred = pred,
Expand Down Expand Up @@ -651,6 +652,8 @@
values <- .pred_features(pred_part)
# Classify
values <- ml_model(values)
# normalize and calibrate values
values <- .ml_normalize(ml_model, values)
# Return classification
values <- tibble::as_tibble(values)
values
Expand Down Expand Up @@ -691,8 +694,10 @@
values <- .pred_features(pred_part)
# Classify
values <- ml_model(values)
# normalize and calibrate values
values <- .ml_normalize(ml_model, values)
# Return classification
values <- tibble::tibble(data.frame(values))
values <- tibble::as_tibble(values)
# Clean GPU memory
.ml_gpu_clean(ml_model)
return(values)
Expand Down
2 changes: 2 additions & 0 deletions R/api_ml_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@
#' @export
#'
.ml_normalize.torch_model <- function(ml_model, values){
column_names <- colnames(values)
values[is.na(values)] <- 0
values <- softmax(values)
colnames(values) <- column_names
return(values)
}
#' @export
Expand Down
3 changes: 2 additions & 1 deletion R/api_tile.R
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,8 @@ NULL
x = r_obj,
y = segments,
fun = NULL,
include_cols = "pol_id"
include_cols = "pol_id",
progress = FALSE
)
values <- dplyr::bind_rows(values)
values <- dplyr::select(values, -"coverage_fraction")
Expand Down
2 changes: 1 addition & 1 deletion R/sits_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ sits_classify.segs_cube <- function(data,
proc_bloat <- .conf("processing_bloat_gpu")
}
# avoid memory race in Apple MPS
if(.torch_mps_enabled(ml_model)){
if (.torch_mps_enabled(ml_model)) {
memsize <- 1
gpu_memory <- 1
}
Expand Down
2 changes: 2 additions & 0 deletions R/sits_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ plot.vector_cube <- function(x, ...,
sf_seg = sf_seg,
seg_color = seg_color,
line_width = line_width,
first_quantile = first_quantile,
last_quantile = last_quantile,
scale = scale,
max_cog_size = max_cog_size,
tmap_params = tmap_params
Expand Down
7 changes: 7 additions & 0 deletions inst/extdata/config_internals.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ default_values :
maximum_value: 1.7014118346015974e+37
offset_value : 0
scale_factor : 1
FLT8S :
data_type : "FLT8S"
missing_value: -3.402823466385288e+37
minimum_value: -3.402823466385288e+37
maximum_value: 1.7014118346015974e+37
offset_value : 0
scale_factor : 1

# Derived cube definitions
derived_cube :
Expand Down

0 comments on commit 3e80043

Please sign in to comment.