Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial proof-of-concept Expectation Propagation #64

Open
wants to merge 113 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 103 commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
3b8a075
initial commit of Laplace approximation & EP
st-- Sep 17, 2021
c73e9ce
Laplace demo
st-- Sep 17, 2021
93ec795
bugfix
st-- Sep 17, 2021
d62b2dc
bugfix demo
st-- Sep 17, 2021
76afc55
WIP: cleanup; pass around dist_y_given_f explicitly
st-- Sep 17, 2021
46009e4
use ForwardDiff elementwise (much faster)
st-- Sep 17, 2021
0c1a329
format
st-- Sep 17, 2021
88410cc
WIP: gradient
st-- Sep 17, 2021
4853b5f
*Very* WIP - but works
st-- Sep 20, 2021
3c91a22
frule bugfix
st-- Sep 21, 2021
513d34d
clean up comments
st-- Sep 21, 2021
33af436
intermediate cleanup (callbacks not tested)
st-- Sep 21, 2021
30b45ac
fix callback support
st-- Sep 21, 2021
09e5e82
Laplace initial tests
st-- Sep 21, 2021
2591555
make argument order consistent
st-- Sep 21, 2021
5d58f5d
more cleanup
st-- Sep 21, 2021
6c4e56d
cleanup
st-- Sep 21, 2021
ead526f
cleanup
st-- Sep 21, 2021
fe76626
@info -> @debug and ChainRulesTestUtil workaround
st-- Sep 21, 2021
f7439d6
chainrule tests
st-- Sep 21, 2021
d71ca20
add @info for res_cold/res_warm
st-- Sep 22, 2021
ecb29d9
cleanup
st-- Sep 22, 2021
b275332
explicit FiniteDifferences gradient test on laplace_lml
st-- Sep 22, 2021
8135c40
Merge branch 'master' of github.com:JuliaGaussianProcesses/Approximat…
st-- Sep 22, 2021
0d8eb02
format
st-- Sep 22, 2021
af07036
format
st-- Sep 22, 2021
6f08e73
remove Zygote dependency - part 1
st-- Sep 23, 2021
f84ed86
remove Zygote dependency - part 2
st-- Sep 23, 2021
e4aab7e
pkg bugfix
st-- Sep 23, 2021
d19003f
update example manifests
st-- Sep 23, 2021
2d43113
fix chainrule test by evaluating frule/rrule on newton_inner_loop bas…
st-- Sep 23, 2021
bbd24ec
add compat
st-- Sep 23, 2021
ac6486d
clean up test
st-- Sep 23, 2021
96957e3
add missing CRC dependency
st-- Sep 23, 2021
5af23fa
remove workaround
st-- Sep 23, 2021
54acbe3
use more of AbstractGPs API
st-- Sep 23, 2021
df88117
clean up laplace_steps
st-- Sep 23, 2021
8d20a65
add laplace example
st-- Sep 23, 2021
30ba663
cleanup
st-- Sep 23, 2021
8209442
cleanup2
st-- Sep 23, 2021
def61e7
format
st-- Sep 23, 2021
4566c17
remove demo script
st-- Sep 23, 2021
4095df1
bugfix
st-- Sep 23, 2021
91d0b34
update notebook
st-- Sep 23, 2021
bdc3618
bugfix 2
st-- Sep 23, 2021
12655f8
bugfiiiix
st-- Sep 23, 2021
b1c0f80
also plot mean
st-- Sep 23, 2021
3c97c97
improve plotting
st-- Sep 23, 2021
81ceb93
Apply suggestions from code review
st-- Sep 23, 2021
82d4695
remove `@ref` that does not work
st-- Sep 23, 2021
aca90b1
cleanup
st-- Sep 23, 2021
02b528b
make use of closure fields
st-- Sep 23, 2021
cbfbc95
Merge branch 'st/laplace_and_ep' of github.com:JuliaGaussianProcesses…
st-- Sep 23, 2021
e788f4c
yaf
st-- Sep 23, 2021
ca68222
improved type stability
st-- Sep 24, 2021
9144ed3
replace QuadGK with Gauss-Hermite
st-- Sep 24, 2021
771ef3e
Apply suggestions from code review
st-- Sep 24, 2021
b52ef55
more type stability cleanup
st-- Sep 24, 2021
4f45ff6
Merge branch 'st/laplace_and_ep' of github.com:JuliaGaussianProcesses…
st-- Sep 24, 2021
4a694d9
fix test
st-- Sep 24, 2021
62840ad
add missing test file
st-- Sep 24, 2021
6f7a5ba
more explanation on the example script
st-- Sep 24, 2021
5445a95
fix test seed
st-- Sep 24, 2021
02f414b
Merge remote-tracking branch 'origin/master' into st/laplace_and_ep
st-- Sep 29, 2021
bf3cf02
initial version of full EP inference
st-- Oct 1, 2021
a37714f
add posterior
st-- Oct 1, 2021
3ab1f14
add export
st-- Oct 1, 2021
d0cc8fd
add EP to example script
st-- Oct 1, 2021
9c4f7c4
Apply suggestions from code review
st-- Oct 1, 2021
c690dd0
bugfix
st-- Oct 1, 2021
88de29b
Merge branch 'master' of github.com:JuliaGaussianProcesses/Approximat…
st-- Oct 1, 2021
97c6257
add prediction test for EP
st-- Oct 1, 2021
bf98f16
run all tests
st-- Oct 1, 2021
4aa068b
Apply suggestions from code review
st-- Oct 1, 2021
67b69a7
pass n_gh around
st-- Oct 1, 2021
41add63
Merge branch 'st/ExpectationPropagation' of github.com:JuliaGaussianP…
st-- Oct 1, 2021
6564670
whitespace
st-- Oct 1, 2021
40eab41
reorder ep.jl
st-- Oct 1, 2021
e068563
bugfix
st-- Oct 1, 2021
1dec761
update Manifests
st-- Oct 4, 2021
d3adaf9
add approx_lml for EP (DOES NOT WORK CORRECTLY YET)
st-- Oct 4, 2021
f5f61d4
Apply suggestions from code review
st-- Oct 4, 2021
29e5cac
more explicit using statements
st-- Dec 16, 2021
0cfc13a
update Project
st-- Dec 16, 2021
7127c9a
Merge branch 'st/ExpectationPropagation' of github.com:JuliaGaussianP…
st-- Dec 16, 2021
89a1281
Merge branch 'master' of github.com:JuliaGaussianProcesses/Approximat…
st-- Dec 16, 2021
8e30287
WIP
st-- Dec 16, 2021
80a99cc
Apply suggestions from code review
st-- Dec 16, 2021
51323ef
Apply suggestions from code review
st-- Dec 16, 2021
b7a38da
update manifests
st-- Dec 16, 2021
514e06a
move testset into test/ep.jl
st-- Dec 16, 2021
1da0cd4
remove approx_lml and tests
st-- Dec 16, 2021
5d88e5a
do not export ExpectationPropagation, add warnings
st-- Dec 16, 2021
5742ca6
fix merge
st-- Dec 17, 2021
83986b1
Merge branch 'master' into st/ExpectationPropagation
st-- Jan 11, 2022
7b34db1
some fixes
st-- Jan 11, 2022
e4a53d9
comment
st-- Jan 11, 2022
82c340a
error with message instead of assert
st-- Jan 11, 2022
191ed5a
Apply suggestions from code review
st-- Jan 11, 2022
6de0a2b
add missing import
st-- Jan 11, 2022
1d9b7b1
Merge branch 'st/ExpectationPropagation' of github.com:JuliaGaussianP…
st-- Jan 11, 2022
8778ea0
update comparison notebook
st-- Jan 11, 2022
194ed6b
fix
st-- Jan 11, 2022
75264e0
Update src/ep.jl
st-- Jan 11, 2022
48c7570
Merge branch 'master' into st/ExpectationPropagation
st-- Jan 13, 2022
409f8c3
Update src/ep.jl
st-- Jan 14, 2022
5900e73
add imports
st-- Jan 14, 2022
c774892
Merge branch 'master' of github.com:JuliaGaussianProcesses/Approximat…
st-- Mar 15, 2022
e71e697
Merge branch 'master' into st/ExpectationPropagation
st-- Mar 17, 2022
2a0ff19
make use of TestUtils
st-- Mar 17, 2022
5e02212
convert to ExpectationPropagationModule
st-- Mar 17, 2022
7234963
Apply suggestions from code review
st-- Mar 21, 2022
db9b459
missing import
st-- Mar 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.5"
version = "0.2.6"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand All @@ -11,8 +11,10 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -27,6 +29,7 @@ FastGaussQuadrature = "0.4"
FillArrays = "0.12"
ForwardDiff = "0.10"
GPLikelihoods = "0.1, 0.2"
IrrationalConstants = "0.1"
KLDivergences = "0.2.1"
PDMats = "0.11"
Reexport = "1"
Expand Down
82 changes: 41 additions & 41 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ version = "0.0.1"

