Skip to content

Commit

Permalink
Merge pull request #21 from AsafManela/master
Browse files Browse the repository at this point in the history
Fix cross-validation and add tests for it
  • Loading branch information
andreasnoack authored Nov 4, 2018
2 parents 0748b2e + beb13b4 commit 5571896
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/cross_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using MLBase
# which chooses λt at lowest mean OOS deviance.

function CVmin(oosdevs)
cvmeans = mean(oosdevs,dims=2)
cvmeans = vec(mean(oosdevs,dims=2))
segCVmin = argmin(cvmeans)
end

Expand Down
70 changes: 25 additions & 45 deletions test/cross_validation.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,36 @@
using MLBase
using MLBase, Random

datapath = joinpath(dirname(@__FILE__), "..","test","data")
datapath = joinpath(dirname(@__FILE__), "data")

(family, dist, link) = (("gaussian", Normal(), IdentityLink()), ("binomial", Binomial(), LogitLink()), ("poisson", Poisson(), LogLink()))[1]
data = readcsvmat(joinpath(datapath,"gamlr.$family.data.csv"))
data = readcsvmat(joinpath(datapath,"gamlr.gaussian.data.csv"))
y = data[:,1]
X = data[:,2:end]
(n,p) = size(X)
γ = [0 2 10][1]
fitname = "gamma"
# get gamlr params and estimates
params = readtable(joinpath(datapath,"gamlr.$family.$fitname.params.csv"))
fittable = readtable(joinpath(datapath,"gamlr.$family.$fitname.fit.csv"))
gcoefs = convert(Matrix{Float64},readcsv(joinpath(datapath,"gamlr.$family.$fitname.coefs.csv")))
family = params[1,:fit_family]
λ = nothing #convert(Vector{Float64},fittable[:fit_lambda]) # should be set to nothing evenatually
# fit julia version
offset = fill(0.001,size(y))

@time glp = fit(GammaLassoPath, X, y, dist, link, λ=λ,γ=γ,standardize=true, λminratio=0.001, offset=offset)
path = glp
path = fit(LassoPath, X, y; offset=offset)
β = coef(path; select=:all)

plot(path)
coefsAICc = coef(path;select=:AICc)
segminAICc = minAICc(path)
@test segminAICc == 71
@test coefsAICc == β[:,segminAICc]

@time coefsAICc = coef(path;select=:AICc)
Random.seed!(13)
@time coefsCVmin = coef(path;select=:CVmin)
Random.seed!(13)
@time coefsCV1se = coef(path;select=:CV1se,nCVfolds=100)
# fieldnames(path.m.pp)
# y == path.m.rr.y
# offset == path.m.rr.offset
# path.m.pp.X
#
# size(path.m.pp.X)
# size(convert(path.Xnorm)

# gen = LOOCV(nobs(path))
# T = eltype(λ)
# offset=Array(T,0)
ix=1:length(y)

# plot(path)

Random.seed!(13); gen = Kfold(length(y[ix]),10)
@time segCVmin = cross_validate_path(path;gen=gen)
gen = Kfold(length(y),10)
segCVmin = cross_validate_path(path;gen=gen)
coefsCVmin = coef(path;select=:CVmin)
@test segCVmin == 71
@test coefsCVmin == β[:,segCVmin]

Random.seed!(13); gen = Kfold(length(y[ix]),10)
@time segCVmin = cross_validate_path(path,X[ix,:],y[ix];offset=offset[ix],gen=gen)

Random.seed!(13); gen = Kfold(length(y[ix]),10)
@time segCV1se = cross_validate_path(path,X[ix,:],y[ix];select=:CV1se,gen=gen,offset=offset[ix])
Random.seed!(13)
gen = Kfold(length(y),10)
segCVmin = cross_validate_path(path,X,y; gen=gen, offset=offset)
coefsCVmin = coef(path;select=:CVmin)
@test segCVmin == 71
@test coefsCVmin == β[:,segCVmin]

λCVmin = path.λ[segCVmin]
λCV1se = path.λ[segCV1se]
Random.seed!(13)
coefsCV1se = coef(path;select=:CV1se,nCVfolds=20)
segCV1se = cross_validate_path(path,X,y;select=:CV1se,gen=gen,offset=offset)
@test segCV1se == 42
@test coefsCV1se == β[:,segCV1se]
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ include("lasso.jl")
include("gammalasso.jl")
include("fusedlasso.jl")
include("trendfiltering.jl")
include("cross_validation.jl")

end

0 comments on commit 5571896

Please sign in to comment.