fix Wint8 lora #9781

Open · wants to merge 4 commits into base: release/3.0-beta2
139 changes: 91 additions & 48 deletions paddlenlp/peft/lora/lora_quantization_layers.py
@@ -17,14 +17,15 @@
import paddle
from paddle import nn
from paddle.distributed.fleet.layers.mpu import mp_ops
from paddle.nn.quant import weight_dequantize, weight_only_linear, weight_quantize
from paddle.nn.quant import weight_dequantize, weight_quantize

from ...quantization.qlora import qlora_weight_dequantize, qlora_weight_quantize
from ...quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
RowParallelQuantizationLinear,
)
from .lora_quick_quant_layers import quick_lora
from .utils import rng_ctx
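
The new import brings in quick_lora from lora_quick_quant_layers, which is not part of this diff. Inferred from the call sites below, its interface is roughly the stub sketched here; the parameter names and defaults are assumptions, not the actual signature:

```python
# Hypothetical stub inferred from the call sites in this file; the real
# implementation lives in paddlenlp/peft/lora/lora_quick_quant_layers.py.
def quick_lora(
    x,                # input activations
    lora_A,           # low-rank down-projection A
    lora_B,           # low-rank up-projection B
    quant_weight,     # weight-only-quantized base weight
    quant_scale,      # dequantization scales for quant_weight
    quant_algo,       # e.g. "weight_only_int8"
    dtype,            # compute dtype of the layer (self._dtype)
    bias,
    scaling,          # lora_alpha / r
    is_column=False,  # column tensor-parallel variant
    is_row=False,     # row tensor-parallel variant
    group=None,       # model-parallel communication group
    world_size=1,
):
    """Fused weight-only-quantized linear + LoRA forward (see call sites below)."""
    raise NotImplementedError
```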


@@ -108,6 +109,8 @@ def __init__(
self.weight = None
self.scaling = self.lora_alpha / self.r
self.disable_lora = False
if self.bias is not None:
self.bias.stop_gradient = True

def dequantize_weight(self):
if self.quant_algo in ["fp4", "nf4"]:
@@ -176,10 +179,21 @@ def merge(self):
self.merged = True

def forward(self, x: paddle.Tensor):
result = super().forward(x)
if not self.merged and not self.disable_lora:
result += (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
return result
# result = super().forward(x)
# if not self.merged and not self.disable_lora:
# result += (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
# return result
return quick_lora(
x,
self.lora_A,
self.lora_B,
self.quant_weight,
self.quant_scale,
self.quant_algo,
self._dtype,
self.bias,
self.scaling,
)
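
For the weight_only_int8 algorithm, the eager path being replaced here is equivalent to the sketch below: a weight_only_linear call on the quantized weight plus the unfused LoRA delta. It is only a reference for what the fused quick_lora call has to reproduce, not the fused implementation; the quant_dtype attribute is assumed to come from the QuantizationLinear base class, as in the row-parallel code further down.

```python
import paddle
from paddle.nn.quant import weight_only_linear


def eager_wint8_lora_forward(layer, x: paddle.Tensor) -> paddle.Tensor:
    """Reference (unfused) forward for a weight-only-int8 LoRA linear layer."""
    # Base projection on the int8 weight, dequantized inside the kernel;
    # this is what super().forward(x) amounts to in the commented-out code above.
    result = weight_only_linear(
        x, layer.quant_weight, layer.bias, layer.quant_scale, layer.quant_dtype
    )
    if not layer.merged and not layer.disable_lora:
        # Low-rank update: (dropout(x) @ A) @ B, scaled by lora_alpha / r.
        result += (layer.lora_dropout(x) @ layer.lora_A @ layer.lora_B) * layer.scaling
    return result
```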


class ColumnParallelQuantizationLoRALinear(ColumnParallelQuantizationLinear):
@@ -223,9 +237,6 @@ def __init__(
raise ValueError("Lora rank r should be a positive integer")
if self.quant_algo == "llm.int8":
raise NotImplementedError("llm.int8 not yet support lora strategy.")
if self.quant_algo in ["fp4", "nf4"]:
raise NotImplementedError(f"{self.quant_algo} not yet support tensor parallelism.")

self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
@@ -258,19 +269,33 @@ def __init__(

def forward(self, x):

result_mp = super().forward(x)

if not self.disable_lora or not self.merged:
input_a = self.lora_dropout(x) @ self.lora_A
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
else:
result = result_mp
return result
# result_mp = super().forward(x)

# if not self.disable_lora or not self.merged:
# input_a = self.lora_dropout(x) @ self.lora_A
# input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
# delta_mp = (input_a_mp @ self.lora_B) * self.scaling
# result_mp += delta_mp

# if self.gather_output and self.is_mp:
# result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
# else:
# result = result_mp
# return result
return quick_lora(
x,
self.lora_A,
self.lora_B,
self.quant_weight,
self.quant_scale,
self.quant_algo,
self._dtype,
self.bias,
self.scaling,
is_column=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
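
The column-parallel eager path being replaced is sketched below, roughly mirroring the commented-out code with the base forward expanded into a weight_only_linear on the local output shard; the base-forward expansion and the quant_dtype attribute are assumptions for illustration:

```python
import paddle
from paddle.distributed.fleet.layers.mpu import mp_ops
from paddle.nn.quant import weight_only_linear


def eager_column_parallel_wint8_lora_forward(layer, x: paddle.Tensor) -> paddle.Tensor:
    """Reference (unfused) forward for a column-parallel weight-only-int8 LoRA linear."""
    # Base projection on this rank's column shard: [bz, in_f] -> [bz, out_f / ws].
    result_mp = weight_only_linear(
        x, layer.quant_weight, layer.bias, layer.quant_scale, layer.quant_dtype
    )
    if not layer.merged and not layer.disable_lora:
        # lora_A is replicated; lora_B is sharded along out_f, like the base weight.
        input_a = layer.lora_dropout(x) @ layer.lora_A
        # Identity in the forward pass, all-reduce of the gradient in the backward pass.
        input_a_mp = mp_ops._c_identity(input_a, group=layer.model_parallel_group)
        result_mp += (input_a_mp @ layer.lora_B) * layer.scaling
    if layer.gather_output and layer.is_mp:
        # Gather the output shards from all ranks along the last axis.
        return mp_ops._c_concat(result_mp, group=layer.model_parallel_group)
    return result_mp
```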

def dequantize_weight(self):
if self.quant_algo in ["fp4", "nf4"]:
@@ -377,8 +402,6 @@ def __init__(
raise ValueError("Lora rank r should be a positive integer")
if self.quant_algo == "llm.int8":
raise NotImplementedError("llm.int8 not yet support lora strategy.")
if self.quant_algo in ["fp4", "nf4"]:
raise NotImplementedError(f"{self.quant_algo} not yet support tensor parallelism.")
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
@@ -411,36 +434,56 @@ def __init__(
self.merged = False

def forward(self, x: paddle.Tensor):
if not self.input_is_parallel:
input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
else:
input_mp = x

# x @ W : [bz, in_f / ws] ===> [bz, out_f]
with paddle.amp.auto_cast(enable=False):
result_mp = weight_only_linear(input_mp, self.quant_weight, None, self.quant_scale, self.quant_dtype)

output = mp_ops._mp_allreduce(
# if not self.input_is_parallel:
# input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
# else:
# input_mp = x

# # x @ W : [bz, in_f / ws] ===> [bz, out_f]
# with paddle.amp.auto_cast(enable=False):
# result_mp = weight_only_linear(input_mp, self.quant_weight, None, self.quant_scale, self.quant_dtype)

# output = mp_ops._mp_allreduce(
# result_mp,
# group=self.model_parallel_group,
# use_calc_stream=True,
# use_model_parallel=True,
# )
# if not self.disable_lora or not self.merged:
# # x @ A: [bz, in_f/ ws] ===> [bz, r]
# input_mp = self.lora_dropout(input_mp) @ self.lora_A
# # all reduce to keep Lora B's gradient on different gpu consistent
# input_dup = mp_ops._mp_allreduce(
# input_mp,
# group=self.model_parallel_group,
# use_calc_stream=True,
# use_model_parallel=True,
# )
# # @ B: [bz, r] ===> [bz, out_f]
# delta_mp = (input_dup @ self.lora_B) * self.scaling
# output += delta_mp
# output = output + self.bias if self.bias is not None else output
# return output
result_mp = quick_lora(
x,
self.lora_A,
self.lora_B,
self.quant_weight,
self.quant_scale,
self.quant_algo,
self._dtype,
self.bias,
self.scaling,
is_row=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
return mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
if not self.disable_lora or not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A
# all reduce to keep Lora B's gradient on different gpu consistent
input_dup = mp_ops._mp_allreduce(
input_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output
return output
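
The replaced row-parallel path issued two all-reduces: one for the base matmul and one to duplicate x @ A before applying lora_B. The new path all-reduces the output of quick_lora once. Assuming quick_lora returns the per-rank partial sum (base partial plus local LoRA partial), a single all-reduce reproduces the same forward value, because the row-parallel shards of x @ A sum to the full x @ A. A toy, single-process check of that identity (dropout omitted):

```python
import paddle

# Toy shapes: in_f is split across world_size row-parallel ranks.
bz, in_f, r, out_f, world_size = 4, 8, 2, 6, 2
x = paddle.randn([bz, in_f])
lora_A = paddle.randn([in_f, r])   # row-parallel: A is sharded along in_f
lora_B = paddle.randn([r, out_f])  # B is replicated on every rank

# Old path: all-reduce (x_shard @ A_shard) into the full x @ A, then apply B once.
old_delta = (x @ lora_A) @ lora_B

# New path (assumed): each rank contributes (x_shard @ A_shard) @ B locally,
# and one all-reduce over the partial outputs sums them.
shard = in_f // world_size
new_delta = sum(
    (x[:, i * shard:(i + 1) * shard] @ lora_A[i * shard:(i + 1) * shard]) @ lora_B
    for i in range(world_size)
)

assert paddle.allclose(old_delta, new_delta, rtol=1e-5, atol=1e-4).item()
```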

def dequantize_weight(self):
if self.quant_algo in ["fp4", "nf4"]: