Image-to-Image
BiRefNet
Safetensors
background-removal
mask-generation
Dichotomous Image Segmentation
Camouflaged Object Detection
Salient Object Detection
pytorch_model_hub_mixin
model_hub_mixin
custom_code
Instructions to use not-lain/BiRefNet with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- BiRefNet
How to use not-lain/BiRefNet with BiRefNet:
# Option 1: use with transformers
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained("not-lain/BiRefNet", trust_remote_code=True)

# Option 2: use with BiRefNet
# Install from https://github.com/ZhengPeng7/BiRefNet
from models.birefnet import BiRefNet
model = BiRefNet.from_pretrained("not-lain/BiRefNet")
- Notebooks
- Google Colab
- Kaggle
| from typing import Dict, List, Any | |
| import base64 | |
| from io import BytesIO | |
| import torch | |
| from loadimg import load_img | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageSegmentation | |
# "high" lets PyTorch pick faster, reduced-precision float32 matmul
# algorithms where the hardware offers them ("highest" would forbid that).
torch.set_float32_matmul_precision("high")

# Resolve the target device before anything is moved onto it, so the module
# can also be imported on hosts without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level model instance; kept because code in this file refers to the
# global name `birefnet`.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)

# Preprocessing applied to every input image: resize to a fixed 1024x1024,
# convert to a float tensor, then normalize each channel with the given
# mean/std (the commonly used ImageNet statistics).
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
class EndpointHandler:
    """Callable handler that cuts out an image's foreground using BiRefNet.

    An instance owns its own copy of the model; calling the instance with a
    request payload returns the input image with the predicted mask applied
    as its alpha channel.
    """

    def __init__(self, path=""):
        """Load the BiRefNet weights and move the model to ``device``.

        Args:
            path: Accepted for interface compatibility but not used; the
                weights are always fetched from the ``ZhengPeng7/BiRefNet``
                hub repository.
        """
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        self.birefnet.to(device)

    def __call__(self, data: Dict[str, Any]):
        """Run background removal on the image referenced by ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is passed straight to
                ``load_img``, so it may be anything that helper accepts.

        Returns:
            The input image (a PIL image converted to RGB, at its original
            size) with the predicted foreground mask set as its alpha channel.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        image = load_img(data["inputs"]).convert("RGB")

        # Add a batch dimension and place the tensor on the same device as
        # the model held by this instance.
        batch = transform_image(image).unsqueeze(0).to(device)

        # Inference only: no gradients needed. The last element of the model
        # output is used as the prediction and squashed into [0, 1].
        with torch.no_grad():
            preds = self.birefnet(batch)[-1].sigmoid().cpu()

        # The mask is produced at the preprocessing resolution; scale it back
        # to the original image size before using it as the alpha channel.
        mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image.size)
        image.putalpha(mask)
        return image