How It Works
This project uses a Deep Learning approach to tag movie posters with multiple genres simultaneously. It is built as part of the CCAI 9028 course.
- Architecture: We utilize a pre-trained ResNet18 model from torchvision.
- Fine-Tuning: The final classification layer was replaced to output probabilities for our specific set of genres.
- Loss Function: Since a movie can have multiple genres (e.g., Action AND Sci-Fi), we use `BCEWithLogitsLoss` rather than standard CrossEntropy.
- Deployment: The model has been extracted and is currently being served via a lightweight FastAPI Python backend inside a Docker container.
Training Script (ccai9028.py)
Below is the complete source code used to train the model inside our Colab environment.
"""ccai9028.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1lXwjdySi2hIId-JvrWzqz7OOAe-kkNwN
# CCAI 9028 -- group 21 group project
## Building an AI model that tags previously unseen movies
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
import datasets
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gc
# Reclaim cached GPU memory and run garbage collection before the first
# large allocations (cheap insurance against Colab OOMs on re-runs).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

# Prefer CUDA when present; everything downstream moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# ImageNet channel statistics — required because ResNet18 below is
# pre-trained on ImageNet and expects inputs normalized the same way.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light augmentation (flip, small rotation, color jitter)
# to reduce overfitting, then tensor conversion + ImageNet normalization.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation pipeline: deterministic — resize only, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# Pull the movie-poster dataset from the HuggingFace Hub.
dataset_name = "skvarre/movie-posters"
dataset = datasets.load_dataset(dataset_name)
print(dataset)

# Sanity check: display a handful of random posters with their genre tags.
num_samples = 5
random_samples = dataset["train"].shuffle(seed=123).select(range(num_samples))

fig, axes = plt.subplots(1, num_samples, figsize=(20, 8))
for i, example in enumerate(random_samples):
    img = example["image"].convert("RGB")
    title = example["title"]
    raw_genres = example["genres"]

    # Records store genres either as dicts ({"name": ...}) or plain strings.
    if raw_genres and isinstance(raw_genres[0], dict):
        genre_list = [g["name"] for g in raw_genres]
    elif raw_genres:
        genre_list = raw_genres
    else:
        genre_list = ["No Genre"]
    genres_str = ", ".join(genre_list)

    axes[i].imshow(img.resize((224, 224)))
    axes[i].set_title(f"{title}\n[{genres_str}]", fontsize=14, pad=20)
    axes[i].axis("off")
plt.tight_layout()
plt.show()
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm # Imported tqdm
# Build the genre vocabulary with a single pass over the training split.
all_genres = set()
for item in tqdm(dataset["train"], desc="Extracting Genres"):
    raw = item["genres"]
    if not raw:
        continue
    # Genres appear either as dicts with a "name" key or as plain strings.
    if isinstance(raw[0], dict):
        all_genres.update(g["name"] for g in raw)
    else:
        all_genres.update(raw)

# Sorted order gives every run a stable genre -> index mapping.
genre_list = sorted(all_genres)
genre_to_idx = {genre: i for i, genre in enumerate(genre_list)}
num_classes = len(genre_list)
print(f"Total unique genres: {num_classes}")
print(f"Genres: {genre_list}")
class MoviePosterDataset(Dataset):
    """Adapts a HuggingFace split into a PyTorch Dataset.

    Yields (image, label) pairs where `label` is a multi-hot float vector
    with 1.0 at the index of every genre attached to the poster.
    """

    def __init__(self, hf_dataset, genre_mapping, transform=None):
        # genre_mapping: genre name -> slot in the multi-hot label vector.
        self.dataset = hf_dataset
        self.genre_mapping = genre_mapping
        self.transform = transform
        self.num_classes = len(genre_mapping)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        image = record["image"].convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Multi-hot target; genres missing from the mapping are ignored.
        label = torch.zeros(self.num_classes)
        for entry in record["genres"] or []:
            name = entry["name"] if isinstance(entry, dict) else entry
            slot = self.genre_mapping.get(name)
            if slot is not None:
                label[slot] = 1.0
        return image, label
# 80/20 train/validation split; fixed seed keeps the split reproducible.
split = dataset["train"].train_test_split(test_size=0.2, seed=42)

train_ds = MoviePosterDataset(split["train"], genre_to_idx, transform=train_transform)
val_ds = MoviePosterDataset(split["test"], genre_to_idx, transform=val_transform)

# num_workers=0: single-process loading (plays nicest with Colab notebooks).
optimal_batch_size = 256
train_loader = DataLoader(
    train_ds,
    batch_size=optimal_batch_size,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_ds,
    batch_size=optimal_batch_size,
    shuffle=False,
    num_workers=0,
)

print(
    f"DataLoaders initialized: {len(train_loader)} training batches, {len(val_loader)} validation batches."
)
import torch
import torch.nn as nn
from torchvision import models
# Phase 2: model architecture.
# Start from ResNet18 pre-trained on ImageNet, then replace the final fully
# connected layer so it outputs one logit per genre. No sigmoid/softmax here:
# BCEWithLogitsLoss (Phase 3) applies the sigmoid internally, and inference
# applies torch.sigmoid explicitly (Phase 5).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print("Phase 2 Complete: Model architecture set up successfully!")
print(
f"Final layer structure: {num_ftrs} input features -> {num_classes} output genres."
)
print(f"Model is running on: {device}")
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
# ==========================================
# PHASE 3: Loss Function & Optimizer
# ==========================================
# BCEWithLogitsLoss: each genre is an independent binary decision (a movie
# can carry several genres), so per-class sigmoid beats softmax CrossEntropy.
criterion = nn.BCEWithLogitsLoss()

# Adam with mild L2 regularization (weight decay).
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Cut the LR by 10x when validation loss stalls for 2 epochs.
# NOTE: the `verbose=True` kwarg was deprecated and then removed from
# ReduceLROnPlateau in recent PyTorch releases (raises TypeError there);
# inspect `scheduler.get_last_lr()` instead if you need LR logging.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# ==========================================
# PHASE 4: Training & Validation Loop
# ==========================================
num_epochs = 15  # Increased to 15 for better convergence
best_val_loss = float("inf")
model_save_path = "movie_genre_model.pth"

# Free cached GPU memory before the training allocations begin.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Starting training on {device} for {num_epochs} epochs...")
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_train = 0.0
    progress = tqdm(
        train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]", leave=False
    )
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is a true per-sample mean.
        running_train += batch_loss.item() * batch_images.size(0)
        progress.set_postfix(loss=batch_loss.item())
    avg_train_loss = running_train / len(train_loader.dataset)

    # ---- validation pass (no gradients) ----
    model.eval()
    running_val = 0.0
    with torch.no_grad():
        progress = tqdm(
            val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Val]", leave=False
        )
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            running_val += batch_loss.item() * batch_images.size(0)
            progress.set_postfix(loss=batch_loss.item())
    avg_val_loss = running_val / len(val_loader.dataset)

    print(
        f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
    )
    scheduler.step(avg_val_loss)

    # Checkpoint whenever validation loss improves.
    if avg_val_loss < best_val_loss:
        print(
            f"Validation Loss decreased ({best_val_loss:.4f} --> {avg_val_loss:.4f}). Saving model to {model_save_path}..."
        )
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), model_save_path)

print("Training Complete!")
# ==========================================
# PHASE 5: Inference & Tagging
# ==========================================
def predict_movie_genres(image_path, model, transform, genre_list, threshold=0.5):
    """Tag a poster image with genres predicted by the trained model.

    Loads the image, runs it through the network, applies a sigmoid to the
    logits (multi-label setup), and returns every genre whose probability
    meets `threshold`. Also displays the poster with the predictions as a
    matplotlib title.

    Args:
        image_path: Path to the poster image file.
        model: Trained classifier producing one logit per genre.
        transform: Preprocessing pipeline (should match the validation one).
        genre_list: Genre names, index-aligned with the model's outputs.
        threshold: Minimum probability for a genre to be reported.

    Returns:
        List of "Name (xx.x%)" strings; empty if nothing passes the
        threshold or the image file could not be found.
    """
    try:
        img = Image.open(image_path).convert("RGB")
    except FileNotFoundError:
        print(f"Error: Could not find '{image_path}'. Did you upload it to Colab?")
        # Fix: return an empty list (not None) so callers always get a list.
        return []

    input_batch = transform(img).unsqueeze(0)

    # Fix: run on whatever device the model actually lives on, instead of
    # re-guessing cuda/cpu — avoids a device mismatch if the model is on CPU.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameterless model: fall back to the old probe
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_batch = input_batch.to(device)

    model.eval()  # Crucial: puts layers like Dropout/BatchNorm in test mode
    with torch.no_grad():
        raw_logits = model(input_batch)
        probabilities = torch.sigmoid(raw_logits)[0]

    predicted_genres = [
        f"{genre_list[i]} ({prob.item() * 100:.1f}%)"
        for i, prob in enumerate(probabilities)
        if prob.item() >= threshold
    ]

    # Show the poster with the predicted tags as its title.
    plt.figure(figsize=(6, 8))
    plt.imshow(img)
    plt.axis("off")
    if predicted_genres:
        title_text = "Predicted Genres:\n" + "\n".join(predicted_genres)
    else:
        title_text = f"No confidence scores above {threshold * 100}% threshold."
    plt.title(title_text, fontsize=12, loc="left", pad=10)
    plt.tight_layout()
    plt.show()

    return predicted_genres
# --- EXECUTION ---
# Run the tagger on a sample poster; `tags` holds the formatted predictions.
my_poster = "avengers.jpg"
tags = predict_movie_genres(
    my_poster, model, val_transform, genre_list, threshold=0.5
)