"""
bin_to_safetensors.py

Convert a PyTorch checkpoint (e.g., pytorch_model.bin / .pt / .ckpt) to a .safetensors file.
- Safe tensors only: tensors are saved; non-tensor Python objects (optimizer, schedulers, etc.) are ignored.
- Heuristics try to locate a model state_dict within common training checkpoints.

USAGE:
    python bin_to_safetensors.py --in pytorch_model.bin --out model.safetensors
    python bin_to_safetensors.py --in trainer.ckpt --out model.safetensors

NOTE:
    Loading with torch.load uses pickle and can execute code from untrusted sources.
    Only run this on checkpoints from sources you trust.
"""

import argparse
import sys
from typing import Any, Dict

import torch
# NOTE: safetensors.torch does not export `is_safe_tensor`; importing that
# name raised ImportError at module load. Only `save_file` is used here.
from safetensors.torch import save_file

| def _is_tensor_dict(d: Any) -> bool: |
| if not isinstance(d, dict) or not d: |
| return False |
| |
| for v in d.values(): |
| if not (torch.is_tensor(v) or (hasattr(v, "tensor") and torch.is_tensor(getattr(v, "tensor")))): |
| return False |
| return True |
|
|
|
|
def _to_cpu_tensors(d: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Materialize a tensor dict on CPU: detach, move, make contiguous.

    Values may be plain tensors or wrapper objects exposing a ``tensor``
    attribute (the same shapes accepted by ``_is_tensor_dict``).
    """
    return {
        k: (v if torch.is_tensor(v) else v.tensor).detach().cpu().contiguous()
        for k, v in d.items()
    }


def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
    """
    Try to extract a {name: tensor} dict from various checkpoint formats.

    Search order:
      1. ``obj`` itself is already a tensor dict.
      2. ``obj`` is a dict holding a tensor dict under a well-known key
         (``state_dict``, ``model_state_dict``, ...), checked in order.
      3. Any value of ``obj`` that is a tensor dict (first match wins).

    Raises:
        ValueError: if no tensor dict can be located.
    """
    if _is_tensor_dict(obj):
        return _to_cpu_tensors(obj)

    if isinstance(obj, dict):
        # Common wrapper keys produced by popular training loops, most
        # specific first.
        candidate_keys = (
            "state_dict",
            "model_state_dict",
            "model",
            "module",
            "network",
            "net",
            "weights",
        )
        for ck in candidate_keys:
            if ck in obj and _is_tensor_dict(obj[ck]):
                return _to_cpu_tensors(obj[ck])

        # Last resort: scan every value for something that looks like a
        # state_dict.
        for value in obj.values():
            if _is_tensor_dict(value):
                return _to_cpu_tensors(value)

    raise ValueError(
        "Could not find a model state_dict (a dict of tensors). "
        "If this is a full training checkpoint, load it in Python, extract model.state_dict(), "
        "and save that mapping instead."
    )

def convert_bin_to_safetensors(in_path: str, out_path: str, metadata: Dict[str, str] = None) -> None:
    """Load a PyTorch checkpoint and write its tensors to *out_path*.

    Args:
        in_path: Path to a file produced by ``torch.save`` (.bin/.pt/.ckpt).
        out_path: Destination .safetensors path.
        metadata: Optional extra metadata; keys and values are coerced to str.
            May be None (treated as no extra metadata).

    Raises:
        ValueError: if no tensor state_dict can be located in the checkpoint.
    """
    # Detect a file that is already in safetensors format *before* handing it
    # to torch.load (which would fail on it anyway). A safetensors file begins
    # with an 8-byte little-endian header length followed by a JSON header, so
    # byte 8 is '{'.
    with open(in_path, "rb") as f:
        head = f.read(9)
    if len(head) == 9 and head[8:9] == b"{":
        print(f"Input appears to already be a safetensors file: {in_path}")
        return

    # Prefer the restricted unpickler (no arbitrary code execution). Fall back
    # to a full load for older torch versions (no `weights_only` kwarg) or
    # checkpoints containing non-tensor Python objects. torch.load uses
    # pickle: only convert files from sources you trust.
    try:
        obj = torch.load(in_path, map_location="cpu", weights_only=True)
    except Exception:
        obj = torch.load(in_path, map_location="cpu")

    state = _extract_state_dict(obj)

    meta = {"format": "converted-from-pytorch-bin"}
    if metadata:
        # safetensors metadata must be str -> str.
        meta.update({str(k): str(v) for k, v in metadata.items()})

    save_file(state, out_path, metadata=meta)
    print(f"✅ Wrote {out_path} with {len(state)} tensors.")

def main(argv=None):
    """CLI entry point: parse arguments, collect metadata, run the conversion."""
    parser = argparse.ArgumentParser(description="Convert PyTorch .bin/.pt/.ckpt to .safetensors")
    parser.add_argument("--in", dest="in_path", required=True, help="Input .bin/.pt/.ckpt file path")
    parser.add_argument("--out", dest="out_path", required=True, help="Output .safetensors file path")
    parser.add_argument("--meta", nargs="*", default=[], help='Optional metadata entries like key=value (repeatable)')
    args = parser.parse_args(argv)

    # Parse key=value metadata entries; everything after the first '=' is the
    # value. Malformed entries are skipped with a warning.
    metadata = {}
    for entry in args.meta:
        key, sep, value = entry.partition("=")
        if not sep:
            print(f"Warning: ignoring malformed --meta entry (expected key=value): {entry}", file=sys.stderr)
            continue
        metadata[key] = value

    convert_bin_to_safetensors(args.in_path, args.out_path, metadata)

# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()