diff --git a/CHANGELOG.md b/CHANGELOG.md index 2dc9d3f..94afa3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 for the entire input (but still possible). ### Changed +- Explicitly disallow `mlm:name`, `mlm:input`, `mlm:output` and `mlm:hyperparameters` at the Asset level. + These fields describe the model as a whole and should therefore be defined in Item properties. - Moved `norm_type` to `value_scaling` object to better reflect the expected operation, which could be another operation than what is typically known as "normalization" or "standardization" techniques in machine learning. - Moved `statistics` to `value_scaling` object to better reflect their mutual `type` and additional @@ -34,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Otherwise, the amount of `value_scaling` objects should match the number of bands or channels involved in the input. ### Fixed +- Fix missing `mlm:artifact_type` property check for a Model Asset definition + (fixes ). + The `mlm:artifact_type` is now mutually and exclusively required by the corresponding Asset with `mlm:model` role. - Fix check of disallowed unknown/undefined `mlm:`-prefixed fields (fixes [#41](https://github.com/stac-extensions/mlm/issues/41)). diff --git a/README.md b/README.md index 265ac4d..9739397 100644 --- a/README.md +++ b/README.md @@ -116,34 +116,48 @@ The fields in the table below can be used in these parts of STAC documents: [item-assets]: https://github.com/stac-extensions/item-assets -| Field Name | Type | Description | -|-----------------------------|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| mlm:name | string | **REQUIRED** A name for the model. This can include, but must be distinct, from simply naming the model architecture. If there is a publication or other published work related to the model, use the official name of the model. | -| mlm:architecture | [Model Architecture](#model-architecture) string | **REQUIRED** A generic and well established architecture name of the model. | -| mlm:tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the model can be used for. If multi-tasks outputs are provided by distinct model heads, specify all available tasks under the main properties and specify respective tasks in each [Model Output Object](#model-output-object). | -| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). | -| mlm:framework_version | string | The `framework` library version. Some models require a specific version of the machine learning `framework` to run. | -| mlm:memory_size | integer | The in-memory size of the model on the accelerator during inference (bytes). | -| mlm:total_parameters | integer | Total number of model parameters, including trainable and non-trainable parameters. | -| mlm:pretrained | boolean | Indicates if the model was pretrained. If the model was pretrained, consider providing `pretrained_source` if it is known. | -| mlm:pretrained_source | string \| null | The source of the pretraining. Can refer to popular pretraining datasets by name (i.e. Imagenet) or less known datasets by URL and description. If trained from scratch (i.e.: `pretrained = false`), the `null` value should be set explicitly. | -| mlm:batch_size_suggestion | integer | A suggested batch size for the accelerator and summarized hardware. | -| mlm:accelerator | [Accelerator Type Enum](#accelerator-type-enum) \| null | The intended computational hardware that runs inference. If undefined or set to `null` explicitly, the model does not require any specific accelerator. | -| mlm:accelerator_constrained | boolean | Indicates if the intended `accelerator` is the only `accelerator` that can run inference. If undefined, it should be assumed `false`. | -| mlm:accelerator_summary | string | A high level description of the `accelerator`, such as its specific generation, or other relevant inference details. | -| mlm:accelerator_count | integer | A minimum amount of `accelerator` instances required to run the model. | -| mlm:input | \[[Model Input Object](#model-input-object)] | **REQUIRED** Describes the transformation between the EO data and the model input. | -| mlm:output | \[[Model Output Object](#model-output-object)] | **REQUIRED** Describes each model output and how to interpret it. | -| mlm:hyperparameters | [Model Hyperparameters Object](#model-hyperparameters-object) | Additional hyperparameters relevant for the model. | - -To decide whether above fields should be applied under Item `properties` or under respective Assets, the context of -each field must be considered. For example, the `mlm:name` should always be provided in the Item `properties`, since -it relates to the model as a whole. In contrast, some models could support multiple `mlm:accelerator`, which could be -handled by distinct source code represented by different Assets. In such case, `mlm:accelerator` definitions should be -nested under their relevant Asset. If a field is defined both at the Item and Asset level, the value at the Asset level -would be considered for that specific Asset, and the value at the Item level would be used for other Assets that did -not override it for their respective reference. For some of the fields, further details are provided in following -sections to provide more precisions regarding some potentially ambiguous use cases. +| Field Name | Type | Description | +|-------------------------------------------|---------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| mlm:name [\[1\]][1] | string | **REQUIRED** A name for the model. This can include, but must be distinct, from simply naming the model architecture. If there is a publication or other published work related to the model, use the official name of the model. | +| mlm:architecture | [Model Architecture](#model-architecture) string | **REQUIRED** A generic and well established architecture name of the model. | +| mlm:tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the model can be used for. If multi-tasks outputs are provided by distinct model heads, specify all available tasks under the main properties and specify respective tasks in each [Model Output Object](#model-output-object). | +| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). | +| mlm:framework_version | string | The `framework` library version. Some models require a specific version of the machine learning `framework` to run. | +| mlm:memory_size | integer | The in-memory size of the model on the accelerator during inference (bytes). | +| mlm:total_parameters | integer | Total number of model parameters, including trainable and non-trainable parameters. | +| mlm:pretrained | boolean | Indicates if the model was pretrained. If the model was pretrained, consider providing `pretrained_source` if it is known. | +| mlm:pretrained_source | string \| null | The source of the pretraining. Can refer to popular pretraining datasets by name (i.e. Imagenet) or less known datasets by URL and description. If trained from scratch (i.e.: `pretrained = false`), the `null` value should be set explicitly. | +| mlm:batch_size_suggestion | integer | A suggested batch size for the accelerator and summarized hardware. | +| mlm:accelerator | [Accelerator Type Enum](#accelerator-type-enum) \| null | The intended computational hardware that runs inference. If undefined or set to `null` explicitly, the model does not require any specific accelerator. | +| mlm:accelerator_constrained | boolean | Indicates if the intended `accelerator` is the only `accelerator` that can run inference. If undefined, it should be assumed `false`. | +| mlm:accelerator_summary | string | A high level description of the `accelerator`, such as its specific generation, or other relevant inference details. | +| mlm:accelerator_count | integer | A minimum amount of `accelerator` instances required to run the model. | +| mlm:input [\[1\]][1] | \[[Model Input Object](#model-input-object)] | **REQUIRED** Describes the transformation between the EO data and the model input. | +| mlm:output [\[1\]][1] | \[[Model Output Object](#model-output-object)] | **REQUIRED** Describes each model output and how to interpret it. | +| mlm:hyperparameters [\[1\]][1] | [Model Hyperparameters Object](#model-hyperparameters-object) | Additional hyperparameters relevant for the model. | + + + +[1]: #notes + +### Notes +[1][1] Fields allowed only in Item `properties` + + + +> [!NOTE] +> Unless stated otherwise by [\[1\]][1] in the table, fields can be used at either the Item or Asset level. +>

