Skip to content

Commit

Permalink
Fix get/setRNGState for gaussian state.
Browse files Browse the repository at this point in the history
Fixes torch#8
  • Loading branch information
timharley committed Apr 11, 2014
1 parent 14b20df commit 9abb33d
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 29 deletions.
12 changes: 12 additions & 0 deletions lib/TH/THRandom.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ THGenerator* THGenerator_new()
return self;
}

THGenerator* THGenerator_copy(THGenerator *self, THGenerator *from)
{
memcpy(self, from, sizeof(THGenerator));
return self;
}

void THGenerator_free(THGenerator *self)
{
THFree(self);
Expand Down Expand Up @@ -89,6 +95,12 @@ unsigned long THRandom_seed(THGenerator *_generator)
void THRandom_manualSeed(THGenerator *_generator, unsigned long the_seed_)
{
int j;

/* This ensures reseeding resets all of the state (i.e. state for Gaussian numbers) */
THGenerator *blank = THGenerator_new();
THGenerator_copy(_generator, blank);
THGenerator_free(blank);

_generator->the_initial_seed = the_seed_;
_generator->state[0] = _generator->the_initial_seed & 0xffffffffUL;
for(j = 1; j < n; j++)
Expand Down
6 changes: 3 additions & 3 deletions lib/TH/THRandom.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#define _MERSENNE_STATE_N 624
#define _MERSENNE_STATE_M 397
/* A THGenerator contains all the state required for a single random number stream */
typedef struct THGenerator {
/* The initial seed. */
unsigned long the_initial_seed;
Expand All @@ -23,10 +24,9 @@ typedef struct THGenerator {

#define torch_Generator "torch.Generator"

/* Create a new random number generator stream */
/* Manipulate THGenerator objects */
TH_API THGenerator * THGenerator_new();

/* Free a random number generator stream */
TH_API THGenerator * THGenerator_copy(THGenerator *self, THGenerator *from);
TH_API void THGenerator_free(THGenerator *gen);

/* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */
Expand Down
33 changes: 12 additions & 21 deletions lib/TH/generic/THTensorRandom.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,33 +210,24 @@ TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator,

#endif

#if defined(TH_REAL_IS_LONG)
#if defined(TH_REAL_IS_BYTE)
TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self)
{
unsigned long *data;
long *offset;
long *left;

THTensor_(resize1d)(self,626);
data = (unsigned long *)THTensor_(data)(self);
offset = (long *)data+624;
left = (long *)data+625;

THRandom_getState(_generator, data, offset, left);
static const size_t size = sizeof(THGenerator);
THGenerator *state;
THTensor_(resize1d)(self, size);
state = (THGenerator *)THTensor_(data)(self);
THGenerator_copy(state, _generator);
}

TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self)
{
unsigned long *data;
long *offset;
long *left;

THArgCheck(THTensor_(nElement)(self) == 626, 1, "state should have 626 elements");
data = (unsigned long *)THTensor_(data)(self);
offset = (long *)(data+624);
left = (long *)(data+625);

THRandom_setState(_generator, data, *offset, *left);
static const size_t size = sizeof(THGenerator);
THGenerator *state;
THArgCheck(THTensor_(nElement)(self) == size, 1, "RNG state is wrong size");
THArgCheck(THTensor_(isContiguous)(self), 1, "RNG state needs to be contiguous");
state = (THGenerator *)THTensor_(data)(self);
THGenerator_copy(_generator, state);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion lib/TH/generic/THTensorRandom.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);
#endif

#if defined(TH_REAL_IS_LONG)
#if defined(TH_REAL_IS_BYTE)
TH_API void THTensor_(getRNGState)(THGenerator *_generator, THTensor *self);
TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self);
#endif
Expand Down
8 changes: 4 additions & 4 deletions random.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ interface:wrap('manualSeed',
{name="long"}})

interface:wrap('getRNGState',
'THLongTensor_getRNGState',
'THByteTensor_getRNGState',
{{name='Generator', default=true},
{name='LongTensor',default=true,returned=true,method={default='nil'}}
{name='ByteTensor',default=true,returned=true,method={default='nil'}}
})

interface:wrap('setRNGState',
'THLongTensor_setRNGState',
'THByteTensor_setRNGState',
{{name='Generator', default=true},
{name='LongTensor',default=true,returned=true,method={default='nil'}}
{name='ByteTensor',default=true,returned=true,method={default='nil'}}
})

interface:register("random__")
Expand Down
14 changes: 14 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,20 @@ function torchtest.RNGState()
mytester:assertTensorEq(before, after, 1e-16, 'getRNGState/setRNGState not generating same sequence')
end

function torchtest.testBoxMullerState()
torch.manualSeed(123)
local odd_number = 101
local seeded = torch.randn(odd_number)
local state = torch.getRNGState()
local midstream = torch.randn(odd_number)
torch.setRNGState(state)
local repeat_midstream = torch.randn(odd_number)
torch.manualSeed(123)
local reseeded = torch.randn(odd_number)
mytester:assertTensorEq(midstream, repeat_midstream, 1e-16, 'getRNGState/setRNGState not generating same sequence of normally distributed numbers')
mytester:assertTensorEq(seeded, reseeded, 1e-16, 'repeated calls to manualSeed not generating same sequence of normally distributed numbers')
end

function torchtest.testCholesky()
local x = torch.rand(10,10)
local A = torch.mm(x, x:t())
Expand Down

0 comments on commit 9abb33d

Please sign in to comment.