Added torch.[de]serialize{To/From}Storage to File.lua.
These mirror the existing [de]serialize functions but return a
torch.CharStorage instead of a Lua string. This avoids hitting the LuaJIT
2GB limit for very large objects: the CharStorage is allocated on the C
side and does not count towards that limit, whereas a Lua string does.

This change fixes a repeatable OOM crash I hit when trying to serialize
certain large models.
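
A minimal usage sketch of the new API (illustrative only, not part of this
commit; the model table and sizes below are placeholders):

-- serialize into a torch.CharStorage; the bytes are allocated on the C side
-- rather than on the LuaJIT heap
local model = { weights = torch.randn(1000, 1000), name = 'example' }
local storage = torch.serializeToStorage(model)

-- reconstruct the object from the storage
local restored = torch.deserializeFromStorage(storage)
assert(restored.name == model.name)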
malcolmreynolds committed Feb 11, 2015
1 parent d652209 commit 3a96d15
Showing 2 changed files with 33 additions and 6 deletions.
22 changes: 16 additions & 6 deletions File.lua
@@ -275,20 +275,25 @@ end

 -- simple helpers to serialize/deserialize arbitrary objects/tables
 function torch.serialize(object, mode)
+   local storage = torch.serializeToStorage(object, mode)
+   return storage:string()
+end
+
+-- Serialize to a CharStorage, not a Lua string, to avoid the LuaJIT 2GB limit.
+function torch.serializeToStorage(object, mode)
    mode = mode or 'binary'
    local f = torch.MemoryFile()
    f = f[mode](f)
    f:writeObject(object)
-   local s = f:storage():string()
+   local storage = f:storage()
    f:close()
-   return s
+   return storage
 end

-function torch.deserialize(str, mode)
+function torch.deserializeFromStorage(storage, mode)
    mode = mode or 'binary'
-   local x = torch.CharStorage():string(str)
-   local tx = torch.CharTensor(x)
-   local xp = torch.CharStorage(x:size(1)+1)
+   local tx = torch.CharTensor(storage)
+   local xp = torch.CharStorage(tx:size(1)+1)
    local txp = torch.CharTensor(xp)
    txp:narrow(1,1,tx:size(1)):copy(tx)
    txp[tx:size(1)+1] = 0
@@ -299,6 +304,11 @@ function torch.deserialize(str, mode)
    return object
 end

+function torch.deserialize(str, mode)
+   local storage = torch.CharStorage():string(str)
+   return torch.deserializeFromStorage(storage, mode)
+end
+
 -- public API (saveobj/loadobj are safe for global import)
 torch.saveobj = torch.save
 torch.loadobj = torch.load
17 changes: 17 additions & 0 deletions test/test.lua
@@ -1873,6 +1873,23 @@ function torchtest.permute()
    mytester:assertTableEq(x:size():totable(), orig, 'Tensor:permute changes tensor')
 end

+function torchtest.serialize()
+   local tableObj = {6, a = 42}
+   local tensObj = torch.randn(3,4,5)
+
+   -- Test serializing a table
+   local serString = torch.serialize(tableObj)
+   local serStorage = torch.serializeToStorage(tableObj)
+   mytester:assertTableEq(tableObj, torch.deserialize(serString))
+   mytester:assertTableEq(tableObj, torch.deserializeFromStorage(serStorage))
+
+   -- Test serializing a Tensor
+   serString = torch.serialize(tensObj)
+   serStorage = torch.serializeToStorage(tensObj)
+   mytester:assertTensorEq(tensObj, torch.deserialize(serString), 1e-10)
+   mytester:assertTensorEq(tensObj, torch.deserializeFromStorage(serStorage), 1e-10)
+end
+
 function torch.test(tests)
    math.randomseed(os.time())
    if torch.getdefaulttensortype() == 'torch.FloatTensor' then
