Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mlflow #77

Open
wants to merge 307 commits into
base: main
Choose a base branch
from
Open

Add support for mlflow #77

wants to merge 307 commits into from

Conversation

khintz
Copy link
Contributor

@khintz khintz commented Oct 3, 2024

Describe your changes

Add support for mlflow logger by utilising pytorch_lightning.loggers
The native wandb module is replaced with pytorch_lightning wandb logger and introducing pytorch_lightning mlflow logger.
https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/logger.py

This will allow people to choose between wandb and mlflow.

Builds upon #66 although this is not strictly necessary for this change, but I am working with this feature to work with our dataset.

Issue Link

Closes #76

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@khintz
Copy link
Contributor Author

khintz commented Dec 4, 2024

A status from me. I got a bit further with the signature (thanks to @sadamov), but I am seeing an error that I am currently trying to understand. @elbdmi will also have a look at this.

The log_model() and create_input_example() functions now looks like (in train_model.py:

def log_model(self, data_module, model):
        input_example = self.create_input_example(data_module)

        with torch.no_grad():
            model_output = model.common_step(input_example)[0] # expects batch, returns tuple (ar_model)

        #TODO: Are we sure we can hardcode the input names?
        signature = infer_signature(
            {name: tensor.cpu().numpy() for name, tensor in zip(['init_states', 'target_states', 'forcing', 'target_times'], input_example)},
            model_output.cpu().numpy()
        )

        mlflow.pytorch.log_model(
            model,
            "model",
            input_example=input_example[0].cpu().numpy(),
            signature=signature
        )

def create_input_example(self, data_module):

        if data_module.val_dataset is None:
            data_module.setup(stage="fit")

        data_loader = data_module.train_dataloader()
        batch_sample = next(iter(data_loader))
        return batch_sample

For example model(*input) had to be changed to model.common_step(input) because pytorch expected a function called forward(), which is default as far as I understand. @joeloskarsson would it make sense to rename common_step() to forward()?

When I log the model I do

training_logger.log_model(data_module, model)

but it fails to validate "serving the input example" and print the full tensor:

2024/12/03 10:56:04 WARNING mlflow.models.model: Failed to validate serving input example {
  "inputs": [
    [
      [
        [
          -1.1313700675964355,
          1.9031524658203125,
          -0.8801671266555786,
          0.11871696263551712,
          0.9589263200759888
        ]
...
...
...
   [ 0.08666334  0.1328821  -3.2410812  -0.21018785 -0.6424305 ]]]]' with schema '['init_states': Tensor('float32', (-1, 2, 464721, 5)), 'target_states': Tensor('float32', (-1, 1, 464721, 5)), 'forcing': Tensor('float32', (-1, 1, 464721, 3)), 'target_times': Tensor('int64', (-1, 1))]'. 
Error: Model is missing inputs ['init_states', 'target_states', 'forcing', 'target_times'].
2024/12/03 10:56:19 INFO mlflow.tracking._tracking_service.client: 🏃 View run dazzling-bird-134 at: https://mlflow.dmidev.org/#/experiments/2/runs/8d183886feb648ed87e0e981b6e2c898.
2024/12/03 10:56:19 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://mlflow.dmidev.org/#/experiments/2.
2024/12/03 10:56:19 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/12/03 10:56:19 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!

Input and output data examples are uploaded to mlflow, see this example.

@khintz
Copy link
Contributor Author

khintz commented Dec 9, 2024

I think it's best to skip the log model feature of mlflow for now.
It is not straightforward to implement in an useful way I think, but I would like some opinions.

Discarding the log model feature, means mlflow only will be supported for training (like wandb).

When I call mlflow.pytorch.log_model() with an input example the full tensor is printed, because ( https://mlflow.org/docs/latest/models.html - thanks @elbdmi )

MLflow will propagate any errors raised by the model if
the model does not accept the provided input type

Instead, I can infer the signature and give that, which kind of works. However, there are two issues with that.

1: mlflow.pytorch.log_model() tries to define a conda env to upload with the model. It can't find my environment (I'm using pdm), so I get some useless requirement files uploaded which doesn't reflect the dependencies needed.

2: I get some warnings from MLFlow about modules not being found. It doesn't seem to affect the results, but I have not figured out what the reason is. However, it is related to when it tries to load the model, where it seems to refer to a name similar to the processor:

stdout:
stderr: Traceback (most recent call last):
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/utils/_capture_modules.py", line 256, in <module>
    main()
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/utils/_capture_modules.py", line 233, in main
    store_imported_modules(
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/utils/_capture_modules.py", line 210, in store_imported_modules
    importlib.import_module(f"mlflow.{flavor}")._load_pyfunc(model_path)
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/pytorch/__init__.py", line 720, in _load_pyfunc
    return _PyTorchWrapper(_load_model(path, device=device), device=device)
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/pytorch/__init__.py", line 614, in _load_model
    pytorch_model = torch.load(model_path, **kwargs)
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/torch/serialization.py", line 1097, in load
    return _load(
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/torch/serialization.py", line 1525, in _load
    result = unpickler.load()
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/torch/serialization.py", line 1515, in find_class
    return super().find_class(mod_name, name)
  File "/vf/kah/neural-lam/.venv/lib/python3.10/site-packages/mlflow/utils/_capture_modules.py", line 57, in wrapper
    result = original(name, globals, locals, fromlist, level)
ModuleNotFoundError: No module named 'Sequential_e291b6'

I have also tried to use mlflows validate_serving_input() function to validate the input when not given to the log_model() function (https://mlflow.org/docs/latest/_modules/mlflow/models/utils.html#validate_serving_input) but I could not get that to work.

The whole CustomMLFlowLogger class looks like this now:

class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
    def __init__(self, experiment_name, tracking_uri):
        super().__init__(
            experiment_name=experiment_name, tracking_uri=tracking_uri
        )
        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
        mlflow.log_param("run_id", self.run_id)

    @property
    def save_dir(self):
        return "mlruns"

    def log_image(self, key, images, step=None):
        # Third-party
        from PIL import Image

        if step is not None:
            key = f"{key}_{step}"

        # Need to save the image to a temporary file, then log that file
        # mlflow.log_image, should do this automatically, but is buggy
        temporary_image = f"{key}.png"
        images[0].savefig(temporary_image)

        img = Image.open(temporary_image)
        mlflow.log_image(img, f"{key}.png")


    def log_model(self, data_module, model):
        input_example = self.create_input_example(data_module)

        with torch.no_grad():
            model_output = model.common_step(input_example)[
                0
            ]  # expects batch, returns tuple (prediction, target, pred_std, _)

        log_model_input_example = {
            name: tensor.cpu().numpy()
            for name, tensor in zip(
                ["init_states", "target_states", "forcing", "target_times"],
                input_example,
            )
        }

        signature = infer_signature(
            log_model_input_example, model_output.cpu().numpy()
        )

        mlflow.pytorch.log_model(
            model,
            "model",
            signature=signature,
        )

        # validate_serving_input(model_uri, validate_example)

    def create_input_example(self, data_module):

        if data_module.val_dataset is None:
            data_module.setup(stage="fit")

        data_loader = data_module.train_dataloader()
        batch_sample = next(iter(data_loader))
        return batch_sample

@khintz khintz marked this pull request as ready for review December 9, 2024 21:15
Comment on lines 75 to 108
def log_model(self, data_module, model):
input_example = self.create_input_example(data_module)

with torch.no_grad():
model_output = model.common_step(input_example)[
0
] # common_step returns tuple (prediction, target, pred_std, _)

log_model_input_example = {
name: tensor.cpu().numpy()
for name, tensor in zip(
["init_states", "target_states", "forcing", "target_times"],
input_example,
)
}

signature = infer_signature(
log_model_input_example, model_output.cpu().numpy()
)

mlflow.pytorch.log_model(
model,
"model",
signature=signature,
)

def create_input_example(self, data_module):

if data_module.val_dataset is None:
data_module.setup(stage="fit")

data_loader = data_module.train_dataloader()
batch_sample = next(iter(data_loader))
return batch_sample
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log_model and thereby also create_input_example is not used so they can be removed. However it can be used to log a model if one wishes to do so, but with an input example that is not validated.
I will vote for removing this and revisit in another PR if needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed with 821443a

Comment on lines 250 to 252
warnings.warn(
"Only WandbLogger & MLFlowLogger is supported for tracking metrics"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess what happens in this case is that experiment results will only go to stdout? It would be good to clarify that in this warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed with d503048

Comment on lines +89 to +91
logger: str = "wandb"
logger_url: str = ""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we actually want the config to contain, vs what should be cmd-line arguments? I would have thought that the choice of logger would be an argparse flag, in a similar way as the plotting choices. My thought process is that logging/plotting does not affect the end product (trained model) whereas all the current options in the config does. But we are not really consistent with this divide either, as there are plenty of argparse options currently that change the model training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it sounds reasonable to have the logging choices as cmd-line arguments given that the plot arguments are already in that category. On the other hand, don't risk to get too many cmd-line arguments? I find it sometimes quite hard to remember the correct names. Either way, I agree that either both plot and logger should be cmd-line or both should be in a config.

@leifdenby leifdenby requested a review from elbdmi December 10, 2024 10:47
@leifdenby leifdenby added this to the v0.4.0 milestone Dec 16, 2024
@khintz khintz requested review from SimonKamuk and removed request for elbdmi January 14, 2025 12:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support mlflow logger (and other loggers from pytorch-lightning)
4 participants