"""
Run a config of benchmarking with a list of models.
If unspecified, run a sweep of all models.
"""
import argparse
import json
import os
import sys
import time
import pathlib
import dataclasses
import itertools

import numpy
import torch

from typing import List, Optional, Dict, Any
from torchbenchmark import ModelTask

WARMUP_ROUNDS = 3
WORKER_TIMEOUT = 600 # seconds
MODEL_DIR = ['torchbenchmark', 'models']
NANOSECONDS_PER_MILLISECONDS = 1_000_000.0


def run_one_step(func, device: str, nwarmup=WARMUP_ROUNDS, num_iter=10) -> float:
    "Run one step of the model and return the median latency in milliseconds."
    # Warm up for `nwarmup` rounds
    for _i in range(nwarmup):
        func()
    result_summary = []
    for _i in range(num_iter):
        if device == "cuda":
            torch.cuda.synchronize()
            # Use time_ns() instead of time(), which may not provide better precision than
            # 1 second according to https://docs.python.org/3/library/time.html#time.time.
            t0 = time.time_ns()
            func()
            torch.cuda.synchronize()  # Wait for the events to be recorded!
            t1 = time.time_ns()
        else:
            t0 = time.time_ns()
            func()
            t1 = time.time_ns()
        result_summary.append((t1 - t0) / NANOSECONDS_PER_MILLISECONDS)
    wall_latency = numpy.median(result_summary)
    return wall_latency


@dataclasses.dataclass
class ModelTestResult:
    name: str
    test: str
    device: str
    extra_args: List[str]
    status: str
    batch_size: Optional[int]
    precision: str
    results: Dict[str, Any]
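
# Each ModelTestResult is exported at the end of the run via dataclasses.asdict();
# one entry in the output JSON looks roughly like this (values are illustrative):
#   {"name": "BERT_pytorch", "test": "eval", "device": "cuda", "extra_args": [],
#    "status": "OK", "batch_size": 8, "precision": "fp32",
#    "results": {"latency_ms": 12.3}}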


def _list_model_paths(models: List[str]) -> List[pathlib.Path]:
    p = pathlib.Path(__file__).parent.joinpath(*MODEL_DIR)
    model_paths = sorted(child for child in p.iterdir() if child.is_dir())
    valid_model_paths = sorted(filter(lambda x: x.joinpath("__init__.py").exists(), model_paths))
    if models:
        valid_model_paths = sorted(filter(lambda x: x.name in models, valid_model_paths))
    return valid_model_paths


def _validate_tests(tests: str) -> List[str]:
    tests_list = list(map(lambda x: x.strip(), tests.split(",")))
    valid_tests = ['train', 'eval']
    for t in tests_list:
        if t not in valid_tests:
            raise ValueError(f'Invalid test {t} passed into --tests. Expected tests: {valid_tests}.')
    return tests_list


def _validate_devices(devices: str) -> List[str]:
    devices_list = list(map(lambda x: x.strip(), devices.split(",")))
    valid_devices = ['cpu', 'cuda']
    for d in devices_list:
        if d not in valid_devices:
            raise ValueError(f'Invalid device {d} passed into --devices. Expected devices: {valid_devices}.')
    return devices_list


def _run_model_test(model_path: pathlib.Path, test: str, device: str, jit: bool, batch_size: Optional[int], extra_args: List[str]) -> ModelTestResult:
    assert test == "train" or test == "eval", f"Test must be either 'train' or 'eval', but got {test}."
    result = ModelTestResult(name=model_path.name, test=test, device=device, extra_args=extra_args, batch_size=None, precision="fp32",
                             status="OK", results={})
    # Run the benchmark test in a separate process
    print(f"Running model {model_path.name} ... ", end='', flush=True)
    status: str = "OK"
    bs_name = "batch_size"
    correctness_name = "correctness"
    error_message: Optional[str] = None
    try:
        task = ModelTask(os.path.basename(model_path), timeout=WORKER_TIMEOUT)
        if not task.model_details.exists:
            status = "NotExist"
            return
        task.make_model_instance(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        # Check that the batch size in the model matches the specified value
        result.batch_size = task.get_model_attribute(bs_name)
        result.precision = task.get_model_attribute("dargs", "precision")
        if batch_size and (not result.batch_size == batch_size):
            raise ValueError(f"User specified batch size {batch_size}, but model {result.name} runs with batch size {result.batch_size}. Please report a bug.")
        result.results["latency_ms"] = run_one_step(task.invoke, device)
        # If NUM_BATCHES is set, convert to per-batch latency
        num_batches = task.get_model_attribute("NUM_BATCHES")
        if num_batches:
            result.results["latency_ms"] = result.results["latency_ms"] / num_batches
        # If the model provides an eager eval result, save it for cosine similarity
        correctness = task.get_model_attribute(correctness_name)
        if correctness is not None:
            result.results[correctness_name] = str(correctness)
    except NotImplementedError as e:
        status = "NotImplemented"
        error_message = str(e)
    except TypeError as e:  # TypeError is raised when the model doesn't support variable batch sizes
        status = "TypeError"
        error_message = str(e)
    except KeyboardInterrupt as e:
        status = "UserInterrupted"
        error_message = str(e)
    except Exception as e:
        status = f"{type(e).__name__}"
        error_message = str(e)
    finally:
        print(f"[ {status} ]")
        result.status = status
        if error_message:
            result.results["error_message"] = error_message
        if status == "UserInterrupted":
            sys.exit(1)
        return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--models", nargs='+', default=[],
                        help="Specify one or more models to run. If not set, trigger a sweep-run on all models.")
    parser.add_argument("-t", "--tests", required=True, type=_validate_tests, help="Specify tests, a comma-separated list chosen from train and eval.")
    parser.add_argument("-d", "--devices", required=True, type=_validate_devices, help="Specify devices, a comma-separated list chosen from cpu and cuda.")
    parser.add_argument("-b", "--bs", type=int, help="Specify batch size.")
    parser.add_argument("--jit", action='store_true', help="Turn on torchscript.")
    parser.add_argument("-o", "--output", type=str, default="tb-output.json", help="The output json file. Default: tb-output.json.")
    parser.add_argument("--proper-bs", action='store_true', help="Find the best batch_size for the current device.")
    args, extra_args = parser.parse_known_args()
    args.models = _list_model_paths(args.models)
    results = []
    for element in itertools.product(*[args.models, args.tests, args.devices]):
        model_path, test, device = element
        if args.proper_bs:
            if test != 'eval':
                print("Error: Only the batch size of the eval test is tunable.")
                sys.exit(1)
            from scripts.proper_bs import _run_model_test_proper_bs
            r = _run_model_test_proper_bs(model_path, test, device, args.jit, batch_size=args.bs, extra_args=extra_args)
        else:
            r = _run_model_test(model_path, test, device, args.jit, batch_size=args.bs, extra_args=extra_args)
        results.append(r)
    results_to_export = list(map(lambda x: dataclasses.asdict(x), results))
    parent_dir = pathlib.Path(args.output).parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    with open(args.output, "w") as outfile:
        json.dump(results_to_export, outfile, indent=4)