| import argparse |
| import os |
| import yaml |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import torchdiffeq |
| import utils |
| from diff2flow import VPDiffusionFlow, dict2namespace |
| import datasets |
| from tqdm import tqdm |
|
|
|
|
def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """
    Integrate the flow ODE forward in time, from t=0 (data) to t=1 (noise).

    Args:
        flow_model: model exposing ``get_velocity(x, t, x_cond, patch_size=...)``.
        x_data: data-space tensor at t=0 (the state being inverted).
        x_cond: conditioning tensor passed through to the velocity field.
        steps: number of intervals on the evaluation time grid.
        method: ODE solver name forwarded to ``torchdiffeq.odeint``.
        patch_size: patch size forwarded to the velocity model.
        atol, rtol: solver tolerances.

    Returns:
        The state at t=1, i.e. the recovered noise latent.
    """

    # Velocity field of the flow, closed over the conditioning input.
    def velocity(t, state):
        return flow_model.get_velocity(state, t, x_cond, patch_size=patch_size)

    # Uniform time grid on [0, 1] with `steps` intervals, on the data's device.
    time_grid = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        velocity, x_data, time_grid, method=method, atol=atol, rtol=rtol
    )

    # Only the terminal state (t=1) is of interest.
    return trajectory[-1]
|
|
|
|
def main():
    """Generate reflow training triplets from a trained diffusion-flow model.

    Loads a VPDiffusionFlow checkpoint, runs the inverse probability-flow ODE
    on ground-truth images (t=0 -> t=1) to recover their noise latents, and
    saves each batch as ``{x_noise, x_data, x_cond}`` to ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # CLI flags override the values from the YAML config.
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.device = device

    # Seed all RNGs for reproducible pair generation; CUDA RNG was previously
    # left unseeded, which could make GPU-side sampling non-reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # BUGFIX: the original first built a non-patch loader (parse_patches=False,
    # validation=...) and then immediately overwrote it with the patch loader
    # below, so that call was dead code doing wasted work. Only the
    # patch-parsing loader is needed here.
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")

    count = 0
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # Patch loaders may yield (B, P, C, H, W); fold patches into the batch
        # dimension so the model sees a plain 4-D batch.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)

        # Channels 0-2 are the degraded input (condition); the rest is the
        # ground-truth target.
        input_img = x_batch[:, :3, :, :].to(device)
        gt_img = x_batch[:, 3:, :, :].to(device)

        # Map images into the model's normalized data range.
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # Invert the flow ODE: data (t=0) -> noise latent (t=1).
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }

        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)

        print(f"Saved batch {i} to {save_path}")

        count += input_img.shape[0]
        # `is not None` (not truthiness) so an explicit --max_images 0 stops
        # immediately instead of being silently treated as "unlimited".
        if args.max_images is not None and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break


if __name__ == "__main__":
    main()
|
|