[[AbstractGPs]]
deps = ["ChainRulesCore", "Distributions", "FillArrays", "IrrationalConstants", "KernelFunctions", "LinearAlgebra", "Random", "RecipesBase", "Reexport", "Statistics", "StatsBase", "Test"]
git-tree-sha1 = "e2af18922f65ea8cee7b8cd97a1668691f76f4b2"
git-tree-sha1 = "821d37c2f571ed5f2dfa028f03e324abd12f9910"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
version = "0.5.3"
version = "0.5.4"

[[ApproximateGPs]]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "PDMats", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "IrrationalConstants", "KLDivergences", "LinearAlgebra", "PDMats", "Random", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
path = ".."
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
version = "0.2.3"
Expand All @@ -28,15 +28,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd"
git-tree-sha1 = "4c26b4e9e91ca528ea212927326ece5918a04b47"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.1"
version = "1.11.2"

[[ChangesOfVariables]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "9a1d594397670492219635b35a3d830b04730d62"
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.1"
version = "0.1.2"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
Expand All @@ -46,9 +46,9 @@ version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313"
git-tree-sha1 = "44c37b4636bc54afac5c574d2d02b625349d6582"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.40.0"
version = "3.41.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -66,9 +66,9 @@ version = "1.9.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02"
git-tree-sha1 = "3daef5523dd2e769dad2365274f760ff5f282c7d"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.10"
version = "0.18.11"

[[Dates]]
deps = ["Printf"]
Expand All @@ -80,9 +80,9 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DensityInterface]]
deps = ["InverseFunctions", "Test"]
git-tree-sha1 = "794daf62dce7df839b8ed446fc59c68db4b5182f"
git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b"
uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
version = "0.3.3"
version = "0.4.0"

