diaoquesang committed on
Commit 6434535 · verified · 1 Parent(s): c2101b4

Upload 29 files

.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/BS/0.png filter=lfs diff=lfs merge=lfs -text
+data/BS/1.png filter=lfs diff=lfs merge=lfs -text
+data/BS/2.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/0.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/1.png filter=lfs diff=lfs merge=lfs -text
+data/CXR/2.png filter=lfs diff=lfs merge=lfs -text
+images/ablation.png filter=lfs diff=lfs merge=lfs -text
+images/comparison.png filter=lfs diff=lfs merge=lfs -text
+images/framework.png filter=lfs diff=lfs merge=lfs -text
+images/GL-LCM_gif.gif filter=lfs diff=lfs merge=lfs -text
codes/batch_lcm_eval.py ADDED
@@ -0,0 +1,155 @@
+from config import config
+import numpy as np
+
+from dataset import myC2BDataset
+from transform import myTransform
+from torch.utils.data import DataLoader
+from diffusers import LCMScheduler
+
+from tqdm import tqdm
+import cv2 as cv
+import torch
+import time
+import os
+
+from monai.utils import set_determinism
+
+set_determinism(42)
+
+
+def eval():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device
+    output_path = os.path.join("lcm_output_bs", "BS")
+    masked_output_path = os.path.join("lcm_output_bs", "Masked_BS")
+    fusion_output_path = os.path.join("lcm_output_bs", "Fusion_BS")
+
+    cxr_path = os.path.join("SZCH-X-Rays", "CXR")
+    masked_cxr_path = os.path.join("SZCH-X-Rays", "Masked_CXR")
+    mask_path = os.path.join("SZCH-X-Rays", "Mask")
+
+    model = torch.load("masked_lcm-600-2024-12-19-myModel.pth").to(device).eval()
+    VQGAN = torch.load("2024-12-12-Mask-SZCH-VQGAN.pth").to(device).eval()
+    testset_list = "SZCH.txt"
+    myTestSet = myC2BDataset(testset_list, cxr_path, masked_cxr_path, myTransform['testTransform'])
+    myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)
+    # set up the noise scheduler
+    noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
+                                   clip_sample=config.clip_sample,
+                                   clip_sample_range=config.initial_clip_sample_range_g)
+    noise_scheduler.set_timesteps(config.num_infer_timesteps)
+    with torch.no_grad():
+        progress_bar = tqdm(enumerate(myTestLoader), total=len(myTestLoader), ncols=100)
+        total_start = time.time()
+        for step, batch in progress_bar:
+            cxr = batch[0].to(device=device, non_blocking=True).float()
+            masked_cxr = batch[1].to(device=device, non_blocking=True).float()
+            filename = batch[2][0]
+            cxr_copy = np.array(cxr.detach().to("cpu"))
+            cxr_copy = np.squeeze(cxr_copy)  # HW
+            cxr_copy = cxr_copy * 0.5 + 0.5
+            cxr_copy *= 255
+            cxr_copy = cxr_copy.astype(np.uint8)  # uint8 keeps the 0-255 range for the background mask below
+
+            cxr = VQGAN.encode_stage_2_inputs(cxr)
+            masked_cxr = VQGAN.encode_stage_2_inputs(masked_cxr)
+
+            noise = torch.randn_like(cxr).to(device)
+            sample = torch.cat((noise, cxr), dim=1).to(device)  # BCHW
+            masked_sample = torch.cat((noise, masked_cxr), dim=1).to(device)  # BCHW
+
+            for j, t in tqdm(enumerate(noise_scheduler.timesteps)):
+                residual = model(sample, torch.Tensor((t,)).to(device).long()).to(device)
+                masked_residual = model(masked_sample, torch.Tensor((t,)).to(device).long()).to(device)
+                # masked_residual = (1 - config.alpha) * residual + config.alpha * masked_residual
+                masked_residual = config.alpha * masked_residual + (1 - config.alpha) * torch.randn_like(
+                    masked_residual).to(device) / torch.std(masked_residual)
+
+                noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
+                                               clip_sample=config.clip_sample,
+                                               clip_sample_range=
+                                               config.initial_clip_sample_range_g
+                                               + config.clip_rate * j
+                                               )
+                noise_scheduler.set_timesteps(config.num_infer_timesteps)
+                sample = noise_scheduler.step(residual, t, sample).prev_sample
+
+                noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
+                                               clip_sample=config.clip_sample,
+                                               clip_sample_range=
+                                               config.initial_clip_sample_range_l
+                                               + config.clip_rate * j
+                                               )
+                noise_scheduler.set_timesteps(config.num_infer_timesteps)
+                masked_sample = noise_scheduler.step(masked_residual, t, masked_sample).prev_sample
+
+            sample = torch.cat((sample[:, :4], cxr), dim=1)  # BCHW
+            masked_sample = torch.cat((masked_sample[:, :4], masked_cxr), dim=1).to(device)  # BCHW
+            if config.output_feature_map:
+                bs_show = np.array(sample[:, 0].detach().to("cpu"))
+                bs_show = np.squeeze(bs_show)  # HW
+                bs_show = bs_show * 0.5 + 0.5
+                bs_show = np.clip(bs_show, 0, 1)
+
+                masked_bs_show = np.array(masked_sample[:, 0].detach().to("cpu"))
+                masked_bs_show = np.squeeze(masked_bs_show)  # HW
+                masked_bs_show = masked_bs_show * 0.5 + 0.5
+                masked_bs_show = np.clip(masked_bs_show, 0, 1)
+
+                if not config.use_server:
+                    cv.imshow("win1", bs_show)
+                    cv.imshow("win2", masked_bs_show)
+                    cv.waitKey(1)
+
+            mask = cv.imread(os.path.join(mask_path, filename), 0)
+            mask[mask < 255] = 0
+
+            bs = VQGAN.decode((sample[:, :4]))
+            bs = np.array(bs.detach().to("cpu"))
+            bs = np.squeeze(bs)  # HW
+            bs = bs * 0.5 + 0.5
+            bs[cxr_copy == 0] = 0
+
+            masked_bs = VQGAN.decode((masked_sample[:, :4]))
+            masked_bs = np.array(masked_bs.detach().to("cpu"))
+            masked_bs = np.squeeze(masked_bs)  # HW
+            masked_bs = masked_bs * 0.5 + 0.5
+            masked_bs[mask > 0] = masked_bs[mask > 0] + np.mean(bs[mask > 0]) - np.mean(masked_bs[mask > 0])
+            masked_bs[cxr_copy == 0] = 0
+            if not config.use_server:
+                cv.imshow("win3", bs)
+                cv.imshow("win4", masked_bs)
+                cv.waitKey(1)
+
+            bs *= 255
+            cv.imwrite(os.path.join(output_path, filename), bs)
+            masked_bs *= 255
+            cv.imwrite(os.path.join(masked_output_path, filename), masked_bs)
+
+
+            num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask)
+            min_area = 100
+            for i in range(1, num_labels):
+                if stats[i, cv.CC_STAT_AREA] < min_area:
+                    labels[labels == i] = 0
+            mask[labels == 0] = 0
+
+            br = cv.boundingRect(mask)
+            p = (br[0] + br[2] // 2, br[1] + br[3] // 2)
+
+            masked_bs = np.clip(masked_bs, 0, 255)
+            masked_bs = cv.cvtColor(masked_bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
+            bs = np.clip(bs, 0, 255)
+            bs = cv.cvtColor(bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
+
+            fusion_bs = cv.seamlessClone(masked_bs, bs, mask, p, cv.MONOCHROME_TRANSFER)
+            # cv.rectangle(fusion_bs, br, (0, 255, 0), 2)
+            # fusion_bs[mask==255]=(255, 0, 0)
+
+            cv.imwrite(os.path.join(fusion_output_path, filename), fusion_bs)
+
+        total_time = time.time() - total_start
+        print(f"Total time: {total_time}.")
+
+
+if __name__ == "__main__":
+    eval()
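The fusion step at the end of eval() pastes the locally refined result (masked_bs) back into the globally generated image (bs) with OpenCV's Poisson blending. A minimal, self-contained sketch of that call on synthetic arrays (the sizes, values, and region below are illustrative only, not the SZCH-X-Rays data):

import cv2 as cv
import numpy as np

# Synthetic stand-ins for the two decoded outputs: a "global" image and a
# "local" image that differs inside a rectangular lung-field region.
bs = np.full((256, 256), 128, dtype=np.uint8)    # global result
masked_bs = bs.copy()
masked_bs[64:192, 64:192] = 180                  # locally refined region

# Binary mask of the region to transplant and its centre point, mirroring the
# boundingRect / centre computation in the script above.
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 255
x, y, w, h = cv.boundingRect(mask)
centre = (x + w // 2, y + h // 2)

# seamlessClone expects 8-bit 3-channel inputs; MONOCHROME_TRANSFER is the same
# cloning mode used in the script above.
src = cv.cvtColor(masked_bs, cv.COLOR_GRAY2BGR)
dst = cv.cvtColor(bs, cv.COLOR_GRAY2BGR)
fused = cv.seamlessClone(src, dst, mask, centre, cv.MONOCHROME_TRANSFER)
print(fused.shape, fused.dtype)  # (256, 256, 3) uint8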
codes/config.py ADDED
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class config():
+    use_server = True
+    test_epoch_interval = 10
+    image_size = 1024
+    r = 8
+
+    # VQGAN
+    vae_epoch_number = 300
+    vae_batch_size = 4
+    milestones_g = [200]
+    milestones_d = [200]
+    initial_learning_rate_g = 1e-4
+    initial_learning_rate_d = 5e-4
+
+    # lcm
+    batch_size = 4
+    epoch_number = 600
+    initial_learning_rate = 1e-4
+    milestones = [300, 400, 500]
+    num_train_timesteps = 1000
+    beta_start = 0.00085
+    beta_end = 0.012
+    offset_noise = True
+    offset_noise_coefficient = 0.1
+    output_feature_map = True
+    clip_sample = True
+    num_infer_timesteps = 50
+    alpha = 3
+    video = False
+
+    # SZCH
+    clip_rate = 0.025
+    initial_clip_sample_range_g = 2
+    initial_clip_sample_range_l = 3.5
+    # JSRT
+    # clip_rate = 0.025
+    # initial_clip_sample_range_g = 1.7
+    # initial_clip_sample_range_l = 3
+
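The two initial_clip_sample_range_* values and clip_rate define the per-step clipping schedule that batch_lcm_eval.py applies: at inference step j the scheduler is rebuilt with clip_sample_range = initial_clip_sample_range + clip_rate * j, so the allowed latent range widens as sampling proceeds, and the global and local paths start from different initial ranges. A small sketch of that schedule using the SZCH values above (the helper function is illustrative, not part of the repository):

# Hypothetical helper reproducing the clip-range schedule used in batch_lcm_eval.py.
def clip_range(initial: float, clip_rate: float, step: int) -> float:
    """Clip range applied at 0-based inference step `step`."""
    return initial + clip_rate * step

for j in (0, 10, 49):                             # 50 inference steps in total
    print(j,
          round(clip_range(2.0, 0.025, j), 3),    # global path, initial_clip_sample_range_g
          round(clip_range(3.5, 0.025, j), 3))    # local path, initial_clip_sample_range_l
# 0 2.0 3.5
# 10 2.25 3.75
# 49 3.225 4.725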
codes/dataset.py ADDED
@@ -0,0 +1,74 @@
+from torch.utils.data import Dataset
+import pandas as pd
+import cv2 as cv
+import os
+from config import config
+import torch
+import numpy as np
+
+
+class mySingleDataset(Dataset):  # dataset of single images
+    def __init__(self, filelist, img_dir, transform=None):  # arguments: file list, image directory, transform
+        self.img_dir = img_dir  # image directory
+        self.transform = transform  # image transform
+        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames
+
+    def __len__(self):
+        return len(self.filelist)  # the number of filenames is the dataset length
+
+    def __getitem__(self, idx):  # fetch one sample
+        img_path = self.img_dir  # image folder path
+
+        file = self.filelist.iloc[idx, 0]  # read the filename
+        image = cv.imread(os.path.join(img_path, file))  # read the image with OpenCV's imread
+
+        if self.transform:
+            image = self.transform(image)  # apply the transform
+        return image, file  # return the image and its filename
+
+
+class myDataset(Dataset):  # dataset of paired CXR / bone-suppressed images
+    def __init__(self, filelist, cxr_dir, bs_dir,
+                 transform=None):  # arguments: file list, CXR directory, BS directory, transform
+        self.cxr_dir = cxr_dir  # CXR directory
+        self.bs_dir = bs_dir  # BS directory
+
+        self.transform = transform  # image transform
+        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames
+
+    def __len__(self):
+        return len(self.filelist)  # the number of filenames is the dataset length
+
+    def __getitem__(self, idx):  # fetch one sample
+        file = self.filelist.iloc[idx, 0]  # read the filename
+        cxr = cv.imread(os.path.join(self.cxr_dir, file))  # read the CXR with OpenCV's imread
+        bs = cv.imread(os.path.join(self.bs_dir, file))  # read the BS image with OpenCV's imread
+
+        if self.transform:
+            cxr = self.transform(cxr)  # apply the transform
+            bs = self.transform(bs)  # apply the transform
+
+        return cxr, bs, file  # return the image, its label and the filename
+
+class myC2BDataset(Dataset):  # dataset of paired CXR / masked-CXR images
+    def __init__(self, filelist, cxr_dir, masked_cxr_dir,
+                 transform=None):  # arguments: file list, CXR directory, masked-CXR directory, transform
+        self.cxr_dir = cxr_dir  # CXR directory
+        self.masked_cxr_dir = masked_cxr_dir  # masked-CXR directory
+
+        self.transform = transform  # image transform
+        self.filelist = pd.read_csv(filelist, sep="\t", header=None)  # read the list of filenames
+
+    def __len__(self):
+        return len(self.filelist)  # the number of filenames is the dataset length
+
+    def __getitem__(self, idx):  # fetch one sample
+        file = self.filelist.iloc[idx, 0]  # read the filename
+        cxr = cv.imread(os.path.join(self.cxr_dir, file))  # read the CXR with OpenCV's imread
+        masked_cxr = cv.imread(os.path.join(self.masked_cxr_dir, file))  # read the masked CXR with OpenCV's imread
+
+        if self.transform:
+            cxr = self.transform(cxr)  # apply the transform
+            masked_cxr = self.transform(masked_cxr)  # apply the transform
+
+        return cxr, masked_cxr, file  # return the images and the filename
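All three dataset classes read the file list with pd.read_csv(filelist, sep="\t", header=None), i.e. a plain text file with one image filename per line, and load images with OpenCV. A minimal usage sketch that builds a toy list plus dummy images and iterates myC2BDataset (every path and file below is a placeholder, not the repository's actual data):

import os
import cv2 as cv
import numpy as np
from torch.utils.data import DataLoader
from dataset import myC2BDataset

# Write a toy file list and matching dummy images (placeholders only).
os.makedirs("demo/CXR", exist_ok=True)
os.makedirs("demo/Masked_CXR", exist_ok=True)
with open("demo_list.txt", "w") as f:
    for name in ("0.png", "1.png"):
        f.write(name + "\n")
        cv.imwrite(os.path.join("demo/CXR", name), np.zeros((64, 64), np.uint8))
        cv.imwrite(os.path.join("demo/Masked_CXR", name), np.zeros((64, 64), np.uint8))

# transform=None yields raw OpenCV arrays; the real scripts pass myTransform[...].
ds = myC2BDataset("demo_list.txt", "demo/CXR", "demo/Masked_CXR", transform=None)
loader = DataLoader(ds, batch_size=1, shuffle=False)
for cxr, masked_cxr, name in loader:
    print(name[0], cxr.shape)  # "0.png" and a [1, 64, 64, 3] tensor (imread returns BGR)
    break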
codes/datasetSegmentation.py ADDED
@@ -0,0 +1,45 @@
+import random
+import os
+
+
+def traverse_directory(directory, txt_name):
+    # create an empty list to store the filenames
+    file_names = []
+    # walk through all files and subdirectories
+    for root, dirs, files in os.walk(directory):
+        for file in files:
+            file_names.append(file)
+    # sort by filename (optional)
+    file_names.sort()
+    # write the filenames to a new txt file
+    with open(txt_name, 'w') as f:
+        for file_name in file_names:
+            f.write(file_name + '\n')
+
+
+def split_dataset(file_path, train_ratio=0.8, val_ratio=0.1, former=None):
+    # read the dataset
+    with open(file_path, 'r') as f:
+        data = f.readlines()
+    # shuffle the dataset
+    random.shuffle(data)
+
+    train_size = int(len(data) * train_ratio)
+    val_size = int(len(data) * val_ratio)
+
+    train_set = data[:train_size]
+    val_set = data[train_size:train_size + val_size]
+    test_set = data[train_size + val_size:]
+    with open(former + '_trainset.txt', 'w') as f:
+        f.writelines(train_set)
+    with open(former + '_valset.txt', 'w') as f:
+        f.writelines(val_set)
+    with open(former + '_testset.txt', 'w') as f:
+        f.writelines(test_set)
+    print("Finished.")
+
+
+if __name__ == "__main__":
+    traverse_directory('JSRTnew1024-241/CXR', "JSRT.txt")
+
+    split_dataset('JSRT.txt', 0.8, 0.1, "JSRT")
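With the default ratios, split_dataset keeps 80% of the shuffled lines for training and 10% for validation; because both sizes are truncated with int(), the test split receives whatever remains, including rounding leftovers. A quick arithmetic check of the resulting counts (the file count below is only an assumed example):

# Assumed example count; the real number comes from the lines in JSRT.txt.
N = 241
train_size = int(N * 0.8)               # 192
val_size = int(N * 0.1)                 # 24
test_size = N - train_size - val_size   # 25, absorbing the rounding remainder
print(train_size, val_size, test_size)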
codes/lcm_train.py ADDED
@@ -0,0 +1,180 @@
+from matplotlib import pyplot as plt
+from config import config
+from dataset import myDataset
+from transform import myTransform
+from torch.utils.data import DataLoader
+from model import myUnet, myVQGANModel
+from diffusers import LCMScheduler, DDPMScheduler
+from torch.optim.lr_scheduler import MultiStepLR
+from tqdm import tqdm
+from datetime import date
+
+import torch.nn.functional as F
+import torch
+import time
+from monai.utils import set_determinism
+
+set_determinism(42)
+
+
+def train():
+    if config.use_server:
+        file = open('log.txt', 'w')  # where the log is saved
+    else:
+        file = None  # no log file; print to stdout
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device
+
+    train_file_list = "JSRT_trainset.txt"  # text file listing the training filenames
+    test_file_list = "JSRT_valset.txt"  # text file listing the validation filenames
+
+    cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/CXR"  # image folder path
+    bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/BS"  # image folder path
+    masked_cxr_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_CXR"  # image folder path
+    masked_bs_path = "/mntcephfs/med_dataset/SYF/JSRTnew1024-241/Masked_BS"  # image folder path
+
+    myTrainSet = myDataset(train_file_list, cxr_path, bs_path,
+                           myTransform['trainTransform']) + myDataset(train_file_list, masked_cxr_path, masked_bs_path,
+                                                                      myTransform['trainTransform'])
+    myTestSet = myDataset(test_file_list, cxr_path, bs_path,
+                          myTransform['testTransform']) + myDataset(test_file_list, masked_cxr_path, masked_bs_path,
+                                                                    myTransform['testTransform'])
+
+    myTrainLoader = DataLoader(myTrainSet, batch_size=config.batch_size, shuffle=True)
+    myTestLoader = DataLoader(myTestSet, batch_size=config.batch_size, shuffle=True)
+
+    print("Number of batches in train set:", len(myTrainLoader))  # number of training batches
+    print("Train set size:", len(myTrainSet))  # training set size
+    print("Number of batches in test set:", len(myTestLoader))  # number of test batches
+    print("Test set size:", len(myTestSet))  # test set size
+
+    model = myUnet.to(device).train()
+
+    # set up the noise scheduler
+    noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps)
+    noise_scheduler.set_timesteps(config.num_infer_timesteps)
+    # set up the optimizer and learning-rate schedule
+    optimizer = torch.optim.AdamW(model.parameters(), lr=config.initial_learning_rate, eps=1e-6)
+    milestones = [x * len(myTrainLoader) for x in config.milestones]
+    optimizer_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
+
+    train_losses = []
+    test_losses = []
+    plt_train_loss_epoch = []
+    plt_test_loss_epoch = []
+    train_epoch_list = list(range(0, config.epoch_number))
+    test_epoch_list = list(range(0, int(config.epoch_number / config.test_epoch_interval)))
+
+    VQGAN = torch.load("2025-02-04-Mask-JSRT-VQGAN.pth").to(device).eval()
+    print(time.strftime("%H:%M:%S", time.localtime()), "----------Begin Training----------", file=file)
+    for epoch in range(config.epoch_number):
+        model.train()
+        print(time.strftime("%H:%M:%S", time.localtime()),
+              f"Epoch:{epoch},learning rate:{optimizer.param_groups[0]['lr']}", file=file)
+        for i, batch in tqdm(enumerate(myTrainLoader)):
+            cxr_i, bs_i = batch[0].to(device), batch[1].to(device)
+
+            with torch.no_grad():
+                cxr = VQGAN.encode_stage_2_inputs(cxr_i)
+                bs = VQGAN.encode_stage_2_inputs(bs_i)
+
+            cat = torch.cat((bs, cxr), dim=-3)
+
+            # add noise to the images
+            if config.offset_noise:
+                noise = torch.randn_like(cxr).to(device) + config.offset_noise_coefficient * torch.randn(
+                    cxr.shape[0], cxr.shape[1], 1,
+                    1).to(device)
+            else:
+                noise = torch.randn_like(cxr).to(device)
+
+            blank = torch.zeros_like(cxr).to(device)
+            noise = torch.cat((noise, blank), dim=-3)
+
+            # sample a random timestep for each image
+            timesteps = torch.randint(0, config.num_train_timesteps, (cxr.shape[0],), device=device).long()
+
+            # add noise to the clean images according to the noise magnitude at each timestep
+            noisy_images = noise_scheduler.add_noise(cat, noise, timesteps)
+
+            # get the model prediction
+            noise_pred = model(noisy_images, timesteps)
+
+            # compute the loss
+            loss = F.mse_loss(noise_pred[:, :4].float(), noise[:, :4].float())
+
+            loss.backward()
+            train_losses.append(loss.item())
+
+            # update the model parameters
+            optimizer.step()
+            optimizer.zero_grad()
+            optimizer_scheduler.step()
+
+        train_loss_epoch = sum(train_losses[-len(myTrainLoader):]) / len(myTrainLoader)
+        print(time.strftime("%H:%M:%S", time.localtime()), f"Epoch:{epoch},train losses:{train_loss_epoch}", file=file)
+        plt_train_loss_epoch.append(train_loss_epoch)
+
+        if (epoch + 1) % config.test_epoch_interval == 0:
+            model.eval()
+            print(time.strftime("%H:%M:%S", time.localtime()), "----------Stop Training----------", file=file)
+            print(time.strftime("%H:%M:%S", time.localtime()), "----------Begin Testing----------", file=file)
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(myTestLoader)):
+                    cxr_i, bs_i = batch[0].to(device), batch[1].to(device)
+
+                    with torch.no_grad():
+                        cxr = VQGAN.encode_stage_2_inputs(cxr_i)
+                        bs = VQGAN.encode_stage_2_inputs(bs_i)
+
+                    cat = torch.cat((bs, cxr), dim=-3)
+
+                    # add noise to the images
+                    if config.offset_noise:
+                        noise = torch.randn_like(cxr).to(device) + config.offset_noise_coefficient * torch.randn(
+                            cxr.shape[0],
+                            cxr.shape[1], 1,
+                            1).to(device)
+                    else:
+                        noise = torch.randn_like(cxr).to(device)
+
+                    blank = torch.zeros_like(cxr).to(device)
+                    noise = torch.cat((noise, blank), dim=-3)
+
+                    # sample a random timestep for each image
+                    timesteps = torch.randint(0, config.num_train_timesteps, (cxr.shape[0],),
+                                              device=device).long()
+
+                    # add noise to the clean images according to the noise magnitude at each timestep
+                    noisy_images = noise_scheduler.add_noise(cat, noise, timesteps)
+
+                    # get the model prediction
+                    noise_pred = model(noisy_images, timesteps)
+
+                    # compute the loss
+                    loss = F.mse_loss(noise_pred[:, :4].float(), noise[:, :4].float())
+
+                    test_losses.append(loss.item())
+
+                test_loss_epoch = sum(test_losses[-len(myTestLoader):]) / len(myTestLoader)
+                print(time.strftime("%H:%M:%S", time.localtime()), f"Epoch:{epoch},test losses:{test_loss_epoch}",
+                      file=file)
+                plt_test_loss_epoch.append(test_loss_epoch)
+                print(time.strftime("%H:%M:%S", time.localtime()), "----------End Validation----------", file=file)
+                print(time.strftime("%H:%M:%S", time.localtime()), "----------Continue to Train----------",
+                      file=file)
+    print(time.strftime("%H:%M:%S", time.localtime()), "----------End Training Normally----------", file=file)
+    # plot the loss curves
+    f, ([ax1, ax2]) = plt.subplots(1, 2)
+    ax1.plot(train_epoch_list, plt_train_loss_epoch, color="red")  # plot the curve
+    ax1.set_title('Train loss')  # add a title
+    ax2.plot(test_epoch_list, plt_test_loss_epoch, color="blue")  # plot the curve
+    ax2.set_title('Test loss')  # add a title
+    plt.savefig("./loss.png")  # save the loss curves
+    if not config.use_server:
+        plt.show()  # show the loss curves
+    torch.save(model, "masked_lcm-600JSRT-" + str(date.today()) + "-myModel.pth")
+
+
+if __name__ == "__main__":
+    train()
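Training operates on an 8-channel latent: the first 4 channels hold the VQGAN latent of the bone-suppressed image (the only part the loss sees via noise_pred[:, :4]), and the last 4 channels carry the CXR latent as conditioning, which is why zeros are concatenated onto the noise. A compact sketch of that layout and of the offset-noise trick, with random tensors standing in for VQGAN latents:

import torch

B, C, H, W = 2, 4, 128, 128            # latent shape; placeholders for VQGAN outputs
bs_latent = torch.randn(B, C, H, W)    # stands in for VQGAN.encode_stage_2_inputs(bs)
cxr_latent = torch.randn(B, C, H, W)   # stands in for the CXR conditioning latent

cat = torch.cat((bs_latent, cxr_latent), dim=-3)  # 8-channel tensor the UNet denoises

# Offset noise: a per-(sample, channel) constant shift on top of white noise,
# with the coefficient 0.1 matching offset_noise_coefficient in config.py.
noise = torch.randn_like(bs_latent) + 0.1 * torch.randn(B, C, 1, 1)
noise = torch.cat((noise, torch.zeros_like(cxr_latent)), dim=-3)  # conditioning stays clean

print(cat.shape, noise.shape)  # torch.Size([2, 8, 128, 128]) for both
# The loss in lcm_train.py then compares only the first 4 channels:
#   F.mse_loss(noise_pred[:, :4], noise[:, :4])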
codes/lungSegmentation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
codes/metrics.py ADDED
@@ -0,0 +1,140 @@
+import cv2 as cv
+import lpips
+import numpy as np
+import os
+from config import config
+from openpyxl import Workbook
+
+import torch
+from math import log10, sqrt
+import torchvision.transforms as transforms
+
+def cal_BSR(cxr_path, gt_path, bs_path):
+    cxr = cv.imread(cxr_path, 0)
+    gt = cv.imread(gt_path, 0)
+    bs = cv.imread(bs_path, 0)
+
+    cxr = cxr / 255
+    gt = gt / 255
+    bs = bs / 255
+
+    bone = cv.subtract(cxr, gt)
+
+
+    gt = cv.resize(gt, (config.image_size, config.image_size))
+    bs = cv.resize(bs, (config.image_size, config.image_size))
+    bone = cv.resize(bone, (config.image_size, config.image_size))
+
+    bs += np.average(cv.subtract(gt, bs))
+
+    bias = cv.subtract(gt, bs)
+    bias[bias < 0] = 0
+
+    BSR = 1 - np.sum(bias ** 2) / np.sum(bone ** 2)
+    return BSR
+
+
+def cal_MSE(gt_path, bs_path):
+    gt = cv.imread(gt_path, 0)
+    bs = cv.imread(bs_path, 0)
+
+
+
+    gt = cv.resize(gt, (config.image_size, config.image_size))
+    bs = cv.resize(bs, (config.image_size, config.image_size))
+
+    # MSE = np.mean((gt - bs) ** 2)  # superseded by the normalised value below
+    MSE = 2 * lpips.l2(gt, bs)  # lpips.l2 normalises by 255 by default, so MSE is on a [0, 1] scale
+    return MSE
+
+
+
+
+def cal_PSNR(gt_path, bs_path):
+    mse = cal_MSE(gt_path, bs_path)
+    max_pixel = 1
+
+    PSNR = 20 * log10(max_pixel / sqrt(mse))
+    return PSNR
+
+
+
+
+def cal_LPIPS(gt_path, bs_path):
+    lpips_model = lpips.LPIPS()
+
+    gt = cv.imread(gt_path, 0)
+    bs = cv.imread(bs_path, 0)
+
+
+    gt = cv.resize(gt, (config.image_size, config.image_size))
+    bs = cv.resize(bs, (config.image_size, config.image_size))
+
+    gt = transforms.ToTensor()(gt)
+    bs = transforms.ToTensor()(bs)
+
+    gt = torch.unsqueeze(gt, dim=0)
+    bs = torch.unsqueeze(bs, dim=0)
+
+    LPIPS = lpips_model(gt, bs).item()
+
+    return LPIPS
+
+
+if __name__ == "__main__":
+    wb = Workbook()
+
+    ws = wb.active
+
+    CXR_path = "SZCH-X-Rays/CXR"
+    GT_path = "SZCH-X-Rays/BS"
+    BS_path = "lcm_output_bs/Fusion_BS"
+
+    BSR_list = []
+    MSE_list = []
+    PSNR_list = []
+    LPIPS_list = []
+    ws.append(["Filename", "BSR", "MSE", "PSNR", "LPIPS"])
+    txt = 'SZCH_testset.txt'
+    with open(txt, 'r', encoding='utf-8') as file:
+        lines = file.readlines()
+        file_names = [line.strip() for line in lines]
+
+    for filename in os.listdir(BS_path):
+        # only evaluate files that belong to the test split
+        if filename not in file_names:
+            continue
+
+        cxr_path = os.path.join(CXR_path, filename)
+        gt_path = os.path.join(GT_path, filename)
+        bs_path = os.path.join(BS_path, filename)
+
+        BSR = cal_BSR(cxr_path, gt_path, bs_path)
+        MSE = cal_MSE(gt_path, bs_path)
+        PSNR = cal_PSNR(gt_path, bs_path)
+        LPIPS = cal_LPIPS(gt_path, bs_path)
+
+        BSR_list.append(BSR)
+        MSE_list.append(MSE)
+        PSNR_list.append(PSNR)
+        LPIPS_list.append(LPIPS)
+        print(f"{filename} BSR: {BSR} MSE: {MSE} PSNR: {PSNR} LPIPS: {LPIPS}")
+        ws.append([filename, BSR, MSE, PSNR, LPIPS])
+
+    ws.append(["Mean",
+               np.mean(np.array(BSR_list)),
+               np.mean(np.array(MSE_list)),
+               np.mean(np.array(PSNR_list)),
+               np.mean(np.array(LPIPS_list))])
+    ws.append(["Std",
+               np.std(np.array(BSR_list)),
+               np.std(np.array(MSE_list)),
+               np.std(np.array(PSNR_list)),
+               np.std(np.array(LPIPS_list))])
+    print("Average BSR:", np.mean(np.array(BSR_list)), "Std:", np.std(np.array(BSR_list)))
+    print("Average MSE:", np.mean(np.array(MSE_list)), "Std:", np.std(np.array(MSE_list)))
+    print("Average PSNR:", np.mean(np.array(PSNR_list)), "Std:", np.std(np.array(PSNR_list)))
+    print("Average LPIPS:", np.mean(np.array(LPIPS_list)), "Std:", np.std(np.array(LPIPS_list)))
+
+    # save the workbook to a file
+    wb.save("sample.xlsx")
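Since cal_MSE returns an error on the normalised [0, 1] intensity scale, cal_PSNR uses max_pixel = 1 in the standard formula PSNR = 20 * log10(max_pixel / sqrt(MSE)). A worked example with an assumed MSE value (not a measured result):

from math import log10, sqrt

mse = 0.001                    # assumed normalised MSE, for illustration only
max_pixel = 1
psnr = 20 * log10(max_pixel / sqrt(mse))
print(round(psnr, 2))          # 30.0 -> an MSE of 1e-3 on a [0, 1] scale is 30 dB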
codes/model.py ADDED
@@ -0,0 +1,37 @@
+
+from modules.unet import UNetModel
+from generative.networks.nets import VQVAE
+from config import config
+
+myUnet = UNetModel(
+    image_size=config.image_size / config.r,
+    model_channels=128,
+    in_channels=8,
+    out_channels=8,
+    num_res_blocks=8,
+    num_heads=8,
+    attention_resolutions=(64, 32, 16, 8),
+    num_heads_upsample=-1,
+    num_head_channels=-1,
+    resblock_updown=True,
+    channel_mult=(1, 1, 2, 2, 4, 4),
+    use_scale_shift_norm=True,
+    use_new_attention_order=True
+)
+
+myVQGANModel = VQVAE(
+    spatial_dims=2,
+    in_channels=1,
+    out_channels=1,
+    num_channels=(128, 256, 512),
+    num_res_channels=512,
+    num_res_layers=2,
+    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1),),
+    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+    num_embeddings=1024,
+    embedding_dim=4,
+)
+
+if __name__ == "__main__":
+    print("Number of UNet parameters:", sum([p.numel() for p in myUnet.parameters()]))
+    print("Number of VQGAN parameters:", sum([p.numel() for p in myVQGANModel.parameters()]))
codes/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""
+Codebase for "Bone Suppression via conditional diffusion model".
+"""
codes/modules/diffusion_model_unet.py ADDED
@@ -0,0 +1,2099 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ #
12
+ # =========================================================================
13
+ # Adapted from https://github.com/huggingface/diffusers
14
+ # which has the following license:
15
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
16
+ #
17
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
18
+ #
19
+ # Licensed under the Apache License, Version 2.0 (the "License");
20
+ # you may not use this file except in compliance with the License.
21
+ # You may obtain a copy of the License at
22
+ #
23
+ # http://www.apache.org/licenses/LICENSE-2.0
24
+ #
25
+ # Unless required by applicable law or agreed to in writing, software
26
+ # distributed under the License is distributed on an "AS IS" BASIS,
27
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28
+ # See the License for the specific language governing permissions and
29
+ # limitations under the License.
30
+ # =========================================================================
31
+
32
+ from __future__ import annotations
33
+
34
+ import importlib.util
35
+ import math
36
+ from collections.abc import Sequence
37
+
38
+ import torch
39
+ import torch.nn.functional as F
40
+ from monai.networks.blocks import Convolution, MLPBlock
41
+ from monai.networks.layers.factories import Pool
42
+ from monai.utils import ensure_tuple_rep
43
+ from torch import nn
44
+
45
+ # To install xformers, use pip install xformers==0.0.16rc401
46
+ if importlib.util.find_spec("xformers") is not None:
47
+ import xformers
48
+ import xformers.ops
49
+
50
+ has_xformers = True
51
+ else:
52
+ xformers = None
53
+ has_xformers = False
54
+
55
+
56
+ # TODO: Use MONAI's optional_import
57
+ # from monai.utils import optional_import
58
+ # xformers, has_xformers = optional_import("xformers.ops", name="xformers")
59
+
60
+ __all__ = ["DiffusionModelUNet"]
61
+
62
+
63
+ def zero_module(module: nn.Module) -> nn.Module:
64
+ """
65
+ Zero out the parameters of a module and return it.
66
+ """
67
+ for p in module.parameters():
68
+ p.detach().zero_()
69
+ return module
70
+
71
+
72
+ class CrossAttention(nn.Module):
73
+ """
74
+ A cross attention layer.
75
+
76
+ Args:
77
+ query_dim: number of channels in the query.
78
+ cross_attention_dim: number of channels in the context.
79
+ num_attention_heads: number of heads to use for multi-head attention.
80
+ num_head_channels: number of channels in each head.
81
+ dropout: dropout probability to use.
82
+ upcast_attention: if True, upcast attention operations to full precision.
83
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ query_dim: int,
89
+ cross_attention_dim: int | None = None,
90
+ num_attention_heads: int = 8,
91
+ num_head_channels: int = 64,
92
+ dropout: float = 0.0,
93
+ upcast_attention: bool = False,
94
+ use_flash_attention: bool = False,
95
+ ) -> None:
96
+ super().__init__()
97
+ self.use_flash_attention = use_flash_attention
98
+ inner_dim = num_head_channels * num_attention_heads
99
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
100
+
101
+ self.scale = 1 / math.sqrt(num_head_channels)
102
+ self.num_heads = num_attention_heads
103
+
104
+ self.upcast_attention = upcast_attention
105
+
106
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
107
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
108
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
109
+
110
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
111
+
112
+ def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch.
115
+ """
116
+ batch_size, seq_len, dim = x.shape
117
+ x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads)
118
+ x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads)
119
+ return x
120
+
121
+ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
122
+ """Combine the output of the attention heads back into the hidden state dimension."""
123
+ batch_size, seq_len, dim = x.shape
124
+ x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim)
125
+ x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)
126
+ return x
127
+
128
+ def _memory_efficient_attention_xformers(
129
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
130
+ ) -> torch.Tensor:
131
+ query = query.contiguous()
132
+ key = key.contiguous()
133
+ value = value.contiguous()
134
+ x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
135
+ return x
136
+
137
+ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
138
+ dtype = query.dtype
139
+ if self.upcast_attention:
140
+ query = query.float()
141
+ key = key.float()
142
+
143
+ attention_scores = torch.baddbmm(
144
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
145
+ query,
146
+ key.transpose(-1, -2),
147
+ beta=0,
148
+ alpha=self.scale,
149
+ )
150
+ attention_probs = attention_scores.softmax(dim=-1)
151
+ attention_probs = attention_probs.to(dtype=dtype)
152
+
153
+ x = torch.bmm(attention_probs, value)
154
+ return x
155
+
156
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
157
+ query = self.to_q(x)
158
+ context = context if context is not None else x
159
+ key = self.to_k(context)
160
+ value = self.to_v(context)
161
+
162
+ # Multi-Head Attention
163
+ query = self.reshape_heads_to_batch_dim(query)
164
+ key = self.reshape_heads_to_batch_dim(key)
165
+ value = self.reshape_heads_to_batch_dim(value)
166
+
167
+ if self.use_flash_attention:
168
+ x = self._memory_efficient_attention_xformers(query, key, value)
169
+ else:
170
+ x = self._attention(query, key, value)
171
+
172
+ x = self.reshape_batch_dim_to_heads(x)
173
+ x = x.to(query.dtype)
174
+
175
+ return self.to_out(x)
176
+
177
+
178
+ class BasicTransformerBlock(nn.Module):
179
+ """
180
+ A basic Transformer block.
181
+
182
+ Args:
183
+ num_channels: number of channels in the input and output.
184
+ num_attention_heads: number of heads to use for multi-head attention.
185
+ num_head_channels: number of channels in each attention head.
186
+ dropout: dropout probability to use.
187
+ cross_attention_dim: size of the context vector for cross attention.
188
+ upcast_attention: if True, upcast attention operations to full precision.
189
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ num_channels: int,
195
+ num_attention_heads: int,
196
+ num_head_channels: int,
197
+ dropout: float = 0.0,
198
+ cross_attention_dim: int | None = None,
199
+ upcast_attention: bool = False,
200
+ use_flash_attention: bool = False,
201
+ ) -> None:
202
+ super().__init__()
203
+ self.attn1 = CrossAttention(
204
+ query_dim=num_channels,
205
+ num_attention_heads=num_attention_heads,
206
+ num_head_channels=num_head_channels,
207
+ dropout=dropout,
208
+ upcast_attention=upcast_attention,
209
+ use_flash_attention=use_flash_attention,
210
+ ) # is a self-attention
211
+ self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
212
+ self.attn2 = CrossAttention(
213
+ query_dim=num_channels,
214
+ cross_attention_dim=cross_attention_dim,
215
+ num_attention_heads=num_attention_heads,
216
+ num_head_channels=num_head_channels,
217
+ dropout=dropout,
218
+ upcast_attention=upcast_attention,
219
+ use_flash_attention=use_flash_attention,
220
+ ) # is a self-attention if context is None
221
+ self.norm1 = nn.LayerNorm(num_channels)
222
+ self.norm2 = nn.LayerNorm(num_channels)
223
+ self.norm3 = nn.LayerNorm(num_channels)
224
+
225
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
226
+ # 1. Self-Attention
227
+ x = self.attn1(self.norm1(x)) + x
228
+
229
+ # 2. Cross-Attention
230
+ x = self.attn2(self.norm2(x), context=context) + x
231
+
232
+ # 3. Feed-forward
233
+ x = self.ff(self.norm3(x)) + x
234
+ return x
235
+
236
+
237
+ class SpatialTransformer(nn.Module):
238
+ """
239
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
240
+ standard transformer action. Finally, reshape to image.
241
+
242
+ Args:
243
+ spatial_dims: number of spatial dimensions.
244
+ in_channels: number of channels in the input and output.
245
+ num_attention_heads: number of heads to use for multi-head attention.
246
+ num_head_channels: number of channels in each attention head.
247
+ num_layers: number of layers of Transformer blocks to use.
248
+ dropout: dropout probability to use.
249
+ norm_num_groups: number of groups for the normalization.
250
+ norm_eps: epsilon for the normalization.
251
+ cross_attention_dim: number of context dimensions to use.
252
+ upcast_attention: if True, upcast attention operations to full precision.
253
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ spatial_dims: int,
259
+ in_channels: int,
260
+ num_attention_heads: int,
261
+ num_head_channels: int,
262
+ num_layers: int = 1,
263
+ dropout: float = 0.0,
264
+ norm_num_groups: int = 32,
265
+ norm_eps: float = 1e-6,
266
+ cross_attention_dim: int | None = None,
267
+ upcast_attention: bool = False,
268
+ use_flash_attention: bool = False,
269
+ ) -> None:
270
+ super().__init__()
271
+ self.spatial_dims = spatial_dims
272
+ self.in_channels = in_channels
273
+ inner_dim = num_attention_heads * num_head_channels
274
+
275
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
276
+
277
+ self.proj_in = Convolution(
278
+ spatial_dims=spatial_dims,
279
+ in_channels=in_channels,
280
+ out_channels=inner_dim,
281
+ strides=1,
282
+ kernel_size=1,
283
+ padding=0,
284
+ conv_only=True,
285
+ )
286
+
287
+ self.transformer_blocks = nn.ModuleList(
288
+ [
289
+ BasicTransformerBlock(
290
+ num_channels=inner_dim,
291
+ num_attention_heads=num_attention_heads,
292
+ num_head_channels=num_head_channels,
293
+ dropout=dropout,
294
+ cross_attention_dim=cross_attention_dim,
295
+ upcast_attention=upcast_attention,
296
+ use_flash_attention=use_flash_attention,
297
+ )
298
+ for _ in range(num_layers)
299
+ ]
300
+ )
301
+
302
+ self.proj_out = zero_module(
303
+ Convolution(
304
+ spatial_dims=spatial_dims,
305
+ in_channels=inner_dim,
306
+ out_channels=in_channels,
307
+ strides=1,
308
+ kernel_size=1,
309
+ padding=0,
310
+ conv_only=True,
311
+ )
312
+ )
313
+
314
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
315
+ # note: if no context is given, cross-attention defaults to self-attention
316
+ batch = channel = height = width = depth = -1
317
+ if self.spatial_dims == 2:
318
+ batch, channel, height, width = x.shape
319
+ if self.spatial_dims == 3:
320
+ batch, channel, height, width, depth = x.shape
321
+
322
+ residual = x
323
+ x = self.norm(x)
324
+ x = self.proj_in(x)
325
+
326
+ inner_dim = x.shape[1]
327
+
328
+ if self.spatial_dims == 2:
329
+ x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
330
+ if self.spatial_dims == 3:
331
+ x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)
332
+
333
+ for block in self.transformer_blocks:
334
+ x = block(x, context=context)
335
+
336
+ if self.spatial_dims == 2:
337
+ x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
338
+ if self.spatial_dims == 3:
339
+ x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()
340
+
341
+ x = self.proj_out(x)
342
+ return x + residual
343
+
344
+
345
+ class AttentionBlock(nn.Module):
346
+ """
347
+ An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to
348
+ compute attention.
349
+
350
+ Args:
351
+ spatial_dims: number of spatial dimensions.
352
+ num_channels: number of input channels.
353
+ num_head_channels: number of channels in each attention head.
354
+ norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
355
+ channels is divisible by this number.
356
+ norm_eps: epsilon value to use for the normalisation.
357
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
358
+ """
359
+
360
+ def __init__(
361
+ self,
362
+ spatial_dims: int,
363
+ num_channels: int,
364
+ num_head_channels: int | None = None,
365
+ norm_num_groups: int = 32,
366
+ norm_eps: float = 1e-6,
367
+ use_flash_attention: bool = False,
368
+ ) -> None:
369
+ super().__init__()
370
+ self.use_flash_attention = use_flash_attention
371
+ self.spatial_dims = spatial_dims
372
+ self.num_channels = num_channels
373
+
374
+ self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
375
+ self.scale = 1 / math.sqrt(num_channels / self.num_heads)
376
+
377
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)
378
+
379
+ self.to_q = nn.Linear(num_channels, num_channels)
380
+ self.to_k = nn.Linear(num_channels, num_channels)
381
+ self.to_v = nn.Linear(num_channels, num_channels)
382
+
383
+ self.proj_attn = nn.Linear(num_channels, num_channels)
384
+
385
+ def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
386
+ batch_size, seq_len, dim = x.shape
387
+ x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads)
388
+ x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads)
389
+ return x
390
+
391
+ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
392
+ batch_size, seq_len, dim = x.shape
393
+ x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim)
394
+ x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)
395
+ return x
396
+
397
+ def _memory_efficient_attention_xformers(
398
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
399
+ ) -> torch.Tensor:
400
+ query = query.contiguous()
401
+ key = key.contiguous()
402
+ value = value.contiguous()
403
+ x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
404
+ return x
405
+
406
+ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
407
+ attention_scores = torch.baddbmm(
408
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
409
+ query,
410
+ key.transpose(-1, -2),
411
+ beta=0,
412
+ alpha=self.scale,
413
+ )
414
+ attention_probs = attention_scores.softmax(dim=-1)
415
+ x = torch.bmm(attention_probs, value)
416
+ return x
417
+
418
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
419
+ residual = x
420
+
421
+ batch = channel = height = width = depth = -1
422
+ if self.spatial_dims == 2:
423
+ batch, channel, height, width = x.shape
424
+ if self.spatial_dims == 3:
425
+ batch, channel, height, width, depth = x.shape
426
+
427
+ # norm
428
+ x = self.norm(x)
429
+
430
+ if self.spatial_dims == 2:
431
+ x = x.view(batch, channel, height * width).transpose(1, 2)
432
+ if self.spatial_dims == 3:
433
+ x = x.view(batch, channel, height * width * depth).transpose(1, 2)
434
+
435
+ # proj to q, k, v
436
+ query = self.to_q(x)
437
+ key = self.to_k(x)
438
+ value = self.to_v(x)
439
+
440
+ # Multi-Head Attention
441
+ query = self.reshape_heads_to_batch_dim(query)
442
+ key = self.reshape_heads_to_batch_dim(key)
443
+ value = self.reshape_heads_to_batch_dim(value)
444
+
445
+ if self.use_flash_attention:
446
+ x = self._memory_efficient_attention_xformers(query, key, value)
447
+ else:
448
+ x = self._attention(query, key, value)
449
+
450
+ x = self.reshape_batch_dim_to_heads(x)
451
+ x = x.to(query.dtype)
452
+
453
+ if self.spatial_dims == 2:
454
+ x = x.transpose(-1, -2).reshape(batch, channel, height, width)
455
+ if self.spatial_dims == 3:
456
+ x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth)
457
+
458
+ return x + residual
459
+
460
+
461
+ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
462
+ """
463
+ Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
464
+ Models" https://arxiv.org/abs/2006.11239.
465
+
466
+ Args:
467
+ timesteps: a 1-D Tensor of N indices, one per batch element.
468
+ embedding_dim: the dimension of the output.
469
+ max_period: controls the minimum frequency of the embeddings.
470
+ """
471
+ if timesteps.ndim != 1:
472
+ raise ValueError("Timesteps should be a 1d-array")
473
+
474
+ half_dim = embedding_dim // 2
475
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
476
+ freqs = torch.exp(exponent / half_dim)
477
+
478
+ args = timesteps[:, None].float() * freqs[None, :]
479
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
480
+
481
+ # zero pad
482
+ if embedding_dim % 2 == 1:
483
+ embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
484
+
485
+ return embedding
486
+
487
+
488
+ class Downsample(nn.Module):
489
+ """
490
+ Downsampling layer.
491
+
492
+ Args:
493
+ spatial_dims: number of spatial dimensions.
494
+ num_channels: number of input channels.
495
+ use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is
496
+ False, the number of output channels must be the same as the number of input channels.
497
+ out_channels: number of output channels.
498
+ padding: controls the amount of implicit zero-paddings on both sides for padding number of points
499
+ for each dimension.
500
+ """
501
+
502
+ def __init__(
503
+ self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
504
+ ) -> None:
505
+ super().__init__()
506
+ self.num_channels = num_channels
507
+ self.out_channels = out_channels or num_channels
508
+ self.use_conv = use_conv
509
+ if use_conv:
510
+ self.op = Convolution(
511
+ spatial_dims=spatial_dims,
512
+ in_channels=self.num_channels,
513
+ out_channels=self.out_channels,
514
+ strides=2,
515
+ kernel_size=3,
516
+ padding=padding,
517
+ conv_only=True,
518
+ )
519
+ else:
520
+ if self.num_channels != self.out_channels:
521
+ raise ValueError("num_channels and out_channels must be equal when use_conv=False")
522
+ self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
523
+
524
+ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
525
+ del emb
526
+ if x.shape[1] != self.num_channels:
527
+ raise ValueError(
528
+ f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
529
+ f"({self.num_channels})"
530
+ )
531
+ return self.op(x)
532
+
533
+
534
+ class Upsample(nn.Module):
535
+ """
536
+ Upsampling layer with an optional convolution.
537
+
538
+ Args:
539
+ spatial_dims: number of spatial dimensions.
540
+ num_channels: number of input channels.
541
+ use_conv: if True uses Convolution instead of Pool average to perform downsampling.
542
+ out_channels: number of output channels.
543
+ padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each
544
+ dimension.
545
+ """
546
+
547
+ def __init__(
548
+ self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
549
+ ) -> None:
550
+ super().__init__()
551
+ self.num_channels = num_channels
552
+ self.out_channels = out_channels or num_channels
553
+ self.use_conv = use_conv
554
+ if use_conv:
555
+ self.conv = Convolution(
556
+ spatial_dims=spatial_dims,
557
+ in_channels=self.num_channels,
558
+ out_channels=self.out_channels,
559
+ strides=1,
560
+ kernel_size=3,
561
+ padding=padding,
562
+ conv_only=True,
563
+ )
564
+ else:
565
+ self.conv = None
566
+
567
+ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
568
+ del emb
569
+ if x.shape[1] != self.num_channels:
570
+ raise ValueError("Input channels should be equal to num_channels")
571
+
572
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
573
+ # https://github.com/pytorch/pytorch/issues/86679
574
+ dtype = x.dtype
575
+ if dtype == torch.bfloat16:
576
+ x = x.to(torch.float32)
577
+
578
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
579
+
580
+ # If the input is bfloat16, we cast back to bfloat16
581
+ if dtype == torch.bfloat16:
582
+ x = x.to(dtype)
583
+
584
+ if self.use_conv:
585
+ x = self.conv(x)
586
+ return x
587
+
588
+
589
+ class ResnetBlock(nn.Module):
590
+ """
591
+ Residual block with timestep conditioning.
592
+
593
+ Args:
594
+ spatial_dims: The number of spatial dimensions.
595
+ in_channels: number of input channels.
596
+ temb_channels: number of timestep embedding channels.
597
+ out_channels: number of output channels.
598
+ up: if True, performs upsampling.
599
+ down: if True, performs downsampling.
600
+ norm_num_groups: number of groups for the group normalization.
601
+ norm_eps: epsilon for the group normalization.
602
+ """
603
+
604
+ def __init__(
605
+ self,
606
+ spatial_dims: int,
607
+ in_channels: int,
608
+ temb_channels: int,
609
+ out_channels: int | None = None,
610
+ up: bool = False,
611
+ down: bool = False,
612
+ norm_num_groups: int = 32,
613
+ norm_eps: float = 1e-6,
614
+ ) -> None:
615
+ super().__init__()
616
+ self.spatial_dims = spatial_dims
617
+ self.channels = in_channels
618
+ self.emb_channels = temb_channels
619
+ self.out_channels = out_channels or in_channels
620
+ self.up = up
621
+ self.down = down
622
+
623
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
624
+ self.nonlinearity = nn.SiLU()
625
+ self.conv1 = Convolution(
626
+ spatial_dims=spatial_dims,
627
+ in_channels=in_channels,
628
+ out_channels=self.out_channels,
629
+ strides=1,
630
+ kernel_size=3,
631
+ padding=1,
632
+ conv_only=True,
633
+ dilation=3
634
+ )
635
+
636
+ self.upsample = self.downsample = None
637
+ if self.up:
638
+ self.upsample = Upsample(spatial_dims, in_channels, use_conv=False)
639
+ elif down:
640
+ self.downsample = Downsample(spatial_dims, in_channels, use_conv=False)
641
+
642
+ self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)
643
+
644
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
645
+ self.conv2 = zero_module(
646
+ Convolution(
647
+ spatial_dims=spatial_dims,
648
+ in_channels=self.out_channels,
649
+ out_channels=self.out_channels,
650
+ strides=1,
651
+ kernel_size=3,
652
+ padding=1,
653
+ conv_only=True,
654
+ dilation=2
655
+ )
656
+ )
657
+
658
+ if self.out_channels == in_channels:
659
+ self.skip_connection = nn.Identity()
660
+ else:
661
+ self.skip_connection = Convolution(
662
+ spatial_dims=spatial_dims,
663
+ in_channels=in_channels,
664
+ out_channels=self.out_channels,
665
+ strides=1,
666
+ kernel_size=1,
667
+ padding=0,
668
+ conv_only=True,
669
+ dilation=1
670
+ )
671
+
672
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
673
+ h = x
674
+ h = self.norm1(h)
675
+ h = self.nonlinearity(h)
676
+
677
+ if self.upsample is not None:
678
+ if h.shape[0] >= 64:
679
+ x = x.contiguous()
680
+ h = h.contiguous()
681
+ x = self.upsample(x)
682
+ h = self.upsample(h)
683
+ elif self.downsample is not None:
684
+ x = self.downsample(x)
685
+ h = self.downsample(h)
686
+
687
+ h = self.conv1(h)
688
+
689
+ if self.spatial_dims == 2:
690
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
691
+ else:
692
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
693
+ h = h + temb
694
+
695
+ h = self.norm2(h)
696
+ h = self.nonlinearity(h)
697
+ h = self.conv2(h)
698
+
699
+ return self.skip_connection(x) + h
700
+
701
+
702
+ class DownBlock(nn.Module):
703
+ """
704
+ Unet's down block containing resnet and downsamplers blocks.
705
+
706
+ Args:
707
+ spatial_dims: The number of spatial dimensions.
708
+ in_channels: number of input channels.
709
+ out_channels: number of output channels.
710
+ temb_channels: number of timestep embedding channels.
711
+ num_res_blocks: number of residual blocks.
712
+ norm_num_groups: number of groups for the group normalization.
713
+ norm_eps: epsilon for the group normalization.
714
+ add_downsample: if True add downsample block.
715
+ resblock_updown: if True use residual blocks for downsampling.
716
+ downsample_padding: padding used in the downsampling block.
717
+ """
718
+
719
+ def __init__(
720
+ self,
721
+ spatial_dims: int,
722
+ in_channels: int,
723
+ out_channels: int,
724
+ temb_channels: int,
725
+ num_res_blocks: int = 1,
726
+ norm_num_groups: int = 32,
727
+ norm_eps: float = 1e-6,
728
+ add_downsample: bool = True,
729
+ resblock_updown: bool = False,
730
+ downsample_padding: int = 1,
731
+ ) -> None:
732
+ super().__init__()
733
+ self.resblock_updown = resblock_updown
734
+
735
+ resnets = []
736
+
737
+ for i in range(num_res_blocks):
738
+ in_channels = in_channels if i == 0 else out_channels
739
+ resnets.append(
740
+ ResnetBlock(
741
+ spatial_dims=spatial_dims,
742
+ in_channels=in_channels,
743
+ out_channels=out_channels,
744
+ temb_channels=temb_channels,
745
+ norm_num_groups=norm_num_groups,
746
+ norm_eps=norm_eps,
747
+ )
748
+ )
749
+
750
+ self.resnets = nn.ModuleList(resnets)
751
+
752
+ if add_downsample:
753
+ if resblock_updown:
754
+ self.downsampler = ResnetBlock(
755
+ spatial_dims=spatial_dims,
756
+ in_channels=out_channels,
757
+ out_channels=out_channels,
758
+ temb_channels=temb_channels,
759
+ norm_num_groups=norm_num_groups,
760
+ norm_eps=norm_eps,
761
+ down=True,
762
+ )
763
+ else:
764
+ self.downsampler = Downsample(
765
+ spatial_dims=spatial_dims,
766
+ num_channels=out_channels,
767
+ use_conv=True,
768
+ out_channels=out_channels,
769
+ padding=downsample_padding,
770
+ )
771
+ else:
772
+ self.downsampler = None
773
+
774
+ def forward(
775
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
776
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
777
+ del context
778
+ output_states = []
779
+
780
+ for resnet in self.resnets:
781
+ hidden_states = resnet(hidden_states, temb)
782
+ output_states.append(hidden_states)
783
+
784
+ if self.downsampler is not None:
785
+ hidden_states = self.downsampler(hidden_states, temb)
786
+ output_states.append(hidden_states)
787
+
788
+ return hidden_states, output_states
789
+
790
+
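A usage sketch of DownBlock with illustrative (assumed) sizes; it returns both the downsampled features and the per-resnet outputs that later feed the up path as skip connections:

import torch

block = DownBlock(spatial_dims=2, in_channels=64, out_channels=128,
                  temb_channels=256, num_res_blocks=2)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 256)
h, skips = block(x, temb)
# h is halved in resolution by the final Downsample; skips holds the two resnet
# outputs plus the downsampled tensor, consumed later by the matching up block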
791
+ class AttnDownBlock(nn.Module):
792
+ """
793
+ Unet's down block containing resnet, downsamplers and self-attention blocks.
794
+
795
+ Args:
796
+ spatial_dims: The number of spatial dimensions.
797
+ in_channels: number of input channels.
798
+ out_channels: number of output channels.
799
+ temb_channels: number of timestep embedding channels.
800
+ num_res_blocks: number of residual blocks.
801
+ norm_num_groups: number of groups for the group normalization.
802
+ norm_eps: epsilon for the group normalization.
803
+ add_downsample: if True add downsample block.
804
+ resblock_updown: if True use residual blocks for downsampling.
805
+ downsample_padding: padding used in the downsampling block.
806
+ num_head_channels: number of channels in each attention head.
807
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
808
+ """
809
+
810
+ def __init__(
811
+ self,
812
+ spatial_dims: int,
813
+ in_channels: int,
814
+ out_channels: int,
815
+ temb_channels: int,
816
+ num_res_blocks: int = 1,
817
+ norm_num_groups: int = 32,
818
+ norm_eps: float = 1e-6,
819
+ add_downsample: bool = True,
820
+ resblock_updown: bool = False,
821
+ downsample_padding: int = 1,
822
+ num_head_channels: int = 1,
823
+ use_flash_attention: bool = False,
824
+ ) -> None:
825
+ super().__init__()
826
+ self.resblock_updown = resblock_updown
827
+
828
+ resnets = []
829
+ attentions = []
830
+
831
+ for i in range(num_res_blocks):
832
+ in_channels = in_channels if i == 0 else out_channels
833
+ resnets.append(
834
+ ResnetBlock(
835
+ spatial_dims=spatial_dims,
836
+ in_channels=in_channels,
837
+ out_channels=out_channels,
838
+ temb_channels=temb_channels,
839
+ norm_num_groups=norm_num_groups,
840
+ norm_eps=norm_eps,
841
+ )
842
+ )
843
+ attentions.append(
844
+ AttentionBlock(
845
+ spatial_dims=spatial_dims,
846
+ num_channels=out_channels,
847
+ num_head_channels=num_head_channels,
848
+ norm_num_groups=norm_num_groups,
849
+ norm_eps=norm_eps,
850
+ use_flash_attention=use_flash_attention,
851
+ )
852
+ )
853
+
854
+ self.attentions = nn.ModuleList(attentions)
855
+ self.resnets = nn.ModuleList(resnets)
856
+
857
+ if add_downsample:
858
+ if resblock_updown:
859
+ self.downsampler = ResnetBlock(
860
+ spatial_dims=spatial_dims,
861
+ in_channels=out_channels,
862
+ out_channels=out_channels,
863
+ temb_channels=temb_channels,
864
+ norm_num_groups=norm_num_groups,
865
+ norm_eps=norm_eps,
866
+ down=True,
867
+ )
868
+ else:
869
+ self.downsampler = Downsample(
870
+ spatial_dims=spatial_dims,
871
+ num_channels=out_channels,
872
+ use_conv=True,
873
+ out_channels=out_channels,
874
+ padding=downsample_padding,
875
+ )
876
+ else:
877
+ self.downsampler = None
878
+
879
+ def forward(
880
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
881
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
882
+ del context
883
+ output_states = []
884
+
885
+ for resnet, attn in zip(self.resnets, self.attentions):
886
+ hidden_states = resnet(hidden_states, temb)
887
+ hidden_states = attn(hidden_states)
888
+ output_states.append(hidden_states)
889
+
890
+ if self.downsampler is not None:
891
+ hidden_states = self.downsampler(hidden_states, temb)
892
+ output_states.append(hidden_states)
893
+
894
+ return hidden_states, output_states
895
+
896
+
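AttnDownBlock is used the same way; the extra num_head_channels argument must divide out_channels (the values below are illustrative assumptions):

import torch

block = AttnDownBlock(spatial_dims=2, in_channels=64, out_channels=64,
                      temb_channels=256, num_head_channels=32)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 256)
h, skips = block(x, temb)   # each resnet output is passed through self-attention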
897
+ class CrossAttnDownBlock(nn.Module):
898
+ """
899
+ Unet's down block containing resnet, downsamplers and cross-attention blocks.
900
+
901
+ Args:
902
+ spatial_dims: number of spatial dimensions.
903
+ in_channels: number of input channels.
904
+ out_channels: number of output channels.
905
+ temb_channels: number of timestep embedding channels.
906
+ num_res_blocks: number of residual blocks.
907
+ norm_num_groups: number of groups for the group normalization.
908
+ norm_eps: epsilon for the group normalization.
909
+ add_downsample: if True add downsample block.
910
+ resblock_updown: if True use residual blocks for downsampling.
911
+ downsample_padding: padding used in the downsampling block.
912
+ num_head_channels: number of channels in each attention head.
913
+ transformer_num_layers: number of layers of Transformer blocks to use.
914
+ cross_attention_dim: number of context dimensions to use.
915
+ upcast_attention: if True, upcast attention operations to full precision.
916
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
917
+ """
918
+
919
+ def __init__(
920
+ self,
921
+ spatial_dims: int,
922
+ in_channels: int,
923
+ out_channels: int,
924
+ temb_channels: int,
925
+ num_res_blocks: int = 1,
926
+ norm_num_groups: int = 32,
927
+ norm_eps: float = 1e-6,
928
+ add_downsample: bool = True,
929
+ resblock_updown: bool = False,
930
+ downsample_padding: int = 1,
931
+ num_head_channels: int = 1,
932
+ transformer_num_layers: int = 1,
933
+ cross_attention_dim: int | None = None,
934
+ upcast_attention: bool = False,
935
+ use_flash_attention: bool = False,
936
+ ) -> None:
937
+ super().__init__()
938
+ self.resblock_updown = resblock_updown
939
+
940
+ resnets = []
941
+ attentions = []
942
+
943
+ for i in range(num_res_blocks):
944
+ in_channels = in_channels if i == 0 else out_channels
945
+ resnets.append(
946
+ ResnetBlock(
947
+ spatial_dims=spatial_dims,
948
+ in_channels=in_channels,
949
+ out_channels=out_channels,
950
+ temb_channels=temb_channels,
951
+ norm_num_groups=norm_num_groups,
952
+ norm_eps=norm_eps,
953
+ )
954
+ )
955
+
956
+ attentions.append(
957
+ SpatialTransformer(
958
+ spatial_dims=spatial_dims,
959
+ in_channels=out_channels,
960
+ num_attention_heads=out_channels // num_head_channels,
961
+ num_head_channels=num_head_channels,
962
+ num_layers=transformer_num_layers,
963
+ norm_num_groups=norm_num_groups,
964
+ norm_eps=norm_eps,
965
+ cross_attention_dim=cross_attention_dim,
966
+ upcast_attention=upcast_attention,
967
+ use_flash_attention=use_flash_attention,
968
+ )
969
+ )
970
+
971
+ self.attentions = nn.ModuleList(attentions)
972
+ self.resnets = nn.ModuleList(resnets)
973
+
974
+ if add_downsample:
975
+ if resblock_updown:
976
+ self.downsampler = ResnetBlock(
977
+ spatial_dims=spatial_dims,
978
+ in_channels=out_channels,
979
+ out_channels=out_channels,
980
+ temb_channels=temb_channels,
981
+ norm_num_groups=norm_num_groups,
982
+ norm_eps=norm_eps,
983
+ down=True,
984
+ )
985
+ else:
986
+ self.downsampler = Downsample(
987
+ spatial_dims=spatial_dims,
988
+ num_channels=out_channels,
989
+ use_conv=True,
990
+ out_channels=out_channels,
991
+ padding=downsample_padding,
992
+ )
993
+ else:
994
+ self.downsampler = None
995
+
996
+ def forward(
997
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
998
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
999
+ output_states = []
1000
+
1001
+ for resnet, attn in zip(self.resnets, self.attentions):
1002
+ hidden_states = resnet(hidden_states, temb)
1003
+ hidden_states = attn(hidden_states, context=context)
1004
+ output_states.append(hidden_states)
1005
+
1006
+ if self.downsampler is not None:
1007
+ hidden_states = self.downsampler(hidden_states, temb)
1008
+ output_states.append(hidden_states)
1009
+
1010
+ return hidden_states, output_states
1011
+
1012
+
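A sketch of the cross-attention variant; the context tensor follows the (N, sequence_length, cross_attention_dim) convention of SpatialTransformer, and all dimensions below are illustrative assumptions:

import torch

block = CrossAttnDownBlock(spatial_dims=2, in_channels=64, out_channels=128,
                           temb_channels=256, num_head_channels=32,
                           cross_attention_dim=512)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 256)
context = torch.randn(1, 1, 512)        # one conditioning token per sample
h, skips = block(x, temb, context=context)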
1013
+ class AttnMidBlock(nn.Module):
1014
+ """
1015
+ Unet's mid block containing resnet and self-attention blocks.
1016
+
1017
+ Args:
1018
+ spatial_dims: The number of spatial dimensions.
1019
+ in_channels: number of input channels.
1020
+ temb_channels: number of timestep embedding channels.
1021
+ norm_num_groups: number of groups for the group normalization.
1022
+ norm_eps: epsilon for the group normalization.
1023
+ num_head_channels: number of channels in each attention head.
1024
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
1025
+ """
1026
+
1027
+ def __init__(
1028
+ self,
1029
+ spatial_dims: int,
1030
+ in_channels: int,
1031
+ temb_channels: int,
1032
+ norm_num_groups: int = 32,
1033
+ norm_eps: float = 1e-6,
1034
+ num_head_channels: int = 1,
1035
+ use_flash_attention: bool = False,
1036
+ ) -> None:
1037
+ super().__init__()
1038
+ self.attention = None
1039
+
1040
+ self.resnet_1 = ResnetBlock(
1041
+ spatial_dims=spatial_dims,
1042
+ in_channels=in_channels,
1043
+ out_channels=in_channels,
1044
+ temb_channels=temb_channels,
1045
+ norm_num_groups=norm_num_groups,
1046
+ norm_eps=norm_eps,
1047
+ )
1048
+ self.attention = AttentionBlock(
1049
+ spatial_dims=spatial_dims,
1050
+ num_channels=in_channels,
1051
+ num_head_channels=num_head_channels,
1052
+ norm_num_groups=norm_num_groups,
1053
+ norm_eps=norm_eps,
1054
+ use_flash_attention=use_flash_attention,
1055
+ )
1056
+
1057
+ self.resnet_2 = ResnetBlock(
1058
+ spatial_dims=spatial_dims,
1059
+ in_channels=in_channels,
1060
+ out_channels=in_channels,
1061
+ temb_channels=temb_channels,
1062
+ norm_num_groups=norm_num_groups,
1063
+ norm_eps=norm_eps,
1064
+ )
1065
+
1066
+ def forward(
1067
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
1068
+ ) -> torch.Tensor:
1069
+ del context
1070
+ hidden_states = self.resnet_1(hidden_states, temb)
1071
+ hidden_states = self.attention(hidden_states)
1072
+ hidden_states = self.resnet_2(hidden_states, temb)
1073
+
1074
+ return hidden_states
1075
+
1076
+
1077
+ class CrossAttnMidBlock(nn.Module):
1078
+ """
1079
+ Unet's mid block containing resnet and cross-attention blocks.
1080
+
1081
+ Args:
1082
+ spatial_dims: The number of spatial dimensions.
1083
+ in_channels: number of input channels.
1084
+ temb_channels: number of timestep embedding channels
1085
+ norm_num_groups: number of groups for the group normalization.
1086
+ norm_eps: epsilon for the group normalization.
1087
+ num_head_channels: number of channels in each attention head.
1088
+ transformer_num_layers: number of layers of Transformer blocks to use.
1089
+ cross_attention_dim: number of context dimensions to use.
1090
+ upcast_attention: if True, upcast attention operations to full precision.
1091
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
1092
+ """
1093
+
1094
+ def __init__(
1095
+ self,
1096
+ spatial_dims: int,
1097
+ in_channels: int,
1098
+ temb_channels: int,
1099
+ norm_num_groups: int = 32,
1100
+ norm_eps: float = 1e-6,
1101
+ num_head_channels: int = 1,
1102
+ transformer_num_layers: int = 1,
1103
+ cross_attention_dim: int | None = None,
1104
+ upcast_attention: bool = False,
1105
+ use_flash_attention: bool = False,
1106
+ ) -> None:
1107
+ super().__init__()
1108
+ self.attention = None
1109
+
1110
+ self.resnet_1 = ResnetBlock(
1111
+ spatial_dims=spatial_dims,
1112
+ in_channels=in_channels,
1113
+ out_channels=in_channels,
1114
+ temb_channels=temb_channels,
1115
+ norm_num_groups=norm_num_groups,
1116
+ norm_eps=norm_eps,
1117
+ )
1118
+ self.attention = SpatialTransformer(
1119
+ spatial_dims=spatial_dims,
1120
+ in_channels=in_channels,
1121
+ num_attention_heads=in_channels // num_head_channels,
1122
+ num_head_channels=num_head_channels,
1123
+ num_layers=transformer_num_layers,
1124
+ norm_num_groups=norm_num_groups,
1125
+ norm_eps=norm_eps,
1126
+ cross_attention_dim=cross_attention_dim,
1127
+ upcast_attention=upcast_attention,
1128
+ use_flash_attention=use_flash_attention,
1129
+ )
1130
+ self.resnet_2 = ResnetBlock(
1131
+ spatial_dims=spatial_dims,
1132
+ in_channels=in_channels,
1133
+ out_channels=in_channels,
1134
+ temb_channels=temb_channels,
1135
+ norm_num_groups=norm_num_groups,
1136
+ norm_eps=norm_eps,
1137
+ )
1138
+
1139
+ def forward(
1140
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
1141
+ ) -> torch.Tensor:
1142
+ hidden_states = self.resnet_1(hidden_states, temb)
1143
+ hidden_states = self.attention(hidden_states, context=context)
1144
+ hidden_states = self.resnet_2(hidden_states, temb)
1145
+
1146
+ return hidden_states
1147
+
1148
+
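The mid blocks keep the channel count and resolution fixed (resnet, attention, resnet); a sketch of the cross-attention variant with illustrative sizes:

import torch

mid = CrossAttnMidBlock(spatial_dims=2, in_channels=128, temb_channels=256,
                        num_head_channels=32, cross_attention_dim=512)
h = torch.randn(1, 128, 16, 16)
temb = torch.randn(1, 256)
context = torch.randn(1, 1, 512)
out = mid(h, temb, context=context)     # same shape as h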
1149
+ class UpBlock(nn.Module):
1150
+ """
1151
+    Unet's up block containing resnet and upsampler blocks.
1152
+
1153
+ Args:
1154
+ spatial_dims: The number of spatial dimensions.
1155
+ in_channels: number of input channels.
1156
+ prev_output_channel: number of channels from residual connection.
1157
+ out_channels: number of output channels.
1158
+ temb_channels: number of timestep embedding channels.
1159
+ num_res_blocks: number of residual blocks.
1160
+ norm_num_groups: number of groups for the group normalization.
1161
+ norm_eps: epsilon for the group normalization.
1162
+        add_upsample: if True add upsample block.
1163
+ resblock_updown: if True use residual blocks for upsampling.
1164
+ """
1165
+
1166
+ def __init__(
1167
+ self,
1168
+ spatial_dims: int,
1169
+ in_channels: int,
1170
+ prev_output_channel: int,
1171
+ out_channels: int,
1172
+ temb_channels: int,
1173
+ num_res_blocks: int = 1,
1174
+ norm_num_groups: int = 32,
1175
+ norm_eps: float = 1e-6,
1176
+ add_upsample: bool = True,
1177
+ resblock_updown: bool = False,
1178
+ ) -> None:
1179
+ super().__init__()
1180
+ self.resblock_updown = resblock_updown
1181
+ resnets = []
1182
+
1183
+ for i in range(num_res_blocks):
1184
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
1185
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1186
+
1187
+ resnets.append(
1188
+ ResnetBlock(
1189
+ spatial_dims=spatial_dims,
1190
+ in_channels=resnet_in_channels + res_skip_channels,
1191
+ out_channels=out_channels,
1192
+ temb_channels=temb_channels,
1193
+ norm_num_groups=norm_num_groups,
1194
+ norm_eps=norm_eps,
1195
+ )
1196
+ )
1197
+
1198
+ self.resnets = nn.ModuleList(resnets)
1199
+
1200
+ if add_upsample:
1201
+ if resblock_updown:
1202
+ self.upsampler = ResnetBlock(
1203
+ spatial_dims=spatial_dims,
1204
+ in_channels=out_channels,
1205
+ out_channels=out_channels,
1206
+ temb_channels=temb_channels,
1207
+ norm_num_groups=norm_num_groups,
1208
+ norm_eps=norm_eps,
1209
+ up=True,
1210
+ )
1211
+ else:
1212
+ self.upsampler = Upsample(
1213
+ spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
1214
+ )
1215
+ else:
1216
+ self.upsampler = None
1217
+
1218
+ def forward(
1219
+ self,
1220
+ hidden_states: torch.Tensor,
1221
+ res_hidden_states_list: list[torch.Tensor],
1222
+ temb: torch.Tensor,
1223
+ context: torch.Tensor | None = None,
1224
+ ) -> torch.Tensor:
1225
+ del context
1226
+ for resnet in self.resnets:
1227
+ # pop res hidden states
1228
+ res_hidden_states = res_hidden_states_list[-1]
1229
+ res_hidden_states_list = res_hidden_states_list[:-1]
1230
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1231
+
1232
+ hidden_states = resnet(hidden_states, temb)
1233
+
1234
+ if self.upsampler is not None:
1235
+ hidden_states = self.upsampler(hidden_states, temb)
1236
+
1237
+ return hidden_states
1238
+
1239
+
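UpBlock consumes the skip tensors saved by the matching down block, popping one from the end of the list per resnet; a sketch with illustrative (assumed) sizes:

import torch

up = UpBlock(spatial_dims=2, in_channels=64, prev_output_channel=128,
             out_channels=64, temb_channels=256, num_res_blocks=1)
h = torch.randn(1, 128, 16, 16)           # output of the deeper block
skips = [torch.randn(1, 64, 16, 16)]      # one skip tensor per resnet
temb = torch.randn(1, 256)
out = up(h, skips, temb)                  # doubled in resolution by the Upsample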
1240
+ class AttnUpBlock(nn.Module):
1241
+ """
1242
+ Unet's up block containing resnet, upsamplers, and self-attention blocks.
1243
+
1244
+ Args:
1245
+ spatial_dims: The number of spatial dimensions.
1246
+ in_channels: number of input channels.
1247
+ prev_output_channel: number of channels from residual connection.
1248
+ out_channels: number of output channels.
1249
+ temb_channels: number of timestep embedding channels.
1250
+ num_res_blocks: number of residual blocks.
1251
+ norm_num_groups: number of groups for the group normalization.
1252
+ norm_eps: epsilon for the group normalization.
1253
+        add_upsample: if True add upsample block.
1254
+ resblock_updown: if True use residual blocks for upsampling.
1255
+ num_head_channels: number of channels in each attention head.
1256
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
1257
+ """
1258
+
1259
+ def __init__(
1260
+ self,
1261
+ spatial_dims: int,
1262
+ in_channels: int,
1263
+ prev_output_channel: int,
1264
+ out_channels: int,
1265
+ temb_channels: int,
1266
+ num_res_blocks: int = 1,
1267
+ norm_num_groups: int = 32,
1268
+ norm_eps: float = 1e-6,
1269
+ add_upsample: bool = True,
1270
+ resblock_updown: bool = False,
1271
+ num_head_channels: int = 1,
1272
+ use_flash_attention: bool = False,
1273
+ ) -> None:
1274
+ super().__init__()
1275
+ self.resblock_updown = resblock_updown
1276
+
1277
+ resnets = []
1278
+ attentions = []
1279
+
1280
+ for i in range(num_res_blocks):
1281
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
1282
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1283
+
1284
+ resnets.append(
1285
+ ResnetBlock(
1286
+ spatial_dims=spatial_dims,
1287
+ in_channels=resnet_in_channels + res_skip_channels,
1288
+ out_channels=out_channels,
1289
+ temb_channels=temb_channels,
1290
+ norm_num_groups=norm_num_groups,
1291
+ norm_eps=norm_eps,
1292
+ )
1293
+ )
1294
+ attentions.append(
1295
+ AttentionBlock(
1296
+ spatial_dims=spatial_dims,
1297
+ num_channels=out_channels,
1298
+ num_head_channels=num_head_channels,
1299
+ norm_num_groups=norm_num_groups,
1300
+ norm_eps=norm_eps,
1301
+ use_flash_attention=use_flash_attention,
1302
+ )
1303
+ )
1304
+
1305
+ self.resnets = nn.ModuleList(resnets)
1306
+ self.attentions = nn.ModuleList(attentions)
1307
+
1308
+ if add_upsample:
1309
+ if resblock_updown:
1310
+ self.upsampler = ResnetBlock(
1311
+ spatial_dims=spatial_dims,
1312
+ in_channels=out_channels,
1313
+ out_channels=out_channels,
1314
+ temb_channels=temb_channels,
1315
+ norm_num_groups=norm_num_groups,
1316
+ norm_eps=norm_eps,
1317
+ up=True,
1318
+ )
1319
+ else:
1320
+ self.upsampler = Upsample(
1321
+ spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
1322
+ )
1323
+ else:
1324
+ self.upsampler = None
1325
+
1326
+ def forward(
1327
+ self,
1328
+ hidden_states: torch.Tensor,
1329
+ res_hidden_states_list: list[torch.Tensor],
1330
+ temb: torch.Tensor,
1331
+ context: torch.Tensor | None = None,
1332
+ ) -> torch.Tensor:
1333
+ del context
1334
+ for resnet, attn in zip(self.resnets, self.attentions):
1335
+ # pop res hidden states
1336
+ res_hidden_states = res_hidden_states_list[-1]
1337
+ res_hidden_states_list = res_hidden_states_list[:-1]
1338
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1339
+
1340
+ hidden_states = resnet(hidden_states, temb)
1341
+ hidden_states = attn(hidden_states)
1342
+
1343
+ if self.upsampler is not None:
1344
+ hidden_states = self.upsampler(hidden_states, temb)
1345
+
1346
+ return hidden_states
1347
+
1348
+
1349
+ class CrossAttnUpBlock(nn.Module):
1350
+ """
1351
+    Unet's up block containing resnet, upsamplers, and cross-attention blocks.
1352
+
1353
+ Args:
1354
+ spatial_dims: The number of spatial dimensions.
1355
+ in_channels: number of input channels.
1356
+ prev_output_channel: number of channels from residual connection.
1357
+ out_channels: number of output channels.
1358
+ temb_channels: number of timestep embedding channels.
1359
+ num_res_blocks: number of residual blocks.
1360
+ norm_num_groups: number of groups for the group normalization.
1361
+ norm_eps: epsilon for the group normalization.
1362
+        add_upsample: if True add upsample block.
1363
+ resblock_updown: if True use residual blocks for upsampling.
1364
+ num_head_channels: number of channels in each attention head.
1365
+ transformer_num_layers: number of layers of Transformer blocks to use.
1366
+ cross_attention_dim: number of context dimensions to use.
1367
+ upcast_attention: if True, upcast attention operations to full precision.
1368
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
1369
+ """
1370
+
1371
+ def __init__(
1372
+ self,
1373
+ spatial_dims: int,
1374
+ in_channels: int,
1375
+ prev_output_channel: int,
1376
+ out_channels: int,
1377
+ temb_channels: int,
1378
+ num_res_blocks: int = 1,
1379
+ norm_num_groups: int = 32,
1380
+ norm_eps: float = 1e-6,
1381
+ add_upsample: bool = True,
1382
+ resblock_updown: bool = False,
1383
+ num_head_channels: int = 1,
1384
+ transformer_num_layers: int = 1,
1385
+ cross_attention_dim: int | None = None,
1386
+ upcast_attention: bool = False,
1387
+ use_flash_attention: bool = False,
1388
+ ) -> None:
1389
+ super().__init__()
1390
+ self.resblock_updown = resblock_updown
1391
+
1392
+ resnets = []
1393
+ attentions = []
1394
+
1395
+ for i in range(num_res_blocks):
1396
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
1397
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1398
+
1399
+ resnets.append(
1400
+ ResnetBlock(
1401
+ spatial_dims=spatial_dims,
1402
+ in_channels=resnet_in_channels + res_skip_channels,
1403
+ out_channels=out_channels,
1404
+ temb_channels=temb_channels,
1405
+ norm_num_groups=norm_num_groups,
1406
+ norm_eps=norm_eps,
1407
+ )
1408
+ )
1409
+ attentions.append(
1410
+ SpatialTransformer(
1411
+ spatial_dims=spatial_dims,
1412
+ in_channels=out_channels,
1413
+ num_attention_heads=out_channels // num_head_channels,
1414
+ num_head_channels=num_head_channels,
1415
+ norm_num_groups=norm_num_groups,
1416
+ norm_eps=norm_eps,
1417
+ num_layers=transformer_num_layers,
1418
+ cross_attention_dim=cross_attention_dim,
1419
+ upcast_attention=upcast_attention,
1420
+ use_flash_attention=use_flash_attention,
1421
+ )
1422
+ )
1423
+
1424
+ self.attentions = nn.ModuleList(attentions)
1425
+ self.resnets = nn.ModuleList(resnets)
1426
+
1427
+ if add_upsample:
1428
+ if resblock_updown:
1429
+ self.upsampler = ResnetBlock(
1430
+ spatial_dims=spatial_dims,
1431
+ in_channels=out_channels,
1432
+ out_channels=out_channels,
1433
+ temb_channels=temb_channels,
1434
+ norm_num_groups=norm_num_groups,
1435
+ norm_eps=norm_eps,
1436
+ up=True,
1437
+ )
1438
+ else:
1439
+ self.upsampler = Upsample(
1440
+ spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
1441
+ )
1442
+ else:
1443
+ self.upsampler = None
1444
+
1445
+ def forward(
1446
+ self,
1447
+ hidden_states: torch.Tensor,
1448
+ res_hidden_states_list: list[torch.Tensor],
1449
+ temb: torch.Tensor,
1450
+ context: torch.Tensor | None = None,
1451
+ ) -> torch.Tensor:
1452
+ for resnet, attn in zip(self.resnets, self.attentions):
1453
+ # pop res hidden states
1454
+ res_hidden_states = res_hidden_states_list[-1]
1455
+ res_hidden_states_list = res_hidden_states_list[:-1]
1456
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1457
+
1458
+ hidden_states = resnet(hidden_states, temb)
1459
+ hidden_states = attn(hidden_states, context=context)
1460
+
1461
+ if self.upsampler is not None:
1462
+ hidden_states = self.upsampler(hidden_states, temb)
1463
+
1464
+ return hidden_states
1465
+
1466
+
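The cross-attention up block adds a SpatialTransformer after each resnet; usage mirrors UpBlock plus a context tensor (all sizes below are illustrative assumptions):

import torch

up = CrossAttnUpBlock(spatial_dims=2, in_channels=64, prev_output_channel=128,
                      out_channels=64, temb_channels=256, num_res_blocks=1,
                      num_head_channels=32, cross_attention_dim=512)
h = torch.randn(1, 128, 16, 16)
skips = [torch.randn(1, 64, 16, 16)]
temb = torch.randn(1, 256)
context = torch.randn(1, 1, 512)
out = up(h, skips, temb, context=context)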
1467
+ def get_down_block(
1468
+ spatial_dims: int,
1469
+ in_channels: int,
1470
+ out_channels: int,
1471
+ temb_channels: int,
1472
+ num_res_blocks: int,
1473
+ norm_num_groups: int,
1474
+ norm_eps: float,
1475
+ add_downsample: bool,
1476
+ resblock_updown: bool,
1477
+ with_attn: bool,
1478
+ with_cross_attn: bool,
1479
+ num_head_channels: int,
1480
+ transformer_num_layers: int,
1481
+ cross_attention_dim: int | None,
1482
+ upcast_attention: bool = False,
1483
+ use_flash_attention: bool = False,
1484
+ ) -> nn.Module:
1485
+ if with_attn:
1486
+ return AttnDownBlock(
1487
+ spatial_dims=spatial_dims,
1488
+ in_channels=in_channels,
1489
+ out_channels=out_channels,
1490
+ temb_channels=temb_channels,
1491
+ num_res_blocks=num_res_blocks,
1492
+ norm_num_groups=norm_num_groups,
1493
+ norm_eps=norm_eps,
1494
+ add_downsample=add_downsample,
1495
+ resblock_updown=resblock_updown,
1496
+ num_head_channels=num_head_channels,
1497
+ use_flash_attention=use_flash_attention,
1498
+ )
1499
+ elif with_cross_attn:
1500
+ return CrossAttnDownBlock(
1501
+ spatial_dims=spatial_dims,
1502
+ in_channels=in_channels,
1503
+ out_channels=out_channels,
1504
+ temb_channels=temb_channels,
1505
+ num_res_blocks=num_res_blocks,
1506
+ norm_num_groups=norm_num_groups,
1507
+ norm_eps=norm_eps,
1508
+ add_downsample=add_downsample,
1509
+ resblock_updown=resblock_updown,
1510
+ num_head_channels=num_head_channels,
1511
+ transformer_num_layers=transformer_num_layers,
1512
+ cross_attention_dim=cross_attention_dim,
1513
+ upcast_attention=upcast_attention,
1514
+ use_flash_attention=use_flash_attention,
1515
+ )
1516
+ else:
1517
+ return DownBlock(
1518
+ spatial_dims=spatial_dims,
1519
+ in_channels=in_channels,
1520
+ out_channels=out_channels,
1521
+ temb_channels=temb_channels,
1522
+ num_res_blocks=num_res_blocks,
1523
+ norm_num_groups=norm_num_groups,
1524
+ norm_eps=norm_eps,
1525
+ add_downsample=add_downsample,
1526
+ resblock_updown=resblock_updown,
1527
+ )
1528
+
1529
+
1530
+ def get_mid_block(
1531
+ spatial_dims: int,
1532
+ in_channels: int,
1533
+ temb_channels: int,
1534
+ norm_num_groups: int,
1535
+ norm_eps: float,
1536
+ with_conditioning: bool,
1537
+ num_head_channels: int,
1538
+ transformer_num_layers: int,
1539
+ cross_attention_dim: int | None,
1540
+ upcast_attention: bool = False,
1541
+ use_flash_attention: bool = False,
1542
+ ) -> nn.Module:
1543
+ if with_conditioning:
1544
+ return CrossAttnMidBlock(
1545
+ spatial_dims=spatial_dims,
1546
+ in_channels=in_channels,
1547
+ temb_channels=temb_channels,
1548
+ norm_num_groups=norm_num_groups,
1549
+ norm_eps=norm_eps,
1550
+ num_head_channels=num_head_channels,
1551
+ transformer_num_layers=transformer_num_layers,
1552
+ cross_attention_dim=cross_attention_dim,
1553
+ upcast_attention=upcast_attention,
1554
+ use_flash_attention=use_flash_attention,
1555
+ )
1556
+ else:
1557
+ return AttnMidBlock(
1558
+ spatial_dims=spatial_dims,
1559
+ in_channels=in_channels,
1560
+ temb_channels=temb_channels,
1561
+ norm_num_groups=norm_num_groups,
1562
+ norm_eps=norm_eps,
1563
+ num_head_channels=num_head_channels,
1564
+ use_flash_attention=use_flash_attention,
1565
+ )
1566
+
1567
+
1568
+ def get_up_block(
1569
+ spatial_dims: int,
1570
+ in_channels: int,
1571
+ prev_output_channel: int,
1572
+ out_channels: int,
1573
+ temb_channels: int,
1574
+ num_res_blocks: int,
1575
+ norm_num_groups: int,
1576
+ norm_eps: float,
1577
+ add_upsample: bool,
1578
+ resblock_updown: bool,
1579
+ with_attn: bool,
1580
+ with_cross_attn: bool,
1581
+ num_head_channels: int,
1582
+ transformer_num_layers: int,
1583
+ cross_attention_dim: int | None,
1584
+ upcast_attention: bool = False,
1585
+ use_flash_attention: bool = False,
1586
+ ) -> nn.Module:
1587
+ if with_attn:
1588
+ return AttnUpBlock(
1589
+ spatial_dims=spatial_dims,
1590
+ in_channels=in_channels,
1591
+ prev_output_channel=prev_output_channel,
1592
+ out_channels=out_channels,
1593
+ temb_channels=temb_channels,
1594
+ num_res_blocks=num_res_blocks,
1595
+ norm_num_groups=norm_num_groups,
1596
+ norm_eps=norm_eps,
1597
+ add_upsample=add_upsample,
1598
+ resblock_updown=resblock_updown,
1599
+ num_head_channels=num_head_channels,
1600
+ use_flash_attention=use_flash_attention,
1601
+ )
1602
+ elif with_cross_attn:
1603
+ return CrossAttnUpBlock(
1604
+ spatial_dims=spatial_dims,
1605
+ in_channels=in_channels,
1606
+ prev_output_channel=prev_output_channel,
1607
+ out_channels=out_channels,
1608
+ temb_channels=temb_channels,
1609
+ num_res_blocks=num_res_blocks,
1610
+ norm_num_groups=norm_num_groups,
1611
+ norm_eps=norm_eps,
1612
+ add_upsample=add_upsample,
1613
+ resblock_updown=resblock_updown,
1614
+ num_head_channels=num_head_channels,
1615
+ transformer_num_layers=transformer_num_layers,
1616
+ cross_attention_dim=cross_attention_dim,
1617
+ upcast_attention=upcast_attention,
1618
+ use_flash_attention=use_flash_attention,
1619
+ )
1620
+ else:
1621
+ return UpBlock(
1622
+ spatial_dims=spatial_dims,
1623
+ in_channels=in_channels,
1624
+ prev_output_channel=prev_output_channel,
1625
+ out_channels=out_channels,
1626
+ temb_channels=temb_channels,
1627
+ num_res_blocks=num_res_blocks,
1628
+ norm_num_groups=norm_num_groups,
1629
+ norm_eps=norm_eps,
1630
+ add_upsample=add_upsample,
1631
+ resblock_updown=resblock_updown,
1632
+ )
1633
+
1634
+
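The three factories only dispatch on the attention flags, so one call site can build plain, self-attention, or cross-attention blocks; a sketch of get_down_block with illustrative arguments:

down = get_down_block(
    spatial_dims=2, in_channels=64, out_channels=128, temb_channels=256,
    num_res_blocks=2, norm_num_groups=32, norm_eps=1e-6,
    add_downsample=True, resblock_updown=False,
    with_attn=True, with_cross_attn=False,
    num_head_channels=32, transformer_num_layers=1, cross_attention_dim=None,
)
# with_attn=True -> AttnDownBlock; with_cross_attn=True -> CrossAttnDownBlock;
# both False -> plain DownBlock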
1635
+ class DiffusionModelUNet(nn.Module):
1636
+ """
1637
+ Unet network with timestep embedding and attention mechanisms for conditioning based on
1638
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
1639
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
1640
+
1641
+ Args:
1642
+ spatial_dims: number of spatial dimensions.
1643
+ in_channels: number of input channels.
1644
+ out_channels: number of output channels.
1645
+ num_res_blocks: number of residual blocks (see ResnetBlock) per level.
1646
+ num_channels: tuple of block output channels.
1647
+ attention_levels: list of levels to add attention.
1648
+ norm_num_groups: number of groups for the normalization.
1649
+ norm_eps: epsilon for the normalization.
1650
+ resblock_updown: if True use residual blocks for up/downsampling.
1651
+ num_head_channels: number of channels in each attention head.
1652
+ with_conditioning: if True add spatial transformers to perform conditioning.
1653
+ transformer_num_layers: number of layers of Transformer blocks to use.
1654
+ cross_attention_dim: number of context dimensions to use.
1655
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
1656
+ classes.
1657
+ upcast_attention: if True, upcast attention operations to full precision.
1658
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
1659
+ """
1660
+
1661
+ def __init__(
1662
+ self,
1663
+ spatial_dims: int,
1664
+ in_channels: int,
1665
+ out_channels: int,
1666
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
1667
+ num_channels: Sequence[int] = (32, 64, 64, 64),
1668
+ attention_levels: Sequence[bool] = (False, False, True, True),
1669
+ norm_num_groups: int = 32,
1670
+ norm_eps: float = 1e-6,
1671
+ resblock_updown: bool = False,
1672
+ num_head_channels: int | Sequence[int] = 8,
1673
+ with_conditioning: bool = False,
1674
+ transformer_num_layers: int = 1,
1675
+ cross_attention_dim: int | None = None,
1676
+ num_class_embeds: int | None = None,
1677
+ upcast_attention: bool = False,
1678
+ use_flash_attention: bool = False,
1679
+ ) -> None:
1680
+ super().__init__()
1681
+ if with_conditioning is True and cross_attention_dim is None:
1682
+ raise ValueError(
1683
+ "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
1684
+ "when using with_conditioning."
1685
+ )
1686
+ if cross_attention_dim is not None and with_conditioning is False:
1687
+ raise ValueError(
1688
+ "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
1689
+ )
1690
+
1691
+ # All number of channels should be multiple of num_groups
1692
+ if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
1693
+ raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
1694
+
1695
+ if len(num_channels) != len(attention_levels):
1696
+ raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")
1697
+
1698
+ if isinstance(num_head_channels, int):
1699
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
1700
+
1701
+ if len(num_head_channels) != len(attention_levels):
1702
+ raise ValueError(
1703
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
1704
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
1705
+ )
1706
+
1707
+ if isinstance(num_res_blocks, int):
1708
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
1709
+
1710
+ if len(num_res_blocks) != len(num_channels):
1711
+ raise ValueError(
1712
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
1713
+ "`num_channels`."
1714
+ )
1715
+
1716
+ if use_flash_attention and not has_xformers:
1717
+ raise ValueError("use_flash_attention is True but xformers is not installed.")
1718
+
1719
+ if use_flash_attention is True and not torch.cuda.is_available():
1720
+ raise ValueError(
1721
+ "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
1722
+ )
1723
+
1724
+ self.in_channels = in_channels
1725
+ self.block_out_channels = num_channels
1726
+ self.out_channels = out_channels
1727
+ self.num_res_blocks = num_res_blocks
1728
+ self.attention_levels = attention_levels
1729
+ self.num_head_channels = num_head_channels
1730
+ self.with_conditioning = with_conditioning
1731
+
1732
+ # input
1733
+ self.conv_in = Convolution(
1734
+ spatial_dims=spatial_dims,
1735
+ in_channels=in_channels,
1736
+ out_channels=num_channels[0],
1737
+ strides=1,
1738
+ kernel_size=3,
1739
+ padding=1,
1740
+ conv_only=True,
1741
+ )
1742
+
1743
+ # time
1744
+ time_embed_dim = num_channels[0] * 4
1745
+ self.time_embed = nn.Sequential(
1746
+ nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
1747
+ )
1748
+
1749
+ # class embedding
1750
+ self.num_class_embeds = num_class_embeds
1751
+ if num_class_embeds is not None:
1752
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
1753
+
1754
+ # down
1755
+ self.down_blocks = nn.ModuleList([])
1756
+ output_channel = num_channels[0]
1757
+ for i in range(len(num_channels)):
1758
+ input_channel = output_channel
1759
+ output_channel = num_channels[i]
1760
+ is_final_block = i == len(num_channels) - 1
1761
+
1762
+ down_block = get_down_block(
1763
+ spatial_dims=spatial_dims,
1764
+ in_channels=input_channel,
1765
+ out_channels=output_channel,
1766
+ temb_channels=time_embed_dim,
1767
+ num_res_blocks=num_res_blocks[i],
1768
+ norm_num_groups=norm_num_groups,
1769
+ norm_eps=norm_eps,
1770
+ add_downsample=not is_final_block,
1771
+ resblock_updown=resblock_updown,
1772
+ with_attn=(attention_levels[i] and not with_conditioning),
1773
+ with_cross_attn=(attention_levels[i] and with_conditioning),
1774
+ num_head_channels=num_head_channels[i],
1775
+ transformer_num_layers=transformer_num_layers,
1776
+ cross_attention_dim=cross_attention_dim,
1777
+ upcast_attention=upcast_attention,
1778
+ use_flash_attention=use_flash_attention,
1779
+ )
1780
+
1781
+ self.down_blocks.append(down_block)
1782
+
1783
+ # mid
1784
+ self.middle_block = get_mid_block(
1785
+ spatial_dims=spatial_dims,
1786
+ in_channels=num_channels[-1],
1787
+ temb_channels=time_embed_dim,
1788
+ norm_num_groups=norm_num_groups,
1789
+ norm_eps=norm_eps,
1790
+ with_conditioning=with_conditioning,
1791
+ num_head_channels=num_head_channels[-1],
1792
+ transformer_num_layers=transformer_num_layers,
1793
+ cross_attention_dim=cross_attention_dim,
1794
+ upcast_attention=upcast_attention,
1795
+ use_flash_attention=use_flash_attention,
1796
+ )
1797
+
1798
+ # up
1799
+ self.up_blocks = nn.ModuleList([])
1800
+ reversed_block_out_channels = list(reversed(num_channels))
1801
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
1802
+ reversed_attention_levels = list(reversed(attention_levels))
1803
+ reversed_num_head_channels = list(reversed(num_head_channels))
1804
+ output_channel = reversed_block_out_channels[0]
1805
+ for i in range(len(reversed_block_out_channels)):
1806
+ prev_output_channel = output_channel
1807
+ output_channel = reversed_block_out_channels[i]
1808
+ input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]
1809
+
1810
+ is_final_block = i == len(num_channels) - 1
1811
+
1812
+ up_block = get_up_block(
1813
+ spatial_dims=spatial_dims,
1814
+ in_channels=input_channel,
1815
+ prev_output_channel=prev_output_channel,
1816
+ out_channels=output_channel,
1817
+ temb_channels=time_embed_dim,
1818
+ num_res_blocks=reversed_num_res_blocks[i] + 1,
1819
+ norm_num_groups=norm_num_groups,
1820
+ norm_eps=norm_eps,
1821
+ add_upsample=not is_final_block,
1822
+ resblock_updown=resblock_updown,
1823
+ with_attn=(reversed_attention_levels[i] and not with_conditioning),
1824
+ with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
1825
+ num_head_channels=reversed_num_head_channels[i],
1826
+ transformer_num_layers=transformer_num_layers,
1827
+ cross_attention_dim=cross_attention_dim,
1828
+ upcast_attention=upcast_attention,
1829
+ use_flash_attention=use_flash_attention,
1830
+ )
1831
+
1832
+ self.up_blocks.append(up_block)
1833
+
1834
+ # out
1835
+ self.out = nn.Sequential(
1836
+ nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
1837
+ nn.SiLU(),
1838
+ zero_module(
1839
+ Convolution(
1840
+ spatial_dims=spatial_dims,
1841
+ in_channels=num_channels[0],
1842
+ out_channels=out_channels,
1843
+ strides=1,
1844
+ kernel_size=3,
1845
+                    padding=2,  # 'same' padding for the dilation=2 conv below
1846
+ conv_only=True,
1847
+ dilation=2
1848
+ )
1849
+ ),
1850
+ )
1851
+
1852
+ def forward(
1853
+ self,
1854
+ x: torch.Tensor,
1855
+ timesteps: torch.Tensor,
1856
+ context: torch.Tensor | None = None,
1857
+ class_labels: torch.Tensor | None = None,
1858
+ down_block_additional_residuals: tuple[torch.Tensor] | None = None,
1859
+ mid_block_additional_residual: torch.Tensor | None = None,
1860
+ ) -> torch.Tensor:
1861
+ """
1862
+ Args:
1863
+ x: input tensor (N, C, SpatialDims).
1864
+ timesteps: timestep tensor (N,).
1865
+ context: context tensor (N, 1, ContextDim).
1866
+            class_labels: class label tensor (N,).
1867
+ down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
1868
+ mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
1869
+ """
1870
+ # 1. time
1871
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
1872
+
1873
+ # timesteps does not contain any weights and will always return f32 tensors
1874
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1875
+ # there might be better ways to encapsulate this.
1876
+ t_emb = t_emb.to(dtype=x.dtype)
1877
+ emb = self.time_embed(t_emb)
1878
+
1879
+ # 2. class
1880
+ if self.num_class_embeds is not None:
1881
+ if class_labels is None:
1882
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1883
+ class_emb = self.class_embedding(class_labels)
1884
+ class_emb = class_emb.to(dtype=x.dtype)
1885
+ emb = emb + class_emb
1886
+
1887
+ # 3. initial convolution
1888
+ h = self.conv_in(x)
1889
+
1890
+ # 4. down
1891
+ if context is not None and self.with_conditioning is False:
1892
+ raise ValueError("model should have with_conditioning = True if context is provided")
1893
+ down_block_res_samples: list[torch.Tensor] = [h]
1894
+ for downsample_block in self.down_blocks:
1895
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
1896
+ for residual in res_samples:
1897
+ down_block_res_samples.append(residual)
1898
+
1899
+        # Additional residual connections for ControlNets
1900
+ if down_block_additional_residuals is not None:
1901
+ new_down_block_res_samples = ()
1902
+ for down_block_res_sample, down_block_additional_residual in zip(
1903
+ down_block_res_samples, down_block_additional_residuals
1904
+ ):
1905
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1906
+ new_down_block_res_samples += (down_block_res_sample,)
1907
+
1908
+ down_block_res_samples = new_down_block_res_samples
1909
+
1910
+ # 5. mid
1911
+ h = self.middle_block(hidden_states=h, temb=emb, context=context)
1912
+
1913
+        # Additional residual connections for ControlNets
1914
+ if mid_block_additional_residual is not None:
1915
+ h = h + mid_block_additional_residual
1916
+
1917
+ # 6. up
1918
+ for upsample_block in self.up_blocks:
1919
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1920
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1921
+ h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
1922
+
1923
+ # 7. output block
1924
+ h = self.out(h)
1925
+
1926
+ return h
1927
+
1928
+
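A small end-to-end sketch of the UNet as a latent denoiser; the channel layout is an illustrative assumption and every entry of num_channels must be divisible by norm_num_groups:

import torch

unet = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=4,               # latent channels, e.g. from a VQ/KL autoencoder
    out_channels=4,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_head_channels=8,
)
latents = torch.randn(1, 4, 64, 64)
timesteps = torch.randint(0, 1000, (1,)).long()
noise_pred = unet(latents, timesteps)    # same shape as the input latents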
1929
+ class DiffusionModelEncoder(nn.Module):
1930
+ """
1931
+ Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on
1932
+ Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).
1933
+
1934
+ Args:
1935
+ spatial_dims: number of spatial dimensions.
1936
+ in_channels: number of input channels.
1937
+ out_channels: number of output channels.
1938
+ num_res_blocks: number of residual blocks (see ResnetBlock) per level.
1939
+ num_channels: tuple of block output channels.
1940
+ attention_levels: list of levels to add attention.
1941
+ norm_num_groups: number of groups for the normalization.
1942
+ norm_eps: epsilon for the normalization.
1943
+ resblock_updown: if True use residual blocks for downsampling.
1944
+ num_head_channels: number of channels in each attention head.
1945
+ with_conditioning: if True add spatial transformers to perform conditioning.
1946
+ transformer_num_layers: number of layers of Transformer blocks to use.
1947
+ cross_attention_dim: number of context dimensions to use.
1948
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
1949
+ upcast_attention: if True, upcast attention operations to full precision.
1950
+ """
1951
+
1952
+ def __init__(
1953
+ self,
1954
+ spatial_dims: int,
1955
+ in_channels: int,
1956
+ out_channels: int,
1957
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
1958
+ num_channels: Sequence[int] = (32, 64, 64, 64),
1959
+ attention_levels: Sequence[bool] = (False, False, True, True),
1960
+ norm_num_groups: int = 32,
1961
+ norm_eps: float = 1e-6,
1962
+ resblock_updown: bool = False,
1963
+ num_head_channels: int | Sequence[int] = 8,
1964
+ with_conditioning: bool = False,
1965
+ transformer_num_layers: int = 1,
1966
+ cross_attention_dim: int | None = None,
1967
+ num_class_embeds: int | None = None,
1968
+ upcast_attention: bool = False,
1969
+ ) -> None:
1970
+ super().__init__()
1971
+ if with_conditioning is True and cross_attention_dim is None:
1972
+ raise ValueError(
1973
+ "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
1974
+ "when using with_conditioning."
1975
+ )
1976
+ if cross_attention_dim is not None and with_conditioning is False:
1977
+ raise ValueError(
1978
+ "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
1979
+ )
1980
+
1981
+ # All number of channels should be multiple of num_groups
1982
+ if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
1983
+ raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
1984
+ if len(num_channels) != len(attention_levels):
1985
+ raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")
1986
+
1987
+ if isinstance(num_head_channels, int):
1988
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
1989
+
1990
+ if len(num_head_channels) != len(attention_levels):
1991
+ raise ValueError(
1992
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
1993
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
1994
+ )
1995
+
1996
+ self.in_channels = in_channels
1997
+ self.block_out_channels = num_channels
1998
+ self.out_channels = out_channels
1999
+ self.num_res_blocks = num_res_blocks
2000
+ self.attention_levels = attention_levels
2001
+ self.num_head_channels = num_head_channels
2002
+ self.with_conditioning = with_conditioning
2003
+
2004
+ # input
2005
+ self.conv_in = Convolution(
2006
+ spatial_dims=spatial_dims,
2007
+ in_channels=in_channels,
2008
+ out_channels=num_channels[0],
2009
+ strides=1,
2010
+ kernel_size=3,
2011
+ padding=1,
2012
+ conv_only=True,
2013
+ )
2014
+
2015
+ # time
2016
+ time_embed_dim = num_channels[0] * 4
2017
+ self.time_embed = nn.Sequential(
2018
+ nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
2019
+ )
2020
+
2021
+ # class embedding
2022
+ self.num_class_embeds = num_class_embeds
2023
+ if num_class_embeds is not None:
2024
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
2025
+
2026
+ # down
2027
+ self.down_blocks = nn.ModuleList([])
2028
+ output_channel = num_channels[0]
2029
+ for i in range(len(num_channels)):
2030
+ input_channel = output_channel
2031
+ output_channel = num_channels[i]
2032
+ is_final_block = i == len(num_channels) # - 1
2033
+
2034
+ down_block = get_down_block(
2035
+ spatial_dims=spatial_dims,
2036
+ in_channels=input_channel,
2037
+ out_channels=output_channel,
2038
+ temb_channels=time_embed_dim,
2039
+ num_res_blocks=num_res_blocks[i],
2040
+ norm_num_groups=norm_num_groups,
2041
+ norm_eps=norm_eps,
2042
+ add_downsample=not is_final_block,
2043
+ resblock_updown=resblock_updown,
2044
+ with_attn=(attention_levels[i] and not with_conditioning),
2045
+ with_cross_attn=(attention_levels[i] and with_conditioning),
2046
+ num_head_channels=num_head_channels[i],
2047
+ transformer_num_layers=transformer_num_layers,
2048
+ cross_attention_dim=cross_attention_dim,
2049
+ upcast_attention=upcast_attention,
2050
+ )
2051
+
2052
+ self.down_blocks.append(down_block)
2053
+
2054
+ self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
2055
+
2056
+ def forward(
2057
+ self,
2058
+ x: torch.Tensor,
2059
+ timesteps: torch.Tensor,
2060
+ context: torch.Tensor | None = None,
2061
+ class_labels: torch.Tensor | None = None,
2062
+ ) -> torch.Tensor:
2063
+ """
2064
+ Args:
2065
+ x: input tensor (N, C, SpatialDims).
2066
+ timesteps: timestep tensor (N,).
2067
+ context: context tensor (N, 1, ContextDim).
2068
+            class_labels: class label tensor (N,).
2069
+ """
2070
+ # 1. time
2071
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
2072
+
2073
+ # timesteps does not contain any weights and will always return f32 tensors
2074
+ # but time_embedding might actually be running in fp16. so we need to cast here.
2075
+ # there might be better ways to encapsulate this.
2076
+ t_emb = t_emb.to(dtype=x.dtype)
2077
+ emb = self.time_embed(t_emb)
2078
+
2079
+ # 2. class
2080
+ if self.num_class_embeds is not None:
2081
+ if class_labels is None:
2082
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
2083
+ class_emb = self.class_embedding(class_labels)
2084
+ class_emb = class_emb.to(dtype=x.dtype)
2085
+ emb = emb + class_emb
2086
+
2087
+ # 3. initial convolution
2088
+ h = self.conv_in(x)
2089
+
2090
+ # 4. down
2091
+ if context is not None and self.with_conditioning is False:
2092
+ raise ValueError("model should have with_conditioning = True if context is provided")
2093
+ for downsample_block in self.down_blocks:
2094
+ h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
2095
+
2096
+ h = h.reshape(h.shape[0], -1)
2097
+ output = self.out(h)
2098
+
2099
+ return output
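A sketch of the encoder used as a classifier; with the default num_channels=(32, 64, 64, 64) every level downsamples, so a 128x128 single-channel input yields the 64*8*8 = 4096 features expected by the Linear(4096, 512) head (the input resolution is therefore not free; batch size and class count below are illustrative):

import torch

encoder = DiffusionModelEncoder(spatial_dims=2, in_channels=1, out_channels=2)
x = torch.randn(2, 1, 128, 128)
t = torch.randint(0, 1000, (2,)).long()
logits = encoder(x, t)    # (2, 2) class logits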
codes/modules/fp16_util.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ from . import logger
11
+
12
+ INITIAL_LOG_LOSS_SCALE = 20.0
13
+
14
+
15
+ def convert_module_to_f16(l):
16
+ """
17
+ Convert primitive modules to float16.
18
+ """
19
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20
+ l.weight.data = l.weight.data.half()
21
+ if l.bias is not None:
22
+ l.bias.data = l.bias.data.half()
23
+
24
+
25
+ def convert_module_to_f32(l):
26
+ """
27
+ Convert primitive modules to float32, undoing convert_module_to_f16().
28
+ """
29
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30
+ l.weight.data = l.weight.data.float()
31
+ if l.bias is not None:
32
+ l.bias.data = l.bias.data.float()
33
+
34
+
35
+ def make_master_params(param_groups_and_shapes):
36
+ """
37
+ Copy model parameters into a (differently-shaped) list of full-precision
38
+ parameters.
39
+ """
40
+ master_params = []
41
+ for param_group, shape in param_groups_and_shapes:
42
+ master_param = nn.Parameter(
43
+ _flatten_dense_tensors(
44
+ [param.detach().float() for (_, param) in param_group]
45
+ ).view(shape)
46
+ )
47
+ master_param.requires_grad = True
48
+ master_params.append(master_param)
49
+ return master_params
50
+
51
+
52
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53
+ """
54
+ Copy the gradients from the model parameters into the master parameters
55
+ from make_master_params().
56
+ """
57
+ for master_param, (param_group, shape) in zip(
58
+ master_params, param_groups_and_shapes
59
+ ):
60
+ master_param.grad = _flatten_dense_tensors(
61
+ [param_grad_or_zeros(param) for (_, param) in param_group]
62
+ ).view(shape)
63
+
64
+
65
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
66
+ """
67
+ Copy the master parameter data back into the model parameters.
68
+ """
69
+ # Without copying to a list, if a generator is passed, this will
70
+ # silently not copy any parameters.
71
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72
+ for (_, param), unflat_master_param in zip(
73
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
74
+ ):
75
+ param.detach().copy_(unflat_master_param)
76
+
77
+
78
+ def unflatten_master_params(param_group, master_param):
79
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80
+
81
+
82
+ def get_param_groups_and_shapes(named_model_params):
83
+ named_model_params = list(named_model_params)
84
+ scalar_vector_named_params = (
85
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86
+ (-1),
87
+ )
88
+ matrix_named_params = (
89
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90
+ (1, -1),
91
+ )
92
+ return [scalar_vector_named_params, matrix_named_params]
93
+
94
+
95
+ def master_params_to_state_dict(
96
+ model, param_groups_and_shapes, master_params, use_fp16
97
+ ):
98
+ if use_fp16:
99
+ state_dict = model.state_dict()
100
+ for master_param, (param_group, _) in zip(
101
+ master_params, param_groups_and_shapes
102
+ ):
103
+ for (name, _), unflat_master_param in zip(
104
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
105
+ ):
106
+ assert name in state_dict
107
+ state_dict[name] = unflat_master_param
108
+ else:
109
+ state_dict = model.state_dict()
110
+ for i, (name, _value) in enumerate(model.named_parameters()):
111
+ assert name in state_dict
112
+ state_dict[name] = master_params[i]
113
+ return state_dict
114
+
115
+
116
+ def state_dict_to_master_params(model, state_dict, use_fp16):
117
+ if use_fp16:
118
+ named_model_params = [
119
+ (name, state_dict[name]) for name, _ in model.named_parameters()
120
+ ]
121
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122
+ master_params = make_master_params(param_groups_and_shapes)
123
+ else:
124
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
125
+ return master_params
126
+
127
+
128
+ def zero_master_grads(master_params):
129
+ for param in master_params:
130
+ param.grad = None
131
+
132
+
133
+ def zero_grad(model_params):
134
+ for param in model_params:
135
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136
+ if param.grad is not None:
137
+ param.grad.detach_()
138
+ param.grad.zero_()
139
+
140
+
141
+ def param_grad_or_zeros(param):
142
+ if param.grad is not None:
143
+ return param.grad.data.detach()
144
+ else:
145
+ return th.zeros_like(param)
146
+
147
+
148
+ class MixedPrecisionTrainer:
149
+ def __init__(
150
+ self,
151
+ *,
152
+ model,
153
+ use_fp16=False,
154
+ fp16_scale_growth=1e-3,
155
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156
+ ):
157
+ self.model = model
158
+ self.use_fp16 = use_fp16
159
+ self.fp16_scale_growth = fp16_scale_growth
160
+
161
+ self.model_params = list(self.model.parameters())
162
+ self.master_params = self.model_params
163
+ self.param_groups_and_shapes = None
164
+ self.lg_loss_scale = initial_lg_loss_scale
165
+
166
+ if self.use_fp16:
167
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
168
+ self.model.named_parameters()
169
+ )
170
+ self.master_params = make_master_params(self.param_groups_and_shapes)
171
+ self.model.convert_to_fp16()
172
+
173
+ def zero_grad(self):
174
+ zero_grad(self.model_params)
175
+
176
+ def backward(self, loss: th.Tensor):
177
+ if self.use_fp16:
178
+ loss_scale = 2 ** self.lg_loss_scale
179
+ (loss * loss_scale).backward()
180
+ else:
181
+ loss.backward()
182
+
183
+ def optimize(self, opt: th.optim.Optimizer):
184
+ if self.use_fp16:
185
+ return self._optimize_fp16(opt)
186
+ else:
187
+ return self._optimize_normal(opt)
188
+
189
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
190
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193
+ if check_overflow(grad_norm):
194
+ self.lg_loss_scale -= 1
195
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196
+ zero_master_grads(self.master_params)
197
+ return False
198
+
199
+ logger.logkv_mean("grad_norm", grad_norm)
200
+ logger.logkv_mean("param_norm", param_norm)
201
+
202
+ for p in self.master_params:
203
+ p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204
+ opt.step()
205
+ zero_master_grads(self.master_params)
206
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207
+ self.lg_loss_scale += self.fp16_scale_growth
208
+ return True
209
+
210
+ def _optimize_normal(self, opt: th.optim.Optimizer):
211
+ grad_norm, param_norm = self._compute_norms()
212
+ logger.logkv_mean("grad_norm", grad_norm)
213
+ logger.logkv_mean("param_norm", param_norm)
214
+ opt.step()
215
+ return True
216
+
217
+ def _compute_norms(self, grad_scale=1.0):
218
+ grad_norm = 0.0
219
+ param_norm = 0.0
220
+ for p in self.master_params:
221
+ with th.no_grad():
222
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223
+ if p.grad is not None:
224
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226
+
227
+ def master_params_to_state_dict(self, master_params):
228
+ return master_params_to_state_dict(
229
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230
+ )
231
+
232
+ def state_dict_to_master_params(self, state_dict):
233
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234
+
235
+
236
+ def check_overflow(value):
237
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
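
A minimal usage sketch of the MixedPrecisionTrainer defined above; the model, loss, data, and optimizer names are placeholders and are not part of this commit:

# Hypothetical training step using MixedPrecisionTrainer (illustrative only).
# Assumes `model` is a module that provides convert_to_fp16(), and that
# `loss_fn`, `batch`, and `target` are placeholders defined elsewhere.
import torch as th

trainer = MixedPrecisionTrainer(model=model, use_fp16=True, fp16_scale_growth=1e-3)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

trainer.zero_grad()
loss = loss_fn(model(batch), target)
trainer.backward(loss)             # scales the loss by 2**lg_loss_scale in fp16 mode
took_step = trainer.optimize(opt)  # unscales grads, steps, copies master -> model weights
if not took_step:
    pass                           # overflow: loss scale was lowered and grads discarded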
codes/modules/logger.py ADDED
@@ -0,0 +1,495 @@
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+
18
+ DEBUG = 10
19
+ INFO = 20
20
+ WARN = 30
21
+ ERROR = 40
22
+
23
+ DISABLED = 50
24
+
25
+
26
+ class KVWriter(object):
27
+ def writekvs(self, kvs):
28
+ raise NotImplementedError
29
+
30
+
31
+ class SeqWriter(object):
32
+ def writeseq(self, seq):
33
+ raise NotImplementedError
34
+
35
+
36
+ class HumanOutputFormat(KVWriter, SeqWriter):
37
+ def __init__(self, filename_or_file):
38
+ if isinstance(filename_or_file, str):
39
+ self.file = open(filename_or_file, "wt")
40
+ self.own_file = True
41
+ else:
42
+ assert hasattr(filename_or_file, "write"), (
43
+ "expected file or str, got %s" % filename_or_file
44
+ )
45
+ self.file = filename_or_file
46
+ self.own_file = False
47
+
48
+ def writekvs(self, kvs):
49
+ # Create strings for printing
50
+ key2str = {}
51
+ for (key, val) in sorted(kvs.items()):
52
+ if hasattr(val, "__float__"):
53
+ valstr = "%-8.3g" % val
54
+ else:
55
+ valstr = str(val)
56
+ key2str[self._truncate(key)] = self._truncate(valstr)
57
+
58
+ # Find max widths
59
+ if len(key2str) == 0:
60
+ print("WARNING: tried to write empty key-value dict")
61
+ return
62
+ else:
63
+ keywidth = max(map(len, key2str.keys()))
64
+ valwidth = max(map(len, key2str.values()))
65
+
66
+ # Write out the data
67
+ dashes = "-" * (keywidth + valwidth + 7)
68
+ lines = [dashes]
69
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70
+ lines.append(
71
+ "| %s%s | %s%s |"
72
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73
+ )
74
+ lines.append(dashes)
75
+ self.file.write("\n".join(lines) + "\n")
76
+
77
+ # Flush the output to the file
78
+ self.file.flush()
79
+
80
+ def _truncate(self, s):
81
+ maxlen = 30
82
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83
+
84
+ def writeseq(self, seq):
85
+ seq = list(seq)
86
+ for (i, elem) in enumerate(seq):
87
+ self.file.write(elem)
88
+ if i < len(seq) - 1: # add space unless this is the last one
89
+ self.file.write(" ")
90
+ self.file.write("\n")
91
+ self.file.flush()
92
+
93
+ def close(self):
94
+ if self.own_file:
95
+ self.file.close()
96
+
97
+
98
+ class JSONOutputFormat(KVWriter):
99
+ def __init__(self, filename):
100
+ self.file = open(filename, "wt")
101
+
102
+ def writekvs(self, kvs):
103
+ for k, v in sorted(kvs.items()):
104
+ if hasattr(v, "dtype"):
105
+ kvs[k] = float(v)
106
+ self.file.write(json.dumps(kvs) + "\n")
107
+ self.file.flush()
108
+
109
+ def close(self):
110
+ self.file.close()
111
+
112
+
113
+ class CSVOutputFormat(KVWriter):
114
+ def __init__(self, filename):
115
+ self.file = open(filename, "w+t")
116
+ self.keys = []
117
+ self.sep = ","
118
+
119
+ def writekvs(self, kvs):
120
+ # Add our current row to the history
121
+ extra_keys = list(kvs.keys() - self.keys)
122
+ extra_keys.sort()
123
+ if extra_keys:
124
+ self.keys.extend(extra_keys)
125
+ self.file.seek(0)
126
+ lines = self.file.readlines()
127
+ self.file.seek(0)
128
+ for (i, k) in enumerate(self.keys):
129
+ if i > 0:
130
+ self.file.write(",")
131
+ self.file.write(k)
132
+ self.file.write("\n")
133
+ for line in lines[1:]:
134
+ self.file.write(line[:-1])
135
+ self.file.write(self.sep * len(extra_keys))
136
+ self.file.write("\n")
137
+ for (i, k) in enumerate(self.keys):
138
+ if i > 0:
139
+ self.file.write(",")
140
+ v = kvs.get(k)
141
+ if v is not None:
142
+ self.file.write(str(v))
143
+ self.file.write("\n")
144
+ self.file.flush()
145
+
146
+ def close(self):
147
+ self.file.close()
148
+
149
+
150
+ class TensorBoardOutputFormat(KVWriter):
151
+ """
152
+ Dumps key/value pairs into TensorBoard's numeric format.
153
+ """
154
+
155
+ def __init__(self, dir):
156
+ os.makedirs(dir, exist_ok=True)
157
+ self.dir = dir
158
+ self.step = 1
159
+ prefix = "events"
160
+ path = osp.join(osp.abspath(dir), prefix)
161
+ import tensorflow as tf
162
+ from tensorflow.python import pywrap_tensorflow
163
+ from tensorflow.core.util import event_pb2
164
+ from tensorflow.python.util import compat
165
+
166
+ self.tf = tf
167
+ self.event_pb2 = event_pb2
168
+ self.pywrap_tensorflow = pywrap_tensorflow
169
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170
+
171
+ def writekvs(self, kvs):
172
+ def summary_val(k, v):
173
+ kwargs = {"tag": k, "simple_value": float(v)}
174
+ return self.tf.Summary.Value(**kwargs)
175
+
176
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178
+ event.step = (
179
+ self.step
180
+ ) # is there any reason why you'd want to specify the step?
181
+ self.writer.WriteEvent(event)
182
+ self.writer.Flush()
183
+ self.step += 1
184
+
185
+ def close(self):
186
+ if self.writer:
187
+ self.writer.Close()
188
+ self.writer = None
189
+
190
+
191
+ def make_output_format(format, ev_dir, log_suffix=""):
192
+ os.makedirs(ev_dir, exist_ok=True)
193
+ if format == "stdout":
194
+ return HumanOutputFormat(sys.stdout)
195
+ elif format == "log":
196
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197
+ elif format == "json":
198
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199
+ elif format == "csv":
200
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201
+ elif format == "tensorboard":
202
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203
+ else:
204
+ raise ValueError("Unknown format specified: %s" % (format,))
205
+
206
+
207
+ # ================================================================
208
+ # API
209
+ # ================================================================
210
+
211
+
212
+ def logkv(key, val):
213
+ """
214
+ Log a value of some diagnostic
215
+ Call this once for each diagnostic quantity, each iteration
216
+ If called many times, last value will be used.
217
+ """
218
+ get_current().logkv(key, val)
219
+
220
+
221
+ def logkv_mean(key, val):
222
+ """
223
+ The same as logkv(), but if called many times, values averaged.
224
+ """
225
+ get_current().logkv_mean(key, val)
226
+
227
+
228
+ def logkvs(d):
229
+ """
230
+ Log a dictionary of key-value pairs
231
+ """
232
+ for (k, v) in d.items():
233
+ logkv(k, v)
234
+
235
+
236
+ def dumpkvs():
237
+ """
238
+ Write all of the diagnostics from the current iteration
239
+ """
240
+ return get_current().dumpkvs()
241
+
242
+
243
+ def getkvs():
244
+ return get_current().name2val
245
+
246
+
247
+ def log(*args, level=INFO):
248
+ """
249
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250
+ """
251
+ get_current().log(*args, level=level)
252
+
253
+
254
+ def debug(*args):
255
+ log(*args, level=DEBUG)
256
+
257
+
258
+ def info(*args):
259
+ log(*args, level=INFO)
260
+
261
+
262
+ def warn(*args):
263
+ log(*args, level=WARN)
264
+
265
+
266
+ def error(*args):
267
+ log(*args, level=ERROR)
268
+
269
+
270
+ def set_level(level):
271
+ """
272
+ Set logging threshold on current logger.
273
+ """
274
+ get_current().set_level(level)
275
+
276
+
277
+ def set_comm(comm):
278
+ get_current().set_comm(comm)
279
+
280
+
281
+ def get_dir():
282
+ """
283
+ Get directory that log files are being written to.
284
+ will be None if there is no output directory (i.e., if you didn't call start)
285
+ """
286
+ return get_current().get_dir()
287
+
288
+
289
+ record_tabular = logkv
290
+ dump_tabular = dumpkvs
291
+
292
+
293
+ @contextmanager
294
+ def profile_kv(scopename):
295
+ logkey = "wait_" + scopename
296
+ tstart = time.time()
297
+ try:
298
+ yield
299
+ finally:
300
+ get_current().name2val[logkey] += time.time() - tstart
301
+
302
+
303
+ def profile(n):
304
+ """
305
+ Usage:
306
+ @profile("my_func")
307
+ def my_func(): code
308
+ """
309
+
310
+ def decorator_with_name(func):
311
+ def func_wrapper(*args, **kwargs):
312
+ with profile_kv(n):
313
+ return func(*args, **kwargs)
314
+
315
+ return func_wrapper
316
+
317
+ return decorator_with_name
318
+
319
+
320
+ # ================================================================
321
+ # Backend
322
+ # ================================================================
323
+
324
+
325
+ def get_current():
326
+ if Logger.CURRENT is None:
327
+ _configure_default_logger()
328
+
329
+ return Logger.CURRENT
330
+
331
+
332
+ class Logger(object):
333
+ DEFAULT = None # A logger with no output files. (See right below class definition)
334
+ # So that you can still log to the terminal without setting up any output files
335
+ CURRENT = None # Current logger being used by the free functions above
336
+
337
+ def __init__(self, dir, output_formats, comm=None):
338
+ self.name2val = defaultdict(float) # values this iteration
339
+ self.name2cnt = defaultdict(int)
340
+ self.level = INFO
341
+ self.dir = dir
342
+ self.output_formats = output_formats
343
+ self.comm = comm
344
+
345
+ # Logging API, forwarded
346
+ # ----------------------------------------
347
+ def logkv(self, key, val):
348
+ self.name2val[key] = val
349
+
350
+ def logkv_mean(self, key, val):
351
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
352
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353
+ self.name2cnt[key] = cnt + 1
354
+
355
+ def dumpkvs(self):
356
+ if self.comm is None:
357
+ d = self.name2val
358
+ else:
359
+ d = mpi_weighted_mean(
360
+ self.comm,
361
+ {
362
+ name: (val, self.name2cnt.get(name, 1))
363
+ for (name, val) in self.name2val.items()
364
+ },
365
+ )
366
+ if self.comm.rank != 0:
367
+ d["dummy"] = 1 # so we don't get a warning about empty dict
368
+ out = d.copy() # Return the dict for unit testing purposes
369
+ for fmt in self.output_formats:
370
+ if isinstance(fmt, KVWriter):
371
+ fmt.writekvs(d)
372
+ self.name2val.clear()
373
+ self.name2cnt.clear()
374
+ return out
375
+
376
+ def log(self, *args, level=INFO):
377
+ if self.level <= level:
378
+ self._do_log(args)
379
+
380
+ # Configuration
381
+ # ----------------------------------------
382
+ def set_level(self, level):
383
+ self.level = level
384
+
385
+ def set_comm(self, comm):
386
+ self.comm = comm
387
+
388
+ def get_dir(self):
389
+ return self.dir
390
+
391
+ def close(self):
392
+ for fmt in self.output_formats:
393
+ fmt.close()
394
+
395
+ # Misc
396
+ # ----------------------------------------
397
+ def _do_log(self, args):
398
+ for fmt in self.output_formats:
399
+ if isinstance(fmt, SeqWriter):
400
+ fmt.writeseq(map(str, args))
401
+
402
+
403
+ def get_rank_without_mpi_import():
404
+ # check environment variables here instead of importing mpi4py
405
+ # to avoid calling MPI_Init() when this module is imported
406
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407
+ if varname in os.environ:
408
+ return int(os.environ[varname])
409
+ return 0
410
+
411
+
412
+ def mpi_weighted_mean(comm, local_name2valcount):
413
+ """
414
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415
+ Perform a weighted average over dicts that are each on a different node
416
+ Input: local_name2valcount: dict mapping key -> (value, count)
417
+ Returns: key -> mean
418
+ """
419
+ all_name2valcount = comm.gather(local_name2valcount)
420
+ if comm.rank == 0:
421
+ name2sum = defaultdict(float)
422
+ name2count = defaultdict(float)
423
+ for n2vc in all_name2valcount:
424
+ for (name, (val, count)) in n2vc.items():
425
+ try:
426
+ val = float(val)
427
+ except ValueError:
428
+ if comm.rank == 0:
429
+ warnings.warn(
430
+ "WARNING: tried to compute mean on non-float {}={}".format(
431
+ name, val
432
+ )
433
+ )
434
+ else:
435
+ name2sum[name] += val * count
436
+ name2count[name] += count
437
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
438
+ else:
439
+ return {}
440
+
441
+
442
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443
+ """
444
+ If comm is provided, average all numerical stats across that comm
445
+ """
446
+ if dir is None:
447
+ dir = os.getenv("OPENAI_LOGDIR")
448
+ if dir is None:
449
+ dir = osp.join(
450
+ tempfile.gettempdir(),
451
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452
+ )
453
+ assert isinstance(dir, str)
454
+ dir = os.path.expanduser(dir)
455
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
456
+
457
+ rank = get_rank_without_mpi_import()
458
+ if rank > 0:
459
+ log_suffix = log_suffix + "-rank%03i" % rank
460
+
461
+ if format_strs is None:
462
+ if rank == 0:
463
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464
+ else:
465
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466
+ format_strs = filter(None, format_strs)
467
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468
+
469
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470
+ if output_formats:
471
+ log("Logging to %s" % dir)
472
+
473
+
474
+ def _configure_default_logger():
475
+ configure()
476
+ Logger.DEFAULT = Logger.CURRENT
477
+
478
+
479
+ def reset():
480
+ if Logger.CURRENT is not Logger.DEFAULT:
481
+ Logger.CURRENT.close()
482
+ Logger.CURRENT = Logger.DEFAULT
483
+ log("Reset logger")
484
+
485
+
486
+ @contextmanager
487
+ def scoped_configure(dir=None, format_strs=None, comm=None):
488
+ prevlogger = Logger.CURRENT
489
+ configure(dir=dir, format_strs=format_strs, comm=comm)
490
+ try:
491
+ yield
492
+ finally:
493
+ Logger.CURRENT.close()
494
+ Logger.CURRENT = prevlogger
495
+
codes/modules/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
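
A small sketch exercising the helpers above; the shapes and the model are illustrative placeholders:

# Illustrative use of timestep_embedding and update_ema (placeholder model).
import torch as th

t = th.tensor([0, 10, 999])              # one timestep per batch element
emb = timestep_embedding(t, dim=128)     # -> [3, 128] sinusoidal embeddings

ema_params = [p.clone().detach() for p in model.parameters()]  # `model` is a placeholder
update_ema(ema_params, model.parameters(), rate=0.999)         # EMA tracks the model slowly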
codes/modules/unet.py ADDED
@@ -0,0 +1,894 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
11
+ from .nn import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+
21
+
22
+ class AttentionPool2d(nn.Module):
23
+ """
24
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ spacial_dim: int,
30
+ embed_dim: int,
31
+ num_heads_channels: int,
32
+ output_dim: int = None,
33
+ ):
34
+ super().__init__()
35
+ self.positional_embedding = nn.Parameter(
36
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
37
+ )
38
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
39
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
40
+ self.num_heads = embed_dim // num_heads_channels
41
+ self.attention = QKVAttention(self.num_heads)
42
+
43
+ def forward(self, x):
44
+ b, c, *_spatial = x.shape
45
+ x = x.reshape(b, c, -1) # NC(HW)
46
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
47
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
48
+ x = self.qkv_proj(x)
49
+ x = self.attention(x)
50
+ x = self.c_proj(x)
51
+ return x[:, :, 0]
52
+
53
+
54
+ class TimestepBlock(nn.Module):
55
+ """
56
+ Any module where forward() takes timestep embeddings as a second argument.
57
+ """
58
+
59
+ @abstractmethod
60
+ def forward(self, x, emb):
61
+ """
62
+ Apply the module to `x` given `emb` timestep embeddings.
63
+ """
64
+
65
+
66
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
67
+ """
68
+ A sequential module that passes timestep embeddings to the children that
69
+ support it as an extra input.
70
+ """
71
+
72
+ def forward(self, x, emb):
73
+ for layer in self:
74
+ if isinstance(layer, TimestepBlock):
75
+ x = layer(x, emb)
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class Upsample(nn.Module):
82
+ """
83
+ An upsampling layer with an optional convolution.
84
+
85
+ :param channels: channels in the inputs and outputs.
86
+ :param use_conv: a bool determining if a convolution is applied.
87
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88
+ upsampling occurs in the inner-two dimensions.
89
+ """
90
+
91
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.out_channels = out_channels or channels
95
+ self.use_conv = use_conv
96
+ self.dims = dims
97
+ if use_conv:
98
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
99
+
100
+ def forward(self, x):
101
+ assert x.shape[1] == self.channels
102
+ if self.dims == 3:
103
+ x = F.interpolate(
104
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
105
+ )
106
+ else:
107
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
108
+ if self.use_conv:
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Downsample(nn.Module):
114
+ """
115
+ A downsampling layer with an optional convolution.
116
+
117
+ :param channels: channels in the inputs and outputs.
118
+ :param use_conv: a bool determining if a convolution is applied.
119
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
120
+ downsampling occurs in the inner-two dimensions.
121
+ """
122
+
123
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+ self.use_conv = use_conv
128
+ self.dims = dims
129
+ stride = 2 if dims != 3 else (1, 2, 2)
130
+ if use_conv:
131
+ self.op = conv_nd(
132
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
133
+ )
134
+ else:
135
+ assert self.channels == self.out_channels
136
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
137
+
138
+ def forward(self, x):
139
+ assert x.shape[1] == self.channels
140
+ return self.op(x)
141
+
142
+
143
+ class ResBlock(TimestepBlock):
144
+ """
145
+ A residual block that can optionally change the number of channels.
146
+
147
+ :param channels: the number of input channels.
148
+ :param emb_channels: the number of timestep embedding channels.
149
+ :param dropout: the rate of dropout.
150
+ :param out_channels: if specified, the number of out channels.
151
+ :param use_conv: if True and out_channels is specified, use a spatial
152
+ convolution instead of a smaller 1x1 convolution to change the
153
+ channels in the skip connection.
154
+ :param dims: determines if the signal is 1D, 2D, or 3D.
155
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
156
+ :param up: if True, use this block for upsampling.
157
+ :param down: if True, use this block for downsampling.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ channels,
163
+ emb_channels,
164
+ dropout,
165
+ out_channels=None,
166
+ use_conv=False,
167
+ use_scale_shift_norm=False,
168
+ dims=2,
169
+ use_checkpoint=False,
170
+ up=False,
171
+ down=False,
172
+ ):
173
+ super().__init__()
174
+ self.channels = channels
175
+ self.emb_channels = emb_channels
176
+ self.dropout = dropout
177
+ self.out_channels = out_channels or channels
178
+ self.use_conv = use_conv
179
+ self.use_checkpoint = use_checkpoint
180
+ self.use_scale_shift_norm = use_scale_shift_norm
181
+
182
+ self.in_layers = nn.Sequential(
183
+ normalization(channels),
184
+ nn.SiLU(),
185
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
186
+ )
187
+
188
+ self.updown = up or down
189
+
190
+ if up:
191
+ self.h_upd = Upsample(channels, False, dims)
192
+ self.x_upd = Upsample(channels, False, dims)
193
+ elif down:
194
+ self.h_upd = Downsample(channels, False, dims)
195
+ self.x_upd = Downsample(channels, False, dims)
196
+ else:
197
+ self.h_upd = self.x_upd = nn.Identity()
198
+
199
+ self.emb_layers = nn.Sequential(
200
+ nn.SiLU(),
201
+ linear(
202
+ emb_channels,
203
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
204
+ ),
205
+ )
206
+ self.out_layers = nn.Sequential(
207
+ normalization(self.out_channels),
208
+ nn.SiLU(),
209
+ nn.Dropout(p=dropout),
210
+ zero_module(
211
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
212
+ ),
213
+ )
214
+
215
+ if self.out_channels == channels:
216
+ self.skip_connection = nn.Identity()
217
+ elif use_conv:
218
+ self.skip_connection = conv_nd(
219
+ dims, channels, self.out_channels, 3, padding=1
220
+ )
221
+ else:
222
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
223
+
224
+ def forward(self, x, emb):
225
+ """
226
+ Apply the block to a Tensor, conditioned on a timestep embedding.
227
+
228
+ :param x: an [N x C x ...] Tensor of features.
229
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
230
+ :return: an [N x C x ...] Tensor of outputs.
231
+ """
232
+ return checkpoint(
233
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
234
+ )
235
+
236
+ def _forward(self, x, emb):
237
+ if self.updown:
238
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
239
+ h = in_rest(x)
240
+ h = self.h_upd(h)
241
+ x = self.x_upd(x)
242
+ h = in_conv(h)
243
+ else:
244
+ h = self.in_layers(x)
245
+ emb_out = self.emb_layers(emb).type(h.dtype)
246
+ while len(emb_out.shape) < len(h.shape):
247
+ emb_out = emb_out[..., None]
248
+ if self.use_scale_shift_norm:
249
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
250
+ scale, shift = th.chunk(emb_out, 2, dim=1)
251
+ h = out_norm(h) * (1 + scale) + shift
252
+ h = out_rest(h)
253
+ else:
254
+ h = h + emb_out
255
+ h = self.out_layers(h)
256
+ return self.skip_connection(x) + h
257
+
258
+
259
+ class AttentionBlock(nn.Module):
260
+ """
261
+ An attention block that allows spatial positions to attend to each other.
262
+
263
+ Originally ported from here, but adapted to the N-d case.
264
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ channels,
270
+ num_heads=1,
271
+ num_head_channels=-1,
272
+ use_checkpoint=False,
273
+ use_new_attention_order=False,
274
+ ):
275
+ super().__init__()
276
+ self.channels = channels
277
+ if num_head_channels == -1:
278
+ self.num_heads = num_heads
279
+ else:
280
+ assert (
281
+ channels % num_head_channels == 0
282
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
283
+ self.num_heads = channels // num_head_channels
284
+ self.use_checkpoint = use_checkpoint
285
+ self.norm = normalization(channels)
286
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
287
+ if use_new_attention_order:
288
+ # split qkv before split heads
289
+ self.attention = QKVAttention(self.num_heads)
290
+ else:
291
+ # split heads before split qkv
292
+ self.attention = QKVAttentionLegacy(self.num_heads)
293
+
294
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
295
+
296
+ def forward(self, x):
297
+ return checkpoint(self._forward, (x,), self.parameters(), True)
298
+
299
+ def _forward(self, x):
300
+ b, c, *spatial = x.shape
301
+ x = x.reshape(b, c, -1)
302
+ qkv = self.qkv(self.norm(x))
303
+ h = self.attention(qkv)
304
+ h = self.proj_out(h)
305
+ return (x + h).reshape(b, c, *spatial)
306
+
307
+
308
+ def count_flops_attn(model, _x, y):
309
+ """
310
+ A counter for the `thop` package to count the operations in an
311
+ attention operation.
312
+ Meant to be used like:
313
+ macs, params = thop.profile(
314
+ model,
315
+ inputs=(inputs, timestamps),
316
+ custom_ops={QKVAttention: QKVAttention.count_flops},
317
+ )
318
+ """
319
+ b, c, *spatial = y[0].shape
320
+ num_spatial = int(np.prod(spatial))
321
+ # We perform two matmuls with the same number of ops.
322
+ # The first computes the weight matrix, the second computes
323
+ # the combination of the value vectors.
324
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
325
+ model.total_ops += th.DoubleTensor([matmul_ops])
326
+
327
+
328
+ class QKVAttentionLegacy(nn.Module):
329
+ """
330
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
331
+ """
332
+
333
+ def __init__(self, n_heads):
334
+ super().__init__()
335
+ self.n_heads = n_heads
336
+
337
+ def forward(self, qkv):
338
+ """
339
+ Apply QKV attention.
340
+
341
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
342
+ :return: an [N x (H * C) x T] tensor after attention.
343
+ """
344
+ bs, width, length = qkv.shape
345
+ assert width % (3 * self.n_heads) == 0
346
+ ch = width // (3 * self.n_heads)
347
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
348
+ scale = 1 / math.sqrt(math.sqrt(ch))
349
+ weight = th.einsum(
350
+ "bct,bcs->bts", q * scale, k * scale
351
+ ) # More stable with f16 than dividing afterwards
352
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
353
+ a = th.einsum("bts,bcs->bct", weight, v)
354
+ return a.reshape(bs, -1, length)
355
+
356
+ @staticmethod
357
+ def count_flops(model, _x, y):
358
+ return count_flops_attn(model, _x, y)
359
+
360
+
361
+ class QKVAttention(nn.Module):
362
+ """
363
+ A module which performs QKV attention and splits in a different order.
364
+ """
365
+
366
+ def __init__(self, n_heads):
367
+ super().__init__()
368
+ self.n_heads = n_heads
369
+
370
+ def forward(self, qkv):
371
+ """
372
+ Apply QKV attention.
373
+
374
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
375
+ :return: an [N x (H * C) x T] tensor after attention.
376
+ """
377
+ bs, width, length = qkv.shape
378
+ assert width % (3 * self.n_heads) == 0
379
+ ch = width // (3 * self.n_heads)
380
+ q, k, v = qkv.chunk(3, dim=1)
381
+ scale = 1 / math.sqrt(math.sqrt(ch))
382
+ weight = th.einsum(
383
+ "bct,bcs->bts",
384
+ (q * scale).view(bs * self.n_heads, ch, length),
385
+ (k * scale).view(bs * self.n_heads, ch, length),
386
+ ) # More stable with f16 than dividing afterwards
387
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
389
+ return a.reshape(bs, -1, length)
390
+
391
+ @staticmethod
392
+ def count_flops(model, _x, y):
393
+ return count_flops_attn(model, _x, y)
394
+
395
+
396
+ class UNetModel(nn.Module):
397
+ """
398
+ The full UNet model with attention and timestep embedding.
399
+
400
+ :param in_channels: channels in the input Tensor.
401
+ :param model_channels: base channel count for the model.
402
+ :param out_channels: channels in the output Tensor.
403
+ :param num_res_blocks: number of residual blocks per downsample.
404
+ :param attention_resolutions: a collection of downsample rates at which
405
+ attention will take place. May be a set, list, or tuple.
406
+ For example, if this contains 4, then at 4x downsampling, attention
407
+ will be used.
408
+ :param dropout: the dropout probability.
409
+ :param channel_mult: channel multiplier for each level of the UNet.
410
+ :param conv_resample: if True, use learned convolutions for upsampling and
411
+ downsampling.
412
+ :param dims: determines if the signal is 1D, 2D, or 3D.
413
+ :param num_classes: if specified (as an int), then this model will be
414
+ class-conditional with `num_classes` classes.
415
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
416
+ :param num_heads: the number of attention heads in each attention layer.
417
+ :param num_head_channels: if specified, ignore num_heads and instead use
418
+ a fixed channel width per attention head.
419
+ :param num_heads_upsample: works with num_heads to set a different number
420
+ of heads for upsampling. Deprecated.
421
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
422
+ :param resblock_updown: use residual blocks for up/downsampling.
423
+ :param use_new_attention_order: use a different attention pattern for potentially
424
+ increased efficiency.
425
+ """
426
+
427
+ def __init__(
428
+ self,
429
+ image_size,
430
+ in_channels,
431
+ model_channels,
432
+ out_channels,
433
+ num_res_blocks,
434
+ attention_resolutions,
435
+ dropout=0,
436
+ channel_mult=(1, 2, 4, 8),
437
+ conv_resample=True,
438
+ dims=2,
439
+ num_classes=None,
440
+ use_checkpoint=False,
441
+ use_fp16=False,
442
+ num_heads=1,
443
+ num_head_channels=-1,
444
+ num_heads_upsample=-1,
445
+ use_scale_shift_norm=False,
446
+ resblock_updown=False,
447
+ use_new_attention_order=False,
448
+ ):
449
+ super().__init__()
450
+
451
+ if num_heads_upsample == -1:
452
+ num_heads_upsample = num_heads
453
+
454
+ self.image_size = image_size
455
+ self.in_channels = in_channels
456
+ self.model_channels = model_channels
457
+ self.out_channels = out_channels
458
+ self.num_res_blocks = num_res_blocks
459
+ self.attention_resolutions = attention_resolutions
460
+ self.dropout = dropout
461
+ self.channel_mult = channel_mult
462
+ self.conv_resample = conv_resample
463
+ self.num_classes = num_classes
464
+ self.use_checkpoint = use_checkpoint
465
+ self.dtype = th.float16 if use_fp16 else th.float32
466
+ self.num_heads = num_heads
467
+ self.num_head_channels = num_head_channels
468
+ self.num_heads_upsample = num_heads_upsample
469
+
470
+ time_embed_dim = model_channels * 4
471
+ self.time_embed = nn.Sequential(
472
+ linear(model_channels, time_embed_dim),
473
+ nn.SiLU(),
474
+ linear(time_embed_dim, time_embed_dim),
475
+ )
476
+
477
+ if self.num_classes is not None:
478
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
479
+
480
+ ch = input_ch = int(channel_mult[0] * model_channels)
481
+ self.input_blocks = nn.ModuleList(
482
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
483
+ )
484
+ self._feature_size = ch
485
+ input_block_chans = [ch]
486
+ ds = 1
487
+ for level, mult in enumerate(channel_mult):
488
+ for _ in range(num_res_blocks):
489
+ layers = [
490
+ ResBlock(
491
+ ch,
492
+ time_embed_dim,
493
+ dropout,
494
+ out_channels=int(mult * model_channels),
495
+ dims=dims,
496
+ use_checkpoint=use_checkpoint,
497
+ use_scale_shift_norm=use_scale_shift_norm,
498
+ )
499
+ ]
500
+ ch = int(mult * model_channels)
501
+ if ds in attention_resolutions:
502
+ layers.append(
503
+ AttentionBlock(
504
+ ch,
505
+ use_checkpoint=use_checkpoint,
506
+ num_heads=num_heads,
507
+ num_head_channels=num_head_channels,
508
+ use_new_attention_order=use_new_attention_order,
509
+ )
510
+ )
511
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
512
+ self._feature_size += ch
513
+ input_block_chans.append(ch)
514
+ if level != len(channel_mult) - 1:
515
+ out_ch = ch
516
+ self.input_blocks.append(
517
+ TimestepEmbedSequential(
518
+ ResBlock(
519
+ ch,
520
+ time_embed_dim,
521
+ dropout,
522
+ out_channels=out_ch,
523
+ dims=dims,
524
+ use_checkpoint=use_checkpoint,
525
+ use_scale_shift_norm=use_scale_shift_norm,
526
+ down=True,
527
+ )
528
+ if resblock_updown
529
+ else Downsample(
530
+ ch, conv_resample, dims=dims, out_channels=out_ch
531
+ )
532
+ )
533
+ )
534
+ ch = out_ch
535
+ input_block_chans.append(ch)
536
+ ds *= 2
537
+ self._feature_size += ch
538
+
539
+ self.middle_block = TimestepEmbedSequential(
540
+ ResBlock(
541
+ ch,
542
+ time_embed_dim,
543
+ dropout,
544
+ dims=dims,
545
+ use_checkpoint=use_checkpoint,
546
+ use_scale_shift_norm=use_scale_shift_norm,
547
+ ),
548
+ AttentionBlock(
549
+ ch,
550
+ use_checkpoint=use_checkpoint,
551
+ num_heads=num_heads,
552
+ num_head_channels=num_head_channels,
553
+ use_new_attention_order=use_new_attention_order,
554
+ ),
555
+ ResBlock(
556
+ ch,
557
+ time_embed_dim,
558
+ dropout,
559
+ dims=dims,
560
+ use_checkpoint=use_checkpoint,
561
+ use_scale_shift_norm=use_scale_shift_norm,
562
+ ),
563
+ )
564
+ self._feature_size += ch
565
+
566
+ self.output_blocks = nn.ModuleList([])
567
+ for level, mult in list(enumerate(channel_mult))[::-1]:
568
+ for i in range(num_res_blocks + 1):
569
+ ich = input_block_chans.pop()
570
+ layers = [
571
+ ResBlock(
572
+ ch + ich,
573
+ time_embed_dim,
574
+ dropout,
575
+ out_channels=int(model_channels * mult),
576
+ dims=dims,
577
+ use_checkpoint=use_checkpoint,
578
+ use_scale_shift_norm=use_scale_shift_norm,
579
+ )
580
+ ]
581
+ ch = int(model_channels * mult)
582
+ if ds in attention_resolutions:
583
+ layers.append(
584
+ AttentionBlock(
585
+ ch,
586
+ use_checkpoint=use_checkpoint,
587
+ num_heads=num_heads_upsample,
588
+ num_head_channels=num_head_channels,
589
+ use_new_attention_order=use_new_attention_order,
590
+ )
591
+ )
592
+ if level and i == num_res_blocks:
593
+ out_ch = ch
594
+ layers.append(
595
+ ResBlock(
596
+ ch,
597
+ time_embed_dim,
598
+ dropout,
599
+ out_channels=out_ch,
600
+ dims=dims,
601
+ use_checkpoint=use_checkpoint,
602
+ use_scale_shift_norm=use_scale_shift_norm,
603
+ up=True,
604
+ )
605
+ if resblock_updown
606
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
607
+ )
608
+ ds //= 2
609
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
610
+ self._feature_size += ch
611
+
612
+ self.out = nn.Sequential(
613
+ normalization(ch),
614
+ nn.SiLU(),
615
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
616
+ )
617
+
618
+ def convert_to_fp16(self):
619
+ """
620
+ Convert the torso of the model to float16.
621
+ """
622
+ self.input_blocks.apply(convert_module_to_f16)
623
+ self.middle_block.apply(convert_module_to_f16)
624
+ self.output_blocks.apply(convert_module_to_f16)
625
+
626
+ def convert_to_fp32(self):
627
+ """
628
+ Convert the torso of the model to float32.
629
+ """
630
+ self.input_blocks.apply(convert_module_to_f32)
631
+ self.middle_block.apply(convert_module_to_f32)
632
+ self.output_blocks.apply(convert_module_to_f32)
633
+
634
+ def forward(self, x, timesteps, y=None):
635
+ """
636
+ Apply the model to an input batch.
637
+
638
+ :param x: an [N x C x ...] Tensor of inputs.
639
+ :param timesteps: a 1-D batch of timesteps.
640
+ :param y: an [N] Tensor of labels, if class-conditional.
641
+ :return: an [N x C x ...] Tensor of outputs.
642
+ """
643
+ assert (y is not None) == (
644
+ self.num_classes is not None
645
+ ), "must specify y if and only if the model is class-conditional"
646
+
647
+ hs = []
648
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
649
+
650
+ if self.num_classes is not None:
651
+ assert y.shape == (x.shape[0],)
652
+ emb = emb + self.label_emb(y)
653
+
654
+ h = x.type(self.dtype)
655
+ for module in self.input_blocks:
656
+ h = module(h, emb)
657
+ hs.append(h)
658
+ h = self.middle_block(h, emb)
659
+ for module in self.output_blocks:
660
+ h = th.cat([h, hs.pop()], dim=1)
661
+ h = module(h, emb)
662
+ h = h.type(x.dtype)
663
+ return self.out(h)
664
+
665
+
666
+ class SuperResModel(UNetModel):
667
+ """
668
+ A UNetModel that performs super-resolution.
669
+
670
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
671
+ """
672
+
673
+ def __init__(self, image_size, in_channels, *args, **kwargs):
674
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
675
+
676
+ def forward(self, x, timesteps, low_res=None, **kwargs):
677
+ _, _, new_height, new_width = x.shape
678
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
679
+ x = th.cat([x, upsampled], dim=1)
680
+ return super().forward(x, timesteps, **kwargs)
681
+
682
+
683
+ class EncoderUNetModel(nn.Module):
684
+ """
685
+ The half UNet model with attention and timestep embedding.
686
+
687
+ For usage, see UNet.
688
+ """
689
+
690
+ def __init__(
691
+ self,
692
+ image_size,
693
+ in_channels,
694
+ model_channels,
695
+ out_channels,
696
+ num_res_blocks,
697
+ attention_resolutions,
698
+ dropout=0,
699
+ channel_mult=(1, 2, 4, 8),
700
+ conv_resample=True,
701
+ dims=2,
702
+ use_checkpoint=False,
703
+ use_fp16=False,
704
+ num_heads=1,
705
+ num_head_channels=-1,
706
+ num_heads_upsample=-1,
707
+ use_scale_shift_norm=False,
708
+ resblock_updown=False,
709
+ use_new_attention_order=False,
710
+ pool="adaptive",
711
+ ):
712
+ super().__init__()
713
+
714
+ if num_heads_upsample == -1:
715
+ num_heads_upsample = num_heads
716
+
717
+ self.in_channels = in_channels
718
+ self.model_channels = model_channels
719
+ self.out_channels = out_channels
720
+ self.num_res_blocks = num_res_blocks
721
+ self.attention_resolutions = attention_resolutions
722
+ self.dropout = dropout
723
+ self.channel_mult = channel_mult
724
+ self.conv_resample = conv_resample
725
+ self.use_checkpoint = use_checkpoint
726
+ self.dtype = th.float16 if use_fp16 else th.float32
727
+ self.num_heads = num_heads
728
+ self.num_head_channels = num_head_channels
729
+ self.num_heads_upsample = num_heads_upsample
730
+
731
+ time_embed_dim = model_channels * 4
732
+ self.time_embed = nn.Sequential(
733
+ linear(model_channels, time_embed_dim),
734
+ nn.SiLU(),
735
+ linear(time_embed_dim, time_embed_dim),
736
+ )
737
+
738
+ ch = int(channel_mult[0] * model_channels)
739
+ self.input_blocks = nn.ModuleList(
740
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
741
+ )
742
+ self._feature_size = ch
743
+ input_block_chans = [ch]
744
+ ds = 1
745
+ for level, mult in enumerate(channel_mult):
746
+ for _ in range(num_res_blocks):
747
+ layers = [
748
+ ResBlock(
749
+ ch,
750
+ time_embed_dim,
751
+ dropout,
752
+ out_channels=int(mult * model_channels),
753
+ dims=dims,
754
+ use_checkpoint=use_checkpoint,
755
+ use_scale_shift_norm=use_scale_shift_norm,
756
+ )
757
+ ]
758
+ ch = int(mult * model_channels)
759
+ if ds in attention_resolutions:
760
+ layers.append(
761
+ AttentionBlock(
762
+ ch,
763
+ use_checkpoint=use_checkpoint,
764
+ num_heads=num_heads,
765
+ num_head_channels=num_head_channels,
766
+ use_new_attention_order=use_new_attention_order,
767
+ )
768
+ )
769
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
770
+ self._feature_size += ch
771
+ input_block_chans.append(ch)
772
+ if level != len(channel_mult) - 1:
773
+ out_ch = ch
774
+ self.input_blocks.append(
775
+ TimestepEmbedSequential(
776
+ ResBlock(
777
+ ch,
778
+ time_embed_dim,
779
+ dropout,
780
+ out_channels=out_ch,
781
+ dims=dims,
782
+ use_checkpoint=use_checkpoint,
783
+ use_scale_shift_norm=use_scale_shift_norm,
784
+ down=True,
785
+ )
786
+ if resblock_updown
787
+ else Downsample(
788
+ ch, conv_resample, dims=dims, out_channels=out_ch
789
+ )
790
+ )
791
+ )
792
+ ch = out_ch
793
+ input_block_chans.append(ch)
794
+ ds *= 2
795
+ self._feature_size += ch
796
+
797
+ self.middle_block = TimestepEmbedSequential(
798
+ ResBlock(
799
+ ch,
800
+ time_embed_dim,
801
+ dropout,
802
+ dims=dims,
803
+ use_checkpoint=use_checkpoint,
804
+ use_scale_shift_norm=use_scale_shift_norm,
805
+ ),
806
+ AttentionBlock(
807
+ ch,
808
+ use_checkpoint=use_checkpoint,
809
+ num_heads=num_heads,
810
+ num_head_channels=num_head_channels,
811
+ use_new_attention_order=use_new_attention_order,
812
+ ),
813
+ ResBlock(
814
+ ch,
815
+ time_embed_dim,
816
+ dropout,
817
+ dims=dims,
818
+ use_checkpoint=use_checkpoint,
819
+ use_scale_shift_norm=use_scale_shift_norm,
820
+ ),
821
+ )
822
+ self._feature_size += ch
823
+ self.pool = pool
824
+ if pool == "adaptive":
825
+ self.out = nn.Sequential(
826
+ normalization(ch),
827
+ nn.SiLU(),
828
+ nn.AdaptiveAvgPool2d((1, 1)),
829
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
830
+ nn.Flatten(),
831
+ )
832
+ elif pool == "attention":
833
+ assert num_head_channels != -1
834
+ self.out = nn.Sequential(
835
+ normalization(ch),
836
+ nn.SiLU(),
837
+ AttentionPool2d(
838
+ (image_size // ds), ch, num_head_channels, out_channels
839
+ ),
840
+ )
841
+ elif pool == "spatial":
842
+ self.out = nn.Sequential(
843
+ nn.Linear(self._feature_size, 2048),
844
+ nn.ReLU(),
845
+ nn.Linear(2048, self.out_channels),
846
+ )
847
+ elif pool == "spatial_v2":
848
+ self.out = nn.Sequential(
849
+ nn.Linear(self._feature_size, 2048),
850
+ normalization(2048),
851
+ nn.SiLU(),
852
+ nn.Linear(2048, self.out_channels),
853
+ )
854
+ else:
855
+ raise NotImplementedError(f"Unexpected {pool} pooling")
856
+
857
+ def convert_to_fp16(self):
858
+ """
859
+ Convert the torso of the model to float16.
860
+ """
861
+ self.input_blocks.apply(convert_module_to_f16)
862
+ self.middle_block.apply(convert_module_to_f16)
863
+
864
+ def convert_to_fp32(self):
865
+ """
866
+ Convert the torso of the model to float32.
867
+ """
868
+ self.input_blocks.apply(convert_module_to_f32)
869
+ self.middle_block.apply(convert_module_to_f32)
870
+
871
+ def forward(self, x, timesteps):
872
+ """
873
+ Apply the model to an input batch.
874
+
875
+ :param x: an [N x C x ...] Tensor of inputs.
876
+ :param timesteps: a 1-D batch of timesteps.
877
+ :return: an [N x K] Tensor of outputs.
878
+ """
879
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
880
+
881
+ results = []
882
+ h = x.type(self.dtype)
883
+ for module in self.input_blocks:
884
+ h = module(h, emb)
885
+ if self.pool.startswith("spatial"):
886
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
887
+ h = self.middle_block(h, emb)
888
+ if self.pool.startswith("spatial"):
889
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
890
+ h = th.cat(results, axis=-1)
891
+ return self.out(h)
892
+ else:
893
+ h = h.type(x.dtype)
894
+ return self.out(h)
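
An illustrative instantiation of the UNetModel defined above; the hyperparameters are placeholders rather than the settings used elsewhere in this repository:

# Illustrative UNetModel construction and forward pass (hyperparameters are placeholders).
import torch as th

unet = UNetModel(
    image_size=64,
    in_channels=3,
    model_channels=128,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(8,),   # attend once the feature map is downsampled 8x
    channel_mult=(1, 2, 4, 8),
    num_heads=4,
)
x = th.randn(2, 3, 64, 64)        # noisy input batch
t = th.randint(0, 1000, (2,))     # one timestep per sample
out = unet(x, t)                  # -> [2, 3, 64, 64]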
codes/modules/unet_2d.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from .modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
+
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ The output of [`UNet2DModel`].
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ The hidden states output from the last layer of the model.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class UNet2DModel(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
49
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
50
+ 1)`.
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
54
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
55
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
56
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
57
+ Whether to flip sin to cos for Fourier time embedding.
58
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
59
+ Tuple of downsample block types.
60
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
62
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
63
+ Tuple of upsample block types.
64
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
65
+ Tuple of block output channels.
66
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
67
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
68
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
69
+ downsample_type (`str`, *optional*, defaults to `conv`):
70
+ The downsample type for downsampling layers. Choose between "conv" and "resnet"
71
+ upsample_type (`str`, *optional*, defaults to `conv`):
72
+ The upsample type for upsampling layers. Choose between "conv" and "resnet"
73
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
74
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
76
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
77
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
78
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
79
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
80
+ class_embed_type (`str`, *optional*, defaults to `None`):
81
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
82
+ `"timestep"`, or `"identity"`.
83
+ num_class_embeds (`int`, *optional*, defaults to `None`):
84
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
85
+ conditioning with `class_embed_type` equal to `None`.
86
+ """
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
92
+ in_channels: int = 3,
93
+ out_channels: int = 3,
94
+ center_input_sample: bool = False,
95
+ time_embedding_type: str = "positional",
96
+ freq_shift: int = 0,
97
+ flip_sin_to_cos: bool = True,
98
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
99
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
100
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
101
+ layers_per_block: int = 2,
102
+ mid_block_scale_factor: float = 1,
103
+ downsample_padding: int = 1,
104
+ downsample_type: str = "conv",
105
+ upsample_type: str = "conv",
106
+ dropout: float = 0.0,
107
+ act_fn: str = "silu",
108
+ attention_head_dim: Optional[int] = 8,
109
+ norm_num_groups: int = 32,
110
+ norm_eps: float = 1e-5,
111
+ resnet_time_scale_shift: str = "default",
112
+ add_attention: bool = True,
113
+ class_embed_type: Optional[str] = None,
114
+ num_class_embeds: Optional[int] = None,
115
+ ):
116
+ super().__init__()
117
+
118
+ self.sample_size = sample_size
119
+ time_embed_dim = block_out_channels[0] * 4
120
+
121
+ # Check inputs
122
+ if len(down_block_types) != len(up_block_types):
123
+ raise ValueError(
124
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
125
+ )
126
+
127
+ if len(block_out_channels) != len(down_block_types):
128
+ raise ValueError(
129
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
130
+ )
131
+
132
+ # input
133
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
134
+
135
+ # time
136
+ if time_embedding_type == "fourier":
137
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
138
+ timestep_input_dim = 2 * block_out_channels[0]
139
+ elif time_embedding_type == "positional":
140
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
141
+ timestep_input_dim = block_out_channels[0]
142
+
143
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
144
+
145
+ # class embedding
146
+ if class_embed_type is None and num_class_embeds is not None:
147
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
148
+ elif class_embed_type == "timestep":
149
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
150
+ elif class_embed_type == "identity":
151
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
152
+ else:
153
+ self.class_embedding = None
154
+
155
+ self.down_blocks = nn.ModuleList([])
156
+ self.mid_block = None
157
+ self.up_blocks = nn.ModuleList([])
158
+
159
+ # down
160
+ output_channel = block_out_channels[0]
161
+ for i, down_block_type in enumerate(down_block_types):
162
+ input_channel = output_channel
163
+ output_channel = block_out_channels[i]
164
+ is_final_block = i == len(block_out_channels) - 1
165
+
166
+ down_block = get_down_block(
167
+ down_block_type,
168
+ num_layers=layers_per_block,
169
+ in_channels=input_channel,
170
+ out_channels=output_channel,
171
+ temb_channels=time_embed_dim,
172
+ add_downsample=not is_final_block,
173
+ resnet_eps=norm_eps,
174
+ resnet_act_fn=act_fn,
175
+ resnet_groups=norm_num_groups,
176
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
177
+ downsample_padding=downsample_padding,
178
+ resnet_time_scale_shift=resnet_time_scale_shift,
179
+ downsample_type=downsample_type,
180
+ dropout=dropout,
181
+ )
182
+ self.down_blocks.append(down_block)
183
+
184
+ # mid
185
+ self.mid_block = UNetMidBlock2D(
186
+ in_channels=block_out_channels[-1],
187
+ temb_channels=time_embed_dim,
188
+ dropout=dropout,
189
+ resnet_eps=norm_eps,
190
+ resnet_act_fn=act_fn,
191
+ output_scale_factor=mid_block_scale_factor,
192
+ resnet_time_scale_shift=resnet_time_scale_shift,
193
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
194
+ resnet_groups=norm_num_groups,
195
+ add_attention=add_attention,
196
+ )
197
+
198
+ # up
199
+ reversed_block_out_channels = list(reversed(block_out_channels))
200
+ output_channel = reversed_block_out_channels[0]
201
+ for i, up_block_type in enumerate(up_block_types):
202
+ prev_output_channel = output_channel
203
+ output_channel = reversed_block_out_channels[i]
204
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
205
+
206
+ is_final_block = i == len(block_out_channels) - 1
207
+
208
+ up_block = get_up_block(
209
+ up_block_type,
210
+ num_layers=layers_per_block + 1,
211
+ in_channels=input_channel,
212
+ out_channels=output_channel,
213
+ prev_output_channel=prev_output_channel,
214
+ temb_channels=time_embed_dim,
215
+ add_upsample=not is_final_block,
216
+ resnet_eps=norm_eps,
217
+ resnet_act_fn=act_fn,
218
+ resnet_groups=norm_num_groups,
219
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
220
+ resnet_time_scale_shift=resnet_time_scale_shift,
221
+ upsample_type=upsample_type,
222
+ dropout=dropout,
223
+ )
224
+ self.up_blocks.append(up_block)
225
+ prev_output_channel = output_channel
226
+
227
+ # out
228
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
229
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
230
+ self.conv_act = nn.SiLU()
231
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
232
+
233
+ def forward(
234
+ self,
235
+ sample: torch.FloatTensor,
236
+ timestep: Union[torch.Tensor, float, int],
237
+ class_labels: Optional[torch.Tensor] = None,
238
+ return_dict: bool = True,
239
+ ) -> Union[UNet2DOutput, Tuple]:
240
+ r"""
241
+ The [`UNet2DModel`] forward method.
242
+
243
+ Args:
244
+ sample (`torch.FloatTensor`):
245
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
246
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
247
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
248
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
249
+ return_dict (`bool`, *optional*, defaults to `True`):
250
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
251
+
252
+ Returns:
253
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`:
254
+ If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
255
+ returned where the first element is the sample tensor.
256
+ """
257
+ # 0. center input if necessary
258
+ if self.config.center_input_sample:
259
+ sample = 2 * sample - 1.0
260
+
261
+ # 1. time
262
+ timesteps = timestep
263
+ if not torch.is_tensor(timesteps):
264
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
265
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
266
+ timesteps = timesteps[None].to(sample.device)
267
+
268
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
269
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
270
+
271
+ t_emb = self.time_proj(timesteps)
272
+
273
+ # timesteps does not contain any weights and will always return f32 tensors
274
+ # but time_embedding might actually be running in fp16. so we need to cast here.
275
+ # there might be better ways to encapsulate this.
276
+ t_emb = t_emb.to(dtype=self.dtype)
277
+ emb = self.time_embedding(t_emb)
278
+
279
+ if self.class_embedding is not None:
280
+ if class_labels is None:
281
+ raise ValueError("class_labels should be provided when doing class conditioning")
282
+
283
+ if self.config.class_embed_type == "timestep":
284
+ class_labels = self.time_proj(class_labels)
285
+
286
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
287
+ emb = emb + class_emb
288
+
289
+ # 2. pre-process
290
+ skip_sample = sample
291
+ sample = self.conv_in(sample)
292
+
293
+ # 3. down
294
+ down_block_res_samples = (sample,)
295
+ for downsample_block in self.down_blocks:
296
+ if hasattr(downsample_block, "skip_conv"):
297
+ sample, res_samples, skip_sample = downsample_block(
298
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
299
+ )
300
+ else:
301
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
302
+
303
+ down_block_res_samples += res_samples
304
+
305
+ # 4. mid
306
+ sample = self.mid_block(sample, emb)
307
+
308
+ # 5. up
309
+ skip_sample = None
310
+ for upsample_block in self.up_blocks:
311
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
312
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
313
+
314
+ if hasattr(upsample_block, "skip_conv"):
315
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
316
+ else:
317
+ sample = upsample_block(sample, res_samples, emb)
318
+
319
+ # 6. post-process
320
+ sample = self.conv_norm_out(sample)
321
+ sample = self.conv_act(sample)
322
+ sample = self.conv_out(sample)
323
+
324
+ if skip_sample is not None:
325
+ sample += skip_sample
326
+
327
+ if self.config.time_embedding_type == "fourier":
328
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
329
+ sample = sample / timesteps
330
+
331
+ if not return_dict:
332
+ return (sample,)
333
+
334
+ return UNet2DOutput(sample=sample)
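For orientation, a minimal sketch of how this vendored UNet2DModel can be instantiated and called. The import path, sample size, and channel counts below are illustrative assumptions, not the GL-LCM training configuration.

import torch
from unet_2d import UNet2DModel  # assumed module name for this vendored file

# Small illustrative config; sample_size must be divisible by 2 ** (len(block_out_channels) - 1).
model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64, 128, 256),
    layers_per_block=2,
)

sample = torch.randn(1, 3, 64, 64)    # (batch, channel, height, width), e.g. a VQ-GAN latent
timestep = torch.tensor([10])         # current denoising timestep
out = model(sample, timestep).sample  # UNet2DOutput.sample, same shape as the input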
codes/pytorch_msssim.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
+ """
7
+ CODE SOURCE:
8
+ https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py
9
+ License:
10
+ MIT
11
+
12
+ Original Paper:
13
+ DOI: 10.1109/ACSSC.2003.1292216
14
+ """
15
+
16
+
17
+ def gaussian(window_size, sigma):
18
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
19
+ return gauss / gauss.sum()
20
+
21
+
22
+ def create_window(window_size, channel=1):
23
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
26
+ return window
27
+
28
+
29
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
30
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
31
+ if val_range is None:
32
+ if torch.max(img1) > 128:
33
+ max_val = 255
34
+ else:
35
+ max_val = 1
36
+
37
+ if torch.min(img1) < -0.5:
38
+ min_val = -1
39
+ else:
40
+ min_val = 0
41
+ L = max_val - min_val
42
+ else:
43
+ L = val_range
44
+
45
+ padd = 0
46
+ (_, channel, height, width) = img1.size()
47
+ if window is None:
48
+ real_size = min(window_size, height, width)
49
+ window = create_window(real_size, channel=channel).to(img1.device)
50
+
51
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
52
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
53
+
54
+ mu1_sq = mu1.pow(2)
55
+ mu2_sq = mu2.pow(2)
56
+ mu1_mu2 = mu1 * mu2
57
+
58
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
59
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
60
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
61
+
62
+ C1 = (0.01 * L) ** 2
63
+ C2 = (0.03 * L) ** 2
64
+
65
+ v1 = 2.0 * sigma12 + C2
66
+ v2 = sigma1_sq + sigma2_sq + C2
67
+ cs = v1 / v2 # contrast sensitivity
68
+
69
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
70
+
71
+ if size_average:
72
+ cs = cs.mean()
73
+ ret = ssim_map.mean()
74
+ else:
75
+ cs = cs.mean(1).mean(1).mean(1)
76
+ ret = ssim_map.mean(1).mean(1).mean(1)
77
+
78
+ if full:
79
+ return ret, cs
80
+ return ret
81
+
82
+
83
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
84
+ device = img1.device
85
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
86
+ levels = weights.size()[0]
87
+ ssims = []
88
+ mcs = []
89
+ for _ in range(levels):
90
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
91
+
92
+ # Relu normalize (not compliant with original definition)
93
+ if normalize == "relu":
94
+ ssims.append(torch.relu(sim))
95
+ mcs.append(torch.relu(cs))
96
+ else:
97
+ ssims.append(sim)
98
+ mcs.append(cs)
99
+
100
+ img1 = F.avg_pool2d(img1, (2, 2))
101
+ img2 = F.avg_pool2d(img2, (2, 2))
102
+
103
+ ssims = torch.stack(ssims)
104
+ mcs = torch.stack(mcs)
105
+
106
+ # Simple normalize (not compliant with original definition)
107
+ # TODO: remove support for normalize == True (kept for backward support)
108
+ if normalize == "simple" or normalize == True:
109
+ ssims = (ssims + 1) / 2
110
+ mcs = (mcs + 1) / 2
111
+
112
+ pow1 = mcs ** weights
113
+ pow2 = ssims ** weights
114
+
115
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
116
+ output = torch.prod(pow1[:-1]) * pow2[-1]
117
+ return output
118
+
119
+
120
+ # Classes to re-use window
121
+ class SSIM(torch.nn.Module):
122
+ def __init__(self, window_size=11, size_average=True, val_range=None):
123
+ super(SSIM, self).__init__()
124
+ self.window_size = window_size
125
+ self.size_average = size_average
126
+ self.val_range = val_range
127
+
128
+ # Assume 1 channel for SSIM
129
+ self.channel = 1
130
+ self.window = create_window(window_size)
131
+
132
+ def forward(self, img1, img2):
133
+ (_, channel, _, _) = img1.size()
134
+
135
+ if channel == self.channel and self.window.dtype == img1.dtype:
136
+ window = self.window
137
+ else:
138
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
139
+ self.window = window
140
+ self.channel = channel
141
+
142
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
143
+
144
+
145
+ class MSSSIM(torch.nn.Module):
146
+ def __init__(self, window_size=11, size_average=True, channel=3, normalize=None):
147
+ super(MSSSIM, self).__init__()
148
+ self.window_size = window_size
149
+ self.size_average = size_average
150
+ self.channel = channel
151
+ self.normalize = normalize
152
+
153
+ def forward(self, img1, img2):
154
+ # TODO: store window between calls if possible
155
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average,
156
+ normalize=self.normalize)
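A short usage sketch for the metrics above, assuming two grayscale batches already scaled to [0, 1]; passing val_range explicitly bypasses the heuristic range detection in ssim().

import torch
from pytorch_msssim import ssim, msssim, SSIM

pred = torch.rand(1, 1, 256, 256)    # e.g. a generated bone-suppressed image in [0, 1]
target = torch.rand(1, 1, 256, 256)  # e.g. the reference image in [0, 1]

print(ssim(pred, target, val_range=1).item())                      # single-scale SSIM
print(msssim(pred, target, val_range=1, normalize="relu").item())  # multi-scale SSIM; "relu" guards against negative cs values

ssim_metric = SSIM(window_size=11, val_range=1)  # module wrapper that caches the Gaussian window
print(ssim_metric(pred, target).item())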
codes/transform.py ADDED
@@ -0,0 +1,28 @@
1
+ from config import config
2
+ from torchvision import transforms
3
+ import cv2 as cv
4
+
5
+
6
+
7
+ class myTransformMethod():
8
+ def __call__(self, img):
9
+
10
+ img = cv.resize(img, (config.image_size, config.image_size))
11
+ if img.shape[-1] == 3: # HWC
12
+ img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
13
+ return img
14
+
15
+
16
+ myTransform = {
17
+ 'trainTransform': transforms.Compose([
18
+ myTransformMethod(),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.5], [0.5])
21
+ ]),
22
+ 'testTransform': transforms.Compose([
23
+ myTransformMethod(),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize([0.5], [0.5])
26
+ ]),
27
+
28
+ }
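Usage sketch: the transform takes an OpenCV image (uint8, HW or HWC), resizes it to config.image_size, converts BGR to grayscale when needed, and returns a 1xHxW tensor normalized to [-1, 1]. The sample path points at the bundled data/CXR images and assumes codes/ is on the Python path and the repository root is the working directory.

import cv2 as cv
from transform import myTransform

img = cv.imread("data/CXR/0.png")      # BGR, HWC, uint8
x = myTransform['testTransform'](img)  # (1, config.image_size, config.image_size), values in [-1, 1]
x = x.unsqueeze(0)                     # add batch dim -> (1, 1, H, W), as the models expect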
codes/vq-gan_eval.py ADDED
@@ -0,0 +1,60 @@
1
+ from config import config
2
+ from transform import myTransform
3
+
4
+ import torch
5
+ import os
6
+ import cv2 as cv
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+
11
+ def eval():
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # select the compute device
13
+
14
+ source_path = "./test_eval" # folder of input images
15
+ recon_output_path = "./vq-gan_recon"
16
+ compress_output_path = "./vq-gan_compress"
+ os.makedirs(recon_output_path, exist_ok=True) # create the output folders up front; cv.imwrite fails silently when the folder is missing
+ os.makedirs(compress_output_path, exist_ok=True)
17
+
18
+
19
+ model = torch.load("2025-02-04-Mask-JSRT-VQGAN.pth").to(device).eval()
20
+
21
+ with torch.no_grad():
22
+ for filename in tqdm(os.listdir(source_path)):
23
+ img_path = os.path.join(source_path, filename)
24
+
25
+
26
+ img = cv.imread(img_path, 0)
27
+
28
+ img = myTransform["testTransform"](img).to(device) # CHW
29
+ img = torch.unsqueeze(img, dim=0).to(device) # BCHW
30
+
31
+ recon, _ = model(img)
32
+
33
+ recon = np.array(recon.detach().to("cpu")) # BCHW
34
+ recon = np.squeeze(recon) # HW
35
+ recon = recon * 0.5 + 0.5
36
+ recon = np.clip(recon, 0, 1)
37
+
38
+ if not config.use_server:
39
+ cv.imshow("win", recon)
40
+ cv.waitKey(0)
41
+
42
+ recon *= 255
43
+ cv.imwrite(os.path.join(recon_output_path, filename), recon)
44
+
45
+ if config.output_feature_map:
46
+ compress = model.encode_stage_2_inputs(img).cpu().detach().numpy()
47
+ compress = np.transpose(np.squeeze(compress)[1:], (1, 2, 0))
48
+ compress = compress * 0.5 + 0.5
49
+ compress = np.clip(compress, 0, 1)
50
+ if not config.use_server:
51
+ cv.imshow("win", compress)
52
+ cv.waitKey(0)
53
+
54
+ compress *= 255
55
+ cv.imwrite(os.path.join(compress_output_path, filename), compress)
56
+
57
+
58
+
59
+ if __name__ == "__main__":
60
+ eval()
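For a quick quantitative check of the reconstructions, the bundled pytorch_msssim module can be reused. A hedged sketch (folder names follow the assumptions above; this is not the paper's evaluation protocol):

import os
import cv2 as cv
import torch
from pytorch_msssim import ssim

source_path = "./test_eval"
recon_path = "./vq-gan_recon"

def to_tensor(img):
    # uint8 HW -> float 1x1xHxW in [0, 1]
    return torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0) / 255.0

scores = []
for filename in os.listdir(recon_path):
    original = cv.imread(os.path.join(source_path, filename), 0)
    recon = cv.imread(os.path.join(recon_path, filename), 0)
    recon = cv.resize(recon, (original.shape[1], original.shape[0]))  # match sizes before comparing
    scores.append(ssim(to_tensor(original), to_tensor(recon), val_range=1).item())

print("mean SSIM:", sum(scores) / len(scores))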
data/BS/0.png ADDED

Git LFS Details

  • SHA256: 596d2da0059a04d0d22248c4f4c14bcde03b6898a229343315a026fce5fda31c
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
data/BS/1.png ADDED

Git LFS Details

  • SHA256: 743a74220b96fae88c0a3c1bd029fd81b7bb2a007c2f0b871940521bbbbaf7de
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
data/BS/2.png ADDED

Git LFS Details

  • SHA256: 7ea0c94ae904f9f28aa85a413e54b2cb65f386529e9df0ebe23b7c56ff837114
  • Pointer size: 131 Bytes
  • Size of remote file: 424 kB
data/CXR/0.png ADDED

Git LFS Details

  • SHA256: 2ff0dbefdee19fcd6a63f5d8e21b25a1fa085b5ef3a96fce3240faa90696d9b7
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
data/CXR/1.png ADDED

Git LFS Details

  • SHA256: b0f49c3985a2a610fbee7f27d372c088fbcf55b73c8198bd5f154dd713484ea3
  • Pointer size: 131 Bytes
  • Size of remote file: 526 kB
data/CXR/2.png ADDED

Git LFS Details

  • SHA256: 38d2a3ce3e0f4afe21047e5f1e7ea8a3c2f6b55b07c15621d8da14044671b674
  • Pointer size: 131 Bytes
  • Size of remote file: 511 kB
images/GL-LCM_gif.gif ADDED

Git LFS Details

  • SHA256: 75a83d8e3dfd78addc45f695e73846ae35661138f56085f483c454a7d9022585
  • Pointer size: 132 Bytes
  • Size of remote file: 8.76 MB
images/ablation.png ADDED

Git LFS Details

  • SHA256: e9457e6032368320af09bc2632c169abb43cec6e9cf6281c29732c0369f381ae
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
images/comparison.png ADDED

Git LFS Details

  • SHA256: 397bb20751ca9d2a81df6694dd0569c5ff476bd47947f9b0af1b8934b04dc855
  • Pointer size: 133 Bytes
  • Size of remote file: 13.6 MB
images/framework.png ADDED

Git LFS Details

  • SHA256: 0bee5e8f15b3fdee26bb24abe444c4982025e6a88c67a87241d79584bc7f5f09
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ diffusers~=0.27.2
2
+ lpips~=0.1.4
3
+ matplotlib~=3.7.2
4
+ matplotlib-inline~=0.1.6
5
+ monai~=1.2.0
6
+ monai-generative~=0.2.2
7
+ numpy~=1.26.4
8
+ opencv-python~=4.8.1.78
9
+ pandas~=2.0.3
10
+ Pillow~=9.3.0
11
+ scikit-image~=0.22.0
12
+ scikit-learn~=1.3.1
13
+ scipy~=1.11.3
14
+ seaborn~=0.13.0
15
+ torch~=2.0.1+cu117
16
+ torch-ema~=0.3
17
+ torchaudio~=2.0.2+cu117
18
+ torchsummary~=1.5.1
19
+ torchvision~=0.15.2+cu117
20
+ tqdm~=4.66.1
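A quick sanity check that the pinned environment resolved as expected (version numbers follow the constraints above):

import torch, diffusers, monai
print(torch.__version__)      # expected ~2.0.1+cu117
print(diffusers.__version__)  # expected ~0.27.x
print(monai.__version__)      # expected ~1.2.x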