From 7dac1b53cd0f1b9744a07108e4e6ab1c0774757a Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Wed, 17 Jul 2024 16:17:27 +0200 Subject: [PATCH 1/3] Added Aqua.jl to tests --- Project.toml | 7 ++++--- README.md | 1 + test/runtests.jl | 13 +++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index cf4c214..21f9446 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,6 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,6 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +Aqua = "0.8" CUDA = "5" ChainRulesCore = "1.16.0 - 1" DataFrames = "1.6" @@ -41,7 +41,6 @@ JSON = "0.21" Lux = "0.5" LuxCUDA = "0.3" Optimisers = "0.3" -Optimization = "3.18" Printf = "1" ProgressMeter = "1.7.0 - 1" Random = "1" @@ -49,13 +48,15 @@ SciMLBase = "2.7.0 - 2" SciMLSensitivity = "7.45 - 7" Statistics = "1" TFRecord = "0.4.2" +Test = "1" Wandb = "0.5" Zygote = "0.6" cuDNN = "1.3" julia = "1.10" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Aqua", "Test"] diff --git a/README.md b/README.md index 72de6e2..ee3caff 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ # MeshGraphNets.jl [![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://una-auxme.github.io/MeshGraphNets.jl/dev) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [*MeshGraphNets.jl*](https://github.com/una-auxme/MeshGraphNets.jl) is a software package for the Julia programming language that provides an implementation of the [MeshGraphNets](https://arxiv.org/abs/2010.03409) framework by [Google DeepMind](https://deepmind.google/) for simulating mesh-based physical systems via graph neural networks: diff --git a/test/runtests.jl b/test/runtests.jl index 5080d49..3005a8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,20 @@ using MeshGraphNets using Test +using Aqua @testset "MeshGraphNets.jl" begin # TODO + @testset "Aqua.jl" begin + # Ambiguities in external packages + @testset "Method ambiguity" begin + Aqua.test_ambiguities([MeshGraphNets]) + end + Aqua.test_all(MeshGraphNets; ambiguities = false, piracies = false) + + # Piracy due to Zygote pullback + # @testset "Piracy" begin + # Aqua.test_piracies(MeshGraphNets) + # end + end end From 5cbcc472485dec1d13d16c15df2faf14a8da7da4 Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Thu, 18 Jul 2024 14:05:37 +0200 Subject: [PATCH 2/3] Fixed piracies of differentation functions --- src/graph.jl | 17 +++++++++++++---- src/utils.jl | 16 ---------------- test/runtests.jl | 7 +------ 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/src/graph.jl b/src/graph.jl index 0cce58b..ddd3b12 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -69,11 +69,20 @@ Constructs a [FeatureGraph](@ref) based on the given arguments. - Resulting [FeatureGraph](@ref). """ function build_graph(mgn::GraphNetwork, data, fields, datapoint::Integer, node_type, edge_features::AbstractArray{Float32, 2}, senders::AbstractArray{T, 1}, receivers::AbstractArray{T, 1}) where {T <: Integer} + # Removed generator in favor of removing Zygote.jl piracies (minimal increase of time and allocations) + # Can be reverted once Enzyme.jl is compatible + nt = mgn.n_norm["node_type"](node_type) + nf = similar(nt, 0, size(nt, 2)) + for field in fields + nf = vcat(nf, mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)])) + end + nf = vcat(nf, nt) return FeatureGraph( - vcat( - [mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)]) for field in fields]..., - mgn.n_norm["node_type"](node_type) - ), + nf, + # vcat( + # [mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)]) for field in fields]..., + # mgn.n_norm["node_type"](node_type) + # ), mgn.e_norm(edge_features), senders, receivers diff --git a/src/utils.jl b/src/utils.jl index d1c5603..de56767 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -286,19 +286,3 @@ function clear_log(lines::Integer, move_up = true) clear_line() end end - -#################################################################### -# Overwritten for differentiation over GraphNetwork and normaliser # -#################################################################### - -Zygote.accum(x::Base.RefValue{Any}, y::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}) = Zygote.accum(x[], y) - -Zygote.accum(x::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}, y::Base.RefValue{Any}) = Zygote.accum(x, y[]) - -Zygote.accum(x::Base.RefValue{Any}, y::NamedTuple{(:data_min, :data_max, :target_min, :target_max)}) = Zygote.accum(x[], y) - -Base.:+(x::Base.RefValue{Any}, y::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}) = Zygote.accum(x[], y) - -Base.:+(x::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}, y::Base.RefValue{Any}) = Zygote.accum(x, y[]) - -Base.:+(x::Base.RefValue{Any}, y::NamedTuple{(:data_min, :data_max, :target_min, :target_max)}) = Zygote.accum(x[], y) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 3005a8b..cc928ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,11 +14,6 @@ using Aqua @testset "Method ambiguity" begin Aqua.test_ambiguities([MeshGraphNets]) end - Aqua.test_all(MeshGraphNets; ambiguities = false, piracies = false) - - # Piracy due to Zygote pullback - # @testset "Piracy" begin - # Aqua.test_piracies(MeshGraphNets) - # end + Aqua.test_all(MeshGraphNets; ambiguities = false) end end From 8c0dd850823db22df77f8cf84ee6d1c98784eff9 Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Thu, 18 Jul 2024 14:08:19 +0200 Subject: [PATCH 3/3] Removed general dependencies in favor of specialized packages --- Project.toml | 4 ++-- examples/cylinder_flow/cylinder_flow.jl | 2 +- src/MeshGraphNets.jl | 4 ++-- src/dataset.jl | 3 ++- src/strategies.jl | 3 +-- src/utils.jl | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 21f9446..2d179d0 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "0.4.0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GraphNetCore = "7809f980-de1b-4f9a-8451-85f041491431" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" @@ -16,6 +15,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -32,7 +32,6 @@ Aqua = "0.8" CUDA = "5" ChainRulesCore = "1.16.0 - 1" DataFrames = "1.6" -DifferentialEquations = "7.11 - 7" Distributions = "0.25" GraphNetCore = "0.3" HDF5 = "0.17" @@ -41,6 +40,7 @@ JSON = "0.21" Lux = "0.5" LuxCUDA = "0.3" Optimisers = "0.3" +OrdinaryDiffEq = "6.85 - 6" Printf = "1" ProgressMeter = "1.7.0 - 1" Random = "1" diff --git a/examples/cylinder_flow/cylinder_flow.jl b/examples/cylinder_flow/cylinder_flow.jl index dbcae7b..151f113 100644 --- a/examples/cylinder_flow/cylinder_flow.jl +++ b/examples/cylinder_flow/cylinder_flow.jl @@ -5,7 +5,7 @@ using MeshGraphNets -import DifferentialEquations: Euler, Tsit5 +import OrdinaryDiffEq: Euler, Tsit5 import Optimisers: Adam ###################### diff --git a/src/MeshGraphNets.jl b/src/MeshGraphNets.jl index 56cab82..6a524cc 100644 --- a/src/MeshGraphNets.jl +++ b/src/MeshGraphNets.jl @@ -13,14 +13,14 @@ using Optimisers using Wandb using Zygote -import DifferentialEquations.OrdinaryDiffEq: OrdinaryDiffEqAlgorithm, Tsit5 +import OrdinaryDiffEq: OrdinaryDiffEqAlgorithm, Tsit5 import ProgressMeter: Progress import SciMLBase: ODEProblem import Base: @kwdef -import DifferentialEquations: solve, remake import HDF5: h5open, create_group, open_group import ProgressMeter: next!, update!, finish! +import SciMLBase: solve, remake import Statistics: mean include("utils.jl") diff --git a/src/dataset.jl b/src/dataset.jl index e73d92e..ff23c76 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -4,10 +4,11 @@ # import Distributions: Normal +import HDF5: Group import Random: MersenneTwister import TFRecord: Example -import HDF5: h5open, Group, read_dataset +import HDF5: read_dataset import JLD2: jldopen import JSON: parse import Random: seed!, make_seed, shuffle diff --git a/src/strategies.jl b/src/strategies.jl index b58084d..4ec96fe 100644 --- a/src/strategies.jl +++ b/src/strategies.jl @@ -3,8 +3,7 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # -import DifferentialEquations: ODEFunction -import SciMLBase: AbstractSensitivityAlgorithm +import SciMLBase: AbstractSensitivityAlgorithm, ODEFunction import SciMLSensitivity: InterpolatingAdjoint, ZygoteVJP ####################################################### diff --git a/src/utils.jl b/src/utils.jl index de56767..5cd144d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,7 +4,7 @@ # import Printf: @sprintf -import Statistics: mean, stdm +import Statistics: stdm """ der_minmax(path)