[[DiffResults]]
deps = ["StaticArrays"]
Expand All @@ -92,25 +92,25 @@ version = "1.0.3"

[[DiffRules]]
deps = ["LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "3287dacf67c3652d3fed09f4c12c187ae4dbb89a"
git-tree-sha1 = "9bc5dac3c8b6706b58ad5ce24cffd9861f07c94f"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.4.0"
version = "1.9.0"

[[Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
git-tree-sha1 = "837c83e5574582e07662bbbba733964ff7c26b9d"
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.6"
version = "0.10.7"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "cce8159f0fee1281335a04bbf876572e46c921ba"
git-tree-sha1 = "c1724611e6ae29c6094c8d9850e3136297ba7fff"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.29"
version = "0.25.36"

[[DocStringExtensions]]
deps = ["LibGit2"]
Expand All @@ -130,9 +130,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FastGaussQuadrature]]
deps = ["LinearAlgebra", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "5829b25887e53fb6730a9df2ff89ed24baa6abf6"
git-tree-sha1 = "58d83dd5a78a36205bdfddb82b1bb67682e64487"
uuid = "442a2c76-b920-505d-bb47-c5924d526838"
version = "0.4.7"
version = "0.4.9"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
Expand All @@ -142,9 +142,9 @@ version = "0.12.7"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "6406b5112809c08b1baa5703ad274e1dded0652f"
git-tree-sha1 = "2b72a5624e289ee18256111657663721d59c143e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.23"
version = "0.10.24"

[[Functors]]
git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2"
Expand Down Expand Up @@ -227,9 +227,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "be9eef9f9d78cecb6f262f3c10da151a6c5ab827"
git-tree-sha1 = "e5718a00af0ab9756305a0392832c8952c7426c1"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.5"
version = "0.3.6"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down Expand Up @@ -261,9 +261,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
git-tree-sha1 = "f755f36b19a5116bb580de457cda0c140153f283"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"
version = "0.3.6"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
Expand All @@ -285,9 +285,9 @@ version = "1.4.1"

[[PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "c8b8775b2f242c80ea85c83714c64ecfa3c53355"
git-tree-sha1 = "ee26b350276c51697c9c2d88a072b339f9f03d73"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.3"
version = "0.11.5"

[[Parsers]]
deps = ["Dates"]
Expand Down Expand Up @@ -324,9 +324,9 @@ deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RecipesBase]]
git-tree-sha1 = "44a75aa7a527910ee3d1751d1f0e4148698add9e"
git-tree-sha1 = "6bf3f380ff52ce0832ddd3a2a7b9538ed1bcca7d"
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
version = "1.1.2"
version = "1.2.1"

[[Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
Expand All @@ -335,9 +335,9 @@ version = "1.2.2"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
git-tree-sha1 = "8f82019e525f4d5c669692772a6f4b0a58b06a6a"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
version = "1.2.0"

[[Rmath]]
deps = ["Random", "Rmath_jll"]
Expand Down Expand Up @@ -391,21 +391,21 @@ deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsAPI]]
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
git-tree-sha1 = "0f2aa8e32d511f758a2ce49208181f7733a0936a"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"
version = "1.1.0"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a"
git-tree-sha1 = "2bb0cb32026a66037360606510fca5984ccc6b75"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"
version = "0.33.13"

[[StatsFuns]]
deps = ["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "385ab64e64e79f0cd7cfcf897169b91ebbb2d6c8"
git-tree-sha1 = "bedb3e17cc1d94ce0e6e66d3afa47157978ba404"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.13"
version = "0.9.14"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down
Loading