StateMINT
StateMINT is a neural emulator for malariasimulation outputs. This model repository contains two exported inference artifacts:
- prevalence/: predicts malaria prevalence over time.
- cases/: predicts malaria cases over time.
Both artifacts use the same Mamba2Regressor architecture but have separate weights and preprocessing metadata. Users should load the folder that matches the target they want to predict.
Repository Layout
.
├── prevalence/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
├── cases/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
└── README.md
Each target folder is self-contained:
- checkpoint/ contains model-only Orbax checkpoint data.
- model_config.json contains the model architecture settings needed to instantiate Mamba2Regressor.
- preprocessing_config.json contains feature ordering, intervention timing, target transform settings, and the fitted static covariate scaler.
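To inspect a local copy of the repository, the two JSON files of a target folder can be read with the standard library. The sketch below assumes only the folder layout shown above; read_target_configs is a hypothetical helper and is not part of the stateMINT package, and the keys inside each JSON file are not assumed.

```python
import json
from pathlib import Path


def read_target_configs(repo_root, predictor):
    """Read both JSON files from one target folder of a local repository copy.

    Hypothetical helper for illustration; it relies only on the documented layout.
    """
    if predictor not in ("prevalence", "cases"):
        raise ValueError(f"unknown predictor: {predictor!r}")
    target_dir = Path(repo_root) / predictor
    configs = {}
    for name in ("model_config.json", "preprocessing_config.json"):
        with open(target_dir / name, encoding="utf-8") as fh:
            configs[name] = json.load(fh)
    # The checkpoint directory holds Orbax data; it is only located here, not parsed.
    configs["checkpoint_dir"] = target_dir / "checkpoint"
    return configs
```

Calling read_target_configs(path_to_repo, "cases") returns a dictionary with the two parsed files and the checkpoint path for the cases target. For actual inference, the from_pretrained API described below remains the recommended route.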
Intended Use
These models are intended for emulating the trajectories that malariasimulation-style simulations generate from their inputs. They are designed for research and analysis workflows where fast approximate prediction of simulated prevalence or cases is useful.
They are not intended for direct clinical decision-making or for use on real-world surveillance data without additional validation.
Installation
Install from PyPI and the Hugging Face Hub client:
pip install mintstate
# For GPU support, install with the `[gpu]` extra:
pip install "mintstate[gpu]"
If installing from source:
git clone https://github.com/mrc-ide/stateMINT.git
cd stateMINT
pip install -e .
# For GPU support, install with the `[gpu]` extra:
pip install -e ".[gpu]"
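After either install route, a quick way to confirm that the package is visible to the current interpreter is to query its distribution metadata. This is a minimal sketch using only the standard library; it assumes the distribution name is mintstate, as in the pip command above, and installed_version is a hypothetical helper.

```python
from importlib import metadata


def installed_version(distribution):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string if the package is installed, otherwise None.
print(installed_version("mintstate"))
```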
Loading A Model
Recommended high-level API:
from stateMINT.model import Mamba2Regressor
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)
model = artifact.model
To load the cases model:
from stateMINT.model import Mamba2Regressor
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
model = artifact.model
from_pretrained returns a ModelArtifact containing:
- artifact.model: the restored Mamba2Regressor.
- artifact.preprocessing_config: the exported preprocessing metadata.
- artifact.scaler: the fitted static covariate scaler.
- artifact.prepare_inputs(...): converts raw static covariate dictionaries into model inputs.
- artifact.predict(...): predicts directly from raw static covariate dictionaries.
Example Use Case
Use the prevalence model to emulate malaria prevalence trajectories for two intervention scenarios. Provide one raw static covariate dictionary per trajectory:
from stateMINT.model import Mamba2Regressor
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)
static_covars = [
    {
        "eir": 50.0,
        "dn0_use": 0.3,
        "dn0_future": 0.4,
        "Q0": 0.8,
        "phi_bednets": 0.7,
        "seasonal": 1.0,
        "routine": 0.5,
        "itn_use": 0.2,
        "irs_use": 0.1,
        "itn_future": 0.3,
        "irs_future": 0.2,
        "lsm": 0.0,
    },
    {
        "eir": 120.0,
        "dn0_use": 0.2,
        "dn0_future": 0.2,
        "Q0": 0.9,
        "phi_bednets": 0.6,
        "seasonal": 1.0,
        "routine": 0.4,
        "itn_use": 0.3,
        "irs_use": 0.0,
        "itn_future": 0.5,
        "irs_future": 0.1,
        "lsm": 0.2,
    },
]
predicted_prevalence = artifact.predict(static_covars)
print(predicted_prevalence.shape) # (2, n_steps)
print(predicted_prevalence[0]) # prevalence trajectory for the first scenario
For case-count predictions, load predictor="cases" and call the same .predict(...) method:
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
predicted_cases = artifact.predict(static_covars)
By default, .predict(...) returns predictions on the original target scale:
- prevalence predictions are probabilities in [0, 1];
- cases predictions are case counts.
To get predictions in the model's transformed training space, pass transformed=True:
raw_predictions = artifact.predict(static_covars, transformed=True)
If you use the low-level model directly, it expects already-prepared arrays with shape:
(batch, time, input_size)
import jax.numpy as jnp
X = artifact.prepare_inputs(static_covars)
raw_predictions = artifact.model(jnp.asarray(X))
Preprocessing Contract
The model was trained on transformed inputs, not raw covariates. To reproduce training-time behavior, users must apply the same preprocessing described in preprocessing_config.json.
The static covariates are expected in this order:
eir
dn0_use
dn0_future
Q0
phi_bednets
seasonal
routine
itn_use
irs_use
itn_future
irs_future
lsm
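Because the covariates arrive as dictionaries but the model consumes them positionally, the order above matters. The following sketch shows one way to turn a raw dictionary into an ordered row and to catch missing or unexpected keys early. FEATURE_ORDER is copied from the list above; to_ordered_row is a hypothetical helper, and artifact.prepare_inputs(...) is the supported way to do this in practice.

```python
# Covariate order copied from the list above.
FEATURE_ORDER = [
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
]


def to_ordered_row(covars, order=FEATURE_ORDER):
    """Return the covariate values as a list of floats in the expected order."""
    missing = [name for name in order if name not in covars]
    unexpected = [name for name in covars if name not in order]
    if missing or unexpected:
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return [float(covars[name]) for name in order]
```

The dictionary's own key order is irrelevant: to_ordered_row always emits eir first and lsm last.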
The following covariates are zeroed before the intervention day:
dn0_future
itn_future
irs_future
lsm
routine
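The zeroing rule can be sketched with NumPy as follows. This is an illustration under stated assumptions, not the package's implementation: time is indexed in days, the intervention day itself is treated as post-intervention, and the function operates on whatever row it is given (whether zeroing happens before or after standardization is not stated here). mask_pre_intervention is a hypothetical helper; the actual intervention timing is stored in preprocessing_config.json.

```python
import numpy as np

# Covariate order and the zeroed subset, copied from the lists above.
FEATURE_ORDER = [
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
]
ZEROED_BEFORE_INTERVENTION = ["dn0_future", "itn_future", "irs_future", "lsm", "routine"]


def mask_pre_intervention(static_row, days, intervention_day):
    """Repeat one static row per timestep and zero the listed columns before the intervention."""
    static_row = np.asarray(static_row, dtype=float)
    days = np.asarray(days)
    tiled = np.tile(static_row, (days.shape[0], 1))  # shape (time, n_covariates)
    cols = [FEATURE_ORDER.index(name) for name in ZEROED_BEFORE_INTERVENTION]
    pre = days < intervention_day  # assumption: the intervention day itself is not zeroed
    tiled[np.ix_(pre, cols)] = 0.0
    return tiled
```

Columns such as eir and itn_use are left untouched at every timestep; only the five listed covariates change across the intervention boundary.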
Each timestep uses:
time_normalized / cyclized, scaled_static_covariates, post_intervention_flag, years_since_intervention
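The two intervention-related features can be illustrated as below. The exact definitions are not given in this document, so everything here is an assumption: time is measured in days, the flag switches on at the intervention day, and a year is 365 days. intervention_time_features is a hypothetical helper; artifact.prepare_inputs(...) and preprocessing_config.json are authoritative.

```python
import numpy as np


def intervention_time_features(days, intervention_day, days_per_year=365.0):
    """Return (post_intervention_flag, years_since_intervention) per timestep.

    Assumed definitions for illustration only.
    """
    days = np.asarray(days, dtype=float)
    # 1.0 on and after the intervention day, 0.0 before it.
    post_flag = (days >= intervention_day).astype(float)
    # Elapsed time since the intervention, clipped at zero before it happens.
    years_since = np.clip(days - intervention_day, 0.0, None) / days_per_year
    return post_flag, years_since
```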
The static covariates must be standardized using the fitted scaler stored in preprocessing_config.json:
scaled_static = (raw_static - scaler_mean) / scaler_scale
Do not refit the scaler for inference. The exported scaler is part of the trained model.
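The standardization formula above is an element-wise operation per covariate. A minimal NumPy sketch follows; the mean and scale values used in the example are made up for illustration, and the keys under which the real values are stored in preprocessing_config.json are not shown in this document. standardize is a hypothetical helper.

```python
import numpy as np


def standardize(raw_static, scaler_mean, scaler_scale):
    """Apply (raw - mean) / scale with already-fitted statistics; nothing is refit."""
    raw_static = np.asarray(raw_static, dtype=float)
    scaler_mean = np.asarray(scaler_mean, dtype=float)
    scaler_scale = np.asarray(scaler_scale, dtype=float)
    if scaler_mean.shape != scaler_scale.shape:
        raise ValueError("mean and scale must have the same shape")
    if raw_static.shape[-1] != scaler_mean.shape[0]:
        raise ValueError("last axis of raw_static must match the scaler length")
    return (raw_static - scaler_mean) / scaler_scale


# Toy two-covariate example with made-up statistics.
toy = standardize([[20.0, 0.75], [10.0, 0.5]], [10.0, 0.5], [5.0, 0.25])
print(toy)
```

With the real artifact, raw_static would have 12 columns in the order listed above, and the statistics would come from the exported scaler.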
Prediction Scale
The model predicts in the transformed target space used during training.
For prevalence:
prevalence = sigmoid(raw_prediction)
For cases:
cases = expm1(raw_prediction)
artifact.predict(...) applies this inverse transform by default. Passing transformed=True, or calling artifact.model directly, returns predictions in the transformed space instead.
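For readers working with raw model outputs, the two inverse transforms can be written with the standard library. This is a sketch of the formulas stated above, not the package's code; to_prevalence and to_cases are hypothetical names, and the sigmoid is written in a numerically stable form.

```python
import math


def to_prevalence(raw_prediction):
    """Logistic sigmoid: maps a transformed prediction to a probability in [0, 1]."""
    if raw_prediction >= 0:
        return 1.0 / (1.0 + math.exp(-raw_prediction))
    z = math.exp(raw_prediction)  # avoids overflow for large negative inputs
    return z / (1.0 + z)


def to_cases(raw_prediction):
    """expm1: maps a transformed prediction back to the case scale."""
    return math.expm1(raw_prediction)
```

A transformed prediction of 0 therefore corresponds to a prevalence of 0.5 and to zero cases.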
General Notes
- The prevalence and cases folders have separate checkpoints and separate fitted scalers. Always load the folder corresponding to the target being predicted.