There's lots of innovation out there building better machine learning models with new neural net structures, regularization methods, etc. Groups like fast.ai are training complex models quickly on commodity hardware by relying more on "algorithmic creativity" than on overwhelming hardware power, which is good news for those of us without data centers full of hardware. Rather than add to your stash of creative algorithms, this post takes a different (but compatible) approach to better performance. We'll step through measuring and optimizing model training from a systems perspective, which applies regardless of what algorithm you're using. As we'll see, there are nice speedups to be had with little effort. While keeping the same model structure, hyperparameters, etc., training speed improves by 26% simply by moving some work to another CPU core. You might find that you have easy performance gains waiting for you whether you're using an SVM or a neural network in your project.
We'll be using the fast neural style example in PyTorch's example projects as the system to optimize. If you just want to see the few code changes needed, check out this branch. Otherwise, to see how to get there step-by-step so that you can replicate this process on your own projects, read on.
If you want to follow along, start by downloading the 2017 COCO training dataset (18 GiB). You'll also need a Linux system with a recent kernel and a GPU (an NVIDIA one, if you want to use the provided commands as-is). My hardware for this experiment is an i7-6850K with 2x GTX 1070 Ti, though we'll only be using one GPU this time.
Getting started
If you're a virtualenv user, you'll probably want to have a virtualenv with the necessary packages installed:
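For example (a sketch; adjust the package versions and the CUDA build of PyTorch to match your system):

```bash
virtualenv -p python3 venv
source venv/bin/activate
pip install torch torchvision numpy
```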
Clone the pytorch/examples repo and go into the fast_neural_style directory, then start training a model. The batch size is left at the default (4) so it will be easier to replicate these results on smaller hardware, but of course feel free to increase the batch size if you have the hardware. We run date first so we have a timestamp to compare later timestamps with, and pass --cuda 1 so that it will use a GPU. The directory to pass to --dataset should be the one containing the train2017 directory, not the path to train2017 itself.
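The commands look roughly like this; the style image and output paths are just examples:

```bash
date
git clone https://github.com/pytorch/examples.git
cd examples/fast_neural_style
python neural_style/neural_style.py train --dataset /path/to/coco --style-image images/style-images/mosaic.jpg --save-model-dir /tmp/models --cuda 1
```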
While that's running for the next few hours, let's dig into its performance characteristics. vmstat 2 isn't fancy, but it's a good place to start. Unsurprisingly, it shows negligible wa (I/O wait) CPU usage and little block I/O, and we're seeing normal user CPU usage for one busy process on a 6-core, 12-thread system (10-12%).
Moving on to the GPU, we'll use nvidia-smi dmon -s u -i 0. dmon periodically prints GPU statistics; here we limit it to utilization (-s u) on the first GPU device (-i 0).
Its output is more interesting: GPU utilization drops as low as 30%. This workload should basically be waiting on the GPU the entire time, so failing to keep the GPU busy is a problem.
To see if there's something seriously wrong, perf stat is a simple way to get a high-level view of what's going on. Just using 100% of a CPU core doesn't mean much; the core could be spending all of its time waiting for memory accesses or flushing pipelines. We can attach to a running process (that's the -p <pid>) and aggregate performance counters. Note that if you're using a virtual machine, you may not have access to performance counters at all, and even my physical hardware doesn't support every counter perf stat looks for.
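Attaching to the training process (substitute its PID) looks like:

```bash
perf stat -p <pid>
```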
After letting that run for a couple of minutes, stopping it with ^C prints a summary of the counters.
Key points from my run:
- 1.25 instructions per cycle isn't awful, but it's not great either. In ideal circumstances, CPUs can retire several instructions per cycle, so a value closer to 3 or 4 would be a signal that the CPU was already being kept quite busy.
- 0.7% branch mispredicts is higher than I'd like, but it isn't catastrophic.
- Hundreds of context switches per second would matter on a realtime system, but is unlikely to affect a batch workload like this one.
Overall, there's nothing obvious like a 5% branch miss rate or 0.5 IPC, so this isn't telling us anything particularly compelling.
Profiling
Capturing performance counter information with perf stat shows that there's some room to improve, but it doesn't provide any detail on what to do about it. perf record, on the other hand, samples the program's stack while it's running, so it will tell us more about what specifically the CPU is spending its time doing.
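Again we attach by PID, leaving the sampling options at their defaults:

```bash
perf record -p <pid>
```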
After letting that run for a few minutes, it will have written a perf.data file. That file isn't human readable, but perf annotate will output disassembly annotated with the percentage of samples in which the CPU was executing each instruction.
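Something like the following is enough; --stdio prints the annotated output to the terminal instead of opening the interactive viewer:

```bash
perf annotate --stdio
```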
For my run, the output starts with a disassembly of syscall_return_via_sysret, indicating that most of the time there is spent on pop instructions. That's not particularly useful knowledge at the moment (the process makes a good number of syscalls, so we do expect to see some time spent there), so let's keep looking. The next item is for jpeg_idct_islow@@LIBJPEG_9.0, part of PIL (the Python Imaging Library, aka Pillow). The output for that function starts with a summary of its hottest offsets.
That summary continues for about 50 lines, and it tells us that rather than having a few hot instructions in this function, the cost is smeared out across many instructions (2.52% at offset 36635, 1.93% at 3624d, and so on). Paging down into the disassembly itself, the picture is similar.
The left column shows the percentage of samples at each instruction (sometimes prefixed with the symbol name for especially busy instructions), next to its offset. In my run, 0.72% of cycles were spent on a shlq (shift left), 0.82% on an addq (integer addition), and so on. Note that due to a phenomenon called skid, the attribution may be off by several or even dozens of instructions, so these percentages should not be taken as gospel. In this case, for instance, it's unlikely that the two shlq instructions at 35ec6 and 35ecb really differ by 5x.
The next section of perf annotate output is for __vdso_clock_gettime@@LINUX_2.6. The vDSO is a mechanism for speeding up certain syscalls, notably gettimeofday. There isn't much to see here, other than to note that maybe we shouldn't be calling gettimeofday(2) as much. The section after that is for ImagingResampleHorizontal_8bpc@@Base from _imaging.cpython-35m-x86_64-linux-gnu.so, which has a large block of fairly hot instructions.
The hottest of them is a movzbl (zero-extend one byte into four bytes). At this point, we have a hypothesis: we're spending a lot of time decoding and scaling images.
Flame graphs
To get a clearer view of which paths through the program are the most relevant, we'll use a flame graph. There are other things we could do with perf, like perf report, but flame graphs are easier to understand in my experience. If you haven't worked with flame graphs before, the general idea is that width = percentage of samples and height = call stack. As an example, a tall, skinny column means a deep call stack that doesn't use much CPU time, while a wide, short column means a shallow call stack that uses a lot of CPU.
Clone the Pyflame repo and follow their compile instructions. There's no need to install it anywhere (the make install step) -- just having a compiled pyflame binary in the build output is sufficient.
As with the other tools, we attach the compiled pyflame binary to the running process and let it run for 10 minutes to get good fidelity:
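The exact flags vary a bit between pyflame versions, but the invocation is along these lines, with -s giving the number of seconds to sample, -o an output file (the name here is just an example), and -p the PID of the training process:

```bash
./pyflame -s 600 -o pyflame.out -p <pid>
```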
In the meantime, clone FlameGraph so we can render pyflame's output as an SVG. From the FlameGraph repo:
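flamegraph.pl reads the folded stack samples and writes an SVG to stdout; the file names here match the previous step:

```bash
./flamegraph.pl pyflame.out > flamegraph.svg
```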
The resulting flamegraph looks like this, which you'll probably want to open in a separate tab and zoom in on (the SVG has helpful mouseover info):
Most of the time is idle, which isn't interesting in this case. Back to pyflame, this time with -x to exclude idle time:
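This is the same invocation as before with -x added (again, the output file name is just an example):

```bash
./pyflame -s 600 -x -o pyflame-busy.out -p <pid>
```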
That's much easier to see. If you spend some time mousing around the left third or so of the graph, you'll find that a significant amount of execution time is spent on image decoding and manipulation, starting at neural_style.py:67.
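Allowing for line numbers drifting between versions of the example, that's the loop that pulls batches from the training DataLoader:

```python
for batch_id, (x, _) in enumerate(train_loader):
    # ... forward pass, loss, backward pass ...
```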
Put another way, it's spending enough time decoding images that it's a significant part of the overall execution, and it's all clumped together in one place rather than being spread across the whole program. In a way, that's good news, because that's a problem we can pretty easily do something about. There's still the other two-thirds of the graph that we'd love to optimize away too, but that's mostly running the PyTorch model, and improving that will require much more surgery. There is a surprising amount of time spent in vgg.py's usage of namedtuple, though -- that might be a good thing to investigate another time. So, let's work on that first third of the graph: loading the training data.
Adding some parallelism
By now that python command has probably been running for a good long while. It helpfully outputs timestamps every 2000 iterations, and comparing the 20,000-iteration timestamp with the start timestamp, I get 21m40s elapsed. We'll use this duration later.
It's unlikely that we'll make huge strides in JPEG decoding itself, as that code is already written in a low-level language and reasonably well optimized. What we can do, though, is move the CPU-intensive work of image decoding onto another core. Normally this would be a good place to use threads, but Python as a language and CPython as a runtime are not very well suited to multithreading. We can use another Process to avoid the GIL (Global Interpreter Lock) and the lack of a memory model, though, and even though we'll have more overhead between processes than we would between threads, it should still be a net win. Conveniently, the work we want to execute concurrently is already small and fairly isolated, so it should be easy to move to another process. The training data DataLoader is set up with just a couple of lines in neural_style.py.
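Paraphrasing the example slightly, the setup looks like this:

```python
# datasets is torchvision.datasets, DataLoader is torch.utils.data.DataLoader,
# and transform is the torchvision transform pipeline defined just above.
train_dataset = datasets.ImageFolder(args.dataset, transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
```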
The loader is then consumed at the top of the training loop.
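That is, the same loop header the flame graph pointed at:

```python
for batch_id, (x, _) in enumerate(train_loader):
```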
So, all we need to do is move the loading to another process. We can do this with a Queue (actually, one of PyTorch's wrappers). Instead of enumerating train_loader in the main process, we'll have another process do that so that all the image decoding can happen on another core, and we'll receive the decoded data in the main process. To make it easy to enumerate a Queue, we'll start with a helper that makes a Queue iterable.
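A minimal sketch of such a helper (the class name QueueIterator comes from the description below; the details are an assumption) could look like this:

```python
class QueueIterator:
    """Iterate over items popped from a queue until a None appears."""

    def __init__(self, queue):
        self.queue = queue

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:
            # The producer signals the end of a pass by enqueueing None.
            raise StopIteration
        return item
```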
We're making the simple assumption that if we pop None off the queue, the iteration is complete. Next, we need a function to invoke in another process that will populate the queue. This is made a bit more complicated by the fact that we can't spawn a process at the intuitive point in the training for loop. If we create a new Process right where we enumerate over train_loader (as in the snippet above), the child process hangs immediately and is never able to populate the queue. The PyTorch docs warn about such issues, but unfortunately using torch.multiprocessing's wrappers or SimpleQueue did not help. Getting to the root cause of that problem will be a task for another day, but it's simple enough to rearrange the code to avoid it: fork a worker process earlier, and re-use it across multiple iterations. To do this, we use another Queue as a simple communication mechanism, which is control_queue below. The usage is pretty basic: sending True through control_queue tells the worker to enumerate the loader and populate batch_queue, finishing with a None to signal completion to the QueueIterator on the other end, while sending False tells the worker that its job is done and it can end its loop (and therefore exit).
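A sketch of that worker function (the queue names come from the description above; the function name is arbitrary):

```python
def populate_batch_queue(train_loader, control_queue, batch_queue):
    # Runs in the child process. Each True received on control_queue
    # triggers one full pass over train_loader; False ends the worker.
    while control_queue.get():
        for batch in train_loader:
            batch_queue.put(batch)
        # Tell the QueueIterator on the other end that this pass is done.
        batch_queue.put(None)
```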
Now we have everything we need to wire it all together. Before the training loop, we make some queues and fork a worker process.
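A sketch, using torch.multiprocessing (which mirrors the standard multiprocessing API) and the worker function from the previous sketch:

```python
import torch.multiprocessing as mp

batch_queue = mp.Queue(maxsize=10)  # bounding the queue is optional, but keeps memory in check
control_queue = mp.Queue()
worker = mp.Process(target=populate_batch_queue,
                    args=(train_loader, control_queue, batch_queue))
worker.start()
```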
And then we iterate in the training loop with a QueueIterator instead of iterating over train_loader directly.
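Inside the per-epoch loop, something along these lines, where the put(True) asks the worker for one pass over the dataset:

```python
control_queue.put(True)  # ask the worker to produce one pass over the data
for batch_id, (x, _) in enumerate(QueueIterator(batch_queue)):
    # ... the training step itself is unchanged ...
```

When training is finished, sending False through control_queue and joining the process shuts the worker down cleanly.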
Measuring
After stopping the previous training invocation and starting a new one, we can immediately see a good change in nvidia-smi dmon's output.
The GPU is staying fully utilized. Here's the flame graph for the master process:
What was previously the largest chunk of work in the flame graph is now the small peak on the left. Training logic dominates the graph instead of image decoding, and namedtuple looks like increasingly low-hanging fruit at 13% of samples...
And the worker that loads training images:
It's spending most of its time doing image decoding. There's some CPU spent in Python's Queue implementation, but the worker sits at about 40% CPU usage total anyway, so the inter-process communication isn't a major bottleneck in this case.
More importantly, for our "time until 20,000 iterations" measurement, that improves from 21m40s to 17m9s, or about a 26% improvement in iterations/sec (15.4 to 19.4). Not bad for just a few lines of straightforward code.