diff --git a/src/modelframe.jl b/src/modelframe.jl
index d3d98dad..2b5a6d73 100644
--- a/src/modelframe.jl
+++ b/src/modelframe.jl
@@ -53,22 +53,43 @@ end
 _missing_omit(x::AbstractVector{T}) where T = copyto!(similar(x, nonmissingtype(T)), x)
 _missing_omit(x::AbstractVector, rows) = _missing_omit(view(x, rows))
 
-function missing_omit(d::T) where T<:ColumnTable
+"""
+    _maybe_missing_omit(d::ColumnTable)
+    _maybe_missing_omit(data::ColumnTable, formula::AbstractTerm)
+
+Omit the rows flagged by `_nonmissing!`, skipping that work when no column can hold
+`missing`. Return `(table, nonmissings)`; `nonmissings` is a `BitVector` as long as the
+first column of `d`.
+
+- If no column's `eltype` admits `Missing`, `d` is returned as-is (not copied) and
+  `nonmissings` is all `true`.
+- Otherwise `_nonmissing!` updates `nonmissings` for every column, and each column is
+  copied, restricted to the rows where `nonmissings` is `true`, into a vector whose
+  element type is the `nonmissingtype` of the original.
+
+The two-argument method first narrows `data` to the columns named by
+`termvars(formula)`.
+"""
+function _maybe_missing_omit(d::T) where T<:ColumnTable
     nonmissings = trues(length(first(d)))
-    for col in d
-        _nonmissing!(nonmissings, col)
-    end
-    d_nonmissing = if all(nonmissings)
-        map(_missing_omit, d)
+    if any(eltype(col) >: Missing for col in d)
+        for col in d
+            _nonmissing!(nonmissings, col)
+        end
+        d_nonmissing = if all(nonmissings)
+            map(_missing_omit, d)
+        else
+            rows = findall(nonmissings)
+            map(Base.Fix2(_missing_omit, rows), d)
+        end
+        return d_nonmissing, nonmissings
     else
-        rows = findall(nonmissings)
-        map(Base.Fix2(_missing_omit, rows), d)
+        return d, nonmissings
     end
-    d_nonmissing, nonmissings
 end
 
-missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
-    missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))
+_maybe_missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
+    _maybe_missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))
 
 function ModelFrame(f::FormulaTerm, data::ColumnTable;
                     model::Type{M}=StatisticalModel, contrasts=Dict{Symbol,Any}()) where M
@@ -78,7 +99,7 @@ function ModelFrame(f::FormulaTerm, data::ColumnTable;
         throw(ArgumentError(msg))
     end
 
-    data, _ = missing_omit(data, f)
+    data, _ = _maybe_missing_omit(data, f)
 
     sch = schema(f, data, contrasts)
     f = apply_schema(f, sch, M)
diff --git a/src/statsmodel.jl b/src/statsmodel.jl
index 0bb67c7c..933c37f9 100644
--- a/src/statsmodel.jl
+++ b/src/statsmodel.jl
@@ -174,7 +174,7 @@ function StatsBase.predict(mm::TableRegressionModel, data; kwargs...)
         throw(ArgumentError("expected data in a Table, got $(typeof(data))"))
 
     f = mm.mf.f
-    cols, nonmissings = missing_omit(columntable(data), f.rhs)
+    cols, nonmissings = _maybe_missing_omit(columntable(data), f.rhs)
     new_x = modelcols(f.rhs, cols)
     y_pred = predict(mm.model, reshape(new_x, size(new_x, 1), :);
                      kwargs...)