diff --git a/matlab/+dagnn/Loss.m b/matlab/+dagnn/Loss.m index a9675b26..19b0ea1b 100644 --- a/matlab/+dagnn/Loss.m +++ b/matlab/+dagnn/Loss.m @@ -12,12 +12,23 @@ methods function outputs = forward(obj, inputs, params) - outputs{1} = vl_nnloss(inputs{1}, inputs{2}, [], 'loss', obj.loss, obj.opts{:}) ; + + % If there are 3 inputs, the third input should contain sample + % specific weights + weights = ones(size(inputs{2})); + i = find(strcmpi(obj.opts,'instanceWeights')); + if numel(inputs)==3 + weights = inputs{3}; + elseif ~isempty(i) + weights = obj.opts{i+1}; + end + + outputs{1} = vl_nnloss(inputs{1}, inputs{2}, [], 'loss', obj.loss, 'InstanceWeights', weights, obj.opts{:}) ; obj.accumulateAverage(inputs, outputs); end function accumulateAverage(obj, inputs, outputs) - if obj.ignoreAverage, return; end; + if obj.ignoreAverage, return; end n = obj.numAveraged ; m = n + size(inputs{1}, 1) * size(inputs{1}, 2) * size(inputs{1}, 4); obj.average = bsxfun(@plus, n * obj.average, gather(outputs{1})) / m ; @@ -25,7 +36,20 @@ function accumulateAverage(obj, inputs, outputs) end function [derInputs, derParams] = backward(obj, inputs, params, derOutputs) - derInputs{1} = vl_nnloss(inputs{1}, inputs{2}, derOutputs{1}, 'loss', obj.loss, obj.opts{:}) ; + + % If there are 3 inputs, the third input should contain sample + % specific weights + weights = ones(size(inputs{2})); + i = find(strcmpi(obj.opts,'instanceWeights')); + if numel(inputs)==3 + weights = inputs{3}; + derInputs{3} = []; + elseif ~isempty(i) + weights = obj.opts{i+1}; + derInputs{3} = []; + end + + derInputs{1} = vl_nnloss(inputs{1}, inputs{2}, derOutputs{1}, 'loss', obj.loss, 'InstanceWeights', weights, obj.opts{:}) ; derInputs{2} = [] ; derParams = {} ; end