diff --git a/intel_pytorch_extension_py/optim/__init__.py b/intel_pytorch_extension_py/optim/__init__.py
index b58308cc4..5265f3556 100644
--- a/intel_pytorch_extension_py/optim/__init__.py
+++ b/intel_pytorch_extension_py/optim/__init__.py
@@ -1,2 +1,3 @@
 from .split_sgd import is_available
 from .split_sgd import SplitSGD
+from .lamb import Lamb
diff --git a/intel_pytorch_extension_py/optim/lamb.py b/intel_pytorch_extension_py/optim/lamb.py
new file mode 100644
index 000000000..96bf827fd
--- /dev/null
+++ b/intel_pytorch_extension_py/optim/lamb.py
@@ -0,0 +1,127 @@
+"""Lamb optimizer."""
+
+import collections
+import math
+
+import torch
+from tensorboardX import SummaryWriter
+from torch.optim import Optimizer
+from _torch_ipex import lamb_fused_step_
+
+def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
+    """Log a histogram of the trust ratio scalars across layers."""
+    results = collections.defaultdict(list)
+    for group in optimizer.param_groups:
+        for p in group['params']:
+            state = optimizer.state[p]
+            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
+                if i in state:
+                    results[i].append(state[i])
+
+    for k, v in results.items():
+        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
+
+class Lamb(Optimizer):
+    r"""Implements the Lamb algorithm.
+
+    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-6)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        adam (bool, optional): always use trust ratio = 1, which turns this into
+            Adam. Useful for comparison purposes.
+        bf16 (bool, optional): use the fused bf16 split-master-weight step
+            (default: False)
+
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+                 weight_decay=0, adam=False, bf16=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay)
+        self.adam = adam
+        self.bf16 = bf16
+        super(Lamb, self).__init__(params, defaults)
+
+    def set_bf16(self, bf16=False):
+        self.bf16 = bf16
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
+                    if self.bf16:
+                        # bf16 tensor holding the lower 16 bits of the fp32 master weights
+                        state['bot_half'] = torch.zeros_like(p.data, dtype=torch.bfloat16, device=p.data.device)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+                if self.bf16:
+                    lamb_fused_step_(p, p.grad, state['bot_half'], exp_avg, exp_avg_sq, state['step'], group['lr'], beta1, beta2, group['weight_decay'], group['eps'])
+                else:
+                    step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1
+                    # m_t
+                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                    # v_t
+                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                    weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+                    adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                    adam_norm = adam_step.pow(2).sum().sqrt()
+
+                    if group['weight_decay'] != 0:
+                        adam_step.add_(group['weight_decay'], p.data)
+
+                    if weight_norm == 0 or adam_norm == 0:
+                        trust_ratio = 1
+                    else:
+                        trust_ratio = weight_norm / adam_norm
+                    state['weight_norm'] = weight_norm
+                    state['adam_norm'] = adam_norm
+                    state['trust_ratio'] = trust_ratio
+                    if self.adam:
+                        trust_ratio = 1
+                    p.data.add_(-step_size * trust_ratio, adam_step)
+
+        return loss
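For orientation, a minimal usage sketch of the new optimizer (not taken from the patch; it assumes the extension and tensorboardX are installed so that `intel_pytorch_extension_py.optim` imports cleanly, and uses a placeholder model and dummy data):

    import torch
    from intel_pytorch_extension_py.optim import Lamb

    model = torch.nn.Linear(128, 128)                      # placeholder model
    optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                     eps=1e-6, weight_decay=0.01)
    optimizer.set_bf16(False)   # False keeps the pure-Python fp32 path shown above

    for _ in range(3):                                     # dummy training loop
        optimizer.zero_grad()
        loss = model(torch.randn(8, 128)).pow(2).mean()
        loss.backward()
        optimizer.step()

With `set_bf16(True)`, `step()` instead dispatches to the fused `lamb_fused_step_` kernel added below, which expects bf16 parameters and gradients.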
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32) + if self.bf16: + # additional fp32 version of master weights + state['bot_half'] = torch.zeros_like(p.data, dtype=torch.bfloat16, device=p.data.device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + + state['step'] += 1 + if self.bf16: + lamb_fused_step_(p, p.grad, state['bot_half'], exp_avg, exp_avg_sq, state['step'], group['lr'], beta1, beta2, group['weight_decay'], group['eps']) + else: + step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 + # m_t + exp_avg.mul_(beta1).add_(1 - beta1, grad) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + weight_norm = data_fp32.pow(2).sum().sqrt().clamp(0, 10) + adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) + adam_norm = adam_step.pow(2).sum().sqrt() + + if group['weight_decay'] != 0: + adam_step.add_(group['weight_decay'], p.data) + + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + if self.adam: + trust_ratio = 1 + p.data.add_(-step_size * trust_ratio, adam_step) + + return loss diff --git a/torch_ipex/csrc/cpu/ExtendOPs.cpp b/torch_ipex/csrc/cpu/ExtendOPs.cpp index bb11a869f..c51e24148 100644 --- a/torch_ipex/csrc/cpu/ExtendOPs.cpp +++ b/torch_ipex/csrc/cpu/ExtendOPs.cpp @@ -12,6 +12,105 @@ #include "DevOPs.h" namespace torch_ipex { +inline float pack_bfloat16_float(at::BFloat16 a, at::BFloat16 b) { + uint16_t* ap = reinterpret_cast(&a); + uint16_t* bp = reinterpret_cast(&b); + uint32_t hi = static_cast(*ap); + uint32_t lo = static_cast(*bp); + uint32_t out = (hi << 16) + lo; + float* outp = reinterpret_cast(&out); + return *outp; +} + +inline std::tuple unpack_float_bfloat16(float a) { + uint32_t* ap = reinterpret_cast(&a); + uint16_t hi = static_cast((*ap) >> 16); + uint16_t lo = static_cast((*ap)); + at::BFloat16* hip = reinterpret_cast(&hi); + at::BFloat16* lop = reinterpret_cast(&lo); + return std::make_tuple(*hip, *lop); +} + +void AtenIpexTypeExt::lamb_fused_step_(at::Tensor & param, at::Tensor & grad, at::Tensor & param2, at::Tensor & exp_avg, at::Tensor & exp_avg_sq, int64_t step, float lr, float beta1, float beta2, float weight_decay, float eps){ + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.scalar_type() == + at::ScalarType::BFloat16); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(param.scalar_type() == + at::ScalarType::BFloat16); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(param2.scalar_type() == + at::ScalarType::BFloat16); + RECORD_FUNCTION("ipex::lamb_fused_step", std::vector({param, param2, grad}), torch::autograd::Node::peek_at_next_sequence_nr()); + at::BFloat16* param_data = param.data_ptr(); + float* exp_avg_data = exp_avg.data_ptr(); + float* exp_avg_sq_data = exp_avg_sq.data_ptr(); + at::BFloat16* grad_data = grad.data_ptr(); + at::BFloat16* param2_data = 
diff --git a/torch_ipex/csrc/init_python_bindings.cpp b/torch_ipex/csrc/init_python_bindings.cpp
index 6ff39b68b..57cad7779 100644
--- a/torch_ipex/csrc/init_python_bindings.cpp
+++ b/torch_ipex/csrc/init_python_bindings.cpp
@@ -65,6 +65,11 @@ void InitIpexModuleBindings(py::module m) {
   m.def("enable_pure_bf16", []() { AutoOptConfig::singleton().set_pure_bf16(true); });
   m.def("disable_pure_bf16", []() { AutoOptConfig::singleton().set_pure_bf16(false); });
   m.def("get_pure_bf16", []() { return AutoOptConfig::singleton().get_pure_bf16(); });
+  m.def("lamb_fused_step_",
+        [](at::Tensor &param, at::Tensor &grad, at::Tensor &param2, at::Tensor &exp_avg, at::Tensor &exp_avg_sq, int64_t step, float lr, float beta1, float beta2, float weight_decay, float eps) {
+          AtenIpexTypeExt::lamb_fused_step_(param, grad, param2, exp_avg, exp_avg_sq, step, lr, beta1, beta2, weight_decay, eps);
+        });
+
   m.def("packed_add_", [](at::Tensor &top_half, at::Tensor &bot_half, const at::Tensor &grad, float alpha) {
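For completeness, a hedged sketch of calling the new binding directly, mainly to show the expected dtypes and argument order (tensor sizes and hyper-parameter values are arbitrary; it assumes a build of the extension that exposes `_torch_ipex.lamb_fused_step_` as registered above):

    import torch
    from _torch_ipex import lamb_fused_step_

    n = 16
    param      = torch.randn(n).bfloat16()              # bf16 weights (upper halves)
    bot_half   = torch.zeros(n, dtype=torch.bfloat16)   # lower 16 bits of the fp32 master weights
    grad       = torch.randn(n).bfloat16()              # bf16 gradients
    exp_avg    = torch.zeros(n, dtype=torch.float32)    # fp32 first moment
    exp_avg_sq = torch.zeros(n, dtype=torch.float32)    # fp32 second moment

    lamb_fused_step_(param, grad, bot_half, exp_avg, exp_avg_sq,
                     1,           # step
                     1e-3,        # lr
                     0.9, 0.999,  # beta1, beta2
                     0.0,         # weight_decay
                     1e-6)        # eps

This mirrors the call made from `Lamb.step()` when `bf16` is enabled, with all tensor arguments updated in place.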