Image Classification
Transformers
PyTorch
English
sybil
medical
cancer
ct-scan
risk-prediction
healthcare
vision
Instructions to use Lab-Rasool/sybil with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Lab-Rasool/sybil with Transformers:
Use a pipeline as a high-level helper:

    from transformers import pipeline

    pipe = pipeline("image-classification", model="Lab-Rasool/sybil")
    pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")

Or load the model directly:

    from transformers import AutoModel

    model = AutoModel.from_pretrained("Lab-Rasool/sybil", dtype="auto")

- Notebooks
- Google Colab
- Kaggle
| """PyTorch Sybil model for lung cancer risk prediction""" | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput | |
| from typing import Optional, Dict, List, Tuple | |
| import numpy as np | |
| from dataclasses import dataclass | |
| try: | |
| from .configuration_sybil import SybilConfig | |
| except ImportError: | |
| from configuration_sybil import SybilConfig | |
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output container for Sybil risk-prediction models.

    Fix: `ModelOutput` subclasses must themselves be decorated with
    `@dataclass` — `ModelOutput` discovers its fields via
    `dataclasses.fields()`, so without the decorator the fields declared
    below are plain class attributes and `SybilOutput(risk_scores=...)`
    raises a TypeError from the inherited dataclass `__init__`.

    Args:
        risk_scores (`torch.FloatTensor` of shape `(batch_size, max_followup)`):
            Predicted risk scores for each year up to max_followup.
        image_attention (`torch.FloatTensor` of shape `(batch_size, num_slices, height, width)`, *optional*):
            Attention weights over image pixels.
        volume_attention (`torch.FloatTensor` of shape `(batch_size, num_slices)`, *optional*):
            Attention weights over CT scan slices.
        hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*):
            Hidden states from the pooling layer.
            NOTE(review): this shadows `BaseModelOutput.hidden_states`, which is
            conventionally a *tuple* of per-layer states — confirm no consumer
            relies on the inherited meaning.
    """
    risk_scores: torch.FloatTensor = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
class CumulativeProbabilityLayer(nn.Module):
    """
    Cumulative-hazard head for discrete-time survival prediction.

    Mirrors the original Sybil implementation:
      - ``hazard_fc`` produces per-year hazards (ReLU keeps them non-negative),
      - ``base_hazard_fc`` produces one base hazard shared by every year,
      - a fixed triangular mask turns per-year hazards into cumulative sums.
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        super().__init__()
        self.max_followup = max_followup
        # Per-year hazard head.
        self.hazard_fc = nn.Linear(hidden_dim, max_followup)
        # Base hazard shared across all follow-up years.
        self.base_hazard_fc = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU(inplace=True)
        # Fixed (non-trainable) upper-triangular mask for the cumulative sum.
        # Registered as a parameter rather than a buffer to keep checkpoint
        # keys compatible with the original Sybil weights.
        tri = torch.t(torch.tril(torch.ones(max_followup, max_followup), diagonal=0))
        self.register_parameter(
            "upper_triangular_mask",
            torch.nn.Parameter(tri, requires_grad=False),
        )

    def hazards(self, x):
        """Return non-negative per-year hazards for features ``x``."""
        return self.relu(self.hazard_fc(x))

    def forward(self, x):
        """
        Compute cumulative hazard scores, matching original Sybil.

        Args:
            x: Hidden features of shape ``[B, hidden_dim]``.
        Returns:
            Cumulative scores of shape ``[B, max_followup]``.
        """
        per_year = self.hazards(x)
        batch, horizon = per_year.size()
        # Broadcast each year's hazard across the mask: [B, T] -> [B, T, T].
        tiled = per_year.unsqueeze(-1).expand(batch, horizon, horizon)
        # Row-sum of the masked hazards is the running (cumulative) total;
        # the shared base hazard is then added to every year.
        cumulative = torch.sum(tiled * self.upper_triangular_mask, dim=1)
        return cumulative + self.base_hazard_fc(x)
class GlobalMaxPool(nn.Module):
    """Collapse all non-channel dimensions to the per-channel maximum."""

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict whose 'hidden' entry is a (B, C) tensor holding, for each
            channel, the max over every temporal/spatial position.
        """
        flattened = x.view(x.size(0), x.size(1), -1)
        return {'hidden': flattened.max(dim=-1).values}
class PerFrameMaxPool(nn.Module):
    """Reduce each slice of a 3D volume to its per-channel maximum."""

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict whose 'multi_image_hidden' entry is a (B, C, T) tensor
            holding, for every (channel, slice) pair, the max over the
            W*H spatial positions.
        """
        assert len(x.shape) == 5
        b, c, t = x.size(0), x.size(1), x.size(2)
        per_slice_max, _ = x.view(b, c, t, -1).max(dim=-1)
        return {'multi_image_hidden': per_slice_max}
