Transfer Learning With Keras (ResNet-50)

Neural Networks Jul 16, 2021

Continuing our series of computer vision blogs, this tutorial explores transfer learning and applies it to an image classification problem.

By the end of this tutorial you will be able to use transfer learning to train state-of-the-art deep learning models on your own custom data set.

GitHub source code and YouTube links for this tutorial are available at the end of the blog. For now, let's start from the basics.

So, What is Transfer Learning?

Simply put, transfer learning is a technique that lets you take what a state-of-the-art machine learning model has already learnt and reuse it for your own custom problem.

The following flow chart represents how transfer learning works in practice. Deep learning model 1 transfers the knowledge it learnt (weights and biases), which can then be used by deep learning model 2.

Flow chart for transfer learning

Okay, but why do we need transfer learning? Why can't we train our own neural networks from scratch?

Most deep learning networks used for real-world problems learn millions of parameters to reach high accuracy.

Now let's say I am working on the problem of face detection. Sure, I can train my deep learning model from scratch for this.

But here's the catch:

Any image classification network trained for this task will end up learning very similar low-level features (edges, textures, shapes) in order to identify a face.

So why not take what state-of-the-art models have already learnt and use it for our problem?

Models that have been trained on extensive data sets like COCO or ImageNet are, in general, good at recognizing objects in images. So it makes sense to reuse what they have learnt and fine-tune it for our purpose.

So in short, transfer learning saves a great deal of training time and compute by reusing what other state-of-the-art models have learnt.

With the basics out of the way, let's start with implementing the ResNet-50 model to solve an image classification problem.

Step 1: Import all the required libraries

import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

Step 2: Import your dataset

You can use the same approach for any data you want. For the purpose of this tutorial, we'll be working on the TensorFlow flower classification problem. The data set consists of 5 classes of flowers, for which we will try to build a classifier.

Sample images of 5 classes from tensorflow data set

Execute the following lines to import the data into your workspace.

import pathlib

dataset_url = ""

data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)

data_dir = pathlib.Path(data_dir)

Step 3: Split Your Data

We are going to split our image data into training and validation subsets. In each epoch, our model is trained on the training subset and its performance is then checked on the validation subset.

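img_height, img_width, batch_size = 180, 180, 32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir, validation_split=0.2, subset="training", seed=123,  # the seed value is an assumption
  image_size=(img_height, img_width), batch_size=batch_size)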

In the above snippet of code, some things to note are:

  1. We resize every input image to 180 x 180 pixels. This ensures uniformity across all images. You can change this according to your custom problem.
  2. We set the validation split to 0.2. This means that 80% of the data will be reserved for training while 20% will be used for validation.
  3. We keep the batch size at 32. If you are working on a system with less RAM, you can reduce the batch size further.
  4. We set subset to training, which means we are first creating the training subset.
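To get a feel for what these settings mean in numbers, here is a small sketch of the arithmetic. The helper name split_sizes and the image count of 1000 are made up for illustration, and we assume the validation share is rounded down with the remainder going to training, so treat the result as an estimate rather than the exact count Keras will report for your data.

```python
import math

def split_sizes(num_images, validation_split, batch_size):
    """Return (train_count, val_count, train_batches, val_batches)."""
    # Assumption: the validation share is rounded down; the rest is training data.
    val_count = int(validation_split * num_images)
    train_count = num_images - val_count
    # The last batch may be smaller than batch_size, hence the ceiling.
    train_batches = math.ceil(train_count / batch_size)
    val_batches = math.ceil(val_count / batch_size)
    return train_count, val_count, train_batches, val_batches

# 1000 is a made-up image count used only for illustration.
train_count, val_count, train_batches, val_batches = split_sizes(1000, 0.2, 32)
print(train_count, val_count, train_batches, val_batches)
```

Halving the batch size doubles the number of steps per epoch but roughly halves the memory needed for each step, which is why a smaller batch helps on low-RAM systems.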

Now, we run the same code again to produce the validation data. The only change here is in the subset argument.

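val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir, validation_split=0.2, subset="validation", seed=123,  # seed (assumed value) must match the training split
  image_size=(img_height, img_width), batch_size=batch_size)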

Step 4: Visualize your data

We are going to be using the matplotlib library to visualize 6 images in our data.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
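for images, labels in train_ds.take(1):
  for i in range(6):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(train_ds.class_names[labels[i]])
    plt.axis("off")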

Here is the output that we get after executing the above lines:

6 sample images from the flower classification data set

Step 5: Import your Pre-trained Model

Now, here is where the power of transfer learning comes in. From the Keras applications module, we can pick any of the state-of-the-art models and use it for our problem.

List of available models in keras

For now, we are going to be using ResNet-50, but the same procedure can be used for any other model as well. The following lines of code allow us to import the model into our workspace:

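resnet_model = Sequential()

pretrained_model = tf.keras.applications.ResNet50(include_top=False,
    input_shape=(180, 180, 3), pooling='avg',  # input_shape and pooling are assumptions
    weights='imagenet')
for layer in pretrained_model.layers:
    layer.trainable = False  # freeze the pre-trained weights

resnet_model.add(pretrained_model)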

Here are a few things to note about the above snippet of code:

  1. While importing the ResNet50 class, we set include_top=False. This drops the original 1000-class ImageNet classification head so that we can add our own fully connected and output layers according to our data.
  2. We set weights='imagenet'. This means that the ResNet50 model will use the weights it learnt while being trained on the ImageNet data.
  3. Finally, we set layer.trainable = False for every layer in the pretrained model. This freezes the ImageNet weights so they are not updated during training, which saves a lot of training time and compute.
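To see the effect of freezing in isolation, here is a minimal sketch that uses a tiny stand-in network instead of ResNet-50, so nothing has to be downloaded. The names base, model and count_params, as well as the layer sizes, are hypothetical and chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

def count_params(weights):
    """Total number of scalar values in a list of weight tensors."""
    return sum(int(np.prod(w.shape)) for w in weights)

# Tiny stand-in for a pre-trained base (hypothetical sizes, not ResNet-50).
base = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
])
for layer in base.layers:
    layer.trainable = False  # frozen: these weights will not be updated

# New head stacked on top of the frozen base; only this part learns.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    base,
    tf.keras.layers.Dense(5, activation="softmax"),
])

trainable = count_params(model.trainable_weights)
frozen = count_params(model.non_trainable_weights)
print("trainable:", trainable, "frozen:", frozen)
```

Only the weights of the new Dense(5) layer are reported as trainable, while the base layer's weights are counted as non-trainable. The same split shows up at the bottom of a model summary.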

Now that we have imported a pre-trained model, we will also add a fully connected layer and an output layer, where the actual learning can take place.

resnet_model.add(Dense(512, activation='relu'))
resnet_model.add(Dense(5, activation='softmax'))

In the output layer we use the softmax activation function and we have 5 output neurons corresponding to the 5 classes in our data.

Now, to look at your model architecture, just call the summary method as shown below:

resnet_model.summary()

Here is what your model architecture should look like:

Model Summary for Resnet-50

The key point to note here is that the total number of parameters in the model is about 24 million, but only about 1 million of them are trainable.

That is precisely how transfer learning saves us so much training time, memory and computation.
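The figure of roughly 1 million trainable parameters can be checked by hand. The sketch below assumes that the frozen ResNet-50 base hands a 2048-value feature vector per image to the new layers (which is the case when its output is global-average-pooled); the helper name dense_params is made up for illustration.

```python
def dense_params(inputs, units):
    """Weights plus biases of one fully connected layer."""
    return inputs * units + units

# Assumption: the frozen base emits 2048 features per image.
features = 2048
hidden = dense_params(features, 512)  # Dense(512, activation='relu')
output = dense_params(512, 5)         # Dense(5, activation='softmax')
trainable = hidden + output
print(hidden, output, trainable)
```

Almost all of the trainable parameters sit in the Dense(512) layer, so changing its width is the quickest way to make the trainable part of the model larger or smaller.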

Now that our model is ready, we simply compile it and train it for 10 epochs for now.


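resnet_model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])  # assumptions: default Adam settings; loss for integer labels
history = resnet_model.fit(train_ds, validation_data=val_ds, epochs=10)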

Once your model is trained, we move on to the next step, which is evaluating the model.

Step 6: Model Evaluation

For now we will use the matplotlib library to plot the training and validation accuracy for each epoch. These logs were stored in the history variable during training.

fig1 = plt.gcf()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.legend(['train', 'validation'])
plt.show()

You should get a plot, similar to the one shown below:

Training vs validation accuracy

The model does seem to have overfit a little bit, but we will talk about measures to prevent overfitting in another blog. For now, since the validation accuracy is good enough (around 90%), we will proceed with the final step of making predictions with our model.
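If you want to put a number on the gap between the two curves, the logs can be read directly as a dictionary with one list per metric and one value per epoch. The sketch below uses made-up accuracy values in that shape, not results from this tutorial, and the helper name summarize is hypothetical.

```python
# Made-up numbers in the shape of history.history (not real results).
history_dict = {
    "accuracy":     [0.70, 0.85, 0.93, 0.97, 0.99],
    "val_accuracy": [0.75, 0.84, 0.88, 0.90, 0.89],
}

def summarize(history_dict):
    """Return (best epoch, best validation accuracy, final train/val gap)."""
    val = history_dict["val_accuracy"]
    best_index = max(range(len(val)), key=val.__getitem__)
    gap = history_dict["accuracy"][-1] - val[-1]
    return best_index + 1, val[best_index], gap  # epochs are counted from 1

best_epoch, best_val, final_gap = summarize(history_dict)
print(best_epoch, best_val, round(final_gap, 2))
```

A training accuracy that keeps rising while the validation accuracy levels off or drops, as in these made-up numbers, is the usual sign of overfitting.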

Step 7: Model Inference

To make predictions on any image, we first have to run a few pre-processing steps so that the image's dimensions match those of the images our model was trained on. We use the OpenCV library for this.

For now we will run predictions on a sample image of a rose from our data.

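import cv2
image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)  # image_path: hypothetical path to a rose image; OpenCV loads BGR
image_resized = cv2.resize(image, (img_width, img_height))  # cv2.resize expects (width, height)
image = np.expand_dims(image_resized, axis=0)  # add a batch dimension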
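Two details of this pre-processing are easy to miss: OpenCV loads images in BGR channel order while the training images were loaded as RGB, and the model expects a batch of images rather than a single one. The sketch below shows both steps with plain NumPy on a synthetic image, so it runs without OpenCV or an image file; the array names are made up for illustration.

```python
import numpy as np

# Synthetic stand-in for an image loaded by OpenCV: 180 x 180 pixels, BGR order.
bgr = np.zeros((180, 180, 3), dtype=np.uint8)
bgr[..., 0] = 255  # channel 0 is blue in BGR, so this image is pure blue

# Reversing the last axis converts BGR to RGB.
rgb = bgr[..., ::-1]

# Add a leading batch axis: one image becomes a batch of size 1.
batch = np.expand_dims(rgb, axis=0)
print(rgb[0, 0], batch.shape)
```

The resulting shape of (1, 180, 180, 3) matches what the model saw during training, with a batch size of 1 instead of 32.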

Now to make predictions, we simply call the predict method:

pred = resnet_model.predict(image)

However, when you print the predictions, you will see an array of 5 numbers, one probability per class, since the output layer uses softmax. To get an output label, we pick the class with the highest probability:

output_class = train_ds.class_names[np.argmax(pred)]
print("The predicted class is", output_class)

We get the following output:

Output prediction for sample rose image
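To see how the step from probabilities to a label works on its own, here is a small sketch with a made-up softmax output. The class names listed are an assumption about how the flower folders are named (Keras sorts folder names alphabetically), and the probabilities are invented for illustration.

```python
import numpy as np

# Assumed class names in alphabetical folder order.
class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

# Made-up softmax output for one image: shape (1, 5), values sum to 1.
pred = np.array([[0.02, 0.03, 0.90, 0.01, 0.04]])

index = int(np.argmax(pred[0]))     # position of the highest probability
output_class = class_names[index]   # map the position back to a label
confidence = float(pred[0][index])
print("The predicted class is", output_class, "with probability", confidence)
```

Keeping the probability alongside the label is handy when you want to flag low-confidence predictions instead of trusting every output.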


In this tutorial, we saw the power of transfer learning and how we can take state-of-the-art deep learning models and train them on our custom data set. Stay tuned for more such blogs!

Code Links:
1. You can find the source code for this project in my GitHub repo.

2. You can find the video explanation for this tutorial on my YouTube channel:


Nachiketa Hebbar

Computer Vision Engineer at Awiros | YouTuber
