Upload 29 files
- .gitattributes +10 -0
- codes/batch_lcm_eval.py +155 -0
- codes/config.py +43 -0
- codes/dataset.py +74 -0
- codes/datasetSegmentation.py +45 -0
- codes/lcm_train.py +180 -0
- codes/lungSegmentation.ipynb +0 -0
- codes/metrics.py +140 -0
- codes/model.py +37 -0
- codes/modules/__init__.py +3 -0
- codes/modules/diffusion_model_unet.py +2099 -0
- codes/modules/fp16_util.py +237 -0
- codes/modules/logger.py +495 -0
- codes/modules/nn.py +170 -0
- codes/modules/unet.py +894 -0
- codes/modules/unet_2d.py +334 -0
- codes/pytorch_msssim.py +156 -0
- codes/transform.py +28 -0
- codes/vq-gan_eval.py +60 -0
- data/BS/0.png +3 -0
- data/BS/1.png +3 -0
- data/BS/2.png +3 -0
- data/CXR/0.png +3 -0
- data/CXR/1.png +3 -0
- data/CXR/2.png +3 -0
- images/GL-LCM_gif.gif +3 -0
- images/ablation.png +3 -0
- images/comparison.png +3 -0
- images/framework.png +3 -0
- requirements.txt +20 -0
.gitattributes
CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/BS/0.png filter=lfs diff=lfs merge=lfs -text
+data/BS/1.png filter=lfs diff=lfs merge=lfs -text
+data/BS/2.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/0.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/1.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/2.png filter=lfs diff=lfs merge=lfs -text
+images/ablation.png filter=lfs diff=lfs merge=lfs -text
+images/comparison.png filter=lfs diff=lfs merge=lfs -text
+images/framework.png filter=lfs diff=lfs merge=lfs -text
+images/GL-LCM_gif.gif filter=lfs diff=lfs merge=lfs -text
codes/batch_lcm_eval.py
ADDED
@@ -0,0 +1,155 @@
from config import config
import numpy as np

from dataset import myC2BDataset
from transform import myTransform
from torch.utils.data import DataLoader
from diffusers import LCMScheduler

from tqdm import tqdm
import cv2 as cv
import torch
import time
import os

from monai.utils import set_determinism

set_determinism(42)


def eval():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device
    output_path = os.path.join("lcm_output_bs", "BS")
    masked_output_path = os.path.join("lcm_output_bs", "Masked_BS")
    fusion_output_path = os.path.join("lcm_output_bs", "Fusion_BS")

    cxr_path = os.path.join("SZCH-X-Rays", "CXR")
    masked_cxr_path = os.path.join("SZCH-X-Rays", "Masked_CXR")
    mask_path = os.path.join("SZCH-X-Rays", "Mask")

    model = torch.load("masked_lcm-600-2024-12-19-myModel.pth").to(device).eval()
    VQGAN = torch.load("2024-12-12-Mask-SZCH-VQGAN.pth").to(device).eval()
    testset_list = "SZCH.txt"
    myTestSet = myC2BDataset(testset_list, cxr_path, masked_cxr_path, myTransform['testTransform'])
    myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)
    # set up the noise scheduler
    noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
                                   clip_sample=config.clip_sample,
                                   clip_sample_range=config.initial_clip_sample_range_g)
    noise_scheduler.set_timesteps(config.num_infer_timesteps)
    with torch.no_grad():
        progress_bar = tqdm(enumerate(myTestLoader), total=len(myTestLoader), ncols=100)
        total_start = time.time()
        for step, batch in progress_bar:
            cxr = batch[0].to(device=device, non_blocking=True).float()
            masked_cxr = batch[1].to(device=device, non_blocking=True).float()
            filename = batch[2][0]
            cxr_copy = np.array(cxr.detach().to("cpu"))
            cxr_copy = np.squeeze(cxr_copy)  # HW
            cxr_copy = cxr_copy * 0.5 + 0.5
            cxr_copy *= 255
            cxr_copy = cxr_copy.astype(np.uint8)  # uint8, not int8: pixel values span 0-255

            cxr = VQGAN.encode_stage_2_inputs(cxr)
            masked_cxr = VQGAN.encode_stage_2_inputs(masked_cxr)

            noise = torch.randn_like(cxr).to(device)
            sample = torch.cat((noise, cxr), dim=1).to(device)  # BCHW
            masked_sample = torch.cat((noise, masked_cxr), dim=1).to(device)  # BCHW

            for j, t in tqdm(enumerate(noise_scheduler.timesteps)):
                residual = model(sample, torch.Tensor((t,)).to(device).long()).to(device)
                masked_residual = model(masked_sample, torch.Tensor((t,)).to(device).long()).to(device)
                # masked_residual = (1 - config.alpha) * residual + config.alpha * masked_residual
                masked_residual = config.alpha * masked_residual + (1 - config.alpha) * torch.randn_like(
                    masked_residual).to(device) / torch.std(masked_residual)

                # rebuild the scheduler so the global clip range widens with the step index j
                noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
                                               clip_sample=config.clip_sample,
                                               clip_sample_range=config.initial_clip_sample_range_g
                                               + config.clip_rate * j)
                noise_scheduler.set_timesteps(config.num_infer_timesteps)
                sample = noise_scheduler.step(residual, t, sample).prev_sample

                # same for the local (masked) branch, with its own initial clip range
                noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
                                               clip_sample=config.clip_sample,
                                               clip_sample_range=config.initial_clip_sample_range_l
                                               + config.clip_rate * j)
                noise_scheduler.set_timesteps(config.num_infer_timesteps)
                masked_sample = noise_scheduler.step(masked_residual, t, masked_sample).prev_sample

            sample = torch.cat((sample[:, :4], cxr), dim=1)  # BCHW
            masked_sample = torch.cat((masked_sample[:, :4], masked_cxr), dim=1).to(device)  # BCHW
            if config.output_feature_map:
                bs_show = np.array(sample[:, 0].detach().to("cpu"))
                bs_show = np.squeeze(bs_show)  # HW
                bs_show = bs_show * 0.5 + 0.5
                bs_show = np.clip(bs_show, 0, 1)

                masked_bs_show = np.array(masked_sample[:, 0].detach().to("cpu"))
                masked_bs_show = np.squeeze(masked_bs_show)  # HW
                masked_bs_show = masked_bs_show * 0.5 + 0.5
                masked_bs_show = np.clip(masked_bs_show, 0, 1)

                if not config.use_server:
                    cv.imshow("win1", bs_show)
                    cv.imshow("win2", masked_bs_show)
                    cv.waitKey(1)

            mask = cv.imread(os.path.join(mask_path, filename), 0)
            mask[mask < 255] = 0

            bs = VQGAN.decode((sample[:, :4]))
            bs = np.array(bs.detach().to("cpu"))
            bs = np.squeeze(bs)  # HW
            bs = bs * 0.5 + 0.5
            bs[cxr_copy == 0] = 0

            masked_bs = VQGAN.decode((masked_sample[:, :4]))
            masked_bs = np.array(masked_bs.detach().to("cpu"))
            masked_bs = np.squeeze(masked_bs)  # HW
            masked_bs = masked_bs * 0.5 + 0.5
            masked_bs[mask > 0] = masked_bs[mask > 0] + np.mean(bs[mask > 0]) - np.mean(masked_bs[mask > 0])
            masked_bs[cxr_copy == 0] = 0
            if not config.use_server:
                cv.imshow("win3", bs)
                cv.imshow("win4", masked_bs)
                cv.waitKey(1)

            bs *= 255
            cv.imwrite(os.path.join(output_path, filename), bs)
            masked_bs *= 255
            cv.imwrite(os.path.join(masked_output_path, filename), masked_bs)

            # keep only connected components of the lung mask larger than min_area
            num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask)
            min_area = 100
            for i in range(1, num_labels):
                if stats[i, cv.CC_STAT_AREA] < min_area:
                    labels[labels == i] = 0
            mask[labels == 0] = 0

            br = cv.boundingRect(mask)
            p = (br[0] + br[2] // 2, br[1] + br[3] // 2)

            masked_bs = np.clip(masked_bs, 0, 255)
            masked_bs = cv.cvtColor(masked_bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
            bs = np.clip(bs, 0, 255)
            bs = cv.cvtColor(bs, cv.COLOR_GRAY2BGR).astype(np.uint8)

            fusion_bs = cv.seamlessClone(masked_bs, bs, mask, p, cv.MONOCHROME_TRANSFER)
            # cv.rectangle(fusion_bs, br, (0, 255, 0), 2)
            # fusion_bs[mask==255]=(255, 0, 0)

            cv.imwrite(os.path.join(fusion_output_path, filename), fusion_bs)

    total_time = time.time() - total_start
    print(f"Total time: {total_time}.")


if __name__ == "__main__":
    eval()
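A note on the sampling loop above: a fresh LCMScheduler is rebuilt at every step only so that clip_sample_range can widen linearly with the step index j. A minimal standalone sketch of that schedule, using the SZCH values from codes/config.py (the variable names here are illustrative, not part of the repository):

# Sketch: how the dynamic clipping ranges grow across the inference steps.
num_infer_timesteps = 50         # config.num_infer_timesteps
clip_rate = 0.025                # config.clip_rate
initial_g, initial_l = 2.0, 3.5  # global / local initial clip ranges (SZCH)

for j in range(num_infer_timesteps):
    range_g = initial_g + clip_rate * j  # global branch: 2.0 up to 3.225
    range_l = initial_l + clip_rate * j  # local (masked) branch: 3.5 up to 4.725
print(range_g, range_l)  # 3.225 4.725 after the final step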
codes/config.py
ADDED
@@ -0,0 +1,43 @@
from dataclasses import dataclass


@dataclass
class config:
    use_server = True
    test_epoch_interval = 10
    image_size = 1024
    r = 8

    # VQGAN
    vae_epoch_number = 300
    vae_batch_size = 4
    milestones_g = [200]
    milestones_d = [200]
    initial_learning_rate_g = 1e-4
    initial_learning_rate_d = 5e-4

    # LCM
    batch_size = 4
    epoch_number = 600
    initial_learning_rate = 1e-4
    milestones = [300, 400, 500]
    num_train_timesteps = 1000
    beta_start = 0.00085
    beta_end = 0.012
    offset_noise = True
    offset_noise_coefficient = 0.1
    output_feature_map = True
    clip_sample = True
    num_infer_timesteps = 50
    alpha = 3
    video = False

    # SZCH
    clip_rate = 0.025
    initial_clip_sample_range_g = 2
    initial_clip_sample_range_l = 3.5
    # JSRT
    # clip_rate = 0.025
    # initial_clip_sample_range_g = 1.7
    # initial_clip_sample_range_l = 3
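Since every field above is a plain class attribute, the config is read without ever instantiating the class. A minimal usage sketch (nothing here goes beyond what the other scripts already do):

from config import config

latent_size = config.image_size // config.r  # 1024 // 8 = 128, the latent resolution
print(latent_size, config.num_infer_timesteps, config.alpha)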
codes/dataset.py
ADDED
@@ -0,0 +1,74 @@
from torch.utils.data import Dataset
import pandas as pd
import cv2 as cv
import os
from config import config
import torch
import numpy as np


class mySingleDataset(Dataset):  # single-image dataset
    def __init__(self, filelist, img_dir, transform=None):  # args: file-list path, image directory, transform
        self.img_dir = img_dir  # image directory
        self.transform = transform  # image preprocessing
        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames

    def __len__(self):
        return len(self.filelist)  # dataset length = number of filenames

    def __getitem__(self, idx):  # fetch one sample
        img_path = self.img_dir  # image directory

        file = self.filelist.iloc[idx, 0]  # filename
        image = cv.imread(os.path.join(img_path, file))  # read the image with OpenCV

        if self.transform:
            image = self.transform(image)  # preprocess
        return image, file  # return the image and its filename


class myDataset(Dataset):  # paired CXR / bone-suppressed (BS) dataset
    def __init__(self, filelist, cxr_dir, bs_dir,
                 transform=None):  # args: file-list path, two image directories, transform
        self.cxr_dir = cxr_dir  # CXR directory
        self.bs_dir = bs_dir  # BS directory

        self.transform = transform  # image preprocessing
        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames

    def __len__(self):
        return len(self.filelist)  # dataset length = number of filenames

    def __getitem__(self, idx):  # fetch one sample
        file = self.filelist.iloc[idx, 0]  # filename
        cxr = cv.imread(os.path.join(self.cxr_dir, file))  # read the CXR with OpenCV
        bs = cv.imread(os.path.join(self.bs_dir, file))  # read the BS image with OpenCV

        if self.transform:
            cxr = self.transform(cxr)  # preprocess
            bs = self.transform(bs)  # preprocess

        return cxr, bs, file  # return the pair and the filename


class myC2BDataset(Dataset):  # paired CXR / lung-masked CXR dataset
    def __init__(self, filelist, cxr_dir, masked_cxr_dir,
                 transform=None):  # args: file-list path, two image directories, transform
        self.cxr_dir = cxr_dir  # CXR directory
        self.masked_cxr_dir = masked_cxr_dir  # masked-CXR directory

        self.transform = transform  # image preprocessing
        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames

    def __len__(self):
        return len(self.filelist)  # dataset length = number of filenames

    def __getitem__(self, idx):  # fetch one sample
        file = self.filelist.iloc[idx, 0]  # filename
        cxr = cv.imread(os.path.join(self.cxr_dir, file))  # read the CXR with OpenCV
        masked_cxr = cv.imread(os.path.join(self.masked_cxr_dir, file))  # read the masked CXR with OpenCV

        if self.transform:
            cxr = self.transform(cxr)  # preprocess
            masked_cxr = self.transform(masked_cxr)  # preprocess

        return cxr, masked_cxr, file  # return the pair and the filename
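For reference, this is how batch_lcm_eval.py wires myC2BDataset into a loader; a minimal sketch assuming the SZCH-X-Rays folders and the SZCH.txt file list exist locally:

from torch.utils.data import DataLoader
from dataset import myC2BDataset
from transform import myTransform

testset = myC2BDataset("SZCH.txt", "SZCH-X-Rays/CXR", "SZCH-X-Rays/Masked_CXR",
                       myTransform['testTransform'])
loader = DataLoader(testset, batch_size=1, shuffle=False)
cxr, masked_cxr, filename = next(iter(loader))  # one CXR / masked-CXR pair plus its filename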
codes/datasetSegmentation.py
ADDED
@@ -0,0 +1,45 @@
import random
import os


def traverse_directory(directory, txt_name):
    # collect every filename under the directory
    file_names = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            file_names.append(file)
    # sort the filenames
    file_names.sort()
    # write them to a new txt file, one per line
    with open(txt_name, 'w') as f:
        for file_name in file_names:
            f.write(file_name + '\n')


def split_dataset(file_path, train_ratio=0.8, val_ratio=0.1, former=None):
    # read the full file list
    with open(file_path, 'r') as f:
        data = f.readlines()
    # shuffle it
    random.shuffle(data)

    train_size = int(len(data) * train_ratio)
    val_size = int(len(data) * val_ratio)

    train_set = data[:train_size]
    val_set = data[train_size:train_size + val_size]
    test_set = data[train_size + val_size:]
    with open(former + '_trainset.txt', 'w') as f:
        f.writelines(train_set)
    with open(former + '_valset.txt', 'w') as f:
        f.writelines(val_set)
    with open(former + '_testset.txt', 'w') as f:
        f.writelines(test_set)
    print("Finished.")


if __name__ == "__main__":
    traverse_directory('JSRTnew1024-241/CXR', "JSRT.txt")

    split_dataset('JSRT.txt', 0.8, 0.1, "JSRT")
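One caveat: random.shuffle is unseeded here, so each run of split_dataset produces a different partition. If a reproducible split is wanted, seeding before the call is enough (a suggestion, not something the script does itself):

import random
from datasetSegmentation import split_dataset

random.seed(42)  # fix the shuffle so the train/val/test split is reproducible
split_dataset('JSRT.txt', 0.8, 0.1, "JSRT")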
codes/lcm_train.py
ADDED
@@ -0,0 +1,180 @@
from matplotlib import pyplot as plt
from config import config
from dataset import myDataset
from transform import myTransform
from torch.utils.data import DataLoader
from model import myUnet, myVQGANModel
from diffusers import LCMScheduler, DDPMScheduler
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
from datetime import date

import torch.nn.functional as F
import torch
import time
from monai.utils import set_determinism

set_determinism(42)


def train():
    if config.use_server:
        file = open('log.txt', 'w')  # log to a file
    else:
        file = None  # log to stdout

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device

    train_file_list = "JSRT_trainset.txt"  # text file listing the training filenames
    test_file_list = "JSRT_valset.txt"  # text file listing the validation filenames

    cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/CXR"  # image directory
    bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/BS"  # image directory
    masked_cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_CXR"  # image directory
    masked_bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_BS"  # image directory

    myTrainSet = myDataset(train_file_list, cxr_path, bs_path,
                           myTransform['trainTransform']) + myDataset(train_file_list, masked_cxr_path, masked_bs_path,
                                                                      myTransform['trainTransform'])
    myTestSet = myDataset(test_file_list, cxr_path, bs_path,
                          myTransform['testTransform']) + myDataset(test_file_list, masked_cxr_path, masked_bs_path,
                                                                    myTransform['testTransform'])

    myTrainLoader = DataLoader(myTrainSet, batch_size=config.batch_size, shuffle=True)
    myTestLoader = DataLoader(myTestSet, batch_size=config.batch_size, shuffle=True)

    print("Number of batches in train set:", len(myTrainLoader))
    print("Train set size:", len(myTrainSet))
    print("Number of batches in test set:", len(myTestLoader))
    print("Test set size:", len(myTestSet))

    model = myUnet.to(device).train()

    # set up the noise scheduler
    noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps)
    noise_scheduler.set_timesteps(config.num_infer_timesteps)
    # optimizer with a stepped learning-rate schedule
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.initial_learning_rate, eps=1e-6)
    milestones = [x * len(myTrainLoader) for x in config.milestones]
    optimizer_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    train_losses = []
    test_losses = []
    plt_train_loss_epoch = []
    plt_test_loss_epoch = []
    train_epoch_list = list(range(0, config.epoch_number))
    test_epoch_list = list(range(0, int(config.epoch_number / config.test_epoch_interval)))

    VQGAN = torch.load("2025-02-04-Mask-JSRT-VQGAN.pth").to(device).eval()
    print(time.strftime("%H:%M:%S", time.localtime()), "----------Begin Training----------", file=file)
    for epoch in range(config.epoch_number):
        model.train()
        print(time.strftime("%H:%M:%S", time.localtime()),
              f"Epoch:{epoch},learning rate:{optimizer.param_groups[0]['lr']}", file=file)
        for i, batch in tqdm(enumerate(myTrainLoader)):
            cxr_i, bs_i = batch[0].to(device), batch[1].to(device)

            with torch.no_grad():
                cxr = VQGAN.encode_stage_2_inputs(cxr_i)
                bs = VQGAN.encode_stage_2_inputs(bs_i)

            cat = torch.cat((bs, cxr), dim=-3)

            # sample noise for the images
            if config.offset_noise:
                noise = torch.randn_like(cxr).to(device) + config.offset_noise_coefficient * torch.randn(
                    cxr.shape[0], cxr.shape[1], 1,
                    1).to(device)
            else:
                noise = torch.randn_like(cxr).to(device)

            blank = torch.zeros_like(cxr).to(device)
            noise = torch.cat((noise, blank), dim=-3)

            # sample a random timestep for each image
            timesteps = torch.randint(0, config.num_train_timesteps, (cxr.shape[0],), device=device).long()

            # add noise to the clean latents according to the noise magnitude at each timestep
            noisy_images = noise_scheduler.add_noise(cat, noise, timesteps)

            # model prediction
            noise_pred = model(noisy_images, timesteps)

            # loss on the bone-suppressed latent channels only
            loss = F.mse_loss(noise_pred[:, :4].float(), noise[:, :4].float())

            loss.backward()
            train_losses.append(loss.item())

            # update the parameters
            optimizer.step()
            optimizer.zero_grad()
            optimizer_scheduler.step()

        train_loss_epoch = sum(train_losses[-len(myTrainLoader):]) / len(myTrainLoader)
        print(time.strftime("%H:%M:%S", time.localtime()), f"Epoch:{epoch},train losses:{train_loss_epoch}", file=file)
        plt_train_loss_epoch.append(train_loss_epoch)

        if (epoch + 1) % config.test_epoch_interval == 0:
            model.eval()
            print(time.strftime("%H:%M:%S", time.localtime()), "----------Stop Training----------", file=file)
            print(time.strftime("%H:%M:%S", time.localtime()), "----------Begin Testing----------", file=file)
            with torch.no_grad():
                for i, batch in tqdm(enumerate(myTestLoader)):
                    cxr_i, bs_i = batch[0].to(device), batch[1].to(device)

                    with torch.no_grad():
                        cxr = VQGAN.encode_stage_2_inputs(cxr_i)
                        bs = VQGAN.encode_stage_2_inputs(bs_i)

                    cat = torch.cat((bs, cxr), dim=-3)

                    # sample noise for the images
                    if config.offset_noise:
                        noise = torch.randn_like(cxr).to(device) + config.offset_noise_coefficient * torch.randn(
                            cxr.shape[0],
                            cxr.shape[1], 1,
                            1).to(device)
                    else:
                        noise = torch.randn_like(cxr).to(device)

                    blank = torch.zeros_like(cxr).to(device)
                    noise = torch.cat((noise, blank), dim=-3)

                    # sample a random timestep for each image
                    timesteps = torch.randint(0, config.num_train_timesteps, (cxr.shape[0],),
                                              device=device).long()

                    # add noise according to the timestep
                    noisy_images = noise_scheduler.add_noise(cat, noise, timesteps)

                    # model prediction
                    noise_pred = model(noisy_images, timesteps)

                    # loss
                    loss = F.mse_loss(noise_pred[:, :4].float(), noise[:, :4].float())

                    test_losses.append(loss.item())

                test_loss_epoch = sum(test_losses[-len(myTestLoader):]) / len(myTestLoader)
                print(time.strftime("%H:%M:%S", time.localtime()), f"Epoch:{epoch},test losses:{test_loss_epoch}",
                      file=file)
                plt_test_loss_epoch.append(test_loss_epoch)
                print(time.strftime("%H:%M:%S", time.localtime()), "----------End Validation----------", file=file)
                print(time.strftime("%H:%M:%S", time.localtime()), "----------Continue to Train----------",
                      file=file)
    print(time.strftime("%H:%M:%S", time.localtime()), "----------End Training Normally----------", file=file)
    # plot the loss curves
    f, ([ax1, ax2]) = plt.subplots(1, 2)
    ax1.plot(train_epoch_list, plt_train_loss_epoch, color="red")
    ax1.set_title('Train loss')
    ax2.plot(test_epoch_list, plt_test_loss_epoch, color="blue")
    ax2.set_title('Test loss')
    plt.savefig("./loss.png")  # save the loss curves
    if not config.use_server:
        plt.show()  # display the loss curves
    torch.save(model, "masked_lcm-600JSRT-" + str(date.today()) + "-myModel.pth")


if __name__ == "__main__":
    train()
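The offset-noise branch above adds one extra Gaussian scalar per image and channel on top of the i.i.d. noise; that scalar broadcasts across the spatial dimensions, which lets the model learn global brightness shifts. A minimal standalone sketch of the same construction:

import torch

x = torch.randn(4, 4, 128, 128)  # a latent-shaped batch (B, C, H, W)
coeff = 0.1                      # config.offset_noise_coefficient
noise = torch.randn_like(x) + coeff * torch.randn(x.shape[0], x.shape[1], 1, 1)
print(noise.shape)  # torch.Size([4, 4, 128, 128]); the offset term broadcasts over H and W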
codes/lungSegmentation.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
codes/metrics.py
ADDED
@@ -0,0 +1,140 @@
import cv2 as cv
import lpips
import numpy as np
import os
from config import config
from openpyxl import Workbook

import torch
from math import log10, sqrt
import torchvision.transforms as transforms


def cal_BSR(cxr_path, gt_path, bs_path):
    cxr = cv.imread(cxr_path, 0)
    gt = cv.imread(gt_path, 0)
    bs = cv.imread(bs_path, 0)

    cxr = cxr / 255
    gt = gt / 255
    bs = bs / 255

    bone = cv.subtract(cxr, gt)  # the bone signal is the CXR minus the ground-truth BS image

    gt = cv.resize(gt, (config.image_size, config.image_size))
    bs = cv.resize(bs, (config.image_size, config.image_size))
    bone = cv.resize(bone, (config.image_size, config.image_size))

    bs += np.average(cv.subtract(gt, bs))  # match the mean intensity of the ground truth

    bias = cv.subtract(gt, bs)
    bias[bias < 0] = 0

    BSR = 1 - np.sum(bias ** 2) / np.sum(bone ** 2)  # bone suppression ratio
    return BSR


def cal_MSE(gt_path, bs_path):
    gt = cv.imread(gt_path, 0)
    bs = cv.imread(bs_path, 0)

    gt = cv.resize(gt, (config.image_size, config.image_size))
    bs = cv.resize(bs, (config.image_size, config.image_size))

    # lpips.l2 returns half the mean squared difference of the [0, 1]-scaled inputs
    MSE = 2 * lpips.l2(gt, bs)
    return MSE


def cal_PSNR(gt_path, bs_path):
    mse = cal_MSE(gt_path, bs_path)
    max_pixel = 1

    PSNR = 20 * log10(max_pixel / sqrt(mse))
    return PSNR


def cal_LPIPS(gt_path, bs_path):
    lpips_model = lpips.LPIPS()

    gt = cv.imread(gt_path, 0)
    bs = cv.imread(bs_path, 0)

    gt = cv.resize(gt, (config.image_size, config.image_size))
    bs = cv.resize(bs, (config.image_size, config.image_size))

    gt = transforms.ToTensor()(gt)
    bs = transforms.ToTensor()(bs)

    gt = torch.unsqueeze(gt, dim=0)
    bs = torch.unsqueeze(bs, dim=0)

    LPIPS = lpips_model(gt, bs).item()

    return LPIPS


if __name__ == "__main__":
    wb = Workbook()

    ws = wb.active

    CXR_path = "SZCH-X-Rays/CXR"
    GT_path = "SZCH-X-Rays/BS"
    BS_path = "lcm_output_bs/Fusion_BS"

    BSR_list = []
    MSE_list = []
    PSNR_list = []
    LPIPS_list = []
    ws.append(["Filename", "BSR", "MSE", "PSNR", "LPIPS"])
    txt = 'SZCH_testset.txt'
    with open(txt, 'r', encoding='utf-8') as file:
        lines = file.readlines()
        file_names = [line.strip() for line in lines]

    for filename in os.listdir(BS_path):
        if filename not in file_names:
            continue
        cxr_path = os.path.join(CXR_path, filename)
        gt_path = os.path.join(GT_path, filename)
        bs_path = os.path.join(BS_path, filename)

        BSR = cal_BSR(cxr_path, gt_path, bs_path)
        MSE = cal_MSE(gt_path, bs_path)
        PSNR = cal_PSNR(gt_path, bs_path)
        LPIPS = cal_LPIPS(gt_path, bs_path)

        BSR_list.append(BSR)
        MSE_list.append(MSE)
        PSNR_list.append(PSNR)
        LPIPS_list.append(LPIPS)
        print(f"{filename} BSR: {BSR} MSE: {MSE} PSNR:{PSNR} LPIPS:{LPIPS}")
        ws.append([filename, BSR, MSE, PSNR, LPIPS])

    ws.append(["Mean",
               np.mean(np.array(BSR_list)),
               np.mean(np.array(MSE_list)),
               np.mean(np.array(PSNR_list)),
               np.mean(np.array(LPIPS_list))])
    ws.append(["Std",
               np.std(np.array(BSR_list)),
               np.std(np.array(MSE_list)),
               np.std(np.array(PSNR_list)),
               np.std(np.array(LPIPS_list))])
    print("Average BSR:", np.mean(np.array(BSR_list)), "Std:", np.std(np.array(BSR_list)))
    print("Average MSE:", np.mean(np.array(MSE_list)), "Std:", np.std(np.array(MSE_list)))
    print("Average PSNR:", np.mean(np.array(PSNR_list)), "Std:", np.std(np.array(PSNR_list)))
    print("Average LPIPS:", np.mean(np.array(LPIPS_list)), "Std:", np.std(np.array(LPIPS_list)))

    # save the workbook
    wb.save("sample.xlsx")
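As a quick sanity check of cal_PSNR: cal_MSE returns squared error on [0, 1]-scaled pixels and max_pixel is 1, so an MSE of 1e-4 corresponds to 40 dB:

from math import log10, sqrt

mse = 1e-4
print(20 * log10(1 / sqrt(mse)))  # 40.0 dB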
codes/model.py
ADDED
@@ -0,0 +1,37 @@
from modules.unet import UNetModel
from generative.networks.nets import VQVAE
from config import config

myUnet = UNetModel(
    image_size=config.image_size / config.r,
    model_channels=128,
    in_channels=8,
    out_channels=8,
    num_res_blocks=8,
    num_heads=8,
    attention_resolutions=(64, 32, 16, 8),
    num_heads_upsample=-1,
    num_head_channels=-1,
    resblock_updown=True,
    channel_mult=(1, 1, 2, 2, 4, 4),
    use_scale_shift_norm=True,
    use_new_attention_order=True
)

myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)),
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    num_embeddings=1024,
    embedding_dim=4,
)

if __name__ == "__main__":
    print("Number of UNet parameters:", sum(p.numel() for p in myUnet.parameters()))
    print("Number of VQGAN parameters:", sum(p.numel() for p in myVQGANModel.parameters()))
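The channel arithmetic above: embedding_dim=4 in the VQVAE gives 4-channel latents, and the UNet consumes the bone-suppressed latent concatenated with the CXR latent, hence in_channels=8 at a 1024 / 8 = 128 latent resolution. A smoke-test sketch, assuming modules.unet and its dependencies import cleanly:

import torch
from model import myUnet

x = torch.randn(1, 8, 128, 128)       # (noise/BS latent, CXR latent) concatenated
t = torch.zeros(1, dtype=torch.long)  # a single timestep
out = myUnet(x, t)
print(out.shape)  # expected torch.Size([1, 8, 128, 128]); only out[:, :4] is supervised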
codes/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Codebase for "Bone Suppression via conditional diffusion model".
"""
codes/modules/diffusion_model_unet.py
ADDED
|
@@ -0,0 +1,2099 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
#
|
| 12 |
+
# =========================================================================
|
| 13 |
+
# Adapted from https://github.com/huggingface/diffusers
|
| 14 |
+
# which has the following license:
|
| 15 |
+
# https://github.com/huggingface/diffusers/blob/main/LICENSE
|
| 16 |
+
#
|
| 17 |
+
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
| 18 |
+
#
|
| 19 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 20 |
+
# you may not use this file except in compliance with the License.
|
| 21 |
+
# You may obtain a copy of the License at
|
| 22 |
+
#
|
| 23 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 24 |
+
#
|
| 25 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 26 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 27 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 28 |
+
# See the License for the specific language governing permissions and
|
| 29 |
+
# limitations under the License.
|
| 30 |
+
# =========================================================================
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import importlib.util
|
| 35 |
+
import math
|
| 36 |
+
from collections.abc import Sequence
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
from monai.networks.blocks import Convolution, MLPBlock
|
| 41 |
+
from monai.networks.layers.factories import Pool
|
| 42 |
+
from monai.utils import ensure_tuple_rep
|
| 43 |
+
from torch import nn
|
| 44 |
+
|
| 45 |
+
# To install xformers, use pip install xformers==0.0.16rc401
|
| 46 |
+
if importlib.util.find_spec("xformers") is not None:
|
| 47 |
+
import xformers
|
| 48 |
+
import xformers.ops
|
| 49 |
+
|
| 50 |
+
has_xformers = True
|
| 51 |
+
else:
|
| 52 |
+
xformers = None
|
| 53 |
+
has_xformers = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# TODO: Use MONAI's optional_import
|
| 57 |
+
# from monai.utils import optional_import
|
| 58 |
+
# xformers, has_xformers = optional_import("xformers.ops", name="xformers")
|
| 59 |
+
|
| 60 |
+
__all__ = ["DiffusionModelUNet"]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def zero_module(module: nn.Module) -> nn.Module:
|
| 64 |
+
"""
|
| 65 |
+
Zero out the parameters of a module and return it.
|
| 66 |
+
"""
|
| 67 |
+
for p in module.parameters():
|
| 68 |
+
p.detach().zero_()
|
| 69 |
+
return module
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CrossAttention(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
A cross attention layer.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
query_dim: number of channels in the query.
|
| 78 |
+
cross_attention_dim: number of channels in the context.
|
| 79 |
+
num_attention_heads: number of heads to use for multi-head attention.
|
| 80 |
+
num_head_channels: number of channels in each head.
|
| 81 |
+
dropout: dropout probability to use.
|
| 82 |
+
upcast_attention: if True, upcast attention operations to full precision.
|
| 83 |
+
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
query_dim: int,
|
| 89 |
+
cross_attention_dim: int | None = None,
|
| 90 |
+
num_attention_heads: int = 8,
|
| 91 |
+
num_head_channels: int = 64,
|
| 92 |
+
dropout: float = 0.0,
|
| 93 |
+
upcast_attention: bool = False,
|
| 94 |
+
use_flash_attention: bool = False,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.use_flash_attention = use_flash_attention
|
| 98 |
+
inner_dim = num_head_channels * num_attention_heads
|
| 99 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 100 |
+
|
| 101 |
+
self.scale = 1 / math.sqrt(num_head_channels)
|
| 102 |
+
self.num_heads = num_attention_heads
|
| 103 |
+
|
| 104 |
+
self.upcast_attention = upcast_attention
|
| 105 |
+
|
| 106 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
| 107 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 108 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 109 |
+
|
| 110 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
| 111 |
+
|
| 112 |
+
def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
"""
|
| 114 |
+
Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch.
|
| 115 |
+
"""
|
| 116 |
+
batch_size, seq_len, dim = x.shape
|
| 117 |
+
x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads)
|
| 118 |
+
x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads)
|
| 119 |
+
return x
|
| 120 |
+
|
| 121 |
+
def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
"""Combine the output of the attention heads back into the hidden state dimension."""
|
| 123 |
+
batch_size, seq_len, dim = x.shape
|
| 124 |
+
x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim)
|
| 125 |
+
x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)
|
| 126 |
+
return x
|
| 127 |
+
|
| 128 |
+
def _memory_efficient_attention_xformers(
|
| 129 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
| 130 |
+
) -> torch.Tensor:
|
| 131 |
+
query = query.contiguous()
|
| 132 |
+
key = key.contiguous()
|
| 133 |
+
value = value.contiguous()
|
| 134 |
+
x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
dtype = query.dtype
|
| 139 |
+
if self.upcast_attention:
|
| 140 |
+
query = query.float()
|
| 141 |
+
key = key.float()
|
| 142 |
+
|
| 143 |
+
attention_scores = torch.baddbmm(
|
| 144 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
| 145 |
+
query,
|
| 146 |
+
key.transpose(-1, -2),
|
| 147 |
+
beta=0,
|
| 148 |
+
alpha=self.scale,
|
| 149 |
+
)
|
| 150 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 151 |
+
attention_probs = attention_probs.to(dtype=dtype)
|
| 152 |
+
|
| 153 |
+
x = torch.bmm(attention_probs, value)
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
|
| 157 |
+
query = self.to_q(x)
|
| 158 |
+
context = context if context is not None else x
|
| 159 |
+
key = self.to_k(context)
|
| 160 |
+
value = self.to_v(context)
|
| 161 |
+
|
| 162 |
+
# Multi-Head Attention
|
| 163 |
+
query = self.reshape_heads_to_batch_dim(query)
|
| 164 |
+
key = self.reshape_heads_to_batch_dim(key)
|
| 165 |
+
value = self.reshape_heads_to_batch_dim(value)
|
| 166 |
+
|
| 167 |
+
if self.use_flash_attention:
|
| 168 |
+
x = self._memory_efficient_attention_xformers(query, key, value)
|
| 169 |
+
else:
|
| 170 |
+
x = self._attention(query, key, value)
|
| 171 |
+
|
| 172 |
+
x = self.reshape_batch_dim_to_heads(x)
|
| 173 |
+
x = x.to(query.dtype)
|
| 174 |
+
|
| 175 |
+
return self.to_out(x)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class BasicTransformerBlock(nn.Module):
    """
    A basic Transformer block.

    Args:
        num_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        dropout: dropout probability to use.
        cross_attention_dim: size of the context vector for cross attention.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        num_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        dropout: float = 0.0,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=num_channels,
            num_attention_heads=num_attention_heads,
            num_head_channels=num_head_channels,
            dropout=dropout,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )  # self-attention
        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
        self.attn2 = CrossAttention(
            query_dim=num_channels,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            num_head_channels=num_head_channels,
            dropout=dropout,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )  # cross-attention; acts as self-attention if context is None
        self.norm1 = nn.LayerNorm(num_channels)
        self.norm2 = nn.LayerNorm(num_channels)
        self.norm3 = nn.LayerNorm(num_channels)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        # 1. Self-Attention
        x = self.attn1(self.norm1(x)) + x

        # 2. Cross-Attention
        x = self.attn2(self.norm2(x), context=context) + x

        # 3. Feed-forward
        x = self.ff(self.norm3(x)) + x
        return x

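# Illustrative sketch (editorial addition): the block above is pre-norm, so each
# sub-layer sees a LayerNorm'ed input and adds its output back to the residual
# stream; the token shape is therefore preserved end to end. Sizes are
# assumptions for demonstration.
def _demo_basic_transformer_block() -> None:
    block = BasicTransformerBlock(
        num_channels=64, num_attention_heads=8, num_head_channels=8, cross_attention_dim=32
    )
    tokens = torch.randn(2, 16, 64)  # (batch, tokens, channels)
    context = torch.randn(2, 4, 32)  # (batch, context tokens, context dim)
    assert block(tokens, context=context).shape == tokens.shape
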
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, projects the input (aka embedding) and reshapes it to b, t, d.
    Then applies standard transformer blocks. Finally, reshapes back to an image.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        num_layers: number of layers of Transformer blocks to use.
        dropout: dropout probability to use.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        inner_dim = num_attention_heads * num_head_channels

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)

        self.proj_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=inner_dim,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    num_channels=inner_dim,
                    num_attention_heads=num_attention_heads,
                    num_head_channels=num_head_channels,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_flash_attention=use_flash_attention,
                )
                for _ in range(num_layers)
            ]
        )

        self.proj_out = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=inner_dim,
                out_channels=in_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        # note: if no context is given, cross-attention defaults to self-attention
        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        residual = x
        x = self.norm(x)
        x = self.proj_in(x)

        inner_dim = x.shape[1]

        if self.spatial_dims == 2:
            x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        if self.spatial_dims == 3:
            x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)

        for block in self.transformer_blocks:
            x = block(x, context=context)

        if self.spatial_dims == 2:
            x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        if self.spatial_dims == 3:
            x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()

        x = self.proj_out(x)
        return x + residual

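# Illustrative sketch (editorial addition): SpatialTransformer flattens a feature
# map to tokens, runs the transformer blocks, and folds the tokens back, so the
# image shape is unchanged. The sizes below are assumptions for demonstration.
def _demo_spatial_transformer() -> None:
    st = SpatialTransformer(
        spatial_dims=2, in_channels=32, num_attention_heads=4, num_head_channels=8, cross_attention_dim=16
    )
    x = torch.randn(1, 32, 8, 8)     # (batch, channels, height, width)
    context = torch.randn(1, 4, 16)  # conditioning tokens for cross-attention
    assert st(x, context=context).shape == x.shape
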
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to
    compute attention.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        num_head_channels: number of channels in each attention head.
        norm_num_groups: number of groups involved for the group normalization layer. Ensure that your number of
            channels is divisible by this number.
        norm_eps: epsilon value to use for the normalization.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        num_channels: int,
        num_head_channels: int | None = None,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.spatial_dims = spatial_dims
        self.num_channels = num_channels

        self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
        self.scale = 1 / math.sqrt(num_channels / self.num_heads)

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)

        self.to_q = nn.Linear(num_channels, num_channels)
        self.to_k = nn.Linear(num_channels, num_channels)
        self.to_v = nn.Linear(num_channels, num_channels)

        self.proj_attn = nn.Linear(num_channels, num_channels)

    def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape
        x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads)
        x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads)
        return x

    def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape
        x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim)
        x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)
        return x

    def _memory_efficient_attention_xformers(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        return x

    def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        x = torch.bmm(attention_probs, value)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        # norm
        x = self.norm(x)

        if self.spatial_dims == 2:
            x = x.view(batch, channel, height * width).transpose(1, 2)
        if self.spatial_dims == 3:
            x = x.view(batch, channel, height * width * depth).transpose(1, 2)

        # proj to q, k, v
        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)

        # Multi-Head Attention
        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if self.use_flash_attention:
            x = self._memory_efficient_attention_xformers(query, key, value)
        else:
            x = self._attention(query, key, value)

        x = self.reshape_batch_dim_to_heads(x)
        x = x.to(query.dtype)

        if self.spatial_dims == 2:
            x = x.transpose(-1, -2).reshape(batch, channel, height, width)
        if self.spatial_dims == 3:
            x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth)

        return x + residual

def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
    Models" https://arxiv.org/abs/2006.11239.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.
    """
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(exponent / half_dim)

    args = timesteps[:, None].float() * freqs[None, :]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))

    return embedding

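# Worked example (editorial addition): for a batch of three timesteps and
# embedding_dim=128, the first 64 entries are cosines and the last 64 sines,
# so the embedding of t=0 is [1, ..., 1, 0, ..., 0].
def _demo_timestep_embedding() -> None:
    emb = get_timestep_embedding(torch.tensor([0, 10, 999]), embedding_dim=128)
    assert emb.shape == (3, 128)
    assert torch.allclose(emb[0, :64], torch.ones(64))   # cos(0) = 1
    assert torch.allclose(emb[0, 64:], torch.zeros(64))  # sin(0) = 0
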
class Downsample(nn.Module):
    """
    Downsampling layer.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True, uses a strided Convolution instead of average pooling to perform downsampling. When
            use_conv is False, the number of output channels must be the same as the number of input channels.
        out_channels: number of output channels.
        padding: controls the amount of implicit zero-padding applied on both sides of each dimension.
    """

    def __init__(
        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        self.use_conv = use_conv
        if use_conv:
            self.op = Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.num_channels,
                out_channels=self.out_channels,
                strides=2,
                kernel_size=3,
                padding=padding,
                conv_only=True,
            )
        else:
            if self.num_channels != self.out_channels:
                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
            self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        del emb
        if x.shape[1] != self.num_channels:
            raise ValueError(
                f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
                f"({self.num_channels})"
            )
        return self.op(x)


class Upsample(nn.Module):
    """
    Upsampling layer with an optional convolution.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True, applies a Convolution after the nearest-neighbour upsampling.
        out_channels: number of output channels.
        padding: controls the amount of implicit zero-padding applied on both sides of each dimension.
    """

    def __init__(
        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.num_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=padding,
                conv_only=True,
            )
        else:
            self.conv = None

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        del emb
        if x.shape[1] != self.num_channels:
            raise ValueError("Input channels should be equal to num_channels")

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = x.dtype
        if dtype == torch.bfloat16:
            x = x.to(torch.float32)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16
        if dtype == torch.bfloat16:
            x = x.to(dtype)

        if self.use_conv:
            x = self.conv(x)
        return x

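# Shape sketch (editorial addition): the convolutional Downsample halves each
# spatial dimension with a stride-2 conv, while Upsample doubles it with
# nearest-neighbour interpolation followed by an optional 3x3 conv.
def _demo_resampling_shapes() -> None:
    x = torch.randn(1, 8, 16, 16)
    assert Downsample(spatial_dims=2, num_channels=8, use_conv=True)(x).shape == (1, 8, 8, 8)
    assert Upsample(spatial_dims=2, num_channels=8, use_conv=True)(x).shape == (1, 8, 32, 32)
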
class ResnetBlock(nn.Module):
    """
    Residual block with timestep conditioning.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        out_channels: number of output channels.
        up: if True, performs upsampling.
        down: if True, performs downsampling.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        out_channels: int | None = None,
        up: bool = False,
        down: bool = False,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = in_channels
        self.emb_channels = temb_channels
        self.out_channels = out_channels or in_channels
        self.up = up
        self.down = down

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
        self.nonlinearity = nn.SiLU()
        # dilated 3x3 convolution; padding is set equal to the dilation so the
        # spatial size is preserved (with padding=1, dilation=3 would shrink the
        # feature map and break the residual sum at the end of forward())
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=3,
            conv_only=True,
            dilation=3,
        )

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample(spatial_dims, in_channels, use_conv=False)
        elif down:
            self.downsample = Downsample(spatial_dims, in_channels, use_conv=False)

        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
        # second dilated conv; again padding=dilation keeps the spatial size
        self.conv2 = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=2,
                conv_only=True,
                dilation=2,
            )
        )

        if self.out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
                dilation=1,
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = self.nonlinearity(h)

        if self.upsample is not None:
            if h.shape[0] >= 64:
                x = x.contiguous()
                h = h.contiguous()
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        if self.spatial_dims == 2:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
        else:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
        h = h + temb

        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.conv2(h)

        return self.skip_connection(x) + h

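# Usage sketch (editorial addition): a ResnetBlock maps (B, C_in, H, W) to
# (B, C_out, H, W), injecting the timestep embedding as a per-channel bias
# between the two dilated convolutions. Sizes are assumptions for demonstration.
def _demo_resnet_block() -> None:
    block = ResnetBlock(spatial_dims=2, in_channels=32, temb_channels=128, out_channels=64)
    x = torch.randn(1, 32, 16, 16)
    temb = torch.randn(1, 128)
    assert block(x, temb).shape == (1, 64, 16, 16)
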
class DownBlock(nn.Module):
    """
    Unet's down block containing resnet and downsampler blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            if resblock_updown:
                self.downsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = Downsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        del context
        output_states = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states

class AttnDownBlock(nn.Module):
    """
    Unet's down block containing resnet, downsampler and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            if resblock_updown:
                self.downsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = Downsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        del context
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states

class CrossAttnDownBlock(nn.Module):
    """
    Unet's down block containing resnet, downsampler and cross-attention blocks.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

            attentions.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    num_layers=transformer_num_layers,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_flash_attention=use_flash_attention,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            if resblock_updown:
                self.downsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = Downsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=context)
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states

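# Skip-connection sketch (editorial addition): each down block returns its final
# hidden state plus one skip tensor per resnet (and one more after downsampling),
# which the corresponding up block later consumes in reverse order. Sizes are
# assumptions for demonstration.
def _demo_down_block_skips() -> None:
    block = DownBlock(
        spatial_dims=2, in_channels=8, out_channels=16, temb_channels=32, num_res_blocks=2, norm_num_groups=8
    )
    h, skips = block(torch.randn(1, 8, 16, 16), torch.randn(1, 32))
    assert h.shape == (1, 16, 8, 8)
    assert [s.shape[-1] for s in skips] == [16, 16, 8]
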
class AttnMidBlock(nn.Module):
    """
    Unet's mid block containing resnet and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()

        self.resnet_1 = ResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )
        self.attention = AttentionBlock(
            spatial_dims=spatial_dims,
            num_channels=in_channels,
            num_head_channels=num_head_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            use_flash_attention=use_flash_attention,
        )

        self.resnet_2 = ResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        del context
        hidden_states = self.resnet_1(hidden_states, temb)
        hidden_states = self.attention(hidden_states)
        hidden_states = self.resnet_2(hidden_states, temb)

        return hidden_states

class CrossAttnMidBlock(nn.Module):
    """
    Unet's mid block containing resnet and cross-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()

        self.resnet_1 = ResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )
        self.attention = SpatialTransformer(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_attention_heads=in_channels // num_head_channels,
            num_head_channels=num_head_channels,
            num_layers=transformer_num_layers,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )
        self.resnet_2 = ResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        hidden_states = self.resnet_1(hidden_states, temb)
        hidden_states = self.attention(hidden_states, context=context)
        hidden_states = self.resnet_2(hidden_states, temb)

        return hidden_states

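# Composition sketch (editorial addition): both mid blocks follow the
# resnet -> attention -> resnet pattern at the bottleneck resolution, so the
# input shape is preserved. Sizes are assumptions for demonstration.
def _demo_mid_block() -> None:
    mid = AttnMidBlock(spatial_dims=2, in_channels=32, temb_channels=64, num_head_channels=8)
    x = torch.randn(1, 32, 8, 8)
    assert mid(x, torch.randn(1, 64)).shape == x.shape
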
class UpBlock(nn.Module):
    """
    Unet's up block containing resnet and upsampler blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add upsample block.
        resblock_updown: if True use residual blocks for upsampling.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown
        resnets = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            if resblock_updown:
                self.upsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:
                self.upsampler = Upsample(
                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
                )
        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del context
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states

class AttnUpBlock(nn.Module):
    """
    Unet's up block containing resnet, upsampler, and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add upsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            if resblock_updown:
                self.upsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:
                self.upsampler = Upsample(
                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
                )
        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del context
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states

class CrossAttnUpBlock(nn.Module):
    """
    Unet's up block containing resnet, upsampler, and cross-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add upsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    num_layers=transformer_num_layers,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_flash_attention=use_flash_attention,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            if resblock_updown:
                self.upsampler = ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:
                self.upsampler = Upsample(
                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
                )
        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=context)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states

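# Skip-popping sketch (editorial addition): up blocks consume the skip list from
# the back, concatenating one skip tensor per resnet along the channel axis
# before upsampling; the channel counts below mirror _demo_down_block_skips.
def _demo_up_block_skip_popping() -> None:
    up = UpBlock(
        spatial_dims=2, in_channels=8, prev_output_channel=16, out_channels=16,
        temb_channels=32, num_res_blocks=2, norm_num_groups=8,
    )
    skips = [torch.randn(1, 8, 8, 8), torch.randn(1, 16, 8, 8)]  # popped last-first
    out = up(torch.randn(1, 16, 8, 8), skips, torch.randn(1, 32))
    assert out.shape == (1, 16, 16, 16)
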
def get_down_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_downsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    if with_attn:
        return AttnDownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )
    elif with_cross_attn:
        return CrossAttnDownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )
    else:
        return DownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
        )

def get_mid_block(
    spatial_dims: int,
    in_channels: int,
    temb_channels: int,
    norm_num_groups: int,
    norm_eps: float,
    with_conditioning: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    if with_conditioning:
        return CrossAttnMidBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )
    else:
        return AttnMidBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )

def get_up_block(
    spatial_dims: int,
    in_channels: int,
    prev_output_channel: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_upsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    if with_attn:
        return AttnUpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )
    elif with_cross_attn:
        return CrossAttnUpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )
    else:
        return UpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
        )

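# Selection sketch (editorial addition): the factory functions above pick the
# block variant from the two flags; cross-attention additionally requires a
# context dimension. Sizes are assumptions for demonstration.
def _demo_block_factories() -> None:
    plain = get_down_block(
        spatial_dims=2, in_channels=8, out_channels=16, temb_channels=32, num_res_blocks=1,
        norm_num_groups=8, norm_eps=1e-6, add_downsample=True, resblock_updown=False,
        with_attn=False, with_cross_attn=False, num_head_channels=8,
        transformer_num_layers=1, cross_attention_dim=None,
    )
    xattn = get_down_block(
        spatial_dims=2, in_channels=8, out_channels=16, temb_channels=32, num_res_blocks=1,
        norm_num_groups=8, norm_eps=1e-6, add_downsample=True, resblock_updown=False,
        with_attn=False, with_cross_attn=True, num_head_channels=8,
        transformer_num_layers=1, cross_attention_dim=16,
    )
    assert isinstance(plain, DownBlock) and isinstance(xattn, CrossAttnDownBlock)
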
class DiffusionModelUNet(nn.Module):
    """
    Unet network with timestep embedding and attention mechanisms for conditioning, based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
            )

        # every entry in num_channels must be a multiple of norm_num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
            raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

        if len(num_channels) != len(attention_levels):
            raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

        if len(num_res_blocks) != len(num_channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`num_channels`."
            )

        if use_flash_attention and not has_xformers:
            raise ValueError("use_flash_attention is True but xformers is not installed.")

        if use_flash_attention is True and not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
|
| 1722 |
+
)
|
| 1723 |
+
|
| 1724 |
+
self.in_channels = in_channels
|
| 1725 |
+
self.block_out_channels = num_channels
|
| 1726 |
+
self.out_channels = out_channels
|
| 1727 |
+
self.num_res_blocks = num_res_blocks
|
| 1728 |
+
self.attention_levels = attention_levels
|
| 1729 |
+
self.num_head_channels = num_head_channels
|
| 1730 |
+
self.with_conditioning = with_conditioning
|
| 1731 |
+
|
| 1732 |
+
# input
|
| 1733 |
+
self.conv_in = Convolution(
|
| 1734 |
+
spatial_dims=spatial_dims,
|
| 1735 |
+
in_channels=in_channels,
|
| 1736 |
+
out_channels=num_channels[0],
|
| 1737 |
+
strides=1,
|
| 1738 |
+
kernel_size=3,
|
| 1739 |
+
padding=1,
|
| 1740 |
+
conv_only=True,
|
| 1741 |
+
)
|
| 1742 |
+
|
| 1743 |
+
# time
|
| 1744 |
+
time_embed_dim = num_channels[0] * 4
|
| 1745 |
+
self.time_embed = nn.Sequential(
|
| 1746 |
+
nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
|
| 1747 |
+
)
|
| 1748 |
+
|
| 1749 |
+
# class embedding
|
| 1750 |
+
self.num_class_embeds = num_class_embeds
|
| 1751 |
+
if num_class_embeds is not None:
|
| 1752 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 1753 |
+
|
| 1754 |
+
# down
|
| 1755 |
+
self.down_blocks = nn.ModuleList([])
|
| 1756 |
+
output_channel = num_channels[0]
|
| 1757 |
+
for i in range(len(num_channels)):
|
| 1758 |
+
input_channel = output_channel
|
| 1759 |
+
output_channel = num_channels[i]
|
| 1760 |
+
is_final_block = i == len(num_channels) - 1
|
| 1761 |
+
|
| 1762 |
+
down_block = get_down_block(
|
| 1763 |
+
spatial_dims=spatial_dims,
|
| 1764 |
+
in_channels=input_channel,
|
| 1765 |
+
out_channels=output_channel,
|
| 1766 |
+
temb_channels=time_embed_dim,
|
| 1767 |
+
num_res_blocks=num_res_blocks[i],
|
| 1768 |
+
norm_num_groups=norm_num_groups,
|
| 1769 |
+
norm_eps=norm_eps,
|
| 1770 |
+
add_downsample=not is_final_block,
|
| 1771 |
+
resblock_updown=resblock_updown,
|
| 1772 |
+
with_attn=(attention_levels[i] and not with_conditioning),
|
| 1773 |
+
with_cross_attn=(attention_levels[i] and with_conditioning),
|
| 1774 |
+
num_head_channels=num_head_channels[i],
|
| 1775 |
+
transformer_num_layers=transformer_num_layers,
|
| 1776 |
+
cross_attention_dim=cross_attention_dim,
|
| 1777 |
+
upcast_attention=upcast_attention,
|
| 1778 |
+
use_flash_attention=use_flash_attention,
|
| 1779 |
+
)
|
| 1780 |
+
|
| 1781 |
+
self.down_blocks.append(down_block)
|
| 1782 |
+
|
| 1783 |
+
# mid
|
| 1784 |
+
self.middle_block = get_mid_block(
|
| 1785 |
+
spatial_dims=spatial_dims,
|
| 1786 |
+
in_channels=num_channels[-1],
|
| 1787 |
+
temb_channels=time_embed_dim,
|
| 1788 |
+
norm_num_groups=norm_num_groups,
|
| 1789 |
+
norm_eps=norm_eps,
|
| 1790 |
+
with_conditioning=with_conditioning,
|
| 1791 |
+
num_head_channels=num_head_channels[-1],
|
| 1792 |
+
transformer_num_layers=transformer_num_layers,
|
| 1793 |
+
cross_attention_dim=cross_attention_dim,
|
| 1794 |
+
upcast_attention=upcast_attention,
|
| 1795 |
+
use_flash_attention=use_flash_attention,
|
| 1796 |
+
)
|
| 1797 |
+
|
| 1798 |
+
# up
|
| 1799 |
+
self.up_blocks = nn.ModuleList([])
|
| 1800 |
+
reversed_block_out_channels = list(reversed(num_channels))
|
| 1801 |
+
reversed_num_res_blocks = list(reversed(num_res_blocks))
|
| 1802 |
+
reversed_attention_levels = list(reversed(attention_levels))
|
| 1803 |
+
reversed_num_head_channels = list(reversed(num_head_channels))
|
| 1804 |
+
output_channel = reversed_block_out_channels[0]
|
| 1805 |
+
for i in range(len(reversed_block_out_channels)):
|
| 1806 |
+
prev_output_channel = output_channel
|
| 1807 |
+
output_channel = reversed_block_out_channels[i]
|
| 1808 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]
|
| 1809 |
+
|
| 1810 |
+
is_final_block = i == len(num_channels) - 1
|
| 1811 |
+
|
| 1812 |
+
up_block = get_up_block(
|
| 1813 |
+
spatial_dims=spatial_dims,
|
| 1814 |
+
in_channels=input_channel,
|
| 1815 |
+
prev_output_channel=prev_output_channel,
|
| 1816 |
+
out_channels=output_channel,
|
| 1817 |
+
temb_channels=time_embed_dim,
|
| 1818 |
+
num_res_blocks=reversed_num_res_blocks[i] + 1,
|
| 1819 |
+
norm_num_groups=norm_num_groups,
|
| 1820 |
+
norm_eps=norm_eps,
|
| 1821 |
+
add_upsample=not is_final_block,
|
| 1822 |
+
resblock_updown=resblock_updown,
|
| 1823 |
+
with_attn=(reversed_attention_levels[i] and not with_conditioning),
|
| 1824 |
+
with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
|
| 1825 |
+
num_head_channels=reversed_num_head_channels[i],
|
| 1826 |
+
transformer_num_layers=transformer_num_layers,
|
| 1827 |
+
cross_attention_dim=cross_attention_dim,
|
| 1828 |
+
upcast_attention=upcast_attention,
|
| 1829 |
+
use_flash_attention=use_flash_attention,
|
| 1830 |
+
)
|
| 1831 |
+
|
| 1832 |
+
self.up_blocks.append(up_block)
|
| 1833 |
+
|
| 1834 |
+
# out
|
| 1835 |
+
self.out = nn.Sequential(
|
| 1836 |
+
nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
|
| 1837 |
+
nn.SiLU(),
|
| 1838 |
+
zero_module(
|
| 1839 |
+
Convolution(
|
| 1840 |
+
spatial_dims=spatial_dims,
|
| 1841 |
+
in_channels=num_channels[0],
|
| 1842 |
+
out_channels=out_channels,
|
| 1843 |
+
strides=1,
|
| 1844 |
+
kernel_size=3,
|
| 1845 |
+
padding=1,
|
| 1846 |
+
conv_only=True,
|
| 1847 |
+
dilation=2
|
| 1848 |
+
)
|
| 1849 |
+
),
|
| 1850 |
+
)
|
| 1851 |
+
|
| 1852 |
+
def forward(
|
| 1853 |
+
self,
|
| 1854 |
+
x: torch.Tensor,
|
| 1855 |
+
timesteps: torch.Tensor,
|
| 1856 |
+
context: torch.Tensor | None = None,
|
| 1857 |
+
class_labels: torch.Tensor | None = None,
|
| 1858 |
+
down_block_additional_residuals: tuple[torch.Tensor] | None = None,
|
| 1859 |
+
mid_block_additional_residual: torch.Tensor | None = None,
|
| 1860 |
+
) -> torch.Tensor:
|
| 1861 |
+
"""
|
| 1862 |
+
Args:
|
| 1863 |
+
x: input tensor (N, C, SpatialDims).
|
| 1864 |
+
timesteps: timestep tensor (N,).
|
| 1865 |
+
context: context tensor (N, 1, ContextDim).
|
| 1866 |
+
class_labels: context tensor (N, ).
|
| 1867 |
+
down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
|
| 1868 |
+
mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
|
| 1869 |
+
"""
|
| 1870 |
+
# 1. time
|
| 1871 |
+
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
|
| 1872 |
+
|
| 1873 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 1874 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 1875 |
+
# there might be better ways to encapsulate this.
|
| 1876 |
+
t_emb = t_emb.to(dtype=x.dtype)
|
| 1877 |
+
emb = self.time_embed(t_emb)
|
| 1878 |
+
|
| 1879 |
+
# 2. class
|
| 1880 |
+
if self.num_class_embeds is not None:
|
| 1881 |
+
if class_labels is None:
|
| 1882 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 1883 |
+
class_emb = self.class_embedding(class_labels)
|
| 1884 |
+
class_emb = class_emb.to(dtype=x.dtype)
|
| 1885 |
+
emb = emb + class_emb
|
| 1886 |
+
|
| 1887 |
+
# 3. initial convolution
|
| 1888 |
+
h = self.conv_in(x)
|
| 1889 |
+
|
| 1890 |
+
# 4. down
|
| 1891 |
+
if context is not None and self.with_conditioning is False:
|
| 1892 |
+
raise ValueError("model should have with_conditioning = True if context is provided")
|
| 1893 |
+
down_block_res_samples: list[torch.Tensor] = [h]
|
| 1894 |
+
for downsample_block in self.down_blocks:
|
| 1895 |
+
h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
|
| 1896 |
+
for residual in res_samples:
|
| 1897 |
+
down_block_res_samples.append(residual)
|
| 1898 |
+
|
| 1899 |
+
# Additional residual conections for Controlnets
|
| 1900 |
+
if down_block_additional_residuals is not None:
|
| 1901 |
+
new_down_block_res_samples = ()
|
| 1902 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 1903 |
+
down_block_res_samples, down_block_additional_residuals
|
| 1904 |
+
):
|
| 1905 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 1906 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
| 1907 |
+
|
| 1908 |
+
down_block_res_samples = new_down_block_res_samples
|
| 1909 |
+
|
| 1910 |
+
# 5. mid
|
| 1911 |
+
h = self.middle_block(hidden_states=h, temb=emb, context=context)
|
| 1912 |
+
|
| 1913 |
+
# Additional residual conections for Controlnets
|
| 1914 |
+
if mid_block_additional_residual is not None:
|
| 1915 |
+
h = h + mid_block_additional_residual
|
| 1916 |
+
|
| 1917 |
+
# 6. up
|
| 1918 |
+
for upsample_block in self.up_blocks:
|
| 1919 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 1920 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 1921 |
+
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
|
| 1922 |
+
|
| 1923 |
+
# 7. output block
|
| 1924 |
+
h = self.out(h)
|
| 1925 |
+
|
| 1926 |
+
return h
|
| 1927 |
+
|
| 1928 |
+
|
| 1929 |
+
class DiffusionModelEncoder(nn.Module):
|
| 1930 |
+
"""
|
| 1931 |
+
Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on
|
| 1932 |
+
Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).
|
| 1933 |
+
|
| 1934 |
+
Args:
|
| 1935 |
+
spatial_dims: number of spatial dimensions.
|
| 1936 |
+
in_channels: number of input channels.
|
| 1937 |
+
out_channels: number of output channels.
|
| 1938 |
+
num_res_blocks: number of residual blocks (see ResnetBlock) per level.
|
| 1939 |
+
num_channels: tuple of block output channels.
|
| 1940 |
+
attention_levels: list of levels to add attention.
|
| 1941 |
+
norm_num_groups: number of groups for the normalization.
|
| 1942 |
+
norm_eps: epsilon for the normalization.
|
| 1943 |
+
resblock_updown: if True use residual blocks for downsampling.
|
| 1944 |
+
num_head_channels: number of channels in each attention head.
|
| 1945 |
+
with_conditioning: if True add spatial transformers to perform conditioning.
|
| 1946 |
+
transformer_num_layers: number of layers of Transformer blocks to use.
|
| 1947 |
+
cross_attention_dim: number of context dimensions to use.
|
| 1948 |
+
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
|
| 1949 |
+
upcast_attention: if True, upcast attention operations to full precision.
|
| 1950 |
+
"""
|
| 1951 |
+
|
| 1952 |
+
def __init__(
|
| 1953 |
+
self,
|
| 1954 |
+
spatial_dims: int,
|
| 1955 |
+
in_channels: int,
|
| 1956 |
+
out_channels: int,
|
| 1957 |
+
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
|
| 1958 |
+
num_channels: Sequence[int] = (32, 64, 64, 64),
|
| 1959 |
+
attention_levels: Sequence[bool] = (False, False, True, True),
|
| 1960 |
+
norm_num_groups: int = 32,
|
| 1961 |
+
norm_eps: float = 1e-6,
|
| 1962 |
+
resblock_updown: bool = False,
|
| 1963 |
+
num_head_channels: int | Sequence[int] = 8,
|
| 1964 |
+
with_conditioning: bool = False,
|
| 1965 |
+
transformer_num_layers: int = 1,
|
| 1966 |
+
cross_attention_dim: int | None = None,
|
| 1967 |
+
num_class_embeds: int | None = None,
|
| 1968 |
+
upcast_attention: bool = False,
|
| 1969 |
+
) -> None:
|
| 1970 |
+
super().__init__()
|
| 1971 |
+
if with_conditioning is True and cross_attention_dim is None:
|
| 1972 |
+
raise ValueError(
|
| 1973 |
+
"DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
|
| 1974 |
+
"when using with_conditioning."
|
| 1975 |
+
)
|
| 1976 |
+
if cross_attention_dim is not None and with_conditioning is False:
|
| 1977 |
+
raise ValueError(
|
| 1978 |
+
"DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
|
| 1979 |
+
)
|
| 1980 |
+
|
| 1981 |
+
# All number of channels should be multiple of num_groups
|
| 1982 |
+
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
|
| 1983 |
+
raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
|
| 1984 |
+
if len(num_channels) != len(attention_levels):
|
| 1985 |
+
raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")
|
| 1986 |
+
|
| 1987 |
+
if isinstance(num_head_channels, int):
|
| 1988 |
+
num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
|
| 1989 |
+
|
| 1990 |
+
if len(num_head_channels) != len(attention_levels):
|
| 1991 |
+
raise ValueError(
|
| 1992 |
+
"num_head_channels should have the same length as attention_levels. For the i levels without attention,"
|
| 1993 |
+
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
|
| 1994 |
+
)
|
| 1995 |
+
|
| 1996 |
+
self.in_channels = in_channels
|
| 1997 |
+
self.block_out_channels = num_channels
|
| 1998 |
+
self.out_channels = out_channels
|
| 1999 |
+
self.num_res_blocks = num_res_blocks
|
| 2000 |
+
self.attention_levels = attention_levels
|
| 2001 |
+
self.num_head_channels = num_head_channels
|
| 2002 |
+
self.with_conditioning = with_conditioning
|
| 2003 |
+
|
| 2004 |
+
# input
|
| 2005 |
+
self.conv_in = Convolution(
|
| 2006 |
+
spatial_dims=spatial_dims,
|
| 2007 |
+
in_channels=in_channels,
|
| 2008 |
+
out_channels=num_channels[0],
|
| 2009 |
+
strides=1,
|
| 2010 |
+
kernel_size=3,
|
| 2011 |
+
padding=1,
|
| 2012 |
+
conv_only=True,
|
| 2013 |
+
)
|
| 2014 |
+
|
| 2015 |
+
# time
|
| 2016 |
+
time_embed_dim = num_channels[0] * 4
|
| 2017 |
+
self.time_embed = nn.Sequential(
|
| 2018 |
+
nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
|
| 2019 |
+
)
|
| 2020 |
+
|
| 2021 |
+
# class embedding
|
| 2022 |
+
self.num_class_embeds = num_class_embeds
|
| 2023 |
+
if num_class_embeds is not None:
|
| 2024 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 2025 |
+
|
| 2026 |
+
# down
|
| 2027 |
+
self.down_blocks = nn.ModuleList([])
|
| 2028 |
+
output_channel = num_channels[0]
|
| 2029 |
+
for i in range(len(num_channels)):
|
| 2030 |
+
input_channel = output_channel
|
| 2031 |
+
output_channel = num_channels[i]
|
| 2032 |
+
is_final_block = i == len(num_channels) # - 1
|
| 2033 |
+
|
| 2034 |
+
down_block = get_down_block(
|
| 2035 |
+
spatial_dims=spatial_dims,
|
| 2036 |
+
in_channels=input_channel,
|
| 2037 |
+
out_channels=output_channel,
|
| 2038 |
+
temb_channels=time_embed_dim,
|
| 2039 |
+
num_res_blocks=num_res_blocks[i],
|
| 2040 |
+
norm_num_groups=norm_num_groups,
|
| 2041 |
+
norm_eps=norm_eps,
|
| 2042 |
+
add_downsample=not is_final_block,
|
| 2043 |
+
resblock_updown=resblock_updown,
|
| 2044 |
+
with_attn=(attention_levels[i] and not with_conditioning),
|
| 2045 |
+
with_cross_attn=(attention_levels[i] and with_conditioning),
|
| 2046 |
+
num_head_channels=num_head_channels[i],
|
| 2047 |
+
transformer_num_layers=transformer_num_layers,
|
| 2048 |
+
cross_attention_dim=cross_attention_dim,
|
| 2049 |
+
upcast_attention=upcast_attention,
|
| 2050 |
+
)
|
| 2051 |
+
|
| 2052 |
+
self.down_blocks.append(down_block)
|
| 2053 |
+
|
| 2054 |
+
self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
|
| 2055 |
+
|
| 2056 |
+
def forward(
|
| 2057 |
+
self,
|
| 2058 |
+
x: torch.Tensor,
|
| 2059 |
+
timesteps: torch.Tensor,
|
| 2060 |
+
context: torch.Tensor | None = None,
|
| 2061 |
+
class_labels: torch.Tensor | None = None,
|
| 2062 |
+
) -> torch.Tensor:
|
| 2063 |
+
"""
|
| 2064 |
+
Args:
|
| 2065 |
+
x: input tensor (N, C, SpatialDims).
|
| 2066 |
+
timesteps: timestep tensor (N,).
|
| 2067 |
+
context: context tensor (N, 1, ContextDim).
|
| 2068 |
+
class_labels: context tensor (N, ).
|
| 2069 |
+
"""
|
| 2070 |
+
# 1. time
|
| 2071 |
+
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
|
| 2072 |
+
|
| 2073 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 2074 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 2075 |
+
# there might be better ways to encapsulate this.
|
| 2076 |
+
t_emb = t_emb.to(dtype=x.dtype)
|
| 2077 |
+
emb = self.time_embed(t_emb)
|
| 2078 |
+
|
| 2079 |
+
# 2. class
|
| 2080 |
+
if self.num_class_embeds is not None:
|
| 2081 |
+
if class_labels is None:
|
| 2082 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 2083 |
+
class_emb = self.class_embedding(class_labels)
|
| 2084 |
+
class_emb = class_emb.to(dtype=x.dtype)
|
| 2085 |
+
emb = emb + class_emb
|
| 2086 |
+
|
| 2087 |
+
# 3. initial convolution
|
| 2088 |
+
h = self.conv_in(x)
|
| 2089 |
+
|
| 2090 |
+
# 4. down
|
| 2091 |
+
if context is not None and self.with_conditioning is False:
|
| 2092 |
+
raise ValueError("model should have with_conditioning = True if context is provided")
|
| 2093 |
+
for downsample_block in self.down_blocks:
|
| 2094 |
+
h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
|
| 2095 |
+
|
| 2096 |
+
h = h.reshape(h.shape[0], -1)
|
| 2097 |
+
output = self.out(h)
|
| 2098 |
+
|
| 2099 |
+
return output
|
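For orientation, a minimal usage sketch of the UNet above (not part of the commit; the import path, channel counts, and latent shape are illustrative assumptions):

import torch
from codes.modules.diffusion_model_unet import DiffusionModelUNet

# Assumed 2D setting with 3-channel 64x64 latents; adjust to your data.
unet = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=(32, 64, 64, 64),
    attention_levels=(False, False, True, True),
)
x = torch.randn(2, 3, 64, 64)       # noisy latents (N, C, H, W)
t = torch.randint(0, 1000, (2,))    # one diffusion timestep per sample
noise_pred = unet(x, timesteps=t)   # output has the same shape as x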
codes/modules/fp16_util.py
ADDED
@@ -0,0 +1,237 @@
"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
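A minimal sketch of the intended MixedPrecisionTrainer loop (illustrative only; the stand-in model and hyperparameters are assumptions, and use_fp16=True additionally requires the model to implement convert_to_fp16()):

import torch as th
from codes.modules.fp16_util import MixedPrecisionTrainer

model = th.nn.Linear(16, 1)                                   # stand-in for the UNet
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

for _ in range(100):
    trainer.zero_grad()
    loss = model(th.randn(8, 16)).pow(2).mean()
    trainer.backward(loss)  # scales the loss by 2**lg_loss_scale when use_fp16=True
    trainer.optimize(opt)   # in fp16 mode, skips the step and shrinks the scale on overflow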
codes/modules/logger.py
ADDED
@@ -0,0 +1,495 @@
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50


class KVWriter(object):
    def writekvs(self, kvs):
        raise NotImplementedError


class SeqWriter(object):
    def writeseq(self, seq):
        raise NotImplementedError


class HumanOutputFormat(KVWriter, SeqWriter):
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            assert hasattr(filename_or_file, "read"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()


class JSONOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, "dtype"):
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


class CSVOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(k)
            self.file.write("\n")
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write("\n")
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = "events"
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        def summary_val(k, v):
            kwargs = {"tag": k, "simple_value": float(v)}
            return self.tf.Summary.Value(**kwargs)

        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = (
            self.step
        )  # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        if self.writer:
            self.writer.Close()
            self.writer = None


def make_output_format(format, ev_dir, log_suffix=""):
    os.makedirs(ev_dir, exist_ok=True)
    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    elif format == "log":
        return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
    elif format == "json":
        return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
    elif format == "csv":
        return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
    elif format == "tensorboard":
        return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
    else:
        raise ValueError("Unknown format specified: %s" % (format,))


# ================================================================
# API
# ================================================================


def logkv(key, val):
    """
    Log a value of some diagnostic.
    Call this once for each diagnostic quantity, each iteration.
    If called many times, the last value will be used.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, the values are averaged.
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log a dictionary of key-value pairs.
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration.
    """
    return get_current().dumpkvs()


def getkvs():
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    get_current().log(*args, level=level)


def debug(*args):
    log(*args, level=DEBUG)


def info(*args):
    log(*args, level=INFO)


def warn(*args):
    log(*args, level=WARN)


def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    get_current().set_level(level)


def set_comm(comm):
    get_current().set_comm(comm)


def get_dir():
    """
    Get directory that log files are being written to.
    Will be None if there is no output directory (i.e., if you didn't call start).
    """
    return get_current().get_dir()


record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    logkey = "wait_" + scopename
    tstart = time.time()
    try:
        yield
    finally:
        get_current().name2val[logkey] += time.time() - tstart


def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """

    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))


def get_rank_without_mpi_import():
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
        if varname in os.environ:
            return int(os.environ[varname])
    return 0


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node.
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank == 0:
        name2sum = defaultdict(float)
        name2count = defaultdict(float)
        for n2vc in all_name2valcount:
            for (name, (val, count)) in n2vc.items():
                try:
                    val = float(val)
                except ValueError:
                    if comm.rank == 0:
                        warnings.warn(
                            "WARNING: tried to compute mean on non-float {}={}".format(
                                name, val
                            )
                        )
                else:
                    name2sum[name] += val * count
                    name2count[name] += count
        return {name: name2sum[name] / name2count[name] for name in name2sum}
    else:
        return {}


def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    If comm is provided, average all numerical stats across that comm.
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        dir = osp.join(
            tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
        )
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
        else:
            format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    configure()
    Logger.DEFAULT = Logger.CURRENT


def reset():
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger
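A minimal sketch of this logger's key/value API (illustrative; the "./logs" directory is an assumption, and configure() falls back to OPENAI_LOGDIR or a temp directory when dir is None):

from codes.modules import logger

logger.configure(dir="./logs", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)             # last value wins within an iteration
    logger.logkv_mean("loss", 0.1 * step)  # repeated calls are averaged instead
    logger.dumpkvs()                       # write this iteration's diagnostics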
codes/modules/nn.py
ADDED
@@ -0,0 +1,170 @@
| 1 |
+
"""
|
| 2 |
+
Various utilities for neural networks.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch as th
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
| 12 |
+
class SiLU(nn.Module):
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
return x * th.sigmoid(x)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GroupNorm32(nn.GroupNorm):
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return super().forward(x.float()).type(x.dtype)
|
| 20 |
+
|
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
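The helpers above are consumed by the UNet implementation that follows. As a quick orientation, here is a minimal sketch (reviewer commentary, not part of the uploaded file) of the two pieces whose contracts are easiest to get wrong: the output shape of `timestep_embedding` and the in-place semantics of `update_ema`. The import path is assumed from the repo layout.

# Minimal usage sketch (commentary, not repository code).
import torch as th
from codes.modules.nn import timestep_embedding, update_ema  # path assumed from repo layout

# timestep_embedding maps an [N] tensor of (possibly fractional) timesteps
# to an [N x dim] tensor of sinusoidal features.
t = th.tensor([0.0, 10.0, 999.0])
emb = timestep_embedding(t, dim=128)
assert emb.shape == (3, 128)

# update_ema mutates target_params in place: targ <- rate*targ + (1-rate)*src.
online = [th.ones(2, 2)]
target = [th.zeros(2, 2)]
update_ema(target, online, rate=0.99)
assert th.allclose(target[0], th.full((2, 2), 0.01))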
codes/modules/unet.py
ADDED
@@ -0,0 +1,894 @@
from abc import abstractmethod

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)


class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)


class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
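`UNetModel` wires these blocks into the usual encoder/middle/decoder layout, stashing each encoder activation in `hs` and concatenating it back on the channel axis during decoding. As a quick smoke test of the forward contract, here is a minimal sketch (reviewer commentary, not part of the commit); the hyperparameters are illustrative, not the values this repo trains with, and the import path is assumed from the repo layout.

# Minimal smoke test (commentary, not repository code).
import torch as th
from codes.modules.unet import UNetModel  # path assumed from repo layout

model = UNetModel(
    image_size=64,
    in_channels=3,
    model_channels=32,
    out_channels=3,
    num_res_blocks=1,
    attention_resolutions=(4,),  # attention only at 4x downsampling
    channel_mult=(1, 2, 4),
)
x = th.randn(2, 3, 64, 64)       # [N x C x H x W]
t = th.randint(0, 1000, (2,))    # one timestep per batch element
out = model(x, t)                # same spatial shape as the input
assert out.shape == x.shape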
codes/modules/unet_2d.py
ADDED
@@ -0,0 +1,334 @@
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from ..utils import BaseOutput
|
| 22 |
+
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
| 23 |
+
from .modeling_utils import ModelMixin
|
| 24 |
+
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class UNet2DOutput(BaseOutput):
|
| 29 |
+
"""
|
| 30 |
+
The output of [`UNet2DModel`].
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 34 |
+
The hidden states output from the last layer of the model.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
sample: torch.FloatTensor
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UNet2DModel(ModelMixin, ConfigMixin):
|
| 41 |
+
r"""
|
| 42 |
+
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
| 43 |
+
|
| 44 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 45 |
+
for all models (such as downloading or saving).
|
| 46 |
+
|
| 47 |
+
Parameters:
|
| 48 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
| 49 |
+
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
| 50 |
+
1)`.
|
| 51 |
+
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
| 52 |
+
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
| 53 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
| 54 |
+
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
| 55 |
+
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
| 56 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
| 57 |
+
Whether to flip sin to cos for Fourier time embedding.
|
| 58 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
| 59 |
+
Tuple of downsample block types.
|
| 60 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
| 61 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
| 62 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
| 63 |
+
Tuple of upsample block types.
|
| 64 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
| 65 |
+
Tuple of block output channels.
|
| 66 |
+
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
| 67 |
+
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
| 68 |
+
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
| 69 |
+
downsample_type (`str`, *optional*, defaults to `conv`):
|
| 70 |
+
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
| 71 |
+
upsample_type (`str`, *optional*, defaults to `conv`):
|
| 72 |
+
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
| 73 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 74 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 75 |
+
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
| 76 |
+
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
| 77 |
+
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
| 78 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
| 79 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
| 80 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 81 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
| 82 |
+
`"timestep"`, or `"identity"`.
|
| 83 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
| 84 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
| 85 |
+
conditioning with `class_embed_type` equal to `None`.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
@register_to_config
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
| 92 |
+
in_channels: int = 3,
|
| 93 |
+
out_channels: int = 3,
|
| 94 |
+
center_input_sample: bool = False,
|
| 95 |
+
time_embedding_type: str = "positional",
|
| 96 |
+
freq_shift: int = 0,
|
| 97 |
+
flip_sin_to_cos: bool = True,
|
| 98 |
+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
| 99 |
+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
| 100 |
+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
| 101 |
+
layers_per_block: int = 2,
|
| 102 |
+
mid_block_scale_factor: float = 1,
|
| 103 |
+
downsample_padding: int = 1,
|
| 104 |
+
downsample_type: str = "conv",
|
| 105 |
+
upsample_type: str = "conv",
|
| 106 |
+
dropout: float = 0.0,
|
| 107 |
+
act_fn: str = "silu",
|
| 108 |
+
attention_head_dim: Optional[int] = 8,
|
| 109 |
+
norm_num_groups: int = 32,
|
| 110 |
+
norm_eps: float = 1e-5,
|
| 111 |
+
resnet_time_scale_shift: str = "default",
|
| 112 |
+
add_attention: bool = True,
|
| 113 |
+
class_embed_type: Optional[str] = None,
|
| 114 |
+
num_class_embeds: Optional[int] = None,
|
| 115 |
+
):
|
| 116 |
+
super().__init__()
|
| 117 |
+
|
| 118 |
+
self.sample_size = sample_size
|
| 119 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 120 |
+
|
| 121 |
+
# Check inputs
|
| 122 |
+
if len(down_block_types) != len(up_block_types):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
if len(block_out_channels) != len(down_block_types):
|
| 128 |
+
raise ValueError(
|
| 129 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# input
|
| 133 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
| 134 |
+
|
| 135 |
+
# time
|
| 136 |
+
if time_embedding_type == "fourier":
|
| 137 |
+
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
| 138 |
+
timestep_input_dim = 2 * block_out_channels[0]
|
| 139 |
+
elif time_embedding_type == "positional":
|
| 140 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 141 |
+
timestep_input_dim = block_out_channels[0]
|
| 142 |
+
|
| 143 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 144 |
+
|
| 145 |
+
# class embedding
|
| 146 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 147 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 148 |
+
elif class_embed_type == "timestep":
|
| 149 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 150 |
+
elif class_embed_type == "identity":
|
| 151 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 152 |
+
else:
|
| 153 |
+
self.class_embedding = None
|
| 154 |
+
|
| 155 |
+
self.down_blocks = nn.ModuleList([])
|
| 156 |
+
self.mid_block = None
|
| 157 |
+
self.up_blocks = nn.ModuleList([])
|
| 158 |
+
|
| 159 |
+
# down
|
| 160 |
+
output_channel = block_out_channels[0]
|
| 161 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 162 |
+
input_channel = output_channel
|
| 163 |
+
output_channel = block_out_channels[i]
|
| 164 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 165 |
+
|
| 166 |
+
down_block = get_down_block(
|
| 167 |
+
down_block_type,
|
| 168 |
+
num_layers=layers_per_block,
|
| 169 |
+
in_channels=input_channel,
|
| 170 |
+
out_channels=output_channel,
|
| 171 |
+
temb_channels=time_embed_dim,
|
| 172 |
+
add_downsample=not is_final_block,
|
| 173 |
+
resnet_eps=norm_eps,
|
| 174 |
+
resnet_act_fn=act_fn,
|
| 175 |
+
resnet_groups=norm_num_groups,
|
| 176 |
+
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
| 177 |
+
downsample_padding=downsample_padding,
|
| 178 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 179 |
+
downsample_type=downsample_type,
|
| 180 |
+
dropout=dropout,
|
| 181 |
+
)
|
| 182 |
+
self.down_blocks.append(down_block)
|
| 183 |
+
|
| 184 |
+
# mid
|
| 185 |
+
self.mid_block = UNetMidBlock2D(
|
| 186 |
+
in_channels=block_out_channels[-1],
|
| 187 |
+
temb_channels=time_embed_dim,
|
| 188 |
+
dropout=dropout,
|
| 189 |
+
resnet_eps=norm_eps,
|
| 190 |
+
resnet_act_fn=act_fn,
|
| 191 |
+
output_scale_factor=mid_block_scale_factor,
|
| 192 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 193 |
+
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
| 194 |
+
resnet_groups=norm_num_groups,
|
| 195 |
+
add_attention=add_attention,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# up
|
| 199 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 200 |
+
output_channel = reversed_block_out_channels[0]
|
| 201 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 202 |
+
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
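
For orientation, a minimal sketch (not part of the uploaded files) of driving this forward pass through the stock diffusers UNet2DModel, whose interface the module above mirrors; a scalar timestep is broadcast across the batch exactly as in step 1 of forward. The toy config here is hypothetical:

import torch
from diffusers import UNet2DModel

# Hypothetical toy config: single-channel (grayscale, CXR-like) input with
# the default diffusers down/up block layout.
unet = UNet2DModel(sample_size=64, in_channels=1, out_channels=1)

x = torch.randn(2, 1, 64, 64)  # (batch, channel, height, width)
t = torch.tensor(999)          # scalar timestep; broadcast to the batch inside forward
out = unet(x, t).sample        # UNet2DOutput.sample, same shape as x
print(out.shape)               # torch.Size([2, 1, 64, 64])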
codes/pytorch_msssim.py
ADDED
@@ -0,0 +1,156 @@
import torch
import torch.nn.functional as F
from math import exp
import numpy as np

"""
CODE SOURCE:
https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py
License:
MIT

Original Paper:
DOI: 10.1109/ACSSC.2003.1292216
"""


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = v1 / v2  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        cs = cs.mean()
        ret = ssim_map.mean()
    else:
        cs = cs.mean(1).mean(1).mean(1)
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    ssims = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)

        # Relu normalize (not compliant with original definition)
        if normalize == "relu":
            ssims.append(torch.relu(sim))
            mcs.append(torch.relu(cs))
        else:
            ssims.append(sim)
            mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    ssims = torch.stack(ssims)
    mcs = torch.stack(mcs)

    # Simple normalize (not compliant with original definition)
    # TODO: remove support for normalize == True (kept for backward support)
    if normalize == "simple" or normalize == True:
        ssims = (ssims + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = ssims ** weights

    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output


# Classes to re-use window
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        # pass the stored val_range through (it was previously ignored here)
        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average,
                    val_range=self.val_range)


class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3, normalize=None):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel
        self.normalize = normalize

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average,
                      normalize=self.normalize)
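
A quick usage sketch (likewise not part of the upload): ssim/SSIM operate on BCHW tensors, and msssim halves the resolution four times over its five levels, so each side needs at least window_size pixels left after four poolings. Passing val_range explicitly skips the range-guessing branch at the top of ssim:

import torch
from pytorch_msssim import ssim, SSIM, MSSSIM

a = torch.rand(2, 1, 256, 256)                    # values in [0, 1]
b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)  # noisy copy of a

print(ssim(a, b, val_range=1.0).item())                  # functional form
print(SSIM(window_size=11, val_range=1.0)(a, b).item())  # module form, reuses the window
print(MSSSIM(channel=1, normalize="relu")(a, b).item())  # 256 / 2**4 = 16 >= window_size 11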
codes/transform.py
ADDED
@@ -0,0 +1,28 @@
from config import config
from torchvision import transforms
import cv2 as cv


class myTransformMethod():
    def __call__(self, img):
        img = cv.resize(img, (config.image_size, config.image_size))
        if img.shape[-1] == 3:  # HWC
            img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        return img


myTransform = {
    'trainTransform': transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
    'testTransform': transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}
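
For reference, a short sketch (not part of the upload) of what this pipeline yields: cv.resize fixes the spatial size, ToTensor() maps uint8 [0, 255] to float [0, 1] and adds a channel axis, and Normalize([0.5], [0.5]) then maps x to (x - 0.5) / 0.5, i.e. into [-1, 1]:

import numpy as np
from transform import myTransform

img = np.random.randint(0, 256, (512, 512), dtype=np.uint8)  # stand-in grayscale image
t = myTransform["trainTransform"](img)
print(t.shape)                         # torch.Size([1, config.image_size, config.image_size])
print(t.min().item(), t.max().item())  # approximately -1.0 and 1.0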
codes/vq-gan_eval.py
ADDED
@@ -0,0 +1,60 @@
from config import config
from transform import myTransform

import torch
import os
import cv2 as cv
import numpy as np
from tqdm import tqdm


def eval():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # set the runtime device

    source_path = "./test_eval"  # image folder path
    recon_output_path = "./vq-gan_recon"
    compress_output_path = "./vq-gan_compress"

    model = torch.load("2025-02-04-Mask-JSRT-VQGAN.pth").to(device).eval()

    with torch.no_grad():
        for filename in tqdm(os.listdir(source_path)):
            img_path = os.path.join(source_path, filename)

            img = cv.imread(img_path, 0)

            img = myTransform["testTransform"](img).to(device)  # CHW
            img = torch.unsqueeze(img, dim=0).to(device)  # BCHW

            recon, _ = model(img)

            recon = np.array(recon.detach().to("cpu"))  # BCHW
            recon = np.squeeze(recon)  # HW
            recon = recon * 0.5 + 0.5
            recon = np.clip(recon, 0, 1)

            if not config.use_server:
                cv.imshow("win", recon)
                cv.waitKey(0)

            recon *= 255
            cv.imwrite(os.path.join(recon_output_path, filename), recon)

            if config.output_feature_map:
                compress = model.encode_stage_2_inputs(img).cpu().detach().numpy()
                # drop channel 0 and reorder the remaining latent channels to HWC for visualization
                compress = np.transpose(np.squeeze(compress)[1:], (1, 2, 0))
                compress = compress * 0.5 + 0.5
                compress = np.clip(compress, 0, 1)
                if not config.use_server:
                    cv.imshow("win", compress)
                    cv.waitKey(0)

                compress *= 255
                cv.imwrite(os.path.join(compress_output_path, filename), compress)


if __name__ == "__main__":
    eval()
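
One detail worth spelling out: recon * 0.5 + 0.5 is exactly the inverse of the Normalize([0.5], [0.5]) applied by testTransform, mapping model outputs from [-1, 1] back to [0, 1] before the * 255 rescale for cv.imwrite. A tiny worked check (not part of the upload):

import numpy as np

x = np.array([-1.0, 0.0, 1.0])           # normalized pixel values
y = np.clip(x * 0.5 + 0.5, 0, 1) * 255   # [0.0, 127.5, 255.0]
print(y.astype(np.uint8))                # [  0 127 255]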
data/BS/0.png
ADDED
Git LFS Details

data/BS/1.png
ADDED
Git LFS Details

data/BS/2.png
ADDED
Git LFS Details

data/CXR/0.png
ADDED
Git LFS Details

data/CXR/1.png
ADDED
Git LFS Details

data/CXR/2.png
ADDED
Git LFS Details

images/GL-LCM_gif.gif
ADDED
Git LFS Details

images/ablation.png
ADDED
Git LFS Details

images/comparison.png
ADDED
Git LFS Details

images/framework.png
ADDED
Git LFS Details
requirements.txt
ADDED
@@ -0,0 +1,20 @@
diffusers~=0.27.2
lpips~=0.1.4
matplotlib~=3.7.2
matplotlib-inline~=0.1.6
monai~=1.2.0
monai-generative~=0.2.2
numpy~=1.26.4
opencv-python~=4.8.1.78
pandas~=2.0.3
Pillow~=9.3.0
scikit-image~=0.22.0
scikit-learn~=1.3.1
scipy~=1.11.3
seaborn~=0.13.0
torch~=2.0.1+cu117
torch-ema~=0.3
torchaudio~=2.0.2+cu117
torchsummary~=1.5.1
torchvision~=0.15.2+cu117
tqdm~=4.66.1
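
Note that the ~= specifiers are PEP 440 "compatible release" pins (numpy~=1.26.4 means numpy >= 1.26.4, == 1.26.*), and the +cu117 local version tags on the torch packages refer to CUDA 11.7 wheels, which pip typically only resolves with the PyTorch wheel index added, e.g. pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117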