diff --git a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py index 12955cb1b..7ab73a132 100644 --- a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py +++ b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source.py @@ -113,9 +113,13 @@ def read_labels(labels_df: gpd.GeoDataFrame, if bbox is not None: boxes = [b for b in boxes if b.intersects(bbox)] class_ids = labels_df['class_id'].astype(int) + if 'scores' in labels_df.columns: + scores = labels_df['scores'] + else: + scores = [None] * len(class_ids) cells_to_class_id = { - cell: (class_id, None) - for cell, class_id in zip(boxes, class_ids) + cell: (class_id, class_scores) + for cell, class_id, class_scores in zip(boxes, class_ids, scores) } labels = ChipClassificationLabels(cells_to_class_id) return labels