Your AI Code Can Be Elegant Too
As the Machine Learning Engineering Manager at an AI-powered SaaS company, I get to peek into the machine learning (ML) code written by data scientists on the team. Oh, and when I’m not doing that, I dabble in the occasional Kaggle competition — although I must admit, I’m more of a casual competitor than a podium contender.
Unfortunately, many data scientists didn’t cut their teeth on coding. Instead, they often come from mathematical or statistical backgrounds, where Python is adopted as an upgrade from R. Some might have an engineering background, but not necessarily in computer science or software engineering. It’s no wonder, then, that even the brightest minds winning top places in Kaggle competitions sometimes produce code that is effective yet about as elegant as a spork.
Why Do I Care?
My coding journey began as a Computer Engineering student at one of the top universities in Egypt back in 2003. I was lucky enough to start with C++ (not C), which naturally inclined me towards a clean, object-oriented programming (OOP) style.
Fast forward to my grad school days at Virginia Tech (2014–2017), where I delved into the realms of machine learning and deep learning, tinkering with Pandas and TensorFlow (yes, even versions 0.12 and 1.0 — Google, I’ll take that apology now).
Coming from such a background, I’ve always had a soft spot for clean code. For me, writing clean code has always been like a piece of art that I proudly sign. I know it may sound like a cliché, but I do mean it. And yes, that includes my machine-learning code. At one point, I even took a stab at rewriting DL4J (Deep Learning for Java) in a more object-oriented manner.
What Exactly Is Clean Code Anyway?
Clean code always looks like it was written by someone who cares
― Robert Martin, Clean Code: A Handbook of Agile Software Craftsmanship
When you write code, you are writing for an audience. This audience can be yourself in the future, your team, or both. Clean code ensures readability and maintainability for whoever needs to work with it next.
Clean code becomes even more crucial with today’s machine learning production platforms like MLflow, TFX, and SageMaker, where poorly written code can have significant consequences once it reaches a production system.
Now, I won’t delve into the nitty-gritty of writing clean code here — there are many resources for that. If you’re keen to dive deeper, I recommend books like Clean Code by Robert Martin and The Pragmatic Programmer by David Thomas and Andrew Hunt. This post focuses on making machine learning code cleaner, more readable, and more enjoyable to maintain.
General Guidelines
Learn your tools
Libraries like NumPy, Pandas, and TensorFlow are powerful enough to do magic with your data in a few lines of code. Always take the time to study the API references and guides to master your tools before trying to do some data gymnastics using for loops on Python lists.
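As a quick illustration (using a tiny, made-up DataFrame), here is the kind of loop-based gymnastics to avoid, followed by the one-liner the library already provides:
from pandas import DataFrame

data = DataFrame({"age": [22, 35, 58], "income": [40_000, 65_000, 90_000]})

# Data gymnastics: a manual loop over a Python list
total = 0
for value in data["age"].tolist():
    total += value
mean_age = total / len(data)

# Letting the library do the work
mean_age = data["age"].mean()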
Adhere to Python naming conventions
I know X is a matrix while y is a vector, but Python doesn’t care. From a Python point of view, both are variables. Variables in Python use snake_case (lowercase with underscores) as per PEP-8 (the style guide for Python):
Variable names follow the same convention as function names.
Function names should be lowercase, with words separated by underscores as necessary to improve readability.
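As a quick sketch (the hyperparameter names here are made up), that means preferring snake_case over camelCase or PascalCase:
# Not PEP-8: camelCase and PascalCase
LearningRate = 0.001
numEpochs = 10

# PEP-8: snake_case
learning_rate = 0.001
num_epochs = 10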
Give your variables descriptive names
Instead of using names like df, x, y, train_ds, val_ds, and test_ds, I prefer to call my variables what they are (see the sketch after the list):
- df is the data
- x and y are the features and labels
- train_ds, val_ds, and test_ds are the training_data, validation_data, and test_data, respectively
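For instance, a minimal CSV-loading sketch (the file name and column name are hypothetical) reads much better with descriptive names:
from pandas import DataFrame, read_csv

# Opaque names
df = read_csv("reviews.csv")
X = df.drop(columns=["label"])
y = df["label"]

# Descriptive names
data: DataFrame = read_csv("reviews.csv")
features = data.drop(columns=["label"])
labels = data["label"]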
Be Precise with your imports
Why import numpy if all I need is the mean function? And why import pandas when all I need is the read_csv function? Instead, import only what you need.
For instance, if I am creating a Pandas DataFrame from a CSV file, instead of this:
import pandas as pd
df = pd.read_csv(file)
I can be more surgical with my imports:
from pandas import DataFrame, read_csv
data: DataFrame = read_csv(file)
Clean, elegant, and to the point. So, dear data scientists, next time you’re crafting a Pandas data frame from your CSV data, embrace the simplicity.
Unless I am using functions with the same name from different libraries, e.g., load from json and load from pickle, I don’t need to import the whole package, and even then, I can do something like:
from json import load as load_json
from pickle import load as load_pickle
Use an Integrated Development Environment (IDE)
While Databricks and Google Colab are popular, they are not the best options for writing production-ready code. Did you know that Visual Studio Code supports Jupyter notebooks too? That means I can write my machine-learning code in a powerful IDE that also offers features like:
- Version control with GitHub
- Code formatting and linting with tools like black, isort, and flake8
- Auto-completion with AI tools like IntelliCode, Gemini, Blackbox, and GitHub Copilot
- Integration with Google Cloud for running code on TPUs
Example: Cats versus Dogs
Now, as an example, let’s look at this transfer learning code snippet from the TensorFlow documentation (markdown cells and explanatory comments are omitted for clarity):
import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)
print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")
resize_fn = keras.layers.Resizing(150, 150)
train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))
augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]

def data_augmentation(x):
    for layer in augmentation_layers:
        x = layer(x)
    return x
train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
from tensorflow import data as tf_data
batch_size = 64
train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(np.expand_dims(first_image, 0))
        plt.imshow(np.array(augmented_image[0]).astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
inputs = keras.Input(shape=(150, 150, 3))
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary(show_trainable=True)
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
base_model.trainable = True
model.summary(show_trainable=True)
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
print("Test dataset evaluation")
model.evaluate(test_ds)
Here are a few areas where we can clean up that code:
- The code imports numpy only to use expand_dims and array
- The code imports matplotlib.pyplot as plt, then uses only four functions of plt
- The code imports layers from keras, yet only uses a few layers from it
- That relaxed import results in a redundant keras.layers prefix for each used layer
- The same goes for keras.optimizers, keras.losses, and keras.metrics
- Variable names train_ds, validation_ds, and test_ds can be training_data, validation_data, and test_data, respectively
- The batch_size is a constant, yet defined as a variable
- The epochs variable is not necessary, since it is only used in one place
- The code can be formatted (black) and imports can be sorted (isort)
Rewriting this code with only those minor improvements, here is how it looks:
from keras import Model
from keras.applications import Xception
from keras.layers import (
    Input,
    Dense,
    Dropout,
    GlobalAveragePooling2D,
    RandomFlip,
    RandomRotation,
    Rescaling,
    Resizing,
)
from keras.losses import BinaryCrossentropy
from keras.metrics import BinaryAccuracy
from keras.optimizers import Adam
from matplotlib.pyplot import axis, figure, imshow, subplot, title
from numpy import array, expand_dims
from tensorflow.data import AUTOTUNE
from tensorflow_datasets import disable_progress_bar, load
# Load and split the dataset
disable_progress_bar()
training_data, validation_data, test_data = load(
    "cats_vs_dogs",
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)
print(f"Number of training samples: {training_data.cardinality()}")
print(f"Number of validation samples: {validation_data.cardinality()}")
print(f"Number of test samples: {test_data.cardinality()}")
# Show training data samples
figure(figsize=(10, 10))
for i, (image, label) in enumerate(training_data.take(9)):
    ax = subplot(3, 3, i + 1)
    imshow(image)
    title(int(label))
    axis("off")
# Resize images to 150x150
resize = Resizing(150, 150)
training_data = training_data.map(lambda x, y: (resize(x), y))
validation_data = validation_data.map(lambda x, y: (resize(x), y))
test_data = test_data.map(lambda x, y: (resize(x), y))
# Augment training data using RandomFlip and RandomRotation layers
augmentation_layers = [RandomFlip("horizontal"), RandomRotation(0.1)]
def data_augmentation(images):
    for layer in augmentation_layers:
        images = layer(images)
    return images
training_data = training_data.map(lambda x, y: (data_augmentation(x), y))
# Batch the data
BATCH_SIZE = 64
training_data = training_data.batch(BATCH_SIZE).prefetch(AUTOTUNE).cache()
validation_data = validation_data.batch(BATCH_SIZE).prefetch(AUTOTUNE).cache()
test_data = test_data.batch(BATCH_SIZE).prefetch(AUTOTUNE).cache()
# Show the first batch
for images, labels in training_data.take(1):
    figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = subplot(3, 3, i + 1)
        augmented_image = data_augmentation(expand_dims(first_image, 0))
        imshow(array(augmented_image[0]).astype("int32"))
        title(int(labels[0]))
        axis("off")
# Prepare the base model
base_model = Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top
)
base_model.trainable = False # Freeze the base model
# Create a new model on top of the base model
inputs = Input(shape=(150, 150, 3), name="input")
x = Rescaling(scale=1 / 127.5, offset=-1)(inputs) # scale input from (0, 255) to (-1, 1)
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x) # Regularize with dropout
outputs = Dense(1)(x)
model = Model(inputs, outputs)
model.summary(show_trainable=True)
# Train the top layer of the model
model.compile(
    optimizer=Adam(),
    loss=BinaryCrossentropy(from_logits=True),
    metrics=[BinaryAccuracy()],
)
model.fit(training_data, epochs=2, validation_data=validation_data)
# Fine-tune the model
base_model.trainable = True # Unfreeze the base model
model.summary(show_trainable=True)
model.compile(
    optimizer=Adam(1e-5),  # Low learning rate
    loss=BinaryCrossentropy(from_logits=True),
    metrics=[BinaryAccuracy()],
)
model.fit(training_data, epochs=1, validation_data=validation_data)
# Evaluate the model using test data
model.evaluate(test_data)
In addition to improved clarity, one can now glance at the import block at the top to see what model is being trained for the problem at hand:
from keras import Model
from keras.applications import Xception
from keras.layers import (
    Input,
    Dense,
    Dropout,
    GlobalAveragePooling2D,
    RandomFlip,
    RandomRotation,
    Rescaling,
)
from keras.losses import BinaryCrossentropy
from keras.metrics import BinaryAccuracy
from keras.optimizers import Adam
That is a pre-trained Xception model with Input and Rescaling layers used for input scaling, and GlobalAveragePooling2D, Dropout, and Dense layers used for the hidden and output layers on top of the pre-trained model. We also see that the data was augmented using RandomFlip and RandomRotation layers, and that the model was built using an Adam optimizer, a BinaryCrossentropy loss function, and a BinaryAccuracy metric.
Testing the updated code gives the same results as before, with more human-readable and maintainable code. It is ready to run in your Python/Jupyter environment. Try it yourself!
Summary
Incorporating clean coding practices into machine learning projects enhances the code's quality and readability and contributes to more reliable and maintainable systems. Strive for clean, efficient, and elegant code in all your projects, whether small-scale experiments or large production systems. Your future self and your team will thank you.