Style Transfer Using PyTorch

Edna Figueira Fernandes
4 min read · May 27, 2020


In neural networks, style transfer refers to getting the content of one image and the style of another image and merging them to acquire the desired output. For example, using the photo of a dog to get the content (the dog) and using an image of a painting to get the style (the colors). The output would be an image of a dog with the colors of the painting.

This blog post shows a simple style transfer project using PyTorch. For the project, the pre-trained model VGG19 is going to be used. This model has two sections: the features (convolutional and pooling layers) and the classifier. For this exercise, only the features’ section is going to be used.
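As a quick illustration (not part of the original notebook's code), the two sections can be inspected directly on the torchvision model:

from torchvision import models

# load the pre-trained VGG19 and look at its two sections
model = models.vgg19(pretrained=True)
print(model.features)    # convolutional and pooling layers (used in this project)
print(model.classifier)  # fully connected layers (not used here)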

To get started, in the cell below, the necessary libraries are imported, the model’s features are assigned to the variable vgg, and the model is moved to cuda so that it runs faster. To check the model’s architecture, one can simply print vgg. It is also important to freeze the model’s parameters, since the goal is not to optimize the network’s weights.

import numpy as np 
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import matplotlib.image as mpimg
from torchvision import transforms, models
import torch.optim as optim
# Extracting the model's features
vgg = models.vgg19(pretrained=True).features
# Freezing the parameters
for param in vgg.parameters():
    param.requires_grad_(False)
# Passing the model to cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg.to(device)

The function load_image is defined to load the image into the notebook. In the function, the module Image from the PIL library (Python Imaging Library) is used to open the image. The PIL library allows images to be opened, manipulated, or saved using different formats.

Once the image is opened, it is resized; this step ensures that both the content and style images have the same size. The image is then converted to a tensor (the data structure that the CNN expects) and normalized (for faster convergence).

Last, since the CNN expects a tensor with the shape batch_size x depth x height x width, the tensor is unsqueezed to add a batch dimension of 1.

def load_image(img_path, size=(400, 592)):

    # open the image and make sure it has three channels
    image = Image.open(img_path).convert('RGB')

    transform = transforms.Compose([transforms.Resize(size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # the CNN expects a tensor with the shape (batch_size, d, h, w)
    image = transform(image).unsqueeze(0)

    return image

# loading the images and passing them to cuda
content = load_image('dog.jpg').to(device)
style = load_image('painting.jpg').to(device)
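As an optional sanity check, not part of the original walkthrough, the loaded tensors can be converted back to numpy arrays and displayed side by side to confirm that both images were resized and normalized as expected. The helper name im_convert below is just an illustrative choice; it simply undoes the normalization applied in load_image.

# illustrative helper: convert a normalized tensor back into a displayable image
def im_convert(tensor):
    image = tensor.to('cpu').clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    # undo the normalization applied in load_image
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
    return image.clip(0, 1)

# display the content and style images side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
plt.show()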

The convolutional layers from the vgg model that are going to be used for the content and style representations are defined in the layers dictionary. Essentially, as the network gets deeper, the input image is transformed into feature maps that focus more on the content than on the precise spatial information. Therefore, the layer ‘conv4_2’ is chosen for the content representation. All the other layers are used for the style representation.

Now that the layers have been defined, and the images have been transformed into tensors, the content and style features can be extracted using the get_features function.

layers = {'0': 'conv1_1',
          '5': 'conv2_1',
          '10': 'conv3_1',
          '19': 'conv4_1',
          '21': 'conv4_2',
          '28': 'conv5_1'}
def get_features(image, model):

    # pass the image through the network layer by layer,
    # keeping the outputs of the layers listed above
    features = {}
    for key, layer in model._modules.items():
        image = layer(image)
        if key in layers:
            features[layers[key]] = image

    return features

# extracting the content and style features
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

The content representation is achieved once the content features are extracted. The style representation, on the other hand, is concerned with the correlations of the features extracted in each layer. These correlations can be determined by calculating the gram matrix, which involves flattening the tensors for each layer and multiplying the flattened tensor by its transpose.

def gram_matrix(tensor):

    # tensor has shape (batch_size, depth, height, width)
    batch_size, d, h, w = tensor.size()

    # flatten the spatial dimensions and multiply by the transpose
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())

    return gram

# computing the gram matrix for each style layer
style_grams = {}
for layer, tensor in style_features.items():
    gram = gram_matrix(tensor)
    style_grams[layer] = gram
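One point worth noting, added here for clarity: the gram matrix is always square, with size depth x depth, regardless of the layer’s spatial dimensions. Since conv1_1 in VGG19 has 64 channels, its gram matrix is 64 x 64, which can be verified with a quick print.

# conv1_1 has 64 channels, so its gram matrix should be 64 x 64
print(style_grams['conv1_1'].shape)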

Now that all the functions have been created, it is time to run the optimization loop that generates the new image. The cell below shows the parameters that were defined for the loop; they can be adjusted depending on the output the user is trying to achieve. The target image, which starts as a copy of the content image, is also created here, since it is the tensor the optimizer will update.

# here you can try different weights for the desired output
style_weights = {'conv1_1': 0.8,
                 'conv2_1': 0.8,
                 'conv3_1': 0.6,
                 'conv4_1': 0.6,
                 'conv5_1': 0.6}
content_weight = 1
style_weight = 1e6
show_every = 1000
steps = 3000
# the target image starts as a copy of the content image; it is the only tensor being optimized
target = content.clone().to(device).requires_grad_(True)
optimizer = optim.Adam([target], lr=0.003)

Finally, at each iteration, both the content and style losses are computed. The content loss ensures that the features of the content and generated images are similar, while the style loss ensures that the feature correlations of the style and generated images are similar. At each step, the total loss is minimized using the Adam optimizer, which updates the target image.

for i in range(steps):

    # extracting the features of the target image
    target_features = get_features(target, vgg)

    # content loss: difference between the content representations
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # style loss: weighted sum over the style layers
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        batch_size, d, h, w = target_feature.shape

        target_gram = gram_matrix(target_feature)
        style_gram = style_grams[layer]

        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    # updating the target image
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # periodically printing the loss and displaying the intermediate image
    if i % show_every == 0:
        print('Total loss: ', total_loss.item())
        image = target.to('cpu').clone().detach()
        image = image.numpy().squeeze()
        image = image.transpose(1, 2, 0)
        image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
        plt.imshow(image.clip(0, 1))
        plt.show()
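After the loop finishes, the final stylized image can be converted back to a regular image array and saved to disk. The snippet below is a small addition to the original post; the filename stylized.jpg is just an example.

# converting the final target tensor to an image and saving it
final = target.to('cpu').clone().detach()
final = final.numpy().squeeze().transpose(1, 2, 0)
final = final * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
# the filename below is just an example
plt.imsave('stylized.jpg', final.clip(0, 1))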
