Skip to content

Commit

Permalink
reduce array allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Oct 24, 2022
1 parent 463eb0a commit 5c2fecf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
28 changes: 16 additions & 12 deletions src/modelframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,26 @@ end
_missing_omit(x::AbstractVector{T}) where T = copyto!(similar(x, nonmissingtype(T)), x)
_missing_omit(x::AbstractVector, rows) = _missing_omit(view(x, rows))

function missing_omit(d::T) where T<:ColumnTable
function _maybe_missing_omit(d::T) where T<:ColumnTable
nonmissings = trues(length(first(d)))
for col in d
_nonmissing!(nonmissings, col)
end
d_nonmissing = if all(nonmissings)
map(_missing_omit, d)
if any(eltype(col) >: Missing for col in d)
for col in d
_nonmissing!(nonmissings, col)
end
d_nonmissing = if all(nonmissings)
map(_missing_omit, d)
else
rows = findall(nonmissings)
map(Base.Fix2(_missing_omit, rows), d)
end
return d_nonmissing, nonmissings
else
rows = findall(nonmissings)
map(Base.Fix2(_missing_omit, rows), d)
return d, nonmissings
end
d_nonmissing, nonmissings
end

missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))
_maybe_missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
_maybe_missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))

function ModelFrame(f::FormulaTerm, data::ColumnTable;
model::Type{M}=StatisticalModel, contrasts=Dict{Symbol,Any}()) where M
Expand All @@ -78,7 +82,7 @@ function ModelFrame(f::FormulaTerm, data::ColumnTable;
throw(ArgumentError(msg))
end

data, _ = missing_omit(data, f)
data, _ = _maybe_missing_omit(data, f)

sch = schema(f, data, contrasts)
f = apply_schema(f, sch, M)
Expand Down
2 changes: 1 addition & 1 deletion src/statsmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ function StatsBase.predict(mm::TableRegressionModel, data; kwargs...)
throw(ArgumentError("expected data in a Table, got $(typeof(data))"))

f = mm.mf.f
cols, nonmissings = missing_omit(columntable(data), f.rhs)
cols, nonmissings = _maybe_missing_omit(columntable(data), f.rhs)
new_x = modelcols(f.rhs, cols)
y_pred = predict(mm.model, reshape(new_x, size(new_x, 1), :);
kwargs...)
Expand Down

0 comments on commit 5c2fecf

Please sign in to comment.