class Simple_AttentionPool(nn.Module):
    """Learned attention pooling over the slice dimension."""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool, self).__init__()
        # Scores each channel vector down to a single attention logit.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, N)
        Returns:
            dict with:
              'volume_attention': (B, N) log-attention weights over slices
              'hidden': (B, C) attention-weighted sum of slice features
        """
        batch = x.shape[0]
        x = x.view(x.size(0), x.size(1), -1)  # ensure (B, C, N)
        logits = self.attention_fc(x.transpose(1, 2))  # (B, N, 1)
        logits = logits.transpose(1, 2)  # (B, 1, N)
        result = {'volume_attention': self.logsoftmax(logits).view(batch, -1)}
        weights = self.softmax(logits)  # (B, 1, N)
        # Weighted sum over the slice axis -> (B, C).
        result['hidden'] = torch.sum(x * weights, dim=-1)
        return result
class Simple_AttentionPool_MultiImg(nn.Module):
    """Learned attention pooling over spatial positions, applied per slice."""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool_MultiImg, self).__init__()
        # Scores each pixel's channel vector down to one attention logit.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict with:
              'image_attention': (B, T, W*H) log-attention weights over pixels
              'multi_image_hidden': (B, C, T) attention-pooled slice features
              'hidden': (B, T*C) slice features flattened per sample
        """
        batch, chans, slices, width, height = x.size()
        # Fold slices into the batch so attention runs independently per slice.
        x = x.permute([0, 2, 1, 3, 4]).contiguous().view(batch * slices, chans, width * height)
        logits = self.attention_fc(x.transpose(1, 2)).transpose(1, 2)  # (BT, 1, WH)
        out = {'image_attention': self.logsoftmax(logits).view(batch, slices, -1)}
        # Attention-weighted sum over pixels -> (BT, C).
        pooled = torch.sum(x * self.softmax(logits), dim=-1)
        out['multi_image_hidden'] = pooled.view(batch, slices, chans).permute([0, 2, 1]).contiguous()
        out['hidden'] = pooled.view(batch, slices * chans)
        return out
