Use Keras Pretrained Models With Tensorflow

In my last post (the Simpsons Detector) I
used Keras as my deep-learning package to train and run CNN models. Since Keras is just
an API on top of TensorFlow, I wanted to play with the underlying layer and therefore implemented
image-style-transfer with TF.

Image-style-transfer requires calculation of VGG19’s output on the given images and
since I was familiar with the nice API of Keras and keras.applications, I expected that to work easily.

Well, that’s not quite the case… While I could ‘get things to work’, I was constantly
confused by inconsistent behavior, occasional weird errors and messy graphs that made me
shamefully admit that I didn’t really understand what was going on.

After spending some time on that, here are 4 tips that I think will make your life
easier if you plan to use Keras pretrained models in your TensorFlow graphs.
I also created my own wrapper for VGG19 to demonstrate them. Feel free to use it as is
or adjust it to your needs.

Keras Pretrained Models

Keras comes with some built-in models that implement famous, widely used architectures
along with their pretrained weights (on common datasets). This allows you to get results quickly and easily:

import keras

vgg19 = keras.applications.VGG19(weights='imagenet', include_top=False)
imgs = ...  # load images
imgs = ...  # apply VGG preprocessing (keras.applications.vgg19.preprocess_input)
keras_output = vgg19.predict(imgs)

The first section in this notebook runs this code on a sample
image I took a couple of years ago in New Zealand. I’m using the mean() of the
activation map on the last VGG19 layer as a hash for the calculation results. We’ll
compare that later with a second, more TF-ish implementation.

keras_output.shape, keras_output.mean()
# => ((1, 6, 9, 512), 1.5227494)

Problems With Keras-TensorFlow Integration

Why would I even want to take a model from one package and run it in another?
I guess there could be many reasons for that, including some psychotic disorders,
but my use-case is much simpler - I wanted to implement an image-style-transfer
model, and for that I needed to compute VGG19 outputs on 3 images.
The model I needed is not a straightforward fit/predict model, so I couldn’t build
it with Keras only, but on the other hand, I didn’t really want to build
the full VGG network in TF and deal with loading the weights myself.

I was naive at first, and expected something similar to the functional API of Keras
(https://keras.io/getting-started/functional-api-guide/) to just work.
THIS DOESN’T WORK:

import tensorflow as tf

input_img = tf.placeholder(tf.float32, (1, 200, 300, 3), name='input_img')
vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
output = vgg19(input_img)

img = ...  # load and preprocess image
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_val = sess.run(output, {input_img: img})

output_val.shape, output_val.mean()
# => ((1, 6, 9, 512), 0.067449108)

There are a few problems with this code, but the most eye-catching one is the fact that
the mean() of the activation map is not the same as in the ‘pure’ Keras code
from before.

Here are the obvious and hidden problems with just ‘plain-integrating’
Keras models into TensorFlow code:

1. Using the model in a new session

Apparently, as anyone will notice after the first couple of minutes of playing
with this code, once we create the VGG model, we can’t use it in a different
session (as in with tf.Session() as sess: ...). Here is some code to demonstrate
that:

# Let's create the model like before:
vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
print(vgg19.get_layer('block1_conv1').get_weights()[0].shape)
print("Everything was loaded as we would expect")

# And now, let's use it in a new session
with tf.Session().as_default():
    try:
        print(vgg19.get_layer('block1_conv1').get_weights()[0].shape)
    except Exception as ex:
        print("EXCEPTION: ", ex)

# => (3, 3, 3, 64)
# => Everything was loaded as we would expect
# => EXCEPTION:  Attempting to use uninitialized value block1_conv1_4/bias 	[[Node: _retval_block1_conv1_4/bias_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](block1_conv1_4/bias)]]

It’s pretty common to create a graph once and run it in many sessions, but here,
even with a simple use-case, we get a weird error. When Keras loads our model with
pretrained weights, it actually runs a tf.assign operation to set the values of
all the weights in the graph. Once we use a new session, this initialization is
gone and TensorFlow is left with uninitialized nodes.

A possible solution would be to create the model in the same session that we’re
using it in (or pass a reference to that session), but that is not always possible.
Another solution is to use model.load_weights(...) in the new session.

My wrapper for VGG (shown at the end) uses something similar to the load_weights()
approach.
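
Here is a hedged sketch of the load_weights() approach; the weights-file path
assumes the default Keras cache location, so adjust it if yours differs:

import os
import tensorflow as tf

# Path where keras.applications caches the downloaded VGG19 weights by default
weights_path = os.path.expanduser(
    '~/.keras/models/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

# Constructing with weights='imagenet' also downloads the file if it's missing
vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
with tf.Session().as_default():
    # Re-run the weight assignment inside this session instead of the old one
    vgg19.load_weights(weights_path)
    print(vgg19.get_layer('block1_conv1').get_weights()[0].shape)
# => (3, 3, 3, 64)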

2. tf.global_variables_initializer() will destroy pretrained weights

Although implied by the previous section, it’s important to understand that
your weights are variables, and they will be randomly re-initialized when you call
the global initializer. So even if you kept the session but then called tf.global_variables_initializer()
to initialize your other variables - congratulations! You now have a random VGG
model.

The notebook that follows this post shows exactly that. I won’t include the code
here to keep things short.
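
If you do need to initialize variables of your own, a safer pattern (sketched
below, assuming you kept the session Keras assigned the weights in) is to
initialize only your own variables with tf.variables_initializer:

import tensorflow as tf

vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
my_var = tf.Variable(0.0, name='my_var')  # some variable of our own

# Initialize everything except the VGG weights, which Keras already assigned
vgg_weights = set(vgg19.weights)
other_vars = [v for v in tf.global_variables() if v not in vgg_weights]

sess = tf.keras.backend.get_session()  # the session holding the VGG weights
sess.run(tf.variables_initializer(other_vars))  # leaves VGG weights intact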

3. Graphs are created multiple times

Things might work after you understand the first 2 issues, but when you open
TensorBoard and look at the graph, you’ll see it’s not as nice as you’d expect.
In the following example, I’m using VGG once to compute an output and therefore
expect to see only one ‘VGG block’ in my graph. Instead, it looks duplicated:

img = tf.placeholder(tf.float32, (1,224,224,3), name='input_image')
vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
output = vgg19(img)

[Figure: TensorFlow graph with a duplicated VGG block]

The cause here is completely my fault, but one that I believe is easy to miss given the
Keras functional API. When I instantiate VGG19, it builds a graph. Then, when
I apply it to the input tensor, it builds another graph that is connected to
that input. The first graph is never used and therefore is not connected to anything
(Keras created a new input tensor for it). It’s basically just garbage in the graph.

The solution is to pass the input_tensor=img parameter to the VGG constructor instead
of using the (confusing) Keras way of vgg19(img).
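
A minimal sketch of the fixed version:

import tensorflow as tf

img = tf.placeholder(tf.float32, (1, 224, 224, 3), name='input_image')
# Passing input_tensor wires VGG directly to our placeholder,
# so only a single VGG block appears in the graph
vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False,
                                    input_tensor=img)
output = vgg19.output  # the block5_pool activations, connected to img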

4. Model weights are trainable

Another issue that is implied by the previous ones, but easy to miss due to the
Keras API, is the fact that the model weights will also be trained (unless
specifically excluded). Notice that the trainable attribute of the Keras Model
has no effect here, as we’re not compiling the model with Keras.

Like in the previous sections, the notebook shows an example
that ‘proves’ this. I used the sum of a specific layer’s weights and the sum of
the image variable as indicators of whether they’re changing or not.

To handle this, I added a model_weights_tensors attribute to my Keras wrapper
that returns a set of the VGG weight tensors, so you can exclude them
from training. A full example is in the notebook, but basically you have to use
optimizer.minimize(..., var_list=VARS_TO_TRAIN).
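
Here is a hedged sketch of that exclusion, assuming the VGG19 wrapper described
in the next section (the image itself is the trainable variable, as in style
transfer; the loss is a toy one, just for illustration):

import tensorflow as tf
from image_style_transfer import VGG19

IMAGE_SHAPE = (1, 200, 300, 3)

# The image we're optimizing is itself a trainable variable
img_var = tf.Variable(tf.random_normal(IMAGE_SHAPE), name='image_to_optimize')
vgg = VGG19(image_shape=IMAGE_SHAPE, input_tensor=img_var)

loss = tf.reduce_mean(vgg['block5_pool'])  # toy loss for the sketch
vars_to_train = [v for v in tf.trainable_variables()
                 if v not in vgg.model_weights_tensors]

# Only the image variable gets updated; the VGG weights stay frozen
train_op = tf.train.AdamOptimizer(0.01).minimize(loss, var_list=vars_to_train)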

My VGG19 Wrapper

In order to address all of these issues, and to have a reusable component that I can
actually work with, I wrapped VGG19 in my own short class.
Feel free to use it or adjust it to your needs.

The code is available here
and is also attached to the notebook.

Here is what it basically does:

  • Can be initialized with an input_tensor (otherwise, a placeholder will be created and stored in self.input_tensor)
  • Deals with VGG preprocessing (subtracts VGG_MEAN and flips RGB to BGR)
  • Creates a clean graph; different parts have different name scopes
  • Saves a checkpoint from the session used when loading the model with the
    pretrained weights, and exposes a load_weights() method to restore the weights
    from that checkpoint (see the sketch below)
  • Exposes all layers’ outputs with __getitem__ access (vgg['block5_pool'], for
    example)
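
To make the checkpoint mechanism concrete, here is a rough sketch of the
save-then-restore idea it’s built on (an illustration of the approach, not the
wrapper’s actual code; the checkpoint path is made up):

import tensorflow as tf

vgg19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False)
saver = tf.train.Saver(var_list=vgg19.weights)

# At construction time: save from the session Keras assigned the weights in
saver.save(tf.keras.backend.get_session(), '/tmp/vgg19.ckpt')

# Later, in any new session: restore instead of re-assigning
with tf.Session() as sess:
    saver.restore(sess, '/tmp/vgg19.ckpt')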

And here is a short example (also demonstrated in the notebook)
and the TensorFlow graph it generates:

import tensorflow as tf
from image_style_transfer import VGG19

IMAGE_SHAPE = (1, 200, 300, 3)

my_img = tf.placeholder(tf.float32, IMAGE_SHAPE, name='my_original_image')
vgg = VGG19(image_shape=IMAGE_SHAPE, input_tensor=my_img)

output = tf.identity(vgg['block5_pool'], name='my_output')  # just to create an 'output' node in the graph

[Figure: the clean TensorFlow graph generated by the VGG19 wrapper]

Just for comparison, we can calculate the mean output of block5_pool and compare
it to the ‘pure’ Keras approach:

imgs = ...    # load images
with tf.Session() as sess:
    vgg.load_weights()

    output_val = sess.run(output, { my_img: imgs })

output_val.shape, output_val.mean()
# => ((1, 6, 9, 512), 1.52275)

Exactly the same!
