Skip to content

Commit

Permalink
fix issues with saving results
Browse files: browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Nov 30, 2024
1 parent ab1fdd9 commit 4382e82
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/scportrait/pipeline/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,25 +397,29 @@ def inference(self,

features_path = tempmmap.create_empty_mmap(shape_features, dtype = np.float32)
cell_ids_path = tempmmap.create_empty_mmap(shape_labels, dtype = np.int64)
labels_path = tempmmap.create_empty_mmap(shape_labels, dtype = np.int64)

features = tempmmap.mmap_array_from_path(features_path)
cell_ids = tempmmap.mmap_array_from_path(cell_ids_path)
labels = tempmmap.mmap_array_from_path(labels_path)

# save the results for each batch into the memory-mapped array at the specified indices
features[ix:(ix+batch_size)] = result.numpy()
cell_ids[ix:(ix+batch_size)] = class_id.unsqueeze(1)
labels[ix:(ix+batch_size)] = label.unsqueeze(1)
ix += batch_size

for i in range(len(dataloader) - 1):
if i % 10 == 0:
self.log(f"processing batch {i}")
x, label, id = next(data_iter)
x, label, class_id = next(data_iter)

r = model_fun(x.to(self.config["inference_device"]))

# save the results for each batch into the memory-mapped array at the specified indices
features[ix:(ix+r.shape[0])] = r.cpu().detach().numpy()
cell_ids[ix:(ix+r.shape[0])] = label.unsqueeze(1)
cell_ids[ix:(ix+r.shape[0])] = class_id.unsqueeze(1)
labels[ix:(ix+r.shape[0])] = label.unsqueeze(1)

ix += r.shape[0]

Expand All @@ -424,14 +428,11 @@ def inference(self,
sigma = 1e-9
features = np.log(features + sigma)

label = label.numpy()
class_id = class_id.numpy()

# save inferred activations / predictions
result_labels = [f"result_{i}" for i in range(features.shape[1])]

dataframe = pd.DataFrame(data=features, columns=result_labels)
dataframe["label"] = label
dataframe["label"] = labels
dataframe["cell_id"] = cell_ids.astype("int")

self.log("finished processing")
Expand Down

0 comments on commit 4382e82

Please sign in to comment.