class Conv1d_AttnPool(nn.Module):
    """Smooth slice features with a 1-D convolution, then attention-pool them."""

    def __init__(self, **kwargs):
        super(Conv1d_AttnPool, self).__init__()
        # Same-padding, bias-free, channel-preserving conv over the slice axis.
        self.conv1d = nn.Conv1d(
            kwargs['num_chan'],
            kwargs['num_chan'],
            kernel_size=kwargs['conv_pool_kernel_size'],
            stride=kwargs['stride'],
            padding=kwargs['conv_pool_kernel_size'] // 2,
            bias=False,
        )
        self.aggregate = Simple_AttentionPool(**kwargs)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T)
        Returns:
            dict from Simple_AttentionPool with:
              'volume_attention': (B, T') log-attention weights
              'hidden': (B, C) attention-weighted features
        """
        # (B, C, T) -> (B, C, T') after convolution, then attention-pool.
        return self.aggregate(self.conv1d(x))
class MultiAttentionPool(nn.Module):
    """
    Multi-pathway pooling for CT-scan feature aggregation, matching the
    original Sybil architecture.

    Fix: the ``channels`` constructor argument was previously ignored (512
    was hard-coded in every sub-module and linear layer); it now
    parameterizes all of them. The default (512) preserves the original
    behavior exactly.
    """

    def __init__(self, channels: int = 512):
        super().__init__()
        params = {
            'num_chan': channels,
            'conv_pool_kernel_size': 11,
            'stride': 1
        }
        # Pathway 1: pixel attention per slice, then attention over slices.
        self.image_pool1 = Simple_AttentionPool_MultiImg(**params)
        self.volume_pool1 = Simple_AttentionPool(**params)
        # Pathway 2: per-slice max pool, then conv + attention over slices.
        self.image_pool2 = PerFrameMaxPool()
        self.volume_pool2 = Conv1d_AttnPool(**params)
        # Pathway 3: global max pool over every position.
        self.global_max_pool = GlobalMaxPool()
        # Final linear layers fusing the pathways' features.
        self.multi_img_hidden_fc = nn.Linear(2 * channels, channels)
        self.hidden_fc = nn.Linear(3 * channels, channels)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H) where
                - B: batch size
                - C: channels (``channels`` from the constructor)
                - T: temporal/depth dimension (slices)
                - W, H: spatial dimensions
        Returns:
            output: dict with keys:
                - 'hidden': (B, C) final aggregated features
                - 'image_attention_1': (B, T, W*H) pixel attention scores
                - 'volume_attention_1': (B, T) slice attention scores
                - 'volume_attention_2': (B, T) slice attention scores (conv path)
                - 'multi_image_hidden': (B, C, T) fused per-slice features
                - 'maxpool_hidden': (B, C) global-max-pooled features
        """
        output = {}
        # Pathway 1: attention over pixels, then attention over slices.
        image_pool_out1 = self.image_pool1(x)
        # -> "multi_image_hidden" (B, C, T), "image_attention" (B, T, W*H), "hidden" (B, T*C)
        volume_pool_out1 = self.volume_pool1(image_pool_out1['multi_image_hidden'])
        # -> "hidden" (B, C), "volume_attention" (B, T)
        # Pathway 2: per-slice max pool, then conv + slice attention.
        image_pool_out2 = self.image_pool2(x)
        # -> "multi_image_hidden" (B, C, T)
        volume_pool_out2 = self.volume_pool2(image_pool_out2['multi_image_hidden'])
        # -> "hidden" (B, C), "volume_attention" (B, T)
        # Collect all pooling outputs under numbered keys. NOTE: 'hidden_1'
        # from image_pool_out1 is overwritten by volume_pool_out1's 'hidden'
        # (both use suffix 1) — kept as-is to match the original Sybil code.
        for pool_out, num in [(image_pool_out1, 1), (volume_pool_out1, 1),
                              (image_pool_out2, 2), (volume_pool_out2, 2)]:
            for key, val in pool_out.items():
                output['{}_{}'.format(key, num)] = val
        # Pathway 3: global max pooling.
        maxpool_out = self.global_max_pool(x)
        output['maxpool_hidden'] = maxpool_out['hidden']
        # Fuse per-slice features from pathways 1 and 2 along the CHANNEL
        # axis (dim=-2): (B, C, T) ++ (B, C, T) -> (B, 2C, T). (The original
        # comment claimed (B, C, 2*T), contradicting nn.Linear(2C, C) below.)
        multi_image_hidden = torch.cat(
            [image_pool_out1['multi_image_hidden'], image_pool_out2['multi_image_hidden']],
            dim=-2
        )
        output['multi_image_hidden'] = self.multi_img_hidden_fc(
            multi_image_hidden.permute([0, 2, 1]).contiguous()
        ).permute([0, 2, 1]).contiguous()  # (B, C, T)
        # Fuse volume-level features from all three pathways: (B, 3C) -> (B, C).
        hidden = torch.cat(
            [volume_pool_out1['hidden'], volume_pool_out2['hidden'], output['maxpool_hidden']],
            dim=-1
        )
        output['hidden'] = self.hidden_fc(hidden)
        return output
class SybilPreTrainedModel(PreTrainedModel):
    """
    Abstract base class wiring Sybil models into the Hugging Face
    machinery: weight initialization and the pretrained-model
    download/loading interface.
    """

    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize weights for linear and 3-D convolutional layers."""
        if isinstance(module, nn.Conv3d):
            # He initialization, suited to ReLU non-linearities.
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(module, nn.Linear):
            # Normal init with config-controlled standard deviation.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        else:
            return
        if module.bias is not None:
            module.bias.data.zero_()
class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    Takes 3D CT scan volumes as input and predicts cancer risk scores for
    multiple future time points (typically 1-6 years).
    """

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config
        # Pretrained R3D-18 video backbone, truncated before its pooling and
        # classification heads so it emits 512-channel 3D feature maps.
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])
        # Multi-pathway attention pooling over the backbone features.
        self.pool = MultiAttentionPool(channels=512)
        # Non-inplace ReLU so the pooled activations remain usable downstream.
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)
        # Cumulative-hazard risk prediction head.
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup
        )
        # Optional calibrator for post-hoc risk-score adjustment.
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)
        # Initialize weights via the Hugging Face machinery.
        self.post_init()

    def set_calibrator(self, calibrator_data: Dict):
        """Set calibration data for risk score adjustment."""
        self.calibrator = calibrator_data

    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Apply per-year calibration to raw risk scores (no-op without a calibrator)."""
        if self.calibrator is None:
            return scores
        # Calibration runs in numpy on CPU copies of the scores.
        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)
        # Calibrator entries are keyed "Year1" .. f"Year{max_followup}";
        # years without an entry pass through unchanged.
        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:
                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key]
                )
            else:
                calibrated[:, year] = scores_np[:, year]
        return torch.from_numpy(calibrated).to(scores.device)

    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply a specific calibration transformation.

        Placeholder: the original Sybil release fits full calibration models;
        this currently returns the scores unchanged.
        """
        return scores

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                Pixel values of CT scan volumes.
            return_attentions (`bool`, *optional*, defaults to `False`):
                Whether to return attention weights (and hidden states).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.

        Returns:
            `SybilOutput` or tuple containing the risk scores (and, when
            requested, the pooled attention maps).
        """
        # 3D CNN backbone -> (B, 512, T', W', H') feature maps.
        features = self.image_encoder(pixel_values)
        # Multi-pathway attention pooling -> dict of pooled features.
        pool_output = self.pool(features)
        # Non-linearity + regularization on the fused hidden vector.
        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)
        # Cumulative-hazard logits -> probabilities via sigmoid.
        risk_logits = self.prob_of_failure_layer(hidden)
        risk_scores = torch.sigmoid(risk_logits)
        # Optional post-hoc calibration.
        risk_scores = self._calibrate_scores(risk_scores)
        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (pool_output.get('image_attention_1'),
                                     pool_output.get('volume_attention_1'))
            return outputs
        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=pool_output.get('image_attention_1') if return_attentions else None,
            volume_attention=pool_output.get('volume_attention_1') if return_attentions else None,
            hidden_states=hidden if return_attentions else None
        )

    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load an ensemble of Sybil models from individual checkpoints.

        Fix: this method takes `cls` as its first parameter but was missing
        the `@classmethod` decorator, so calling it on the class bound the
        model path to `cls` and failed at `cls(config)`.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or model identifier.
            checkpoint_paths: List of paths to individual model checkpoints.
            calibrator_path: Optional path to a JSON file of calibration data.
            **kwargs: Additional keyword arguments; `config` may be supplied here.

        Returns:
            SybilEnsemble: An ensemble of Sybil models.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
        # Load and attach the calibrator if one was provided.
        calibrator_data = None
        if calibrator_path:
            import json
            with open(calibrator_path, 'r') as f:
                calibrator_data = json.load(f)
            config.calibrator_data = calibrator_data
        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)
            # SECURITY NOTE: torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources (consider
            # weights_only=True on torch >= 2.0).
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # Accept either a Lightning-style {'state_dict': ...} checkpoint
            # or a raw state dict.
            raw_state = checkpoint.get('state_dict', checkpoint)
            # Strip a leading 'model.' prefix if present (Lightning wraps the
            # module as `model.` in its checkpoints).
            state_dict = {}
            for k, v in raw_state.items():
                if k.startswith('model.'):
                    state_dict[k[len('model.'):]] = v
                else:
                    state_dict[k] = v
            # Keep only the weights this architecture uses; strict=False
            # tolerates training-only keys that were dropped.
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            model.load_state_dict(mapped_state_dict, strict=False)
            models.append(model)
        return SybilEnsemble(models, config)

    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Filter original Sybil checkpoint weights down to the keys this model uses.

        Keeps encoder, pooling, and prediction-head weights; anything else
        (e.g. training-only heads) is dropped.
        """
        kept_prefixes = ('image_encoder', 'pool', 'prob_of_failure_layer')
        return {k: v for k, v in state_dict.items() if k.startswith(kept_prefixes)}
