diff --git a/pytext/task/accelerator_lowering.py b/pytext/task/accelerator_lowering.py
index 969f8bab3..c5d71ae78 100644
--- a/pytext/task/accelerator_lowering.py
+++ b/pytext/task/accelerator_lowering.py
@@ -243,20 +243,6 @@ def forward(
         return rep, (new_hidden, new_cell)
 
 
-# Swap a transformer for only RoBERTaEncoder encoders
-def swap_modules_for_accelerator(model):
-    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
-        old_transformer = model.encoder.encoder.transformer
-        model.encoder.encoder.transformer = AcceleratorTransformer(old_transformer)
-        return model
-    elif hasattr(model, "representation") and isinstance(model.representation, BiLSTM):
-        old_biLSTM = model.representation
-        model.representation = AcceleratorBiLSTM(old_biLSTM)
-        return model
-    else:
-        return model
-
-
 def lower_modules_to_accelerator(
     model: nn.Module, trace, export_options: ExportConfig, throughput_optimize=False
 ):
@@ -321,3 +307,13 @@ def lower_modules_to_accelerator(
         return trace
     else:
         return trace
+
+
+def nnpi_rewrite_roberta_transformer(model):
+    model.encoder.encoder.transformer = AcceleratorTransformer(
+        model.encoder.encoder.transformer
+    )
+
+
+def nnpi_rewrite_bilstm(model):
+    model.representation = AcceleratorBiLSTM(model.representation)
diff --git a/pytext/task/cuda_lowering.py b/pytext/task/cuda_lowering.py
index 58796e126..24761b15e 100644
--- a/pytext/task/cuda_lowering.py
+++ b/pytext/task/cuda_lowering.py
@@ -189,11 +189,7 @@ def forward(self, tokens: Tensor) -> List[Tensor]:
         return states
 
 
-# Swap a transformer for only RoBERTaEncoder encoders
-def swap_modules_for_faster_transformer(model):
-    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
-        old_transformer = model.encoder.encoder.transformer
-        model.encoder.encoder.transformer = NVFasterTransformerEncoder(old_transformer)
-        return model
-    else:
-        return model
+def cuda_rewrite_roberta_transformer(model):
+    model.encoder.encoder.transformer = NVFasterTransformerEncoder(
+        model.encoder.encoder.transformer
+    )
diff --git a/pytext/task/new_task.py b/pytext/task/new_task.py
index 276024d26..d2547ce63 100644
--- a/pytext/task/new_task.py
+++ b/pytext/task/new_task.py
@@ -13,6 +13,8 @@
 from pytext.data.tensorizers import Tensorizer
 from pytext.metric_reporters import MetricReporter
 from pytext.models.model import BaseModel
+from pytext.models.representations.bilstm import BiLSTM
+from pytext.models.roberta import RoBERTaEncoder
 from pytext.trainers import TaskTrainer, TrainingState
 from pytext.utils import cuda, onnx, precision
 from pytext.utils.file_io import PathManager
@@ -25,10 +27,11 @@
 from .accelerator_lowering import (
     lower_modules_to_accelerator,
-    swap_modules_for_accelerator,
+    nnpi_rewrite_roberta_transformer,
+    nnpi_rewrite_bilstm,
 )
 from .cuda_lowering import (
-    swap_modules_for_faster_transformer,
+    cuda_rewrite_roberta_transformer,
 )
 from .quantize import (
     quantize_statically,
@@ -44,6 +47,57 @@
 )
 
 
+MODULE_TO_REWRITER = {
+    "nnpi": {
+        RoBERTaEncoder: nnpi_rewrite_roberta_transformer,
+        BiLSTM: nnpi_rewrite_bilstm,
+    },
+    "cuda": {
+        RoBERTaEncoder: cuda_rewrite_roberta_transformer,
+    },
+}
+
+
+def find_module_instances(model, module_type, cur_path):
+    """
+    Finds all module instances of the specified type and returns the paths to get
+    to each of those instances
+    """
+    if isinstance(model, module_type):
+        yield list(cur_path)  # copy the list since cur_path is a shared list
+    for attr in dir(model):
+        if (
+            attr[0] == "_" or len(cur_path) > 4
+        ):  # avoids infinite recursion and exploring unnecessary paths
+            continue
+        cur_path.append(attr)
+        # recursively yield
+        yield from find_module_instances(getattr(model, attr), module_type, cur_path)
+        cur_path.pop()
+
+
+def rewrite_transformer(model, module_path, rewriter):
+    """
+    Descends model hierarchy according to module_path and calls the rewriter at the end
+    """
+    for prefix in module_path[:-1]:
+        model = getattr(model, prefix)
+    rewriter(model)
+
+
+def swap_modules(model, module_to_rewriter):
+    """
+    Finds modules within a model that can be rewritten and rewrites them with
+    predefined rewrite functions
+    """
+    for module in module_to_rewriter:
+        instance_paths = find_module_instances(model, module, [])
+        rewriter = module_to_rewriter[module]
+        for path in instance_paths:
+            rewrite_transformer(model, path, rewriter)
+    return model
+
+
 def create_schema(
     tensorizers: Dict[str, Tensorizer], extra_schema: Optional[Dict[str, Type]] = None
 ) -> Schema:
@@ -346,7 +400,7 @@
             optimizer.pre_export(model)
 
         if use_nnpi or use_fx_quantize:
-            model = swap_modules_for_accelerator(model)
+            model = swap_modules(model, MODULE_TO_REWRITER["nnpi"])
 
         # Trace needs eval mode, to disable dropout etc
         model.eval()
@@ -392,7 +446,7 @@
             precision.FP16_ENABLED = True
             cuda.CUDA_ENABLED = True
 
-            model = swap_modules_for_faster_transformer(model)
+            model = swap_modules(model, MODULE_TO_REWRITER["cuda"])
             model.eval()
             model.half().cuda()
             # obtain new inputs with cuda/fp16 enabled.
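
For reviewers, a minimal self-contained sketch of how the new traversal-plus-rewriter pattern composes. The toy classes (`ToyModel`, `ToyEncoder`, `ToyTransformer`) and `toy_rewriter` are illustrative stand-ins, not part of the patch; the two helpers are simplified copies of `find_module_instances` and `rewrite_transformer` from `new_task.py` so the snippet runs on its own:

```python
# Illustrative only: toy stand-ins for PyText modules. Only the traversal and
# path-descent logic mirrors the patch.


class ToyTransformer:
    pass


class ToyEncoder:
    def __init__(self):
        self.transformer = ToyTransformer()


class ToyModel:
    def __init__(self):
        self.encoder = ToyEncoder()


def find_module_instances(model, module_type, cur_path):
    # DFS over public attributes, depth-capped at 5, yielding attribute paths
    if isinstance(model, module_type):
        yield list(cur_path)
    for attr in dir(model):
        if attr[0] == "_" or len(cur_path) > 4:
            continue
        cur_path.append(attr)
        yield from find_module_instances(getattr(model, attr), module_type, cur_path)
        cur_path.pop()


def rewrite_transformer(model, module_path, rewriter):
    # descend to the parent of the matched module, then let the rewriter mutate it
    for prefix in module_path[:-1]:
        model = getattr(model, prefix)
    rewriter(model)


def toy_rewriter(parent):
    # stands in for e.g. nnpi_rewrite_roberta_transformer: the rewriter receives
    # the parent object and replaces the child module in place
    parent.transformer = "rewritten"


model = ToyModel()
paths = list(find_module_instances(model, ToyTransformer, []))
print(paths)  # [['encoder', 'transformer']]
for path in paths:
    rewrite_transformer(model, path, toy_rewriter)
print(model.encoder.transformer)  # rewritten
```

Note the contract this illustrates: each rewriter receives the object one level above the matched instance. That is why `nnpi_rewrite_roberta_transformer` takes the whole model when a `RoBERTaEncoder` is found at `model.encoder`, and why adding hardware support for a new module type only requires writing a rewriter and registering it in `MODULE_TO_REWRITER`.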