Skip to content

Commit

Permalink
Merge pull request #17 from una-auxme/dev
Browse files Browse the repository at this point in the history
Added Aqua.jl to tests
  • Loading branch information
JulianTrommer authored Jul 18, 2024
2 parents a484830 + 8c0dd85 commit 12b7e3d
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 32 deletions.
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.4.0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GraphNetCore = "7809f980-de1b-4f9a-8451-85f041491431"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Expand All @@ -16,7 +15,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -29,10 +28,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
Aqua = "0.8"
CUDA = "5"
ChainRulesCore = "1.16.0 - 1"
DataFrames = "1.6"
DifferentialEquations = "7.11 - 7"
Distributions = "0.25"
GraphNetCore = "0.3"
HDF5 = "0.17"
Expand All @@ -41,21 +40,23 @@ JSON = "0.21"
Lux = "0.5"
LuxCUDA = "0.3"
Optimisers = "0.3"
Optimization = "3.18"
OrdinaryDiffEq = "6.85 - 6"
Printf = "1"
ProgressMeter = "1.7.0 - 1"
Random = "1"
SciMLBase = "2.7.0 - 2"
SciMLSensitivity = "7.45 - 7"
Statistics = "1"
TFRecord = "0.4.2"
Test = "1"
Wandb = "0.5"
Zygote = "0.6"
cuDNN = "1.3"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Aqua", "Test"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# MeshGraphNets.jl

[![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://una-auxme.github.io/MeshGraphNets.jl/dev)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[*MeshGraphNets.jl*](https://github.com/una-auxme/MeshGraphNets.jl) is a software package for the Julia programming language that provides an implementation of the [MeshGraphNets](https://arxiv.org/abs/2010.03409) framework by [Google DeepMind](https://deepmind.google/) for simulating mesh-based physical systems via graph neural networks:

Expand Down
2 changes: 1 addition & 1 deletion examples/cylinder_flow/cylinder_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

using MeshGraphNets

import DifferentialEquations: Euler, Tsit5
import OrdinaryDiffEq: Euler, Tsit5
import Optimisers: Adam

######################
Expand Down
4 changes: 2 additions & 2 deletions src/MeshGraphNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ using Optimisers
using Wandb
using Zygote

import DifferentialEquations.OrdinaryDiffEq: OrdinaryDiffEqAlgorithm, Tsit5
import OrdinaryDiffEq: OrdinaryDiffEqAlgorithm, Tsit5
import ProgressMeter: Progress
import SciMLBase: ODEProblem

import Base: @kwdef
import DifferentialEquations: solve, remake
import HDF5: h5open, create_group, open_group
import ProgressMeter: next!, update!, finish!
import SciMLBase: solve, remake
import Statistics: mean

include("utils.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#

import Distributions: Normal
import HDF5: Group
import Random: MersenneTwister
import TFRecord: Example

import HDF5: h5open, Group, read_dataset
import HDF5: read_dataset
import JLD2: jldopen
import JSON: parse
import Random: seed!, make_seed, shuffle
Expand Down
17 changes: 13 additions & 4 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,20 @@ Constructs a [FeatureGraph](@ref) based on the given arguments.
- Resulting [FeatureGraph](@ref).
"""
function build_graph(mgn::GraphNetwork, data, fields, datapoint::Integer, node_type, edge_features::AbstractArray{Float32, 2}, senders::AbstractArray{T, 1}, receivers::AbstractArray{T, 1}) where {T <: Integer}
# Removed generator in favor of removing Zygote.jl piracies (minimal increase of time and allocations)
# Can be reverted once Enzyme.jl is compatible
nt = mgn.n_norm["node_type"](node_type)
nf = similar(nt, 0, size(nt, 2))
for field in fields
nf = vcat(nf, mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)]))
end
nf = vcat(nf, nt)
return FeatureGraph(
vcat(
[mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)]) for field in fields]...,
mgn.n_norm["node_type"](node_type)
),
nf,
# vcat(
# [mgn.n_norm[field](data[field][:, :, min(size(data[field], 3), datapoint)]) for field in fields]...,
# mgn.n_norm["node_type"](node_type)
# ),
mgn.e_norm(edge_features),
senders,
receivers
Expand Down
3 changes: 1 addition & 2 deletions src/strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# Licensed under the MIT license. See LICENSE file in the project root for details.
#

import DifferentialEquations: ODEFunction
import SciMLBase: AbstractSensitivityAlgorithm
import SciMLBase: AbstractSensitivityAlgorithm, ODEFunction
import SciMLSensitivity: InterpolatingAdjoint, ZygoteVJP

#######################################################
Expand Down
18 changes: 1 addition & 17 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#

import Printf: @sprintf
import Statistics: mean, stdm
import Statistics: stdm

"""
der_minmax(path)
Expand Down Expand Up @@ -286,19 +286,3 @@ function clear_log(lines::Integer, move_up = true)
clear_line()
end
end

####################################################################
# Overwritten for differentiation over GraphNetwork and normaliser #
####################################################################

Zygote.accum(x::Base.RefValue{Any}, y::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}) = Zygote.accum(x[], y)

Zygote.accum(x::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}, y::Base.RefValue{Any}) = Zygote.accum(x, y[])

Zygote.accum(x::Base.RefValue{Any}, y::NamedTuple{(:data_min, :data_max, :target_min, :target_max)}) = Zygote.accum(x[], y)

Base.:+(x::Base.RefValue{Any}, y::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}) = Zygote.accum(x[], y)

Base.:+(x::NamedTuple{(:model, :ps, :st, :e_norm, :n_norm, :o_norm)}, y::Base.RefValue{Any}) = Zygote.accum(x, y[])

Base.:+(x::Base.RefValue{Any}, y::NamedTuple{(:data_min, :data_max, :target_min, :target_max)}) = Zygote.accum(x[], y)
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

using MeshGraphNets
using Test
using Aqua

@testset "MeshGraphNets.jl" begin
# TODO
@testset "Aqua.jl" begin
# Ambiguities in external packages
@testset "Method ambiguity" begin
Aqua.test_ambiguities([MeshGraphNets])
end
Aqua.test_all(MeshGraphNets; ambiguities = false)
end
end

0 comments on commit 12b7e3d

Please sign in to comment.