class SybilEnsemble:
    """Ensemble of Sybil models whose predictions are averaged at inference."""

    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        # Fail fast: __call__ would otherwise crash later on torch.stack([]).
        if not models:
            raise ValueError("SybilEnsemble requires at least one model")
        self.models = models
        self.config = config
        self.device = None

    def to(self, device):
        """Move all member models to `device`; returns self for chaining."""
        self.device = device
        for model in self.models:
            model.to(device)
        return self

    def eval(self):
        """Set all member models to evaluation mode.

        Returns self for chaining — consistent with `to()` above and with
        `torch.nn.Module.eval()` (previously returned None).
        """
        for model in self.models:
            model.eval()
        return self

    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference and average predictions over all member models.

        Args:
            pixel_values: Input CT scan volumes.
            return_attentions: Whether to also average and return attention maps.

        Returns:
            SybilOutput with predictions averaged across the ensemble.
        """
        all_risk_scores = []
        all_image_attentions = []
        all_volume_attentions = []
        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )
                all_risk_scores.append(output.risk_scores)
                if return_attentions:
                    all_image_attentions.append(output.image_attention)
                    all_volume_attentions.append(output.volume_attention)
        # Mean over the ensemble axis.
        risk_scores = torch.stack(all_risk_scores).mean(dim=0)
        image_attention = None
        volume_attention = None
        if return_attentions:
            image_attention = torch.stack(all_image_attentions).mean(dim=0)
            volume_attention = torch.stack(all_volume_attentions).mean(dim=0)
        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention
        )