Fast Neural Style Transfer β Starry Night
This repository contains weights for a Fast Neural Style Transfer network based on Johnson et al. It is trained on the COCO val2017 dataset to instantly apply Vincent van Gogh's The Starry Night style to any input image.
Style Transfer Preview
| Content Image | Stylized Output |
|---|---|
![]() |
![]() |
How to Use Programmatically
You can run inference using the official huggingface_hub utility library. The script automatically downloads your weights file directly from the cloud and applies the necessary ImageNet normalization matching the training routine.
Dependencies
Ensure you have the required packages installed:
pip install torch torchvision pillow huggingface_hub
Inference Script (inference.py)
Save the following code as inference.py. You can run it via terminal with python inference.py your_image.jpg.
import sys
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download
# ββ CONFIG βββββββββββββββββββββββββββββββββββββββββββββββββββ
REPO_ID = "Rohanify/Brawnz-StyleTransferSN"
FILENAME = "pytorch_model.bin"
IMG_SIZE = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# ββ NATIVE PYTORCH NETWORK DEFINITION ββββββββββββββββββββββββ
def conv_bn_relu(in_c, out_c, k, stride=1, pad=0):
return nn.Sequential(
nn.ReflectionPad2d(pad),
nn.Conv2d(in_c, out_c, k, stride),
nn.InstanceNorm2d(out_c),
nn.ReLU(inplace=True),
)
class ResBlock(nn.Module):
def __init__(self, c):
super().__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(c, c, 3),
nn.InstanceNorm2d(c),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(c, c, 3),
nn.InstanceNorm2d(c),
)
def forward(self, x):
return x + self.block(x)
class TransformNet(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
conv_bn_relu(3, 32, 9, pad=4),
conv_bn_relu(32, 64, 3, stride=2, pad=1),
conv_bn_relu(64, 128, 3, stride=2, pad=1),
ResBlock(128), ResBlock(128), ResBlock(128),
ResBlock(128), ResBlock(128),
nn.Upsample(scale_factor=2, mode="nearest"),
conv_bn_relu(128, 64, 3, pad=1),
nn.Upsample(scale_factor=2, mode="nearest"),
conv_bn_relu(64, 32, 3, pad=1),
nn.ReflectionPad2d(4),
nn.Conv2d(32, 3, 9),
nn.Tanh(),
)
def forward(self, x):
return self.net(x)
# ββ LOAD INPUT IMAGE βββββββββββββββββββββββββββββββββββββββββ
if len(sys.argv) < 2:
print("Usage: python inference.py path_to_input_image.jpg")
sys.exit(1)
input_path = sys.argv[1]
output_path = "output_styled.jpg"
transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = Image.open(input_path).convert("RGB")
x = transform(img).unsqueeze(0).to(DEVICE)
# ββ SECURE FILE DOWNLOAD & STATE LOAD ββββββββββββββββββββββββ
print("Downloading weights from Hugging Face Hub...")
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model = TransformNet().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()
print(f"Weights successfully loaded on: {DEVICE}")
# ββ RUN INFERENCE ββββββββββββββββββββββββββββββββββββββββββββ
print("Processing style transfer...")
with torch.no_grad():
out = model(x)
save_image(out[0] * 0.5 + 0.5, output_path)
print(f"Success! Styled image saved to: {output_path}")
- Downloads last month
- 87

