Skip to content

Commit

Permalink
add plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 31, 2024
1 parent 8ee1609 commit 77cbd11
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
24 changes: 15 additions & 9 deletions src/anemoi/inference/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,34 @@
import tqdm
from ai_models.model import Model

from anemoi.inference.runner import AUTOCAST
from anemoi.inference.runner import DefaultRunner

LOG = logging.getLogger(__name__)


class AIModelPlugin(Model):

_expver = "0000"

@property
def expver(self):
if self._expver == "0000":
LOG.warning(f"'expver' is not available in this model, using '{self._expver}'.")
return self._expver

@expver.setter
def expver(self, value):
self._expver = value

def parse_model_args(self, args):
parser = argparse.ArgumentParser()

parser.add_argument("--checkpoint")
parser.add_argument("--checkpoint", required=not hasattr(self, "download_files"))
parser.add_argument(
"--autocast",
type=str,
choices=(
"16",
"32",
"float16",
"float32",
"bfloat16",
"b16",
),
choices=sorted(AUTOCAST.keys()),
)
args = parser.parse_args(args)

Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
None: torch.float16,
}


Expand Down Expand Up @@ -106,6 +105,10 @@ def run(
if autocast is None:
autocast = self.checkpoint.precision

if autocast is None:
LOGGER.warning("No autocast given, using float16")
autocast = "16"

autocast = AUTOCAST[autocast]

input_fields = input_fields.sel(**self.checkpoint.select)
Expand Down

0 comments on commit 77cbd11

Please sign in to comment.