forked from muellerzr/pippy-device-map-playground
-
Notifications
You must be signed in to change notification settings - Fork 0
/
t5.py
68 lines (59 loc) · 1.82 KB
/
t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import time
import torch
from accelerate import prepare_pippy
from accelerate.utils import set_seed
from transformers import AutoModelForSeq2SeqLM
# Seed the random number generators (via accelerate's `set_seed`) so the
# randomly generated example inputs, and therefore the outputs, are
# reproducible from run to run.
set_seed(42)

# Load a small pretrained encoder-decoder model and switch it to evaluation
# mode, since this script only runs inference.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

# Input configuration: shape of the random token batch fed to the model.
BATCH_SIZE = 2
SEQ_LEN = 1024

# Random token ids below the model's vocabulary size, created on the CPU.
# Named `input_ids` rather than `input` to avoid shadowing the builtin.
input_ids = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(BATCH_SIZE, SEQ_LEN),
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)
# The same ids are used as both the encoder and the decoder inputs.
example_inputs = {"input_ids": input_ids, "decoder_input_ids": input_ids}
# Wrap the model as a pipeline-parallel stage.
# `split_points="auto"` (accelerate's default, spelled out here for clarity)
# lets accelerate choose the split points itself, similar to
# `device_map="auto"`, splitting the model across the available GPUs.
# `T5Block` modules are never split across stages. The example kwargs are
# presumably used to trace the model when building the stages -- confirm
# against the accelerate docs.
model = prepare_pippy(
    model,
    split_points="auto",
    no_split_module_classes=["T5Block"],
    example_kwargs=example_inputs,
)

# During real inference the pipelined model is called with positional
# arguments, with the data placed on the first device.
# NOTE(review): every process moves its inputs to "cuda:0"; confirm this is
# intended when launched with multiple processes.
args = (
    example_inputs["input_ids"].to("cuda:0"),
    example_inputs["decoder_input_ids"].to("cuda:0"),
)
# Number of warm forward passes averaged for the steady-state timing.
NUM_TIMED_RUNS = 5

# Time the first (cold) forward pass on its own: it includes one-off CUDA
# initialization costs that would skew the average. `synchronize()` waits
# for queued GPU work so the interval covers the actual computation.
# `perf_counter()` is monotonic and high-resolution, which suits
# benchmarking better than `time.time()`.
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
    output = model(*args)
torch.cuda.synchronize()
end_time = time.perf_counter()
first_batch = end_time - start_time

# Now that CUDA is initialized, time several warm passes and average them.
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
    for _ in range(NUM_TIMED_RUNS):
        output = model(*args)
torch.cuda.synchronize()
end_time = time.perf_counter()

# The model outputs are located on the last device. `output` is presumably
# None on the other pipeline stages, so report only where it is available
# and the timings are printed once.
if output is not None:
    # The first element of `output` holds the model outputs.
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time) / NUM_TIMED_RUNS}")