Distributed Training: Scaling Across GPUs

Single-GPU training hits three limits: dataset size (iterating over 100GB+ of training data on a single device becomes prohibitively slow), model size (LLMs with billions of parameters don't fit in one GPU's memory), and time (waiting 72 hours for each experiment means 2 experiments per week instead of 2 per day — too slow for business iteration). Distributed training addresses all three by spreading the workload across multiple GPUs — on one machine (multi-GPU) or across machines (multi-node). There are two fundamental approaches: data parallelism (same model on every GPU, different data slices) and model parallelism (model split across GPUs, same data).

Distributed training isn't about having more GPUs — it's about using multiple GPUs efficiently. Poorly configured distributed training runs slower than single-GPU because of communication overhead. The engineering is in minimizing that overhead.

Data Parallelism: The Standard Approach

Data parallelism: each GPU holds a complete copy of the model. The training batch is split across GPUs — each GPU processes its slice, computes gradients, and the gradients are averaged across all GPUs (allreduce operation) before updating the model weights. Example: 8 GPUs, batch size 256 → each GPU processes 32 samples. Training time: approximately 1/8th of single-GPU (minus communication overhead for gradient synchronization). Communication overhead: 5-15% for well-configured setups, 30-50% for poorly configured (slow interconnect, oversized gradients). Implementation: PyTorch DistributedDataParallel (DDP) — the standard for multi-GPU training. Configuration: set world_size (total GPU count), rank (GPU index), and backend (NCCL for GPU communication). DDP handles gradient synchronization automatically. Scaling efficiency: Near-linear up to 8 GPUs on one machine (NVLink interconnect). 70-85% efficiency for 16-64 GPUs across machines (network becomes the bottleneck). Above 64 GPUs: diminishing returns unless using advanced techniques (gradient compression, asynchronous training).
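A minimal DDP training sketch, assuming `model` and `dataset` are defined elsewhere and the script is launched with `torchrun --nproc_per_node=8 train.py` (hyperparameters are illustrative):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train(model, dataset, epochs=10):
    # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, and MASTER_ADDR for us
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Every GPU holds a full model replica; DDP hooks gradient allreduce
    model = DDP(model.cuda(local_rank), device_ids=[local_rank])

    # DistributedSampler gives each rank a disjoint slice of the data:
    # global batch 256 on 8 GPUs -> 32 samples per rank
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()  # gradients are averaged across ranks here
            optimizer.step()
            optimizer.zero_grad()
    dist.destroy_process_group()
```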

Model Parallelism: When Models Don't Fit on One GPU

Model parallelism splits the model across GPUs — each GPU holds a portion of the model's layers. Required when: the model's parameters exceed one GPU's memory (a 70B parameter LLM requires ~140GB in FP16 — well beyond the 80GB of an A100 or H100). Types: tensor parallelism (individual layers split across GPUs — a large matrix multiplication is divided and computed in parallel), pipeline parallelism (layers assigned to different GPUs — GPU 1 runs layers 1-10, GPU 2 runs layers 11-20, input flows through the pipeline), and fully sharded data parallelism (FSDP) (strictly a sharded form of data parallelism rather than model parallelism, but it solves the same memory constraint: model parameters, gradients, and optimizer states are sharded across GPUs — each GPU holds only a fraction of the full state, gathering shards on demand). FSDP (PyTorch) and DeepSpeed ZeRO (Microsoft) are the standard approaches for training large models. For fine-tuning LLMs: FSDP with 4-8 GPUs handles 7-13B parameter models. For training from scratch: pipeline parallelism with 64-256+ GPUs for 70B+ models.
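A minimal FSDP wrapping sketch, assuming `model` and `local_rank` are set up as in the DDP example above (the 100M-parameter wrapping threshold is an illustrative choice):

```python
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

# Shard parameters, gradients, and optimizer state across all ranks;
# submodules above 100M parameters become their own shard units
model = FSDP(
    model,
    device_id=local_rank,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=100_000_000
    ),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# The training loop itself is unchanged: FSDP gathers each shard on
# demand during forward/backward, then frees it to keep memory low
```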

Model Optimization for Inference

A model optimized for training accuracy is not optimized for serving speed. Inference optimization techniques: quantization (reduce numerical precision: FP32 → FP16 → INT8 → INT4. Each reduction: ~2x speedup with minimal accuracy loss. INT8 quantization: 4x smaller model than FP32, 2-4x faster inference, typically under 1% accuracy degradation), pruning (remove model weights that contribute minimally — 30-60% of weights can be pruned with under 2% accuracy loss, producing a smaller, faster model), distillation (train a small "student" model to mimic a large "teacher" — the student achieves 90-95% of teacher accuracy at 5-10x the speed), ONNX conversion (convert from PyTorch/TensorFlow to ONNX format — hardware-agnostic inference with runtime optimizations), and TensorRT optimization (NVIDIA's inference optimizer: kernel fusion, memory optimization, and precision calibration — 2-5x speedup on NVIDIA GPUs). The optimization pipeline for production: train (in mixed precision, per the section below) → quantize to INT8 → convert to ONNX → optimize with TensorRT → benchmark latency and accuracy against the original → deploy.
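A sketch of the first two pipeline steps (post-training dynamic quantization and ONNX export), assuming `model` is a trained PyTorch model taking a 200-feature input; names and shapes are placeholders:

```python
import torch
import onnxruntime as ort

# Dynamic quantization: weights of Linear layers become INT8, activations
# are quantized on the fly. (Static quantization with calibration data
# usually wins for conv-heavy vision models.)
quantized = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Export the model to ONNX for hardware-agnostic serving
dummy = torch.randn(1, 200)  # match the model's real input shape
torch.onnx.export(
    model, dummy, "model.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}},
)

# Score through ONNX Runtime and benchmark against the PyTorch baseline
session = ort.InferenceSession("model.onnx")
outputs = session.run(None, {"input": dummy.numpy()})
```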

Production Serving Architecture

ML model serving in production: batch serving (model runs on all records periodically — nightly churn scores, weekly demand forecasts. Architecture: Spark job or scheduled pipeline that loads model, scores all records, writes results to the data warehouse. Simple, cost-effective, no uptime requirement), real-time serving (model serves predictions per-request via REST API — fraud detection, recommendations, dynamic pricing. Architecture: model deployed in a serving framework (TorchServe, Triton, TensorFlow Serving) behind a load balancer on Kubernetes. Requirements: sub-100ms latency, 99.9% uptime, auto-scaling for traffic spikes), and streaming serving (model scores events from a Kafka/Event Hubs stream — IoT sensor scoring, real-time anomaly detection. Architecture: model embedded in a Spark Structured Streaming or Flink application, scoring each event as it arrives). Serving framework selection: Triton Inference Server (NVIDIA — best performance on GPU, supports multiple model formats), TorchServe (PyTorch-native — simplest for PyTorch models), TensorFlow Serving (TF-native — production-proven for TensorFlow models), and Azure ML Managed Endpoints (fully managed — no infrastructure to manage, auto-scaling built in).
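As an illustration of the real-time pattern, a minimal FastAPI endpoint; a production deployment would put a serving framework like TorchServe or Triton behind the load balancer instead, but the request/response shape is the same (model path and feature count are hypothetical):

```python
import torch
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
model = torch.jit.load("churn_model.pt")  # hypothetical TorchScript model
model.eval()

class Features(BaseModel):
    values: list[float]  # e.g. 200 engineered features per customer

@app.post("/predict")
def predict(req: Features):
    with torch.inference_mode():
        score = model(torch.tensor([req.values]))
    return {"churn_probability": float(score)}

# Run with: uvicorn app:app --workers 4  (behind a load balancer)
```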

Auto-Scaling for ML Inference

ML inference traffic is often spiky — the recommendation model serves 100 requests/second during business hours and 5/second at night. The fraud model handles 50 requests/second normally but 500 during a flash sale. Auto-scaling configuration: metric-based (scale based on: request queue depth, GPU utilization, or request latency — when P95 latency exceeds 100ms, add a replica), schedule-based (known traffic patterns: scale up at 8 AM, scale down at 8 PM — cheaper than reactive scaling for predictable patterns), and scale-to-zero (models that serve infrequent requests scale to 0 replicas between requests — cold start of 5-30 seconds is acceptable for low-frequency predictions). On Kubernetes: Horizontal Pod Autoscaler (HPA) scales replicas based on CPU/memory/custom metrics. KEDA scales based on event-driven triggers (Kafka queue depth, HTTP request count). GPU-aware scheduling ensures pods are scheduled on nodes with available GPU capacity.
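The metric-based decision reduces to a simple proportion. A sketch of the scaling rule Kubernetes' HPA applies (desiredReplicas = ceil(currentReplicas × currentMetric / targetMetric)), here driven by P95 latency; the replica limits are illustrative:

```python
import math

def desired_replicas(current_replicas: int, p95_latency_ms: float,
                     target_latency_ms: float = 100.0,
                     min_replicas: int = 1, max_replicas: int = 20) -> int:
    """HPA-style proportional scaling on a latency metric."""
    desired = math.ceil(current_replicas * p95_latency_ms / target_latency_ms)
    return max(min_replicas, min(desired, max_replicas))

# P95 at 180ms with 4 replicas and a 100ms target -> scale to 8 replicas
assert desired_replicas(4, 180.0) == 8
# P95 at 40ms -> scale down to 2 replicas
assert desired_replicas(4, 40.0) == 2
```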

Cost Optimization for Training and Serving

| Optimization | Savings | Applicable To |
|---|---|---|
| Spot/preemptible instances | 60-80% | Training (with checkpointing) |
| INT8 quantization | 50-75% (smaller GPU needed) | Inference serving |
| Auto-scaling to zero | 70-90% (off-hours) | Low-traffic inference endpoints |
| Model distillation | 80-90% (smaller model) | Inference serving (if accuracy acceptable) |
| Batch inference (vs real-time) | 50-70% | Predictions that don't need real-time |

GPU Selection Guide for ML Workloads

| GPU | Memory | Best For | Azure VM | Cost/hr |
|---|---|---|---|---|
| T4 | 16GB | Fine-tuning small models, inference | NC4as_T4_v3 | $0.53 |
| A10 | 24GB | Fine-tuning medium models, inference | NC24ads_A10_v4 | $1.80 |
| A100 40GB | 40GB | Training medium models, fast inference | NC24ads_A100_v4 | $3.67 |
| A100 80GB | 80GB | Training large models, multi-GPU | ND96asr_v4 | $27.20 (8x) |
| H100 | 80GB | Large model training, highest performance | ND96isr_H100_v5 | $32.77 (8x) |

Prices marked (8x) are for the full 8-GPU node.

Selection rule: Use the cheapest GPU that fits your workload. Fine-tuning a 7B parameter model: T4 or A10 (16-24GB is sufficient with parameter-efficient methods like LoRA). Fine-tuning a 13B model: A100 40GB (needs 30GB+ for model + optimizer state). Training from scratch or fine-tuning 70B+: A100 80GB or H100 with multi-GPU parallelism. For inference: quantized models run on T4 (an INT8-quantized 7B model fits comfortably in a T4's 16GB) — roughly 7x cheaper than serving on an A100. The GPU cost difference between T4 ($0.53/hr) and H100 ($4.10/hr per GPU) is 8x — right-sizing the GPU is the highest-impact cost optimization in ML engineering.
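The sizing arithmetic behind these rules is simple enough to script. A back-of-the-envelope sketch using common rules of thumb (bytes per parameter for weights; roughly 16 bytes per parameter for full fine-tuning with Adam in mixed precision, before activations):

```python
def weights_memory_gb(params_billions: float, bits: int = 16) -> float:
    """Weights-only memory: parameter count times bytes per parameter."""
    return params_billions * bits / 8

def full_finetune_memory_gb(params_billions: float) -> float:
    """Rule of thumb for mixed-precision Adam: ~16 bytes/parameter
    (FP16 weights + gradients, FP32 master weights + two optimizer
    moments), excluding activations."""
    return params_billions * 16

print(weights_memory_gb(7))          # 14.0 GB  -> FP16 7B inference
print(weights_memory_gb(7, bits=8))  # 7.0 GB   -> INT8 7B fits a T4
print(weights_memory_gb(70))         # 140.0 GB -> why 70B needs multi-GPU
print(full_finetune_memory_gb(13))   # 208.0 GB -> why full fine-tuning of
                                     # 13B needs LoRA to fit a 40GB card
```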

Inference Latency Optimization: From 2 Seconds to 200 Milliseconds

Latency optimization techniques in order of impact: model quantization (FP32 → INT8: 2-4x speedup, on the order of 10ms saved per inference), batch inference (group multiple requests and process them simultaneously — 3-5x higher throughput at comparable latency), continuous batching (vLLM and TensorRT-LLM dynamically batch incoming requests — better than fixed-size batching for variable-length inputs), speculative decoding (a small "draft" model generates candidate tokens, the large model verifies them — 2-3x speedup for autoregressive generation), KV cache optimization (for LLMs: PagedAttention manages the key-value cache efficiently — reducing memory waste 50-70% and enabling more concurrent requests per GPU), and model compilation (torch.compile or TensorRT compile the model graph into optimized kernels — 20-40% speedup for compiled models). For production LLM serving: vLLM (open-source, PagedAttention, continuous batching) or TensorRT-LLM (NVIDIA, highest performance) provide the serving runtime. Both achieve 10-50x higher throughput than naive PyTorch serving.
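A minimal vLLM sketch showing continuous batching in practice; the model name is an example, and any HF-compatible checkpoint works:

```python
from vllm import LLM, SamplingParams

# vLLM applies PagedAttention and continuous batching automatically
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", dtype="float16")
params = SamplingParams(temperature=0.7, max_tokens=256)

# All prompts in one generate() call are batched dynamically on the GPU;
# in a server deployment, vLLM's OpenAI-compatible endpoint does the
# same across concurrent HTTP requests
outputs = llm.generate(
    ["Summarize this quarter's churn drivers.",
     "Draft a renewal reminder email."],
    params,
)
for out in outputs:
    print(out.outputs[0].text)
```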

Model Serving Benchmarks: What Performance to Expect

Production serving benchmarks for common model types: tabular ML (XGBoost/LightGBM) — inference latency: 1-5ms per prediction. Throughput: 10,000+ predictions/second on CPU. No GPU needed. Serving: Flask/FastAPI + ONNX Runtime. Computer vision (YOLOv8/EfficientNet) — inference latency: 10-50ms per image on GPU (T4). Throughput: 20-100 images/second per GPU. Serving: Triton Inference Server + TensorRT. NLP classification (BERT) — inference latency: 5-20ms per text on GPU. Throughput: 50-200 predictions/second per GPU. Serving: TorchServe or Triton. LLM inference (GPT-4o-mini equivalent) — inference latency: 200-2000ms per response (varies with output length). Throughput: 20-50 concurrent requests per GPU (A100). Serving: vLLM or TensorRT-LLM. These benchmarks guide infrastructure sizing: if the business requires 1,000 fraud scoring requests/second → XGBoost on 1 CPU server handles it. If the business requires 100 concurrent LLM conversations → 2-4 A100 GPUs with vLLM. Over-provisioning wastes money; under-provisioning degrades the user experience.
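Before sizing infrastructure against these numbers, measure your own model. A small benchmarking sketch that works with any `predict` callable (warmup and iteration counts are illustrative):

```python
import time
import statistics

def benchmark(predict, payload, warmup: int = 20, iters: int = 200):
    """Measure single-request latency percentiles and implied throughput."""
    for _ in range(warmup):
        predict(payload)  # warm caches, JIT, CUDA kernels
    latencies_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        predict(payload)
        latencies_ms.append((time.perf_counter() - t0) * 1000)
    latencies_ms.sort()
    return {
        "p50_ms": statistics.median(latencies_ms),
        "p95_ms": latencies_ms[int(0.95 * len(latencies_ms)) - 1],
        "throughput_rps": 1000 / statistics.mean(latencies_ms),
    }
```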

Mixed Precision Training: Speed Without Sacrifice

Mixed precision training uses FP16 (half precision) for most operations and FP32 (full precision) only where needed (master weight updates, loss scaling). Benefits: 2x faster training (FP16 operations are twice as fast on modern GPUs), 50% less GPU memory (FP16 weights are half the size — enabling larger batch sizes or larger models on the same GPU), and negligible accuracy impact (loss scaling prevents the underflow issues that naive FP16 would cause). Implementation: PyTorch AMP (Automatic Mixed Precision) — a single context manager wraps the forward pass. Enable with 3 lines of code. Mixed precision is the default for all modern training — there's no reason to train in FP32 unless the model specifically requires it (rare for enterprise ML). For fine-tuning LLMs: QLoRA (Quantized Low-Rank Adaptation) combines 4-bit quantization of the base model with LoRA adapters trained in BF16. This makes a 70B parameter model fine-tunable on a single A100 GPU, instead of the 4x A100s required without quantization — a 4x cost reduction for the same training job.
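The "3 lines" are a `GradScaler`, an `autocast` context, and the scaled backward/step. A sketch assuming `model`, `loader`, and `optimizer` already exist:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for x, y in loader:
    optimizer.zero_grad()
    # Forward pass runs in FP16 where safe, FP32 where precision matters
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.cross_entropy(model(x.cuda()), y.cuda())
    scaler.scale(loss).backward()  # loss scaling prevents FP16 underflow
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()                # adjusts the scale factor adaptively
```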

Practical Example: Distributed Training for Enterprise Churn Model

An enterprise with 50M customer records and 200 features trains a churn prediction model. Single GPU (A100 80GB): training time 18 hours. Data parallelism across 4 A100 GPUs: training time 5 hours (3.6x speedup, i.e. 90% scaling efficiency after communication overhead). Cost comparison: single GPU 18 hours at $5/hour = $90. Four GPUs 5 hours at $20/hour = $100. The 4-GPU approach costs 11% more but delivers results 13 hours sooner — enabling the data scientist to iterate 3-4x per day instead of once. For hyperparameter tuning (50 experiments): single GPU 900 hours ($4,500). Four GPUs running single-GPU experiments in parallel: 225 hours wall clock ($2,700 using spot instances at $12/hour for the 4-GPU node). Distributed training reduces both time and cost when combined with spot instances and efficient experiment scheduling.
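The trade-off is easy to check with a two-line cost model (prices and efficiency are the example's assumptions):

```python
def distributed_run(single_gpu_hours: float, n_gpus: int,
                    efficiency: float, price_per_gpu_hr: float):
    """Wall-clock hours and total cost of a data-parallel training run."""
    hours = single_gpu_hours / (n_gpus * efficiency)
    return hours, hours * n_gpus * price_per_gpu_hr

print(distributed_run(18, 1, 1.0, 5.0))  # (18.0, 90.0)  -> single A100
print(distributed_run(18, 4, 0.9, 5.0))  # (5.0, 100.0)  -> 4-way DDP
```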

Inference Cost Modeling: When to Invest in Optimization

The decision to invest in inference optimization depends on volume: under 1,000 predictions/day — do not optimize. Use the model as-is on a T4 instance. The cost is negligible ($15-30/month). Engineering time spent optimizing costs more than the compute saved. 1,000-100,000 predictions/day — basic optimization: INT8 quantization (2-4x speedup, 2 days engineering effort), ONNX Runtime conversion (1.5-2x speedup, 1 day effort). Total savings: $200-2,000/month. 100,000+ predictions/day — full optimization: quantization plus pruning plus distillation plus TensorRT compilation. Dedicated serving infrastructure with auto-scaling. Total savings: $5,000-50,000/month. The rule of thumb: the optimization investment pays for itself when the one-time engineering cost (hours × hourly rate) is recovered by monthly compute savings within 6 months, i.e. monthly savings must exceed engineering cost / 6.
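That break-even rule as a one-liner (rates and hours below are hypothetical):

```python
def payback_months(engineering_hours: float, hourly_rate: float,
                   monthly_savings: float) -> float:
    """Months for compute savings to repay the one-time engineering cost."""
    return (engineering_hours * hourly_rate) / monthly_savings

# 2 engineer-days (16h) of INT8 work at $150/hr vs $800/month saved:
print(payback_months(16, 150.0, 800.0))  # 3.0 months -> under the 6-month bar
```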

The Xylity Approach

We engineer ML systems with the train-optimize-serve methodology — distributed training for speed, quantization + distillation for efficiency, and auto-scaling serving for cost-effective production deployment. Our ML engineers and AI architects build the infrastructure that makes models fast to train and efficient to serve — because a model that takes 72 hours to train and costs $50/hour to serve isn't production-viable.


Train Faster, Serve Smarter, Spend Less

Distributed training, inference optimization, auto-scaling serving. ML engineering that makes models production-viable.

Start Your ML Engineering →