There's lots of innovation out there building better machine learning models with new neural net structures, regularization methods, etc. Groups like fast.ai are training complex models quickly on commodity hardware by relying more on "algorithmic creativity" than on overwhelming hardware power, which is good news for those of us without data centers full of hardware. Rather than add to your stash of creative algorithms, this post takes a different (but compatible) approach to better performance: we'll step through measuring and optimizing model training from a systems perspective, which applies regardless of what algorithm you're using. As we'll see, there are nice speedups to be had with little effort. While keeping the same model structure, hyperparameters, etc., training speed improves by 26% simply by moving some work to another CPU core. You might find similar easy performance gains waiting for you whether your project uses an SVM or a neural network.
We'll be using the fast neural style example in PyTorch's example projects as the system to optimize. If you just want to see the few code changes needed, check out this branch. Otherwise, to see how to get there step-by-step so that you can replicate this process on your own projects, read on.
If you want to follow along, start by downloading the 2017 COCO training dataset (18GiB). You'll also need a Linux system with a recent kernel and a GPU (an NVIDIA one, if you want to use the provided commands as-is). My hardware for this experiment is an i7-6850K with 2x GTX 1070 Ti, though we'll only be using one GPU this time.
If you're a virtualenv user, you'll probably want a virtualenv with the necessary packages (at minimum, PyTorch and Pillow) installed.
Clone the pytorch/examples repo and go into the
fast_neural_style directory, then start training a model. The batch size is left at the default (4) so it will be easier to replicate these results on smaller hardware, but of course feel free to increase the batch size if you have the hardware. We run
date first so we have a timestamp to compare later timestamps with, and pass
--cuda 1 so that it will use a GPU. The directory passed to
--dataset should be the one containing the
train2017 directory, not the path to train2017 itself.
While that's running for the next few hours, let's dig in to its performance characteristics.
vmstat 2 isn't fancy, but it's a good place to start. Unsurprisingly, we have negligible
wa (i/o wait) CPU usage and little block i/o, and we're seeing normal user CPU usage for one busy process on a 6-core, 12-thread system (10-12%).
Moving on to the GPU, we'll use
nvidia-smi dmon -s u -i 0.
dmon periodically outputs GPU info, and we're limiting it to utilization (
-s u) and we want the first GPU device (-i 0).
Now this is more interesting. GPU utilization as low as 30%? This workload should basically be waiting on the GPU the entire time, so failing to keep the GPU busy is a problem.
To see if there's something seriously wrong,
perf stat is a simple way to get a high-level view of what's going on. Just using 100% of a CPU core doesn't mean much; it could be spending all of its time waiting for memory access or pipeline flushes. We can attach to a running process (that's the
-p <pid>) and aggregate performance counters. Note that if you're using a virtual machine, you may not have access to performance counters, and even my physical hardware doesn't support all the counters
perf stat looks for.
After letting that run for a couple of minutes, stopping it with
^C prints a summary.
- 1.25 instructions per cycle isn't awful, but it's not great either. In ideal circumstances, CPUs can retire several instructions per cycle, so if this were more like 3 or 4, that would be a sign that the CPU was already being kept quite busy.
- 0.7% branch mispredicts is higher than I'd like, but it isn't catastrophic.
- Hundreds of context switches per second is not suitable for realtime systems, but is unlikely to affect batch workloads like this one.
Overall, there's nothing obvious like a 5% branch miss rate or 0.5 IPC, so this isn't telling us anything particularly compelling.
Capturing performance counter information with
perf stat shows that there's some room to improve, but it's not providing any details on what to do about it.
perf record, on the other hand, samples the program's stack while it's running, so it will tell us more about what specifically the CPU is spending its time doing.
After letting that run for a few minutes, that will have written a
perf.data file. It's not human readable, but
perf annotate will output disassembly with the percentage of samples where the CPU was executing that instruction. For my run, the output starts with a disassembly of
syscall_return_via_sysret indicating that most of the time there is spent on
pop. That's not particularly useful knowledge at the moment (the process makes a good number of syscalls, so we do expect to see some time spent there), so let's keep looking. The next item is for
jpeg_idct_islow@@LIBJPEG_9.0, part of PIL (Python Imaging Library, aka Pillow). The output starts with a summary that runs for about 50 lines, telling us that rather than having a few hot instructions in this function, the cost is smeared out across many instructions (2.52% at offset 36635, 1.93% at 3624d, etc.). Paging down to the disassembly confirms it: lots of arithmetic instructions, each accounting for a small slice of the samples.
In perf annotate's disassembly view, the percentage of samples appears in the left column (sometimes prefixed with the symbol name for particularly busy samples), next to each instruction's offset. In my run, 0.72% of cycles were spent on a
shlq (shift left), 0.82% on an
addq (integer addition), and so on. Note that due to a phenomenon called skid, the attribution may be off by several or even dozens of instructions, so these percentages should not be taken as gospel. In this case, for instance, it's unlikely that the two
shlq instructions at offsets 35ec6 and 35ecb really differ by 5x.
The next section of
perf annotate output is of
__vdso_clock_gettime@@LINUX_2.6. VDSO is a way to speed up certain syscalls, notably
gettimeofday. Not much to see here, other than to note that maybe we shouldn't be calling
gettimeofday(2) as much.
The next section is of
_imaging.cpython-35m-x86_64-linux-gnu.so, which has a large block of fairly hot instructions, the hottest being a
movzbl (zero-extend 1 byte into 4 bytes). At this point, we have a hypothesis: we're spending a lot of time decoding and scaling images.
To get a clearer view of what paths through the program are the most relevant, we'll use a flame graph. There are other things we could do with
perf report, but flame graphs are easier to understand in my experience. If you haven't worked with flame graphs before, the general idea is that width = percentage of samples and height = call stack. As an example, a tall, skinny column means a deep call stack that doesn't use much CPU time, while a wide, short column means a shallow call stack that uses a lot of CPU.
Clone the Pyflame repo and follow their compile instructions. There's no need to install it anywhere (the
make install step) -- just having a compiled
pyflame binary in the build output is sufficient.
As with the other tools, we attach the compiled
pyflame binary to the running process and let it run for 10 minutes to get good fidelity.
In the meantime, clone FlameGraph so we can render
pyflame's output as an SVG using its flamegraph.pl script.
The resulting flamegraph looks like this, which you'll probably want to open in a separate tab and zoom in on (the SVG has helpful mouseover info):
Most of the time is idle, which isn't interesting in this case. Back to pyflame, this time with
-x to exclude idle time.
That's much easier to see. If you spend some time mousing around the left third or so of the graph, you'll find that a significant amount of execution time is spent on image decoding and manipulation, starting with
neural_style.py:67.
Put another way, it's spending enough time decoding images that it's a significant part of the overall execution, and it's all clumped together in one place rather than being spread across the whole program. In a way, that's good news, because that's a problem we can pretty easily do something about. There's still the other 2/3rds of the graph that we'd love to optimize away too, but that's mostly running the PyTorch model, and improving that will require much more surgery. There is a surprising amount of time spent in
vgg.py's usage of
namedtuple, though -- that might be a good thing to investigate another time. So, let's work on that first third of the graph: loading the training data.
Adding some parallelism
By now that
python command has probably been running for a good long while. It helpfully outputs timestamps every 2000 iterations, and comparing the 20,000 iteration timestamp with the start timestamp I get 21m 40s elapsed. We'll use this duration later.
It's unlikely that we'll make huge strides in JPEG decoding itself, as that code is already written in a low-level language and reasonably well optimized. What we can do, though, is move the CPU-intensive work of image decoding onto another core. Normally this would be a good place to use threads, but Python as a language and CPython as a runtime are not well suited to multithreading. We can use a separate
Process to sidestep the GIL (Global Interpreter Lock) and CPython's lack of a well-defined memory model, though; even though inter-process communication has more overhead than sharing memory between threads, it should still be a net win. Conveniently, the work we want to execute concurrently is small and fairly isolated, so it should be easy to move to another process. The training data
DataLoader is set up before the training loop in neural_style.py and then enumerated inside the loop.
So, all we need to do is move the loading to another process. We can do this with a Queue (actually, one of PyTorch's wrappers). Instead of enumerating
train_loader in the main process, we'll have another process do that so that all the image decoding can happen on another core, and we'll receive the decoded data in the main process. To make it easy to enumerate a
Queue, we'll start with a helper that makes a
Queue iterable.
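A minimal sketch of such a helper, using the QueueIterator name that appears later in the post and a None end-of-data sentinel (the method bodies are an assumption):

```python
class QueueIterator:
    """Wrap a queue so it can be consumed with a plain for loop.

    Iteration stops when None is popped off the queue; the producer
    pushes None as an end-of-data sentinel.
    """

    def __init__(self, queue):
        self.queue = queue

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item
```

With this in place, the consumer side is just `for batch in QueueIterator(batch_queue): ...`, and the producer's contract is to push batches followed by a final None.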
We're making the simple assumption that if we pop
None off the queue, the iteration is complete. Next, we need a function to invoke in another process that will populate the queue. This is made a bit more complicated by the fact that we can't spawn a process at the intuitive point in the training
for loop. If we create a new
Process where we enumerate over
train_loader, the child process hangs immediately and is never able to populate the queue. The PyTorch docs warn about such issues, but unfortunately using
torch.multiprocessing's wrappers or
SimpleQueue did not help. Getting to the root cause of that problem will be a task for another day, but it's simple enough to rearrange the code to avoid the problem: fork a worker process earlier, and re-use it across multiple iterations. To do this, we use another
Queue as a simple communication mechanism, which is
control_queue below. The usage is pretty basic: sending True on
control_queue tells the worker to enumerate the loader and populate
batch_queue, finishing with a
None to signal completion to the
QueueIterator on the other end, while sending
False tells the worker that its job is done and it can end its loop (and therefore exit).
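Following the control_queue protocol just described, the worker might look roughly like this (the function and parameter names are assumptions; a truthy control message triggers one pass over the loader, False ends the worker):

```python
def loader_worker(loader, batch_queue, control_queue):
    """Runs in the child process.

    Each truthy value received on control_queue triggers one full
    enumeration of the loader, pushing every batch onto batch_queue
    and finishing with a None sentinel. A falsy value (False) tells
    the worker to exit its loop.
    """
    while control_queue.get():
        for batch in loader:
            batch_queue.put(batch)
        batch_queue.put(None)  # end-of-epoch marker for the consumer
```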
Now we have everything we need to wire it all together: before the training loop, create the two queues and fork the worker process, then have the training loop iterate over a QueueIterator wrapped around batch_queue.
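Putting it all together, here's a self-contained sketch of the wiring. It uses Python's standard multiprocessing where the original used PyTorch's wrappers, a plain list stands in for the real DataLoader, and iter(queue.get, None) plays the role of QueueIterator; all names are assumptions:

```python
import multiprocessing as mp

def loader_worker(loader, batch_queue, control_queue):
    # One pass over the loader per truthy control message; exit on falsy.
    while control_queue.get():
        for batch in loader:
            batch_queue.put(batch)
        batch_queue.put(None)  # end-of-epoch sentinel

if __name__ == "__main__":
    batch_queue = mp.Queue()
    control_queue = mp.Queue()
    train_loader = [[0, 1], [2, 3]]  # stand-in for the real DataLoader

    worker = mp.Process(target=loader_worker,
                        args=(train_loader, batch_queue, control_queue))
    worker.start()

    for epoch in range(2):
        control_queue.put(True)  # ask the worker for one epoch of batches
        for batch in iter(batch_queue.get, None):
            pass  # the training step would consume `batch` here

    control_queue.put(False)  # tell the worker it can exit
    worker.join()
```

Because the worker is forked once and reused across epochs, it never has to be spawned mid-loop, which avoids the hang described above.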
After stopping the previous training invocation and starting a new one, we can immediately see a good change in the nvidia-smi dmon utilization numbers.
The GPU is staying fully utilized. Here's the master process:
What was previously the largest chunk of work in the flame graph is now the small peak on the left. Training logic dominates the graph, instead of image decoding, and
namedtuple looks like increasingly low-hanging fruit at 13% of samples...
And the worker that loads training images:
It's spending most of its time doing image decoding. There's some CPU spent in Python's Queue implementation, but the worker sits at about 40% CPU usage total anyway, so the inter-process communication isn't a major bottleneck in this case.
More importantly, for our "time until 20,000 iterations" measurement, that improves from 21m40s to 17m9s, or about a 26% improvement in iterations/sec (15.4 to 19.4). Not bad for just a few lines of straightforward code.