Track 1: Why Keras Exists and Why It Matters
Understand Keras's origin, philosophy, and the multi-backend revolution that makes Keras 3 unique.
✓The Origin Story
François Chollet created Keras in March 2015 while working at Google. His vision was radical for the time: deep learning should be accessible to everyone, not just PhD researchers with years of framework experience.
Before Keras, building a neural network required extensive boilerplate code — configuring computation graphs, manually managing sessions, and wrestling with low-level tensor operations. Chollet saw this as an unnecessary barrier and set out to create an API that was "designed for human beings, not machines."
His guiding principle: "Being able to go from idea to result with the least possible delay is key to doing good research." This idea — that developer experience is a research accelerator, not a luxury — became the foundation of Keras's design philosophy.
✓Keras Philosophy
Keras is built on four core principles that have guided its development for a decade:
- User-friendliness — Designed for humans, not machines. The API minimizes the number of user actions required for common use cases and provides clear error messages. As Chollet puts it: "Keras follows best practices for reducing cognitive load."
- Modularity — A model is a sequence or graph of standalone, fully configurable modules. Layers, optimizers, loss functions, and regularizers are all independent building blocks that can be combined with as few restrictions as possible.
- Extensibility — New modules are easy to add (as new classes and functions), and existing modules provide ample examples. The ability to easily create new components makes Keras suitable for advanced research.
- Minimalism — Just enough to get the job done, no more. Each module is kept small and simple. Source code is readable. No unnecessary complexity lurking beneath the surface.
These principles explain why Keras consistently ranks as one of the most-loved ML frameworks in developer surveys. The API is predictable: once you learn the patterns, everything else follows logically.
✓The Journey
Keras has gone through three major eras:
Era 1: Standalone Library (2015–2017) — Keras started as an independent library that could run on top of Theano, TensorFlow, or CNTK. This multi-backend design was Keras's original superpower, letting researchers switch engines without changing code.
Era 2: TensorFlow Integration (2017–2023) — Keras was integrated into TensorFlow as tf.keras, becoming TF's official high-level API. This brought massive adoption but created tight coupling. Keras became synonymous with TensorFlow in many developers' minds, and running Keras on PyTorch or JAX wasn't possible.
Era 3: Keras 3 Multi-Backend Rewrite (2023–present) — Chollet and the team performed a complete rewrite, restoring the original multi-backend vision but at a much deeper level. Keras 3 runs on TensorFlow, PyTorch, JAX, and OpenVINO. The split back to standalone was necessary because TensorFlow coupling limited Keras's reach — researchers using PyTorch or JAX were locked out.
```python
# Keras 3 — works on any backend
import keras
from keras import layers

# Same code runs on TF, PyTorch, or JAX
model = keras.Sequential([
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
```
✓Keras 3's Breakthrough
One API, three backends: TensorFlow, PyTorch, and JAX (plus OpenVINO for inference-only). The latest stable version is Keras 3.11.3.
The breakthrough is the keras.ops namespace — a universal operation layer that implements the full NumPy API plus neural network-specific functions. When you write keras.ops.matmul(x, w), Keras dispatches the call to jax.numpy.matmul, tf.matmul, or torch.matmul depending on the active backend, producing numerically identical results.
"Backend-agnostic" in practice means:
- Write your model once, train it on any framework
- Save a model trained on JAX, load it in TensorFlow for deployment
- Use PyTorch's ecosystem for research, then export for TF Serving in production
- Every built-in Keras layer, metric, loss, and optimizer uses keras.ops internally — they're all portable
```python
# keras.ops — the universal operation layer
import keras

# These work identically on TF, PyTorch, or JAX:
x = keras.ops.matmul(a, b)
y = keras.ops.nn.softmax(logits)
z = keras.ops.image.resize(img, size=(224, 224))
```
The OpenVINO backend is inference-only: you run model.predict() using Intel's optimized inference engine.
✓When to Choose Keras
Choose Keras when you need:
- Rapid prototyping — Keras provides the fastest path from idea to working model. Build a CNN in 10 lines, not 100.
- Education — The cleanest API for learning deep learning concepts. The code reads almost like pseudocode.
- Production — Battle-tested at Google scale. Export to TF Serving, TFLite, ONNX, or OpenVINO.
- Research flexibility — Multi-backend means you can train on JAX for speed, prototype on PyTorch for debugging, and deploy on TF for serving.
When NOT to choose Keras:
- Extremely custom training loops that need raw framework-specific features (e.g., JAX's jax.pmap for fine-grained parallelism)
- Tight integration with framework-specific tooling that doesn't have Keras equivalents
- Projects already deeply embedded in a single framework's ecosystem
That said, Keras 3's subclassing API and custom training steps cover most "advanced" needs. You'll rarely hit Keras's ceiling before reaching production-grade results.
✓The Ecosystem
Keras is more than a single library — it's a full ecosystem of tools:
- KerasCV — Computer vision: pretrained backbones, object detection (YOLOv8), segmentation, advanced augmentation (CutMix, MixUp, RandAugment)
- KerasNLP → KerasHub — NLP and beyond. Renamed to KerasHub in September 2024 as a unified model hub. Access BERT, GPT-2, Gemma, Llama, Mistral with from_preset()
- KerasTuner — Hyperparameter search with RandomSearch, BayesianOptimization, and Hyperband strategies
- keras.io — Official site with hundreds of code examples, guides, and API documentation
Real-world users:
| Organization | Use Case |
|---|---|
| Google | Gemma model family, internal ML infrastructure |
| CERN | Large Hadron Collider particle physics research |
| NASA | Heliophysics and space weather prediction |
| Waymo | Self-driving car perception models |
| YouTube | Recommendation systems |
✓Installation & Setup
Getting started with Keras 3 is straightforward:
```shell
# Install Keras
pip install keras

# Install your preferred backend
pip install tensorflow   # TensorFlow backend
pip install jax jaxlib   # JAX backend
pip install torch        # PyTorch backend
```
Set your backend before importing Keras:
```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # Must be set BEFORE import keras
import keras

# Or check/set programmatically (Keras 3.5+):
keras.config.set_backend("torch")
print(keras.backend.backend())  # → "torch"
```
You can also set the backend in the ~/.keras/keras.json config file. The default backend is TensorFlow.
🧠 Track 1 Quiz
Track 2: Backend Selection and Multi-Backend Thinking
Master Keras 3's multi-backend architecture. Learn how to choose, configure, and leverage TensorFlow, PyTorch, and JAX backends.
✓How Keras 3 Decouples API from Engine
The key insight behind Keras 3 is separating model description from execution. When you define layers, connect them, and specify training — you're describing what to compute. The backend decides how to compute it.
This separation works through keras.ops, which dispatches every operation to the active backend. Under the hood:
- keras.ops.matmul(a, b) → jax.numpy.matmul / tf.matmul / torch.matmul
- keras.ops.nn.softmax(x) → jax.nn.softmax / tf.nn.softmax / torch.nn.functional.softmax
KerasTensors are backend-agnostic symbolic tensors used during model construction. They describe the computation graph without committing to any framework. When you save a .keras file, it contains no backend-specific operations.
✓Setting the Backend
There are three methods to set the Keras backend, with a clear priority order:
1. Environment variable (highest priority):
```python
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras  # Must come AFTER setting the env var
```
2. Config file (~/.keras/keras.json):
```json
{
  "backend": "jax",
  "floatx": "float32",
  "epsilon": 1e-07,
  "image_data_format": "channels_last"
}
```
3. Programmatic (Keras 3.5+):
```python
import keras
keras.config.set_backend("torch")  # Change backend after import
```
When the environment variable is set, e.g. KERAS_BACKEND="jax", the config file and programmatic calls are ignored at initial import time.
✓TensorFlow Backend
TensorFlow is the default backend for Keras 3, requiring TF 2.16.1+. It's the most mature option with the broadest deployment ecosystem.
Strengths:
- Best TPU support for training at scale
- Deployment via TF Serving (production-grade HTTP/gRPC serving)
- TFLite for mobile/edge devices
- TF.js for browser-based inference
- Most extensive documentation and community resources
Good for: Production deployment pipelines, mobile/edge deployment, teams already using TF infrastructure.
```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras

model = keras.Sequential([...])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train)

# Export for TF Serving
model.export("serving_model", format="tf_saved_model")
```
✓PyTorch Backend
The PyTorch backend requires PyTorch 2.1.0+ and provides seamless integration with the massive PyTorch ecosystem.
Key detail: When using the PyTorch backend, Keras layers become torch.nn.Module instances. This means you can use Keras models inside PyTorch code and vice versa.
Strengths:
- Full interoperability with Hugging Face Transformers, torchvision, torchaudio
- Eager mode debugging with standard Python debugging tools
- Massive research community and pretrained model ecosystem
```python
import os
os.environ["KERAS_BACKEND"] = "torch"
import torch
import keras
from keras import layers

model = keras.Sequential([
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

# Keras model IS a torch.nn.Module
print(isinstance(model, torch.nn.Module))  # True
```
Use keras.utils.TorchModuleWrapper to wrap existing PyTorch modules for use as Keras layers: keras_layer = keras.utils.TorchModuleWrapper(torch_module)
✓JAX Backend
The JAX backend requires JAX 0.4.20+ and is often the fastest option for GPU and TPU training thanks to JIT compilation and XLA.
Strengths:
- JIT compilation via XLA for maximum performance
- Functional transforms: jax.vmap, jax.grad, jax.pmap
- Best support for keras.distribution (multi-GPU/TPU training)
- Stateless API for pure functional programming patterns
```python
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras

model = keras.Sequential([...])
model.compile(optimizer="adam", loss="mse")

# JAX stateless API for functional purity
trainable = model.trainable_variables
non_trainable = model.non_trainable_variables
outputs, non_trainable = model.stateless_call(trainable, non_trainable, inputs)
```
The stateless API includes layer.stateless_call(), optimizer.stateless_apply(), and metric.stateless_update_state().
✓keras.ops: The Universal Layer
keras.ops implements the full NumPy API plus neural network-specific operations. It's the backbone of Keras 3's multi-backend architecture.
| Namespace | Examples |
|---|---|
| keras.ops (NumPy) | matmul, sum, mean, reshape, concatenate, stack, einsum, arange, clip |
| keras.ops.nn | softmax, sigmoid, relu, conv, depthwise_conv, binary_crossentropy |
| keras.ops.image | resize, crop_images, pad_images, rgb_to_grayscale |
| keras.random | normal, uniform, dropout (with SeedGenerator) |
TF → Keras ops migration:
| TensorFlow | Keras 3 |
|---|---|
| tf.reduce_sum | keras.ops.sum |
| tf.concat | keras.ops.concatenate |
| tf.range | keras.ops.arange |
| tf.reduce_mean | keras.ops.mean |
| tf.gather | keras.ops.take |
| tf.clip_by_value | keras.ops.clip |
✓What Changes Across Backends
Understanding what stays the same and what changes across backends is crucial:
What stays the SAME:
- Model definition (Sequential, Functional, Subclassing)
- Training loop (model.compile() + model.fit())
- Callbacks (EarlyStopping, ModelCheckpoint, etc.)
- Saving format (.keras)
- All keras.ops operations
What CHANGES:
| Aspect | TensorFlow | PyTorch | JAX |
|---|---|---|---|
| Tensor type | tf.Tensor | torch.Tensor | jax.Array |
| Device mgmt | tf.device | .to(device) | jax.devices() |
| Layers are | Keras layers | torch.nn.Module | Keras layers |
| JIT compilation | tf.function | torch.compile | jax.jit (default) |
✓Practical: Same Model on All Three Backends
Here's the power of Keras 3 — the exact same code runs identically on any backend:
```python
import keras
from keras import layers

# This model definition is 100% backend-agnostic
def build_model():
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(256, activation="relu")(inputs)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    return keras.Model(inputs, outputs)

model = build_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train on any backend — same result
model.fit(x_train, y_train, epochs=10, validation_split=0.2)

# Save — cross-backend compatible!
model.save("my_model.keras")  # Load this on ANY backend — TF, PyTorch, JAX
```
The .keras format is the key to cross-backend portability. Save from one backend, load in another — the model, weights, and optimizer state all transfer.
🧠 Track 2 Quiz
Track 3: Your First Keras Model — Sequential API
Build your first neural network with Keras's simplest model type: the Sequential API.
✓What is Sequential?
The Sequential model is the simplest way to build a neural network in Keras: a linear stack of layers where data flows from first to last, one step at a time.
When it works: Single input tensor, single output tensor, no branching, no skip connections, no layer sharing. If your model is a straight pipeline, Sequential is your friend.
```python
import keras
from keras import layers

# Method 1: Pass a list of layers
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(128, activation="relu"),
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

# Method 2: Incremental add()
model = keras.Sequential()
model.add(keras.Input(shape=(784,)))
model.add(layers.Dense(128, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))
```
✓Adding Layers
The most common layers in a Sequential model:
- Dense(units, activation) — Fully connected layer. The workhorse of neural networks.
- Dropout(rate) — Randomly sets input units to 0 during training to prevent overfitting.
- Flatten() — Reshapes multi-dimensional input into 1D (e.g., 28×28 image → 784 vector).
- Activation(fn) — Applies an activation function. Usually specified inline: Dense(64, activation="relu").
Critical: The first layer needs shape information. Use keras.Input(shape=(...)) as the first element to tell Keras the expected input dimensions.
```python
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),          # 28x28 grayscale images
    layers.Flatten(),                        # → (784,)
    layers.Dense(128, activation="relu"),    # → (128,)
    layers.Dropout(0.2),                     # Regularization
    layers.Dense(10, activation="softmax"),  # → (10,) probabilities
])
```
✓Understanding Layer Arguments
Let's break down the most important arguments for Dense and other layers:
| Argument | Type | Description |
|---|---|---|
| units | int | Output dimensionality — how many neurons in this layer |
| activation | str/fn | Activation function: 'relu', 'sigmoid', 'softmax', 'tanh', None |
| input_shape | tuple | Shape of input (only needed for first layer, prefer keras.Input) |
| use_bias | bool | Whether to add a bias term (default: True) |
| kernel_initializer | str | Weight initialization: 'glorot_uniform' (default), 'he_normal', etc. |
Common activations and when to use them:
- ReLU — Default for hidden layers. Simple, fast, works well in practice.
- Sigmoid — Output layer for binary classification (0 or 1).
- Softmax — Output layer for multi-class classification (probabilities sum to 1).
- Tanh — Output range [-1, 1]. Occasionally used in hidden layers.
- None / Linear — Output layer for regression (no transformation).
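To make the softmax choice concrete, here is a minimal NumPy sketch (illustrative only, not Keras internals) showing that softmax turns arbitrary scores into a valid probability distribution:

```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(round(probs.sum(), 6))  # → 1.0
```

The largest logit gets the largest probability, which is why argmax over softmax outputs gives the predicted class.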
✓model.summary()
model.summary() prints the model architecture — it's your primary debugging tool for understanding what your model looks like:
```python
model.summary()

# Output:
# ┏━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
# ┃ Layer (type)          ┃ Output Shape     ┃ Param #    ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
# │ dense (Dense)         │ (None, 128)      │ 100,480    │
# │ dropout (Dropout)     │ (None, 128)      │ 0          │
# │ dense_1 (Dense)       │ (None, 10)       │ 1,290      │
# └───────────────────────┴──────────────────┴────────────┘
```
How parameters are calculated: A Dense layer with input size n and output size m has n × m + m parameters (weights + biases). For example, Dense(128) with input shape (784,) = 784 × 128 + 128 = 100,480 params.
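The rule above can be checked with plain arithmetic (a sketch independent of Keras; dense_param_count is a hypothetical helper, not a Keras function):

```python
def dense_param_count(n_inputs, units):
    # weights: n_inputs × units, plus one bias per output unit
    return n_inputs * units + units

print(dense_param_count(784, 128))  # → 100480 (the dense layer above)
print(dense_param_count(128, 10))   # → 1290 (the dense_1 layer above)
```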
✓Building a Classifier: MNIST
Let's build a complete MNIST digit classifier from scratch — the "Hello World" of deep learning:
```python
import keras
from keras import layers

# 1. Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 2. Normalize pixel values to [0, 1]
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# 3. Build model
model = keras.Sequential([
    keras.Input(shape=(28, 28)),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.2),
    layers.Dense(10, activation="softmax"),
])

# 4. Compile
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# 5. Train
model.fit(x_train, y_train, epochs=5, validation_split=0.1)

# 6. Evaluate
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc:.4f}")  # ~97.5%
```
✓Common Mistakes
Watch out for these frequent Sequential model pitfalls:
- Forgetting input shape — Without keras.Input(shape=...), the model won't build until it sees data. Add it explicitly for clarity.
- Wrong output activation — Using sigmoid for multi-class (10 classes) instead of softmax. Sigmoid is for binary classification only.
- Mismatched loss function — sparse_categorical_crossentropy expects integer labels (0, 1, 2...). categorical_crossentropy expects one-hot encoded labels ([0,0,1,0...]).
- Not normalizing input data — Raw pixel values (0–255) make training unstable. Always normalize to [0, 1] or standardize to zero mean / unit variance.
- Too many or too few layers — Start simple. A 2-3 layer Dense network handles most tabular/image classification. Add complexity only when needed.
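The integer-vs-one-hot label distinction is easy to see with NumPy (an illustrative sketch; keras.utils.to_categorical performs the same conversion):

```python
import numpy as np

# Integer labels — what sparse_categorical_crossentropy expects
labels = np.array([2, 0, 1])

# One-hot labels — what categorical_crossentropy expects
one_hot = np.eye(3)[labels]
print(one_hot)
# → [[0. 0. 1.]
#    [1. 0. 0.]
#    [0. 1. 0.]]
```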
✓When to Graduate Beyond Sequential
Sequential is great for learning and simple models, but real-world architectures often need more. It's time to move on when you need:
- Multiple inputs — Text + image, metadata + time series
- Multiple outputs — Classification + regression heads
- Skip connections — ResNet-style residual blocks
- Shared layers — Same weights applied to different inputs (Siamese networks)
- Non-linear topology — Any architecture that isn't a straight line
The answer is the Functional API (Track 4) — Keras's most versatile model-building approach.
🧠 Track 3 Quiz
Track 4: Functional API — Flexible Model Architectures
Build complex model architectures with multiple inputs, outputs, skip connections, and shared layers.
✓Why Functional API?
Real-world models aren't linear pipelines. They branch, merge, and loop. The Functional API builds directed acyclic graphs (DAGs) of layers, supporting:
- Multiple inputs — Combine text + image + metadata
- Multiple outputs — Predict class label + confidence score simultaneously
- Skip connections — ResNet's residual blocks that skip over layers
- Shared layers — Same weights applied to different inputs (Siamese networks for similarity)
- Nested models — Use one model as a layer inside another
The Functional API is the most commonly used model-building approach in production Keras code. It's more flexible than Sequential while remaining declarative and serializable.
✓The Pattern
The Functional API follows a simple three-step pattern:
1. Define inputs with keras.Input(shape=...)
2. Chain layers by calling them on tensors: x = Layer()(x)
3. Create model with keras.Model(inputs, outputs)
```python
import keras
from keras import layers

# Step 1: Define input
inputs = keras.Input(shape=(784,))

# Step 2: Chain layers
x = layers.Dense(256, activation="relu")(inputs)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(10, activation="softmax")(x)

# Step 3: Create model
model = keras.Model(inputs=inputs, outputs=outputs, name="my_classifier")
model.summary()
```
Notice the pattern: each layer is called on the previous output. The parentheses are doing double duty — first creating the layer, then calling it on the input.
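The "construct, then call" idiom is ordinary Python: any object with a __call__ method can be invoked like a function. A toy sketch (Doubler is a hypothetical stand-in, not Keras code):

```python
class Doubler:
    # Stands in for a Keras layer: __init__ configures it, __call__ applies it
    def __call__(self, x):
        return 2 * x

y = Doubler()(21)  # create the object AND call it in one expression
print(y)  # → 42
```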
✓Multiple Inputs and Outputs
One of the Functional API's most powerful features: combining different data types and producing multiple predictions.
```python
# Multi-input, multi-output ticket classification
title_input = keras.Input(shape=(100,), name="title")
body_input = keras.Input(shape=(500,), name="body")
tags_input = keras.Input(shape=(12,), name="tags")

# Process each branch
title_features = layers.Dense(64, activation="relu")(title_input)
body_features = layers.Dense(128, activation="relu")(body_input)
tags_features = layers.Dense(32, activation="relu")(tags_input)

# Merge branches
x = layers.concatenate([title_features, body_features, tags_features])
x = layers.Dense(128, activation="relu")(x)

# Multiple outputs
priority = layers.Dense(3, activation="softmax", name="priority")(x)
department = layers.Dense(5, activation="softmax", name="department")(x)

model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority, department],
)
```
✓Skip Connections and Residual Blocks
Residual connections (skip connections) are the pattern behind ResNet — one of the most influential architectures in deep learning. The idea: let information skip over layers, making it easier for very deep networks to train.
```python
# A simple residual block
def residual_block(x, filters):
    shortcut = x  # Save the input
    x = layers.Dense(filters, activation="relu")(x)
    x = layers.Dense(filters)(x)
    # Add the skip connection
    x = layers.add([x, shortcut])
    x = layers.Activation("relu")(x)
    return x

# Use it in a model
inputs = keras.Input(shape=(128,))
x = layers.Dense(128, activation="relu")(inputs)
x = residual_block(x, 128)
x = residual_block(x, 128)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)
```
The layers.add([x, shortcut]) performs element-wise addition. This requires both tensors to have the same shape — if they don't, add a 1×1 convolution or Dense layer to the shortcut to match dimensions.
✓Shared Layers and Nested Models
Shared layers apply the same weights to different inputs. This is essential for Siamese networks (comparing similarity between two inputs) and encoder-decoder architectures.
```python
# Shared embedding for a Siamese network
shared_embedding = layers.Dense(64, activation="relu", name="shared_embed")

input_a = keras.Input(shape=(128,))
input_b = keras.Input(shape=(128,))

# Same weights used for both inputs
encoded_a = shared_embedding(input_a)
encoded_b = shared_embedding(input_b)

# Compute distance
distance = layers.Lambda(
    lambda x: keras.ops.abs(x[0] - x[1])
)([encoded_a, encoded_b])
output = layers.Dense(1, activation="sigmoid")(distance)

model = keras.Model(inputs=[input_a, input_b], outputs=output)
```
Nested models: Any Keras Model can be called like a layer inside another model:
```python
# Use an existing model as a layer
encoder = keras.Model(encoder_inputs, encoded, name="encoder")
decoder = keras.Model(decoder_inputs, decoded, name="decoder")

# Nest them
inputs = keras.Input(shape=(784,))
z = encoder(inputs)  # Model called like a layer
outputs = decoder(z)
autoencoder = keras.Model(inputs, outputs)
```
✓Inspecting the Graph
The Functional API creates a visible graph you can inspect and debug:
```python
# Print architecture
model.summary()

# Generate a diagram (requires pydot/graphviz)
keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)

# Inspect intermediate outputs for debugging
intermediate_model = keras.Model(
    inputs=model.input,
    outputs=model.get_layer("dense_2").output,
)
intermediate_output = intermediate_model.predict(sample_data)
```
✓Practical: Multi-Input Image + Metadata Model
Combine a CNN branch for images with a Dense branch for metadata:
```python
# Image branch
image_input = keras.Input(shape=(64, 64, 3), name="image")
x = layers.Conv2D(32, 3, activation="relu")(image_input)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
image_features = layers.Dense(64, activation="relu")(x)

# Metadata branch
meta_input = keras.Input(shape=(10,), name="metadata")
meta_features = layers.Dense(32, activation="relu")(meta_input)

# Merge and predict
combined = layers.concatenate([image_features, meta_features])
x = layers.Dense(64, activation="relu")(combined)
x = layers.Dropout(0.3)(x)
output = layers.Dense(5, activation="softmax")(x)

model = keras.Model(
    inputs=[image_input, meta_input],
    outputs=output,
)
```
🧠 Track 4 Quiz
Track 5: Model Subclassing — Full Control
Take full control of your model's forward pass with subclassing, then combine it with the Functional API for the best of both worlds.
✓When to Subclass
Model subclassing gives you full control over the forward pass. Use it when:
- The forward pass depends on input data dynamically (e.g., tree-structured networks)
- You need Python control flow (if/else, loops) in the forward pass
- Research experiments that push beyond standard architectures
- Custom forward logic that can't be expressed as a static DAG
Important tradeoff: Subclassed models lose some Functional API benefits — they can't be serialized from config alone, inspected as graphs, or have intermediate outputs extracted as easily. The recommended pattern is to create custom layers via subclassing, then compose them using the Functional API.
✓Subclassing keras.Model
Create a custom model by inheriting from keras.Model:
```python
import keras
from keras import layers

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(128, activation="relu")
        self.dropout = layers.Dropout(0.3)
        self.dense2 = layers.Dense(10, activation="softmax")

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=5)
```
__init__() creates the layers. call() defines the forward pass. The training argument controls behavior of Dropout and BatchNormalization.
✓Subclassing keras.layers.Layer
Custom layers are the more common and recommended use of subclassing. Override build() for lazy weight creation and call() for the forward pass.
```python
class MyDense(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Lazy weight creation — called on first use
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config
```
build() is called once when the layer first sees data — it allows lazy weight creation based on actual input shapes. get_config() enables serialization.
✓The call() Method
The call() method is where the computation happens. Key arguments:
- training — Boolean that controls behavior of Dropout (active during training, inactive during inference) and BatchNormalization (uses batch vs. moving statistics).
- mask — Optional boolean tensor for masking timesteps in sequences.
```python
class ConditionalLayer(keras.layers.Layer):
    def call(self, inputs, training=False):
        # Use keras.ops / keras.random for backend-agnostic code!
        x = keras.ops.relu(inputs)
        if training:
            x = keras.random.dropout(x, rate=0.5)
        return x
```
Always use keras.ops instead of raw framework operations (tf.nn.relu, torch.relu, etc.) inside custom layers to maintain backend portability.
✓Mixing Functional and Subclass
The recommended pattern: create custom layers via subclassing, then compose them using the Functional API. This gives you the best of both worlds.
```python
# Custom layer
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = layers.Dense(filters, activation="relu")
        self.dense2 = layers.Dense(filters)
        self.add_layer = layers.Add()
        self.activation = layers.Activation("relu")

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        x = self.add_layer([x, inputs])
        return self.activation(x)

# Use custom layer in Functional API
inputs = keras.Input(shape=(64,))
x = layers.Dense(64, activation="relu")(inputs)
x = ResidualBlock(64)(x)  # Custom layer used like any built-in
x = ResidualBlock(64)(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)
```
✓Custom Training Step
Override train_step() for custom training logic while still using model.fit():
```python
import tensorflow as tf  # gradient computation is backend-specific; TF backend shown

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data

        # Forward pass with gradient tracking
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute and apply gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply(gradients, self.trainable_variables)

        # Update metrics
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```
This keeps model.fit()'s progress bar, callbacks, and validation — while giving you full control over what happens each step.
✓Practical: Custom Transformer Block
Build a Transformer block as a subclassed Layer, then use it in a Functional model:
```python
class TransformerBlock(keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.ffn = keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.norm2(out1 + ffn_output)
```
🧠 Track 5 Quiz
Track 6: Layers — The Building Blocks
Explore Keras's rich layer library: core, convolutional, recurrent, attention, normalization, and preprocessing layers.
✓Core Layers
These are the fundamental building blocks you'll use in almost every model:
| Layer | Purpose | Key Args |
|---|---|---|
| Dense | Fully connected | units, activation |
| Embedding | Integer → dense vector lookup | input_dim, output_dim, mask_zero |
| Masking | Marks timesteps as padding | mask_value=0.0 |
| Lambda | Wraps arbitrary operation | function |
| Input | Placeholder tensor | shape, dtype, name |
| EinsumDense | Generalized dense via einsum | equation, output_shape |
```python
# Embedding layer: integer tokens → dense vectors
embedding = layers.Embedding(
    input_dim=10000,   # Vocabulary size
    output_dim=128,    # Embedding dimension
    mask_zero=True,    # Treat 0 as padding
)
# Input shape: (batch, seq_len) of integers
# Output shape: (batch, seq_len, 128) of floats
```
✓Convolutional Layers
Convolutional layers detect local patterns in spatial/temporal data:
| Layer | Use Case | Key Detail |
|---|---|---|
| Conv1D | Temporal/text data | 1D sliding window |
| Conv2D | Images | 2D sliding window |
| Conv3D | Video/volumetric | 3D sliding window |
| SeparableConv2D | Efficient images | Depthwise + pointwise (fewer params) |
| DepthwiseConv2D | Per-channel conv | No cross-channel mixing |
# Standard Conv2D: 32 filters of 3×3
layers.Conv2D(32, kernel_size=3, activation="relu", padding="same")

# SeparableConv2D: same API, ~8x fewer params
layers.SeparableConv2D(32, kernel_size=3, activation="relu", padding="same")
Padding: 'valid' = no padding (output shrinks), 'same' = pad to keep output size equal to input. Strides: step size of the sliding window (strides=2 halves spatial dimensions).
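To make the padding and stride arithmetic concrete, here is a quick shape check (a minimal sketch; the filter counts and input size are arbitrary):

```python
import numpy as np
import keras
from keras import layers

x = np.random.rand(1, 8, 8, 3).astype("float32")  # (batch, h, w, channels)

same = layers.Conv2D(4, kernel_size=3, padding="same")(x)
valid = layers.Conv2D(4, kernel_size=3, padding="valid")(x)
strided = layers.Conv2D(4, kernel_size=3, padding="same", strides=2)(x)

print(tuple(same.shape))     # (1, 8, 8, 4) — spatial size preserved
print(tuple(valid.shape))    # (1, 6, 6, 4) — shrinks by kernel_size - 1
print(tuple(strided.shape))  # (1, 4, 4, 4) — stride 2 halves h and w
```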
✓Pooling and Recurrent Layers
Pooling layers downsample spatial data:
- MaxPooling2D(2) — Takes the maximum value in each 2×2 window
- AveragePooling2D(2) — Takes the mean
- GlobalAveragePooling2D() — Averages across all spatial dimensions → single vector per channel. Very common in transfer learning heads.
Recurrent layers process sequences:
# LSTM: Long Short-Term Memory
layers.LSTM(64, return_sequences=True)  # Output at every timestep
layers.LSTM(64)                         # Output at last timestep only

# GRU: Gated Recurrent Unit (simpler, often comparable)
layers.GRU(64)

# Bidirectional wrapper: processes sequence forward AND backward
layers.Bidirectional(layers.LSTM(64))
Set return_sequences=True when stacking RNNs — intermediate layers need the full sequence.
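A quick shape check makes the return_sequences distinction concrete (sketch with arbitrary sizes):

```python
import numpy as np
import keras
from keras import layers

x = np.random.rand(1, 5, 3).astype("float32")  # (batch, timesteps, features)

seq_out = layers.LSTM(8, return_sequences=True)(x)
last_out = layers.LSTM(8)(x)

print(tuple(seq_out.shape))   # (1, 5, 8) — one output per timestep
print(tuple(last_out.shape))  # (1, 8)    — last timestep only

# Stacking: the second LSTM consumes the full sequence from the first
stacked = layers.LSTM(4)(seq_out)
print(tuple(stacked.shape))   # (1, 4)
```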
✓Attention and Normalization
MultiHeadAttention is the core of Transformer architectures:
# Multi-head attention (Transformer key component)
layers.MultiHeadAttention(
    num_heads=8,           # Number of attention heads
    key_dim=64,            # Dimension of each head
    flash_attention=True,  # Enable Flash Attention (3.7+)
)
Normalization layers stabilize training:
| Layer | Normalizes | Best For |
|---|---|---|
BatchNormalization | Per batch (across samples) | CNNs, general training |
LayerNormalization | Per sample (across features) | Transformers, RNNs |
GroupNormalization | Per group of channels | Small batch sizes |
Set training=False on frozen BatchNorm layers to use the learned moving statistics instead of batch statistics.
✓Regularization and Reshaping
Regularization prevents overfitting:
- Dropout(0.3) — Randomly zeros out 30% of units during training
- SpatialDropout1D / SpatialDropout2D — Drops entire feature maps (better for conv layers)
Reshaping layers:
- Flatten() — Reshapes to 1D (e.g., for a Dense layer after Conv layers)
- Reshape(target_shape) — Arbitrary reshape
- Permute(dims) — Reorders dimensions
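Note that all three operate per-sample: the batch dimension is excluded from target_shape and dims. A quick sketch:

```python
import numpy as np
from keras import layers

x = np.zeros((2, 3, 4), dtype="float32")  # batch of 2, each sample (3, 4)

print(tuple(layers.Flatten()(x).shape))        # (2, 12)
print(tuple(layers.Reshape((4, 3))(x).shape))  # (2, 4, 3)
print(tuple(layers.Permute((2, 1))(x).shape))  # (2, 4, 3) — axes swapped
```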
Merging layers for combining branches:
# Merge two branches
layers.Concatenate()([branch_a, branch_b])  # Concatenate along axis
layers.Add()([branch_a, branch_b])          # Element-wise addition
layers.Multiply()([branch_a, branch_b])     # Element-wise multiplication
layers.Average()([branch_a, branch_b])      # Element-wise average
✓Preprocessing Layers
Keras includes preprocessing layers that can be placed inside the model — making preprocessing portable with the model itself:
# Numeric preprocessing
layers.Rescaling(1.0/255)  # Normalize pixel values
layers.Normalization()     # Standardize features (.adapt() first)

# Text preprocessing
layers.TextVectorization(
    max_tokens=10000,
    output_mode="int",
    output_sequence_length=200,
)

# Category encoding
layers.CategoryEncoding(num_tokens=10)
layers.StringLookup(vocabulary=["cat", "dog", "bird"])
Preprocessing layers that use .adapt() (like Normalization and TextVectorization) need to see a sample of the data first to learn statistics like mean, variance, or vocabulary.
✓Practical: Custom Attention Layer
Build a simplified attention mechanism from scratch:
class SimpleAttention(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.W_q = self.add_weight(
            shape=(input_shape[-1], self.units), name="query_weight"
        )
        self.W_k = self.add_weight(
            shape=(input_shape[-1], self.units), name="key_weight"
        )
        self.W_v = self.add_weight(
            shape=(input_shape[-1], self.units), name="value_weight"
        )

    def call(self, inputs):
        q = keras.ops.matmul(inputs, self.W_q)
        k = keras.ops.matmul(inputs, self.W_k)
        v = keras.ops.matmul(inputs, self.W_v)
        # Scaled dot-product attention
        scale = keras.ops.sqrt(
            keras.ops.cast(self.units, dtype="float32")
        )
        # Transpose only the last two axes so the batch axis is preserved
        scores = keras.ops.matmul(
            q, keras.ops.transpose(k, axes=(0, 2, 1))
        ) / scale
        weights = keras.ops.softmax(scores)
        return keras.ops.matmul(weights, v)
🧠 Track 6 Quiz
Track 7: Training — compile, fit, and the Training Loop
Master Keras's training pipeline: optimizers, loss functions, metrics, callbacks, and monitoring.
✓model.compile()
model.compile() configures the model for training by wiring together three components:
model.compile(
    optimizer="adam",  # or keras.optimizers.Adam(1e-3)
    loss="sparse_categorical_crossentropy",  # or Loss instance
    metrics=["accuracy"],  # list of metric names/instances
)

# For regression:
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="mse",
    metrics=["mae"],
)
You can pass string shortcuts ("adam", "mse") or class instances for more control. Always call compile() before fit().
✓Optimizers
The optimizer controls how weights are updated based on gradients:
| Optimizer | Key Feature | When to Use |
|---|---|---|
Adam | Adaptive learning rates | Default starting point |
AdamW | Adam + weight decay | Fine-tuning pretrained models |
SGD | Simple + momentum | Still competitive, good with schedules |
RMSprop | Adaptive per-parameter | RNNs, non-stationary objectives |
Muon | Newton-style updates | New in Keras 3.10, experimental |
Learning rate schedules adjust the LR during training:
# Cosine decay with warmup
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,  # LR at the start of warmup
    decay_steps=10000,
    alpha=1e-3,                 # Floor, as a fraction of the peak LR
    warmup_target=1e-3,         # Peak LR reached after warmup
    warmup_steps=1000,
)
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
✓Loss Functions
The loss function measures how wrong the model's predictions are:
| Loss | Task | Labels Format |
|---|---|---|
CategoricalCrossentropy | Multi-class classification | One-hot: [0,0,1,0] |
SparseCategoricalCrossentropy | Multi-class classification | Integer: 2 |
BinaryCrossentropy | Binary / multi-label | 0 or 1 |
MeanSquaredError | Regression | Continuous values |
Huber | Robust regression | Continuous (outlier-resistant) |
from_logits=True — Use this when your model output is raw scores (no softmax/sigmoid). It's numerically more stable because the loss function applies the activation internally:
# Option A: softmax output + standard loss
model.compile(loss="sparse_categorical_crossentropy")

# Option B (preferred): no output activation + from_logits
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)
)
✓Metrics
Metrics track performance during training without affecting the gradient:
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=[
        "accuracy",
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
        keras.metrics.AUC(name="auc"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
        keras.metrics.F1Score(name="f1"),
    ],
)
Custom metrics are created by subclassing keras.metrics.Metric and implementing update_state(), result(), and reset_state().
✓model.fit()
model.fit() is the main training method. Here are the key arguments:
history = model.fit(
    x_train, y_train,
    epochs=20,             # Number of passes through the data
    batch_size=32,         # Samples per gradient update
    validation_split=0.2,  # Use 20% of training data for validation
    # OR: validation_data=(x_val, y_val),
    class_weight={0: 1.0, 1: 5.0},  # For imbalanced data
    callbacks=[...],       # List of callback instances
)

# history.history contains loss/metric values per epoch
print(history.history["loss"])          # [0.83, 0.54, 0.41, ...]
print(history.history["val_accuracy"])  # [0.72, 0.81, 0.85, ...]
The History object returned by fit() contains the training and validation loss/metrics for each epoch — use it to plot learning curves and diagnose training issues.
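Because history.history is a plain Python dict of lists, post-hoc analysis needs no Keras at all (sketch with made-up numbers):

```python
# history.history as returned by fit() — these values are invented
history = {
    "loss": [0.83, 0.54, 0.41, 0.36],
    "val_loss": [0.85, 0.60, 0.52, 0.55],
    "val_accuracy": [0.72, 0.81, 0.85, 0.84],
}

# Epoch with the lowest validation loss (0-indexed)
best_epoch = min(
    range(len(history["val_loss"])), key=history["val_loss"].__getitem__
)
print(best_epoch)                           # 2
print(history["val_accuracy"][best_epoch])  # 0.85

# val_loss rising while loss keeps falling after epoch 2 is an
# early sign of overfitting
```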
✓Callbacks
Callbacks are hooks that execute at various points during training:
callbacks = [
# Stop when val_loss hasn't improved for 5 epochs
keras.callbacks.EarlyStopping(
monitor="val_loss",
patience=5,
restore_best_weights=True, # Go back to best model
),
# Save the best model
keras.callbacks.ModelCheckpoint(
"best_model.keras",
monitor="val_loss",
save_best_only=True,
),
# Reduce LR when stuck
keras.callbacks.ReduceLROnPlateau(
monitor="val_loss",
factor=0.5, # Halve the LR
patience=3,
min_lr=1e-6,
),
# Log to CSV
keras.callbacks.CSVLogger("training_log.csv"),
]
model.fit(x_train, y_train, epochs=50, callbacks=callbacks)

Custom callbacks override methods like on_epoch_begin, on_batch_end, on_train_end for custom logic.
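A minimal custom callback following that pattern (the class name and behavior are my own illustration):

```python
import keras

class BestEpochLogger(keras.callbacks.Callback):
    """Records which epoch had the lowest val_loss."""

    def on_train_begin(self, logs=None):
        self.best_epoch = None
        self.best_val_loss = float("inf")

    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get("val_loss")
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch = epoch

# The hooks can be exercised directly, without a real training run:
cb = BestEpochLogger()
cb.on_train_begin()
for epoch, vl in enumerate([0.5, 0.3, 0.4]):
    cb.on_epoch_end(epoch, {"val_loss": vl})
print(cb.best_epoch)  # 1
```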
✓model.evaluate() and model.predict()
After training, evaluate performance and make predictions:
# Evaluate on test data
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc:.4f}")

# Get predictions (probabilities)
predictions = model.predict(x_test)
# predictions.shape = (num_samples, num_classes)

# Convert probabilities to class labels
predicted_classes = keras.ops.argmax(predictions, axis=1)
✓Practical: Full Training Pipeline
Complete example with callbacks, learning curves, and best practices:
import keras
from keras import layers

# Build model
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(256, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.2),
    layers.Dense(10, activation="softmax"),
])

# Compile
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Callbacks
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
    keras.callbacks.ModelCheckpoint(
        "best.keras",
        save_best_only=True,
    ),
    keras.callbacks.ReduceLROnPlateau(
        factor=0.5,
        patience=3,
    ),
]

# Train
history = model.fit(
    x_train, y_train,
    epochs=50,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks,
)
🧠 Track 7 Quiz
Track 8: Custom Training Loops
Go beyond model.fit() with custom training steps, manual loops, and GAN training patterns.
✓Why Custom Loops?
Most of the time, model.fit() is all you need. But certain scenarios require custom training logic:
- GANs — Alternating generator and discriminator training steps
- Multi-task learning — Custom loss weighting or gradients for different tasks
- Research experiments — Gradient manipulation, custom regularization, meta-learning
- Gradient accumulation — Simulating larger batch sizes on limited GPU memory
Keras provides three levels of customization: (1) override train_step() for custom logic inside fit(), (2) write a full manual loop for maximum control, (3) use backend-specific APIs for extreme flexibility.
✓Overriding train_step()
The sweet spot: custom logic while keeping fit()'s progress bar, callbacks, and validation.
import tensorflow as tf  # gradient computation is backend-specific; TF shown

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data

        # Forward pass + loss, with gradient tracking
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute and apply gradients
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply(grads, self.trainable_variables)

        # Update and return metrics
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Still use fit()!
model = CustomModel(...)
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=10)
✓Full Manual Loop
For maximum control, write the training loop from scratch. keras.ops keeps the tensor math backend-agnostic, but the gradient computation itself must come from the backend (tf.GradientTape on TensorFlow, jax.grad on JAX, loss.backward() on PyTorch):
# Manual training loop (TensorFlow backend shown)
import tensorflow as tf

optimizer = keras.optimizers.Adam(1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

for epoch in range(10):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        # Forward pass with gradient tracking
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            loss = loss_fn(y_batch, y_pred)

        # Compute gradients and update weights
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply(grads, model.trainable_variables)

    print(f"Epoch {epoch}, Loss: {loss:.4f}")
Gradient computation is the one step Keras 3 does not abstract away: there is no backend-agnostic autodiff API. Use tf.GradientTape on TensorFlow, jax.grad (with a stateless model call) on JAX, or loss.backward() on PyTorch.
✓GAN Training Pattern
GANs need alternating training of two models. The train_step() override pattern is ideal:
class GAN(keras.Model):
    def __init__(self, generator, discriminator, latent_dim):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

    def train_step(self, real_images):
        batch_size = keras.ops.shape(real_images)[0]
        noise = keras.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        # Train discriminator
        fake_images = self.generator(noise)
        combined = keras.ops.concatenate([real_images, fake_images])
        labels = keras.ops.concatenate([
            keras.ops.ones((batch_size, 1)),
            keras.ops.zeros((batch_size, 1)),
        ])
        d_loss = self._train_discriminator(combined, labels)

        # Train generator
        noise = keras.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = keras.ops.ones((batch_size, 1))
        g_loss = self._train_generator(noise, misleading_labels)

        return {"d_loss": d_loss, "g_loss": g_loss}
✓Gradient Accumulation
When GPU memory limits your batch size, accumulate gradients over N steps before applying them:
# Gradient accumulation (TensorFlow backend shown)
import tensorflow as tf

accumulation_steps = 4  # Effective batch = batch_size × 4

for step, (x, y) in enumerate(dataset):
    # Compute gradients
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred) / accumulation_steps
    grads = tape.gradient(loss, model.trainable_variables)

    # Accumulate
    if step % accumulation_steps == 0:
        accumulated_grads = grads
    else:
        accumulated_grads = [
            a + g for a, g in zip(accumulated_grads, grads)
        ]

    # Apply every N steps
    if (step + 1) % accumulation_steps == 0:
        optimizer.apply(accumulated_grads, model.trainable_variables)
✓Practical: Custom GAN Training Loop
The full GAN pattern uses train_step() override, keeping callbacks and progress bars. The generator learns to produce realistic images while the discriminator learns to tell real from fake — a minimax game.
Key insights:
- Train discriminator on real + fake images together
- Train generator by trying to fool the discriminator
- Use separate optimizers for each network
- Monitor both losses — they should oscillate, not converge to zero
🧠 Track 8 Quiz
Track 9: Data Loading and Preprocessing
Build efficient data pipelines with preprocessing layers, augmentation, and multiple data source formats.
✓Keras Preprocessing Layers
Keras preprocessing layers can be placed inside the model, making preprocessing portable:
# Normalization: learns mean/variance from data
normalizer = layers.Normalization()
normalizer.adapt(training_data)  # Must see data first

# Rescaling: simple linear transform
rescaler = layers.Rescaling(1.0 / 255.0)

# TextVectorization: text → integer tokens
vectorizer = layers.TextVectorization(
    max_tokens=20000,
    output_sequence_length=200,
)
vectorizer.adapt(text_data)

# Put inside the model
model = keras.Sequential([
    rescaler,                                 # Preprocessing
    layers.Conv2D(32, 3, activation="relu"),  # Model layers
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation="softmax"),
])
✓Image Augmentation
Augmentation layers apply random transformations during training only, expanding effective dataset size:
# Built-in Keras augmentation layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),  # ±10% of full rotation
    layers.RandomZoom(0.1),      # ±10% zoom
    layers.RandomContrast(0.1),  # ±10% contrast
])

# Applied only during training (training=True)
# At inference time, augmentation is a no-op
KerasCV adds CutMix, MixUp, RandAugment, and MosaicAugmentation for more advanced training strategies.
✓Data Pipeline Options
Keras accepts multiple data formats in model.fit():
| Format | Backend | Best For |
|---|---|---|
| NumPy arrays | All | Small datasets that fit in memory |
tf.data.Dataset | All | Large datasets, complex pipelines |
PyTorch DataLoader | All | PyTorch ecosystem integration |
keras.utils.PyDataset | All | Custom data generators |
| Pandas DataFrame | All | Tabular data |
# All of these work with model.fit():
model.fit(numpy_x, numpy_y)   # NumPy arrays
model.fit(tf_dataset)         # tf.data.Dataset
model.fit(torch_dataloader)   # PyTorch DataLoader
model.fit(keras_pydataset)    # keras.utils.PyDataset
✓Image Loading
keras.utils.image_dataset_from_directory creates a dataset from a folder structure:
# Folder structure:
# data/
#   cats/
#     cat001.jpg
#     cat002.jpg
#   dogs/
#     dog001.jpg
#     dog002.jpg

train_ds = keras.utils.image_dataset_from_directory(
    "data/",
    image_size=(224, 224),     # Resize all images
    batch_size=32,
    validation_split=0.2,
    subset="training",
    seed=42,
    label_mode="categorical",  # or "int" for integer labels
)
This automatically infers labels from subdirectory names and handles batching, shuffling, and resizing.
✓Handling Class Imbalance
When classes are imbalanced (e.g., 95% negative, 5% positive), models tend to just predict the majority class. Three strategies:
# 1. class_weight in model.fit()
model.fit(x, y, class_weight={0: 1.0, 1: 10.0})

# 2. sample_weight array (per-sample importance)
weights = np.where(y_train == 1, 10.0, 1.0)
model.fit(x, y, sample_weight=weights)

# 3. Oversample minority class in the pipeline
# (duplicate minority samples or use SMOTE)
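The class_weight dict can also be derived from the label distribution rather than hand-tuned. A small inverse-frequency helper (my own sketch, not a Keras API):

```python
import numpy as np

def balanced_class_weights(y):
    """Inverse-frequency weights, normalized so the majority class gets 1.0."""
    classes, counts = np.unique(y, return_counts=True)
    weights = counts.max() / counts
    return {int(c): float(w) for c, w in zip(classes, weights)}

y_train = np.array([0] * 95 + [1] * 5)  # 95% negative, 5% positive
print(balanced_class_weights(y_train))  # {0: 1.0, 1: 19.0}

# Then: model.fit(x, y, class_weight=balanced_class_weights(y_train))
```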
✓Practical: Image Pipeline with Augmentation
Complete pipeline from directory loading to training:
import keras
from keras import layers

# Data augmentation layers (only active during training)
augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# Model with preprocessing built in
model = keras.Sequential([
    keras.Input(shape=(224, 224, 3)),
    layers.Rescaling(1.0 / 255.0),  # Normalize
    augmentation,                   # Augment
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(2),
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(2),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu"),
    layers.Dense(5, activation="softmax"),
])

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, validation_data=val_ds, epochs=20)
🧠 Track 9 Quiz
Track 10: Transfer Learning and Pretrained Models
Leverage pretrained models with feature extraction, fine-tuning, and the KerasHub model hub.
✓What is Transfer Learning?
Transfer learning leverages knowledge from models trained on massive datasets (like ImageNet with 14M+ images) and adapts it to your specific task — even with limited data.
Why it works: Early layers in deep networks learn universal features — edges, textures, color patterns. Later layers learn task-specific features. By reusing the early layers, you start with a foundation of visual understanding rather than learning from scratch.
The result: Instead of needing millions of images and days of training, you can achieve strong results with hundreds of images and minutes of training.
✓Keras Applications
Keras includes 37+ pretrained models out of the box:
| Model | Params | Top-1 Acc | Best For |
|---|---|---|---|
| EfficientNetV2 | 6M–119M | Up to 85.7% | Best accuracy/efficiency tradeoff |
| ResNet50 | 25M | 76.0% | Widely used, well understood |
| MobileNetV3 | 2.5M–5.4M | 75.2% | Mobile/edge deployment |
| ConvNeXt | 28M–350M | Up to 88.5% | Modern CNN architecture |
| VGG16 | 138M | 71.3% | Simple, educational |
# Load pretrained model (without classification head)
base_model = keras.applications.EfficientNetV2S(
    weights="imagenet",
    include_top=False,  # Remove classification layers
    input_shape=(224, 224, 3),
)
✓Feature Extraction
Step 1: Freeze the base and train only the new head.
# Freeze the pretrained weights
base_model.trainable = False

# Add a custom classification head
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)  # Keep BN in inference mode
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Train only the head (~5-10 epochs)
model.fit(train_ds, epochs=10, validation_data=val_ds)
✓Fine-Tuning
Step 2: Unfreeze top layers and retrain with a very low learning rate.
# Unfreeze the top layers of the base model
base_model.trainable = True

# Freeze everything except the top 20 layers
for layer in base_model.layers[:-20]:
    layer.trainable = False

# CRITICAL: Re-compile with a very low learning rate
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # 100x lower!
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Continue training (~10-20 more epochs)
model.fit(train_ds, epochs=20, validation_data=val_ds)
Always re-run model.compile() after changing .trainable properties. The optimizer needs to know which variables to update.
✓KerasHub
KerasHub is the unified pretrained model hub, providing access to NLP, vision, and audio models:
import keras_hub

# Load a pretrained NLP model
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en",
    num_classes=4,
)

# Available models include:
# BERT, GPT-2, Gemma (1/2/3/4), Llama 2/3,
# Mistral, T5, Whisper, Qwen, and many more

# Fine-tune with LoRA for efficiency
classifier.backbone.enable_lora(rank=4)
classifier.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
)
classifier.fit(train_ds, epochs=3)
✓The Fine-Tuning Recipe
The proven two-phase approach that prevents catastrophic forgetting:
- Phase 1: Feature extraction — Freeze base model, train only the new head for ~5-10 epochs with normal LR (1e-3)
- Phase 2: Fine-tuning — Unfreeze top layers of base model, retrain everything with very low LR (1e-5) for ~10-20 more epochs
Why two phases? If you fine-tune immediately with a high learning rate, the random initialization of the new head will generate large gradients that destroy the pretrained features. By training the head first, you get it to a reasonable state before touching the base model.
✓Practical: Fine-Tuning EfficientNet
Complete workflow:
import keras
from keras import layers

# Load pretrained EfficientNetV2S
base = keras.applications.EfficientNetV2S(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3),
)
base.trainable = False

# Build model
inputs = keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# Phase 1: Feature extraction
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, epochs=10, validation_data=val_ds)

# Phase 2: Fine-tuning
base.trainable = True
for layer in base.layers[:-20]:
    layer.trainable = False
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, epochs=20, validation_data=val_ds)
🧠 Track 10 Quiz
Track 11: Saving, Loading, and Exporting
Master Keras's cross-backend saving, model export, and deployment formats.
✓The .keras Format
The .keras format is Keras 3's native and recommended save format:
# Save the complete model
model.save("my_model.keras")

# Load it back
loaded_model = keras.models.load_model("my_model.keras")
What's saved: Architecture (config.json), weights, optimizer state. The config is human-readable JSON, making the format inspectable and debuggable.
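A quick round trip confirms the format restores a model exactly (sketch; the tiny model and temp path are just for illustration):

```python
import os
import tempfile
import numpy as np
import keras
from keras import layers

model = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(2)])
x = np.random.rand(3, 4).astype("float32")

path = os.path.join(tempfile.mkdtemp(), "roundtrip.keras")
model.save(path)
reloaded = keras.models.load_model(path)

# The reloaded model reproduces the original's outputs exactly
print(np.allclose(model.predict(x, verbose=0),
                  reloaded.predict(x, verbose=0)))  # True
```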
✓Legacy Formats
Keras still supports older formats for backwards compatibility:
| Format | Extension | Notes |
|---|---|---|
| .keras | .keras | Recommended. Cross-backend, architecture + weights + optimizer |
| HDF5 | .h5 | Legacy. Weights only or full model. Still widely used |
| SavedModel | directory | TF-specific. Good for TF Serving |
Use .keras for all new projects. Legacy formats are supported for loading old models but shouldn't be the target for new work.
✓Saving/Loading Weights Only
When you only need the trained weights (not the architecture):
# Save weights
model.save_weights("weights.weights.h5")

# Load weights into an identical architecture
new_model = build_model()  # Must match original architecture
new_model.load_weights("weights.weights.h5")

# Weight sharding for large models (Keras 3.10+)
model.save_weights("large_model.weights.h5", max_shard_size="2GB")
✓Cross-Backend Saving
This is a Keras 3 superpower: save from one backend, load in another.
# Train on JAX (fast JIT compilation)
# Note: the backend must be set before keras is first imported
os.environ["KERAS_BACKEND"] = "jax"
model.fit(x_train, y_train, epochs=10)
model.save("trained_on_jax.keras")

# Later, in a separate process: load on TensorFlow for deployment
os.environ["KERAS_BACKEND"] = "tensorflow"
model = keras.models.load_model("trained_on_jax.keras")
model.export("tf_serving_model", format="tf_saved_model")
The .keras format stores no backend-specific operations — it's truly portable. This enables workflows like training on JAX for speed, then deploying via TF Serving.
✓Exporting for Serving
Keras 3 supports multiple export formats for production deployment:
# TF SavedModel (for TF Serving)
model.export("model_dir", format="tf_saved_model")

# ONNX (cross-platform inference)
model.export("model.onnx", format="onnx")

# OpenVINO (Intel optimized inference)
model.export("model_ov", format="openvino")

# TFLite (mobile/edge)
model.export("model.tflite", format="litert")
✓Custom Object Registration
If your model uses custom layers, losses, or metrics, register them for serialization:
@keras.saving.register_keras_serializable(package="my_package")
class MyCustomLayer(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config

# Now model.save() / load_model() works with custom objects
✓Practical: Save → Cross-Backend Load → Export
Full cross-backend workflow:
# Step 1: Train on JAX
os.environ["KERAS_BACKEND"] = "jax"
import keras

model = build_model()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=10)
model.save("model.keras")

# Step 2: Load on TensorFlow
# (In a new script with KERAS_BACKEND="tensorflow")
model = keras.models.load_model("model.keras")

# Step 3: Export for production
model.export("production_model", format="tf_saved_model")
# Deploy with TF Serving, TFLite, or TF.js
🧠 Track 11 Quiz
Track 12: KerasCV — Computer Vision
Explore KerasCV's pretrained backbones, YOLOV8 object detection, segmentation, and advanced augmentation.
✓What is KerasCV?
KerasCV is the industry-strength computer vision library built on Keras 3. It provides pretrained backbones, object detection models, segmentation, and advanced augmentation — all supporting multiple backends.
Important: KerasCV is currently migrating into KerasHub. Existing code works, but new models are being added to KerasHub. Install with pip install keras-cv.
✓Pretrained Backbones
KerasCV offers a rich selection of pretrained vision backbones:
| Backbone | Style | Sizes |
|---|---|---|
| EfficientNetV1/V2 | Efficient CNN | B0–B7, S/M/L |
| ResNet | Classic residual | 18/34/50/101/152 |
| MiT (Mix Transformer) | ViT-style | B0–B5 |
| MobileNetV3 | Mobile-optimized | Small/Large |
| DenseNet | Dense connections | 121/169/201 |
| CSPDarkNet | YOLO backbone | Tiny/S/M/L/XL |
import keras_cv

# Load a pretrained backbone
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset(
    "efficientnetv2_s_imagenet"
)
✓Object Detection: YOLOV8
KerasCV includes YOLOV8 — one of the most popular real-time object detection models:
# Load YOLOV8 with COCO pretrained weights
detector = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_coco",  # xs, s, m, l, xl sizes
)

# Run inference
predictions = detector.predict(images)
# predictions["boxes"]      → bounding box coordinates
# predictions["classes"]    → class IDs
# predictions["confidence"] → confidence scores
YOLOV8 comes in 5 sizes: xs (extra small, fastest), s (small), m (medium), l (large), xl (extra large, most accurate).
✓Image Segmentation
KerasCV provides segmentation models for pixel-level classification:
- DeepLabV3Plus — Semantic segmentation. Assigns a class label to every pixel. Uses atrous (dilated) convolution for multi-scale feature extraction.
- BASNet — Boundary-aware salient object detection. Focuses on cleanly separating foreground from background.
Both models support pretrained backbones and multi-backend operation.
✓Data Augmentation Layers
KerasCV's augmentation layers go far beyond basic flips and rotations:
| Augmentation | Description |
|---|---|
CutMix | Cuts a random patch from one image and pastes it onto another, mixing labels proportionally |
MixUp | Linear interpolation between two images and their labels |
RandAugment | Applies a random policy of N transforms from a predefined set |
MosaicAugmentation | Combines 4 images into a mosaic (used in YOLO training) |
# CutMix and MixUp
augmenter = keras_cv.layers.CutMix(alpha=1.0)
augmenter = keras_cv.layers.MixUp(alpha=0.2)

# RandAugment
augmenter = keras_cv.layers.RandAugment(
    value_range=(0, 255),
    augmentations_per_image=3,
)
✓Practical: Object Detection Pipeline
Complete YOLOV8 inference pipeline:
import keras
import keras_cv

# Load pretrained YOLOV8
model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_coco",
)

# Prepare image
image = keras.utils.load_img("photo.jpg", target_size=(640, 640))
image_array = keras.utils.img_to_array(image)
image_batch = keras.ops.expand_dims(image_array, axis=0)

# Run detection
predictions = model.predict(image_batch)
# Process bounding boxes, classes, and confidence scores
🧠 Track 12 Quiz
Track 13: KerasNLP — Natural Language Processing
Access pretrained NLP models, tokenizers, text classification, generation, and efficient fine-tuning with LoRA.
✓What is KerasNLP/KerasHub?
KerasNLP was the modular NLP library for Keras. In September 2024, it was renamed to KerasHub as the team unified all pretrained model access under one hub.
The keras_nlp import still works for backwards compatibility, but new code should use keras_hub.
✓Pretrained Models
KerasHub provides access to a wide range of pretrained models:
| Model | Type | Use Case |
|---|---|---|
| BERT | Encoder | Classification, NER, embeddings |
| GPT-2 | Causal LM | Text generation |
| Gemma (1/2/3/4) | Causal LM | Google's open model family |
| Llama 2/3 | Causal LM | Meta's open models |
| Mistral | Causal LM | Efficient open model |
| T5 | Encoder-Decoder | Translation, summarization |
| Whisper | Audio | Speech recognition |
import keras_hub

# Load a pretrained classifier
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en",
    num_classes=4,
)

# Load a text generator
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
✓Tokenizers
KerasHub includes built-in tokenizers matched to each model:
# Tokenizers are included with model presets
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab_data,
)

# Or use the tokenizer that comes with a preset
preprocessor = keras_hub.models.BertPreprocessor.from_preset(
    "bert_base_en",
    sequence_length=128,
)
Available tokenizers: WordPieceTokenizer (BERT), BytePairTokenizer (GPT), SentencePieceTokenizer (T5, Llama). All are TF-free in the latest versions.
✓Text Classification
Fine-tune a pretrained model for text classification:
import keras_hub

# Load BERT for classification
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en",
    num_classes=2,  # Binary sentiment
)

# Compile and train
classifier.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
classifier.fit(train_ds, validation_data=val_ds, epochs=3)
✓Text Generation
Generate text with causal language models:
# Load GPT-2 for text generation
gpt2 = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Generate with various sampling strategies
output = gpt2.generate(
    "The future of AI is",
    max_length=100,
)
print(output)

# Control generation with temperature, top_k, top_p
gpt2.compile(
    sampler=keras_hub.samplers.TopKSampler(k=50, temperature=0.7)
)
✓Fine-Tuning with LoRA
LoRA (Low-Rank Adaptation) enables parameter-efficient fine-tuning by adding small trainable adapter layers:
```python
# Enable LoRA on the model backbone
classifier.backbone.enable_lora(rank=4)

# Check trainable params — dramatically reduced!
classifier.summary()
# Total params: 110M, Trainable: ~300K (0.3%!)

# QLoRA: quantize first, then LoRA
classifier.backbone.quantize("int8")     # Reduce model size
classifier.backbone.enable_lora(rank=4)  # Add adapters
# Even less memory, nearly the same quality
```
QLoRA combines quantization with LoRA: first reduce the model to int8 (or int4), then add LoRA adapters. This dramatically cuts memory while maintaining quality.
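As a back-of-the-envelope check on why LoRA shrinks the trainable footprint so much: the adapter for a weight matrix of shape `(d_in, d_out)` has only `rank * (d_in + d_out)` parameters. The dimensions below are illustrative, not measured from a real model:

```python
def lora_params(d_in, d_out, rank):
    """Trainable params added by a LoRA adapter: A (d_in x r) plus B (r x d_out)."""
    return rank * (d_in + d_out)

# Example: one 768x768 dense kernel (a typical BERT-base attention projection)
d_in, d_out, rank = 768, 768, 4
full = d_in * d_out                  # full fine-tuning of that one kernel
lora = lora_params(d_in, d_out, rank)
print(full, lora, f"{lora / full:.2%}")  # 589824 6144 1.04%
```

The ratio shrinks further as layers get wider, since the adapter grows linearly in the dimensions while the full kernel grows quadratically.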
✓Practical: Fine-Tuning for Classification
Complete text classification workflow with KerasHub:
```python
import keras
import keras_hub

# Load pretrained classifier
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en",
    num_classes=3,
)

# Enable LoRA for efficient fine-tuning
classifier.backbone.enable_lora(rank=4)

# Compile
classifier.compile(
    optimizer=keras.optimizers.AdamW(5e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train
classifier.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3,
)

# Save the fine-tuned model
classifier.save("my_classifier.keras")
```
🧠 Track 13 Quiz
Track 14: Advanced Patterns and Best Practices
Master hyperparameter tuning, mixed precision, multi-GPU training, quantization, distillation, and debugging.
✓Hyperparameter Tuning
KerasTuner automates the search for optimal hyperparameters:
```python
import keras
import keras_tuner
from keras import layers

def build_model(hp):
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        layers.Dense(
            units=hp.Int("units", min_value=32, max_value=512, step=32),
            activation="relu",
        ),
        layers.Dropout(hp.Float("dropout", 0.0, 0.5, step=0.1)),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Float("lr", 1e-4, 1e-2, sampling="log")
        ),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

# Search strategies
tuner = keras_tuner.BayesianOptimization(
    build_model,
    objective="val_accuracy",
    max_trials=20,
)
tuner.search(x_train, y_train, epochs=10, validation_split=0.2)
best_model = tuner.get_best_models()[0]
```
Strategies: RandomSearch (simple sampling), BayesianOptimization (Gaussian process guided), Hyperband (bandit-based early stopping — efficient for large search spaces).
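Hyperband's core trick is successive halving: train many configurations briefly, keep the best fraction, and give the survivors more budget. A minimal pure-Python sketch of one bracket, where the `evaluate` function stands in for "validation accuracy after `budget` epochs" and is purely illustrative:

```python
def successive_halving(configs, evaluate, budget=1, eta=3):
    """Keep the top 1/eta of configs each round; survivors get eta times the budget."""
    while len(configs) > 1:
        scored = sorted(configs, key=lambda c: evaluate(c, budget), reverse=True)
        configs = scored[: max(1, len(scored) // eta)]  # keep the best third
        budget *= eta                                   # more epochs next round
    return configs[0]

# Toy objective: learning rates closer to 0.01 score higher
best = successive_halving(
    configs=[0.1, 0.05, 0.01, 0.005, 0.001, 0.0005],
    evaluate=lambda lr, budget: -abs(lr - 0.01),
)
print(best)  # 0.01
```

Most of the compute goes to cheap, short runs that eliminate bad configurations early, which is why Hyperband scales to large search spaces better than plain random search.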
✓Mixed Precision Training
Mixed precision uses float16 for computation and float32 for accumulation — roughly 2× speedup on modern GPUs:
```python
import keras
from keras import layers

# Enable mixed precision globally
keras.mixed_precision.set_global_policy("mixed_float16")

# Build and train normally — Keras handles the casting
model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(256, activation="relu"),  # Computes in float16
    layers.Dense(10),  # Raw logits (no activation)
])

# Use from_logits for numerical stability with mixed precision
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
```
On TPUs, use mixed_bfloat16 instead.
✓Multi-GPU Training
Keras 3's keras.distribution API makes multi-GPU training straightforward:
```python
import keras

# Data parallelism: same model on each GPU, split data
devices = keras.distribution.list_devices("gpu")
data_parallel = keras.distribution.DataParallel(devices=devices)

# Set distribution before building the model
keras.distribution.set_distribution(data_parallel)

model = build_model()
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train)  # Automatically distributed!

# Model parallelism: split model across GPUs
device_mesh = keras.distribution.DeviceMesh(
    shape=(2,), axis_names=["model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["dense/kernel"] = keras.distribution.TensorLayout(["model", None])
```
The JAX backend has the most mature keras.distribution support; TensorFlow and PyTorch support is also available.
✓Quantization
Reduce model size and speed up inference with post-training quantization:
```python
# Int8 quantization (4x smaller, faster inference)
model.quantize("int8")

# Int4 quantization (8x smaller, Keras 3.11+)
model.quantize("int4")

# Selective quantization (exclude specific layers)
model.quantize("int8", type_filter=["Dense"])  # Only Dense layers
```
Tradeoffs: int8 typically has negligible accuracy loss (<0.5%). int4 may lose 1-3% accuracy but provides maximum compression. Always benchmark on your specific task.
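To see where the accuracy loss comes from, here is the int8 scheme in miniature: weights are mapped to integer codes with a per-tensor scale, and dequantizing recovers only an approximation. This is a simplified symmetric sketch; real implementations also use zero points and per-channel scales:

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: w ~= scale * q, q in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127
    q = [round(w / scale) for w in weights]
    return q, scale

def dequantize(q, scale):
    return [qi * scale for qi in q]

weights = [0.42, -1.27, 0.003, 0.9]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(q)        # integer codes: 1 byte each instead of 4
print(max_err)  # worst-case rounding error, bounded by scale / 2
```

Each weight is now one byte instead of four (hence "4x smaller"), and the rounding error per weight is bounded by half the scale, which is why accuracy usually degrades only slightly at int8 and more noticeably at int4, where only 15 codes span the same range.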
✓Knowledge Distillation
Train a small "student" model to mimic a large "teacher" model's predictions:
```python
import keras

class Distiller(keras.Model):
    def __init__(self, student, teacher, temperature=3.0, alpha=0.1):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha

    def train_step(self, data):
        x, y = data
        # Teacher predictions (soft labels)
        teacher_pred = self.teacher(x, training=False)
        # Student predictions
        student_pred = self.student(x, training=True)
        # Distillation loss (soft) + student loss (hard);
        # hard_loss is e.g. cross-entropy against the true labels,
        # soft_loss e.g. KL divergence on temperature-scaled predictions
        loss = (
            self.alpha * hard_loss(y, student_pred)
            + (1 - self.alpha) * soft_loss(teacher_pred, student_pred)
        )
        ...
```
The student learns the teacher's soft probability distributions, which contain more information than hard labels (e.g., "this 7 looks a bit like a 1").
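The "soft labels carry more information" point is easy to see numerically: raising the softmax temperature exposes the teacher's secondary preferences that a hard label throws away. The logits below are made up for illustration:

```python
import math

def softmax(logits, temperature=1.0):
    exps = [math.exp(z / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for digit classes 0..9; the true class is 7,
# but the teacher also sees some resemblance to class 1
logits = [0.1, 2.0, 0.2, 0.1, 0.3, 0.1, 0.2, 6.0, 0.1, 0.4]

hard = softmax(logits)                   # T=1: nearly all mass on class 7
soft = softmax(logits, temperature=3.0)  # T=3: class 1 clearly second
print(round(hard[7], 3), round(hard[1], 3))
print(round(soft[7], 3), round(soft[1], 3))
```

At temperature 1 the "looks a bit like a 1" signal is almost invisible; at temperature 3 it becomes a meaningful target for the student, which is exactly what the `temperature` parameter in the Distiller above controls.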
✓Debugging Strategies
Common problems and their solutions:
| Problem | Symptom | Solution |
|---|---|---|
| Shape mismatch | ValueError about incompatible shapes | Print shapes at each layer; check model.summary() |
| NaN loss | Loss becomes NaN during training | Lower LR, check data for NaN/inf, add gradient clipping |
| Overfitting | Train acc high, val acc low | More data, add dropout, L2 regularization, augmentation |
| Underfitting | Both train and val acc low | Bigger model, more epochs, lower regularization |
| Slow training | Epochs take too long | Check GPU utilization, prefetch data, mixed precision |
```python
# Gradient clipping to prevent NaN
optimizer = keras.optimizers.Adam(
    learning_rate=1e-3,
    clipnorm=1.0,  # Clip gradient norm
)
```
✓The Keras Philosophy in Practice
After 14 tracks, here's the distilled wisdom:
- Start simple, add complexity only when needed. Sequential → Functional → Subclassing. Don't reach for the complex tool first.
- Prefer Functional API over subclassing. Custom layers (subclassed) + Functional composition = best of both worlds.
- Use built-in components before custom ones. Keras's built-in layers, losses, and metrics are battle-tested and optimized.
- Profile before optimizing. Don't guess where the bottleneck is — measure it.
- Leverage the multi-backend superpower. Train on JAX for speed, deploy on TF for serving, use PyTorch for ecosystem access.
Where to go next: keras.io/examples (hundreds of runnable examples), the Keras GitHub repo (source code is readable!), and the Keras community on GitHub Discussions.