Bio_ClinicalBERT_MIMIC_IV_in_hospital_mortality_prediction_IA3_ti

This model predicts in-hospital mortality (i.e., whether a patient is likely to die during their upcoming or current visit) from prior hospital records. It was trained on clinical notes from prior hospitalizations in MIMIC-IV, using a novel tabular-infused IA3 approach in which the pre-operative tabular features (e.g., patient demographics and insurance information) were used to initialize the newly introduced IA3 parameters.
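For readers unfamiliar with IA3: it keeps the base model frozen and learns small vectors that rescale activations elementwise (keys, values, and feed-forward activations in the original formulation). These vectors are normally initialized to ones, so training starts from the unmodified base model. The sketch below illustrates that rescaling and contrasts the standard all-ones initialization with a tabular-infused one. This card does not specify how the tabular features are mapped onto the IA3 vectors, so `tabular_infused_init`, its fixed projection matrix, and the `1 + tanh(...)` form are purely hypothetical stand-ins, not the author's method.

```python
import math
from typing import List, Sequence


def ia3_scale(activations: Sequence[float], scale: Sequence[float]) -> List[float]:
    """IA3 rescales an activation vector elementwise by a learned vector."""
    if len(activations) != len(scale):
        raise ValueError("activation and scale vectors must have the same length")
    return [a * s for a, s in zip(activations, scale)]


def standard_ia3_init(dim: int) -> List[float]:
    # Standard IA3 initializes each scaling vector to ones, so the adapted
    # model initially behaves exactly like the frozen base model.
    return [1.0] * dim


def tabular_infused_init(
    tabular: Sequence[float], projection: Sequence[Sequence[float]]
) -> List[float]:
    # HYPOTHETICAL: project the tabular feature vector (length t) with a fixed
    # (dim x t) matrix, squash with tanh, and add 1 so the initial scaling
    # stays close to the identity while still depending on the patient's features.
    return [1.0 + math.tanh(sum(w * x for w, x in zip(row, tabular))) for row in projection]


if __name__ == "__main__":
    hidden = [2.0, -1.0, 0.5]
    print(ia3_scale(hidden, standard_ia3_init(3)))  # unchanged activations
    # Hypothetical tabular vector, e.g. [normalized age, insurance indicator].
    scale = tabular_infused_init([0.3, 1.0], [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]])
    print(ia3_scale(hidden, scale))
```

With zero-valued tabular features the hypothetical initialization collapses back to the standard all-ones vector, which keeps the starting point identical to plain IA3.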

Model Details

How to use the model

from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_in_hospital_mortality_prediction_IA3_ti")
model = AutoModelForSequenceClassification.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_in_hospital_mortality_prediction_IA3_ti")

You can then use the function below to obtain a prediction for a single input text.

import torch

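def get_outcome(tokenizer, model, text, device="cuda:0" if torch.cuda.is_available() else "cpu", max_length=512):
    """Return predicted in-hospital mortality probabilities for a single clinical note."""
    device = torch.device(device)
    model = model.to(device)
    model.eval()  # disable dropout for deterministic inference

    # Tokenize one note; text longer than max_length tokens is truncated.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the two class logits of the first (and only) example.
        probs = torch.softmax(outputs.logits, dim=-1)[0]  # shape: (2,)

    probs = probs.cpu().numpy()
    # Index 0 = no in-hospital mortality ("False"); index 1 = in-hospital mortality ("True").
    result = {
        "False": float(probs[0]),
        "True": float(probs[1])
    }

    return result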
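`get_outcome` returns a softmax over the model's two output logits. The dependency-free sketch below reproduces that computation for a pair of raw logits and adds a thresholding step. The `predict_mortality` helper and its 0.5 default threshold are hypothetical; a cutoff for real use would need to be chosen for the clinical setting (e.g., on a validation set). The mapping of index 1 to "True" (mortality) follows the keys used in `get_outcome`.

```python
import math
from typing import Dict, Sequence


def logits_to_result(logits: Sequence[float]) -> Dict[str, float]:
    # Numerically stable two-class softmax, mirroring
    # torch.softmax(outputs.logits, dim=-1)[0] in get_outcome.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return {"False": exps[0] / total, "True": exps[1] / total}


def predict_mortality(result: Dict[str, float], threshold: float = 0.5) -> bool:
    # HYPOTHETICAL decision rule: flag in-hospital mortality when P(True) >= threshold.
    return result["True"] >= threshold


if __name__ == "__main__":
    result = logits_to_result([0.2, 1.4])
    print(result, predict_mortality(result))
```

With two classes, P("True") equals the sigmoid of the logit difference, 1 / (1 + exp(z0 - z1)), so only the gap between the two logits matters.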

Questions?

Contact me at alba@wustl.edu