+> To decide whether above fields should be applied under Item `properties` or under respective Assets, the context of +> each field must be considered. For example, the `mlm:name` should always be provided in the Item `properties`, since +> it relates to the model as a whole. In contrast, some models could support multiple `mlm:accelerator`, which could be +> handled by distinct source code represented by different Assets. In such case, `mlm:accelerator` definitions should be +> nested under their relevant Asset. If a field is defined both at the Item and Asset level, the value at the Asset +> level would be considered for that specific Asset, and the value at the Item level would be used for other Assets that +> did not override it for their respective reference. For some of the fields, further details are provided in following +> sections to provide more precisions regarding some potentially ambiguous use cases. + + In addition, fields from the multiple relevant extensions should be defined as applicable. See [Best Practices - Recommended Extensions to Compose with the ML Model Extension](best-practices.md#recommended-extensions-to-compose-with-the-ml-model-extension) @@ -632,13 +646,13 @@ In order to provide more context, the following roles are also recommended were ### Model Asset -| Field Name | Type | Description | -|-------------------|-------------------------------------------|--------------------------------------------------------------------------------------------------| -| title | string | Description of the model asset. | -| href | string | URI to the model artifact. | -| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). | -| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. | -| mlm:artifact_type | [Artifact Type Enum](#artifact-type-enum) | Specifies the kind of model artifact. Typically related to a particular ML framework. | +| Field Name | Type | Description | +|-------------------|---------------------------------|--------------------------------------------------------------------------------------------------| +| title | string | Description of the model asset. | +| href | string | URI to the model artifact. | +| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). | +| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. | +| mlm:artifact_type | [Artifact Type](#artifact-type) | Specifies the kind of model artifact. Typically related to a particular ML framework. | Recommended Asset `roles` include `mlm:weights` or `mlm:checkpoint` for model weights that need to be loaded by a model definition and `mlm:compiled` for models that can be loaded directly without an intermediate model definition. @@ -674,7 +688,7 @@ official. In order to validate the specific framework and artifact type employed [iana-media-type]: https://www.iana.org/assignments/media-types/media-types.xhtml -#### Artifact Type Enum +#### Artifact Type This value can be used to provide additional details about the specific model artifact being described. For example, PyTorch offers [various strategies][pytorch-frameworks] for providing model definitions, diff --git a/examples/item_bands_expression.json b/examples/item_bands_expression.json index edd4a41..38ba9f0 100644 --- a/examples/item_bands_expression.json +++ b/examples/item_bands_expression.json @@ -150,6 +150,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", "eo:bands": [ { diff --git a/examples/item_basic.json b/examples/item_basic.json index 4ce3c24..3806446 100644 --- a/examples/item_basic.json +++ b/examples/item_basic.json @@ -120,7 +120,8 @@ "type": "text/html", "roles": [ "mlm:model" - ] + ], + "mlm:artifact_type": "torch.save" } }, "links": [ diff --git a/examples/item_eo_and_raster_bands.json b/examples/item_eo_and_raster_bands.json index 3c9da1d..1f1d9ab 100644 --- a/examples/item_eo_and_raster_bands.json +++ b/examples/item_eo_and_raster_bands.json @@ -508,6 +508,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", "eo:bands": [ { @@ -557,7 +558,6 @@ "description": "Source code to run the model.", "type": "text/x-python", "roles": [ - "mlm:model", "code", "metadata" ] diff --git a/examples/item_eo_bands.json b/examples/item_eo_bands.json index 569c54b..605f3e6 100644 --- a/examples/item_eo_bands.json +++ b/examples/item_eo_bands.json @@ -285,6 +285,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", "eo:bands": [ { diff --git a/examples/item_eo_bands_summarized.json b/examples/item_eo_bands_summarized.json index 53c66ab..d3c8709 100644 --- a/examples/item_eo_bands_summarized.json +++ b/examples/item_eo_bands_summarized.json @@ -377,6 +377,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", "eo:bands": [ { @@ -426,7 +427,6 @@ "description": "Source code to run the model.", "type": "text/x-python", "roles": [ - "mlm:model", "code", "metadata" ] diff --git a/examples/item_multi_io.json b/examples/item_multi_io.json index dce2b87..e8f6b11 100644 --- a/examples/item_multi_io.json +++ b/examples/item_multi_io.json @@ -227,6 +227,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "raster:bands": [ { "name": "B02 - blue", diff --git a/examples/item_raster_bands.json b/examples/item_raster_bands.json index c52fd42..33a29b9 100644 --- a/examples/item_raster_bands.json +++ b/examples/item_raster_bands.json @@ -216,6 +216,7 @@ "mlm:model", "mlm:weights" ], + "mlm:artifact_type": "torch.save", "raster:bands": [ { "name": "B01", diff --git a/json-schema/schema.json b/json-schema/schema.json index a7e14e8..b87ad45 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -6,7 +6,7 @@ "$comment": "Use 'allOf+if/then' for each 'type' to allow implementations to report more specific messages about the exact case in error (if any). Using only a 'oneOf/allOf' with the 'type' caused any incompatible 'type' to be reported first with a minimal and poorly described error by 'pystac'.", "allOf": [ { - "$comment": "This is the schema for STAC extension MLM in Items.", + "description": "This is the schema for STAC extension MLM in Items.", "if": { "required": [ "type" @@ -20,7 +20,7 @@ "then": { "allOf": [ { - "$comment": "Schema to validate the MLM fields under Item properties or Assets properties.", + "description": "Schema to validate the MLM fields permitted under Item properties or Assets properties.", "type": "object", "required": [ "properties", @@ -28,29 +28,13 @@ ], "properties": { "properties": { - "allOf": [ - { - "required": [ - "mlm:name", - "mlm:architecture", - "mlm:tasks", - "mlm:input", - "mlm:output" - ] - }, - { - "$ref": "#/$defs/fields" - } - ] + "$comment": "Schema to validate the MLM fields permitted under Item properties.", + "$ref": "#/$defs/mlmItemFields" }, "assets": { - "type": "object", "additionalProperties": { - "allOf": [ - { - "$ref": "#/$defs/fields" - } - ] + "$comment": "Schema to validate the MLM fields permitted under Asset properties.", + "$ref": "#/$defs/mlmAssetFields" } } } @@ -63,14 +47,18 @@ "$ref": "#/$defs/AnyBandsRef" }, { - "$comment": "Schema to validate model role requirement.", + "$comment": "Schema to validate that at least one Asset defines a model role.", "$ref": "#/$defs/AssetModelRoleMinimumOneDefinition" + }, + { + "$comment": "Schema to validate that the Asset model properties are mutually exclusive to the model role.", + "$ref": "#/$defs/AssetModelRequiredProperties" } ] } }, { - "$comment": "This is the schema for STAC extension MLM in Collections.", + "description": "This is the schema for STAC extension MLM in Collections.", "if": { "required": [ "type" @@ -89,19 +77,19 @@ "summaries": { "type": "object", "additionalProperties": { - "$ref": "#/$defs/fields" + "$ref": "#/$defs/mlmCollectionFields" } }, "assets": { "type": "object", "additionalProperties": { - "$ref": "#/$defs/fields" + "$ref": "#/$defs/mlmAssetFields" } }, "item_assets": { "type": "object", "additionalProperties": { - "$ref": "#/$defs/fields" + "$ref": "#/$defs/mlmAssetFields" } } } @@ -258,6 +246,7 @@ } }, "fields": { + "description": "All possible MLM fields regardless of the level they apply (Collection, Item, Asset, Link).", "type": "object", "properties": { "mlm:name": { @@ -310,14 +299,110 @@ }, "mlm:hyperparameters": { "$ref": "#/$defs/mlm:hyperparameters" + }, + "mlm:artifact_type": { + "$ref": "#/$defs/mlm:artifact_type" } }, - "$comment": "Allow properties not defined by MLM prefix to allow combination with other extensions.", + "$comment": "Allow properties not defined by MLM prefix to work with other extensions and attributes, but disallow undefined MLM fields.", "patternProperties": { "^(?!mlm:)": {} }, "additionalProperties": false }, + "mlmCollectionFields": { + "description": "Schema to validate the MLM fields permitted under Collection summaries.", + "allOf": [ + { + "description": "Fields that are mandatory under the Collection summaries.", + "type": "object", + "required": [] + }, + { + "description": "Fields that are disallowed under the Collection summaries.", + "not": { + "required": [ + "mlm:input", + "mlm:output", + "mlm:artifact_type" + ] + } + }, + { + "description": "Field with known definitions that must be validated.", + "$ref": "#/$defs/fields" + } + ] + }, + "mlmItemFields": { + "description": "Schema to validate the MLM fields permitted under Item properties.", + "allOf": [ + { + "description": "Fields that are mandatory under the Item properties.", + "required": [ + "mlm:name", + "mlm:architecture", + "mlm:tasks", + "mlm:input", + "mlm:output" + ] + }, + { + "description": "Fields that are disallowed under the Item properties.", + "$comment": "Particularity of the 'not/required' approach: they must be tested one by one. Otherwise, it validates that they are all (simultaneously) not present.", + "not": { + "anyOf": [ + { + "required": [ + "mlm:artifact_type" + ] + } + ] + } + }, + { + "description": "Field with known definitions that must be validated.", + "$ref": "#/$defs/fields" + } + ] + }, + "mlmAssetFields": { + "description": "Schema to validate the MLM fields permitted under Assets properties.", + "allOf": [ + { + "description": "Fields that are disallowed under the Asset properties.", + "$comment": "Particularity of the 'not/required' approach: they must be tested one by one. Otherwise, it validates that they are all (simultaneously) not present.", + "not": { + "anyOf": [ + { + "required": [ + "mlm:name" + ] + }, + { + "required": [ + "mlm:input" + ] + }, + { + "required": [ + "mlm:output" + ] + }, + { + "required": [ + "mlm:hyperparameters" + ] + } + ] + } + }, + { + "description": "Field with known definitions that must be validated.", + "$ref": "#/$defs/fields" + } + ] + }, "mlm:name": { "type": "string", "pattern": "^[a-zA-Z][a-zA-Z0-9_.\\-\\s]+[a-zA-Z0-9]$" @@ -369,6 +454,15 @@ "type": "string", "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$" }, + "mlm:artifact_type": { + "type": "string", + "minLength": 1, + "examples": [ + "torch.save", + "torch.jit.save", + "torch.export.save" + ] + }, "mlm:tasks": { "type": "array", "uniqueItems": true, @@ -845,6 +939,57 @@ "DataType": { "$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/data_type" }, + "HasArtifactType": { + "$comment": "Used to check the artifact type property that is required by a Model Asset annotated by 'mlm:model' role.", + "type": "object", + "required": [ + "mlm:artifact_type" + ], + "properties": { + "mlm:artifact_type": { + "$ref": "#/$defs/mlm:artifact_type" + } + } + }, + "AssetModelRole": { + "$comment": "Used to check the presence of 'mlm:model' role required by a Model Asset.", + "type": "object", + "required": [ + "roles" + ], + "properties": { + "roles": { + "type": "array", + "contains": { + "const": "mlm:model" + }, + "minItems": 1 + } + } + }, + "AssetModelRequiredProperties": { + "$comment": "Asset containing the model definition must indicate both the 'mlm:model' role and an artifact type.", + "required": [ + "assets" + ], + "properties": { + "assets": { + "additionalProperties": { + "if": { + "$ref": "#/$defs/AssetModelRole" + }, + "then": { + "$ref": "#/$defs/HasArtifactType" + }, + "else": { + "not": { + "$ref": "#/$defs/HasArtifactType" + } + } + } + } + } + }, "AssetModelRoleMinimumOneDefinition": { "$comment": "At least one Asset must provide the model definition indicated by the 'mlm:model' role.", "required": [ @@ -855,15 +1000,7 @@ "properties": { "assets": { "additionalProperties": { - "properties": { - "roles": { - "type": "array", - "items": { - "const": "mlm:model" - }, - "minItems": 1 - } - } + "$ref": "#/$defs/AssetModelRole" } } } @@ -891,19 +1028,6 @@ } ] }, - "AssetModelRole": { - "required": [ - "roles" - ], - "properties": { - "roles": { - "contains": { - "type": "string", - "const": "mlm:model" - } - } - } - }, "ModelBands": { "description": "List of bands (if any) that compose the input. Band order represents the index position of the bands.", "$comment": "No 'minItems' here to support model inputs not using any band (other data source).", diff --git a/stac_model/examples.py b/stac_model/examples.py index e5047cc..05c1c0b 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -130,6 +130,7 @@ def eurosat_resnet() -> ItemMLModelExtension: "mlm:weights", "data", ], + extra_fields={"mlm:artifact_type": "torch.save"} ), "source_code": pystac.Asset( title="Model implementation.", diff --git a/stac_model/output.py b/stac_model/output.py index 5b1d2ea..6654240 100644 --- a/stac_model/output.py +++ b/stac_model/output.py @@ -7,8 +7,8 @@ class ModelResult(MLMBaseModel): - shape: list[int | float] = Field(..., min_items=1) - dim_order: list[str] = Field(..., min_items=1) + shape: list[int | float] = Field(..., min_length=1) + dim_order: list[str] = Field(..., min_length=1) data_type: DataType diff --git a/stac_model/runtime.py b/stac_model/runtime.py index 197f7b8..15abc7b 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -46,4 +46,4 @@ class Runtime(MLMBaseModel): accelerator: AcceleratorType | None = Field(default=None) accelerator_constrained: bool = Field(default=False) accelerator_summary: Annotated[str | None, OmitIfNone] = Field(default=None) - accelerator_count: Annotated[int | None, OmitIfNone] = Field(default=None, minimum=1) + accelerator_count: Annotated[int | None, OmitIfNone] = Field(default=None, ge=1) diff --git a/tests/test_schema.py b/tests/test_schema.py index 53addef..48e120a 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -4,6 +4,7 @@ import pystac import pytest +from jsonschema.exceptions import ValidationError from pystac.validation.stac_validator import STACValidator from stac_model.base import JSON @@ -11,6 +12,9 @@ from conftest import get_all_stac_item_examples +# ignore typing errors introduced by generic JSON manipulation errors +# mypy: disable_error_code="arg-type,call-overload,index,union-attr" + @pytest.mark.parametrize( "mlm_example", # value passed to 'mlm_example' fixture @@ -32,7 +36,7 @@ def test_mlm_schema( ["item_raster_bands.json"], indirect=True, ) -def test_mlm_no_undefined_prefixed_field( +def test_mlm_no_undefined_prefixed_field_item_properties( mlm_validator: STACValidator, mlm_example: dict[str, JSON], ) -> None: @@ -40,11 +44,82 @@ def test_mlm_no_undefined_prefixed_field( mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid - mlm_data["properties"]["mlm:unknown"] = "random" # type: ignore + # undefined property anywhere in the schema + mlm_data = copy.deepcopy(mlm_example) + mlm_data["properties"]["mlm:unknown"] = "random" + with pytest.raises(pystac.errors.STACValidationError) as exc: + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) + assert all( + info in str(exc.value.source) + for info in ["mlm:unknown", "^(?!mlm:)"] + ) + + # defined property only allowed at the Asset level + mlm_data = copy.deepcopy(mlm_example) + mlm_data["properties"]["mlm:artifact_type"] = "torch.save" with pytest.raises(pystac.errors.STACValidationError) as exc: mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) - assert all(field in str(exc.value.source) for field in ["mlm:unknown", "^(?!mlm:)"]) + errors = cast(list[ValidationError], exc.value.source) + assert "mlm:artifact_type" in str(errors[0].validator_value) + assert errors[0].schema["description"] == "Fields that are disallowed under the Item properties." + + +@pytest.mark.parametrize( + "mlm_example", + ["item_raster_bands.json"], + indirect=True, +) +@pytest.mark.parametrize( + ["test_field", "test_value"], + [ + ("mlm:unknown", "random"), + ("mlm:name", "test-model"), + ("mlm:input", []), + ("mlm:output", []), + ("mlm:hyperparameters", {"test": {}}), + ] +) +def test_mlm_no_undefined_prefixed_field_asset_properties( + mlm_validator: STACValidator, + mlm_example: dict[str, JSON], + test_field: str, + test_value: Any, +) -> None: + mlm_data = copy.deepcopy(mlm_example) + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid + assert mlm_data["assets"]["weights"] + + mlm_data = copy.deepcopy(mlm_example) + mlm_data["assets"]["weights"][test_field] = test_value + with pytest.raises(pystac.errors.STACValidationError) as exc: + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) + assert len(exc.value.source) == 1 + errors = cast(list[ValidationError], exc.value.source) + assert test_field in errors[0].instance + assert errors[0].schema["description"] in [ + "All possible MLM fields regardless of the level they apply (Collection, Item, Asset, Link).", + "Fields that are disallowed under the Asset properties." + ] + + +@pytest.mark.parametrize( + "mlm_example", + ["item_raster_bands.json"], + indirect=True, +) +def test_mlm_allowed_field_asset_properties_override( + mlm_validator: STACValidator, + mlm_example: dict[str, JSON], +) -> None: + # defined property allowed both at the Item at the Asset level + mlm_data = copy.deepcopy(mlm_example) + mlm_data["assets"]["weights"]["mlm:accelerator"] = "cuda" + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) @pytest.mark.parametrize( @@ -60,7 +135,7 @@ def test_mlm_missing_bands_invalid_if_mlm_input_lists_bands( pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid mlm_bands_bad_data = copy.deepcopy(mlm_example) - mlm_bands_bad_data["assets"]["weights"].pop("raster:bands") # type: ignore # no 'None' to raise in case modified + mlm_bands_bad_data["assets"]["weights"].pop("raster:bands") # no 'None' to raise in case missing with pytest.raises(pystac.errors.STACValidationError): mlm_bands_bad_item = pystac.Item.from_dict(mlm_bands_bad_data) pystac.validation.validate(mlm_bands_bad_item, validator=mlm_validator) @@ -79,7 +154,7 @@ def test_mlm_eo_bands_invalid_only_in_item_properties( pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid mlm_eo_bands_bad_data = copy.deepcopy(mlm_example) - mlm_eo_bands_bad_data["assets"]["weights"].pop("eo:bands") # type: ignore # no 'None' to raise in case modified + mlm_eo_bands_bad_data["assets"]["weights"].pop("eo:bands") # no 'None' to raise in case missing with pytest.raises(pystac.errors.STACValidationError): mlm_eo_bands_bad_item = pystac.Item.from_dict(mlm_eo_bands_bad_data) pystac.validation.validate(mlm_eo_bands_bad_item, validator=mlm_validator) @@ -95,12 +170,12 @@ def test_mlm_no_input_allowed_but_explicit_empty_array_required( mlm_example: dict[str, JSON], ) -> None: mlm_data = copy.deepcopy(mlm_example) - mlm_data["properties"]["mlm:input"] = [] # type: ignore + mlm_data["properties"]["mlm:input"] = [] mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) with pytest.raises(pystac.errors.STACValidationError): - mlm_data["properties"].pop("mlm:input") # type: ignore # no 'None' to raise in case modified + mlm_data["properties"].pop("mlm:input") # no 'None' to raise in case missing mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) @@ -163,13 +238,13 @@ def test_mlm_other_non_mlm_assets_allowed( mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand - mlm_data["assets"]["sample"] = { # type: ignore + mlm_data["assets"]["sample"] = { "type": "image/jpeg", "href": "https://example.com/sample/output.jpg", "roles": ["preview"], "title": "Model Output Predictions Sample", } - mlm_data["assets"]["model-cart"] = { # type: ignore + mlm_data["assets"]["model-cart"] = { "type": "text/markdown", "href": "https://example.com/sample/model.md", "roles": ["metadata"], @@ -184,25 +259,71 @@ def test_mlm_other_non_mlm_assets_allowed( ["item_basic.json"], indirect=True, ) +@pytest.mark.parametrize( + ["model_asset_extras", "is_valid"], + [ + ({"roles": ["checkpoint"]}, False), + ({"roles": ["checkpoint", "mlm:model"]}, False), + ({"roles": ["checkpoint"], "mlm:artifact_type": "test"}, False), + ({"roles": ["checkpoint", "mlm:model"], "mlm:artifact_type": "test"}, True), + ] +) def test_mlm_at_least_one_asset_model( mlm_validator: STACValidator, mlm_example: dict[str, JSON], + model_asset_extras: dict[str, Any], + is_valid: bool, ) -> None: mlm_data = copy.deepcopy(mlm_example) mlm_item = pystac.Item.from_dict(mlm_data) pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand - mlm_data["assets"] = { # needs at least 1 asset with role 'mlm:model' - "model": { - "type": "application/octet-stream; application=pytorch", - "href": "https://example.com/sample/checkpoint.pt", - "roles": ["checkpoint"], - "title": "Model Weights Checkpoint", - } + mlm_model = { + "type": "application/octet-stream; application=pytorch", + "href": "https://example.com/sample/checkpoint.pt", + "title": "Model Weights Checkpoint", } - with pytest.raises(pystac.errors.STACValidationError): - mlm_item = pystac.Item.from_dict(mlm_data) + mlm_model.update(model_asset_extras) + mlm_data["assets"] = { + "model": mlm_model # type: ignore + } + mlm_item = pystac.Item.from_dict(mlm_data) + if is_valid: + pystac.validation.validate(mlm_item, validator=mlm_validator) + else: + with pytest.raises(pystac.errors.STACValidationError) as exc: + pystac.validation.validate(mlm_item, validator=mlm_validator) + errors = cast(list[ValidationError], exc.value.source) + assert errors[0].schema["$comment"] in [ + "At least one Asset must provide the model definition indicated by the 'mlm:model' role.", + "Used to check the artifact type property that is required by a Model Asset annotated by 'mlm:model' role." + ] + + +@pytest.mark.parametrize( + "mlm_example", + ["item_basic.json"], + indirect=True, +) +def test_mlm_asset_artifact_type_checked( + mlm_validator: STACValidator, + mlm_example: dict[str, JSON], +) -> None: + mlm_data = copy.deepcopy(mlm_example) + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand + + mlm_data["assets"]["model"]["mlm:artifact_type"] = 1234 # type: ignore + mlm_item = pystac.Item.from_dict(mlm_data) + with pytest.raises(pystac.errors.STACValidationError) as exc: + pystac.validation.validate(mlm_item, validator=mlm_validator) + assert "1234 is not of type 'string'" in str(exc.value.source) + + mlm_data["assets"]["model"]["mlm:artifact_type"] = "" # type: ignore + mlm_item = pystac.Item.from_dict(mlm_data) + with pytest.raises(pystac.errors.STACValidationError) as exc: pystac.validation.validate(mlm_item, validator=mlm_validator) + assert "should be non-empty" in str(exc.value.source) def test_model_metadata_to_dict(eurosat_resnet):