Research / Keck USC

Multi-Modal Missing-Modality VLM + Retrieval-Augmented VQA for Alzheimer's Detection

2,363 ADNI subjects · ~70M model params · 0.707 DX 3-class Bal. Acc. · 0.933 CN vs Dem Bal. Acc.

A custom vision-language model that ingests T1 MRI, DTI FA maps, and clinical scores to predict six tasks, including preclinical amyloid detection. Extended with a FAISS-based retrieval pipeline and a three-way LLM comparison (Mistral 7B, Gemma 4 26B MoE, MedGemma 1.5 4B) for visual question answering over brain scans.

PyTorch · 3D ResNet-18 · Cross-Attention · CLIP · FAISS · RAG · Mistral 7B · Gemma 4 26B · MedGemma · ADNI
01

The Clinical Ask

Under the NIA-AA A/T/N framework, cognitively normal (CN) individuals who test positive for amyloid are already on the Alzheimer's disease biological continuum. This is preclinical AD: no symptoms yet, but a 3 to 5x increased risk of progressing to MCI or dementia within 3 to 5 years. Anti-amyloid therapies (lecanemab, donanemab) are expected to work best early in the disease course, which makes this stage a prime target.

The problem is finding these people. Amyloid PET scans cost $3,000 to $6,000 each. CSF draws are invasive. Neither works at screening scale. This project asks whether structural MRI + DTI + routine clinical scores can do that job. No PET infrastructure needed. Just the imaging and labs most patients already have.

Primary target

Detect amyloid positivity in CN subjects well enough to work as a triage tool for PET referral. High specificity matters here: the fewer amyloid-negative subjects the model flags, the fewer expensive PET scans are wasted on people who turn out to be negative.

02

Dataset

All data comes from the Alzheimer's Disease Neuroimaging Initiative (ADNI). After filtering to subjects with valid DX labels and 9DOF T1 paths, deduplicating to one scan per subject, and recovering DTI paths from an earlier dataset cut, the combined v3 cohort is 2,363 subjects with DTI coverage of 39.4% (nearly doubled from the earlier 19.8%).

80 / 20 stratified split by diagnosis

Class | Train | Test | Total
CN (Cognitively Normal) | 669 | 168 | 837
MCI (Mild Cognitive Impairment) | 650 | 163 | 813
Dementia | 570 | 143 | 713
Total | 1,889 | 474 | 2,363

Modality coverage: T1 MRI (9DOF, 2 mm) 100% · DTI FA 39.4% · Clinical scores ~100%

03

Model Architecture

The model is a multi-modal vision-language model with missing-modality support. Three modality-specific encoders produce ℓ2-normalized 512-d embeddings, each gated by a per-modality masking probability during training. The masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing a fused representation z_f ∈ ℝ⁵¹² that feeds six MLP task heads.

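[Figure: model architecture. T1 MRI (91×109×91) and DTI FA (91×109×91) each pass through a 3D ResNet-18 to a 512-d embedding; clinical inputs (6 features + APOE) pass through a clinical MLP to 512-d. The ℓ2-normalized embeddings z_T1, z_DTI, z_Clin ∈ ℝ⁵¹² are fused by 8-head cross-attention (pool query + modality embeddings) into z_f ∈ ℝ⁵¹², which feeds six heads: DX 3-class (3 outputs), DX binary (2), sex (2), age (1), CDR-SB (1), amyloid (2).]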

Each modality embedding passes through a masking gate that randomly drops it during training (T1: 10%, DTI: 30%, Clinical: 5%). The amyloid head is upweighted as the primary clinical target.
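The fusion step can be sketched in NumPy as a single learnable pool query attending over the three modality tokens, with missing or masked modalities excluded by a large negative attention score. This is a minimal sketch under assumptions: the random projection matrices, the even head split (512 / 8 = 64 dims per head), the function names, and the masking trick are illustrative choices, not the project's actual layer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention_fuse(tokens, present, pool_query, modality_emb, w_q, w_k, w_v, w_o, n_heads=8):
    """Fuse per-modality tokens into one vector z_f with a learnable pool query (sketch).

    tokens:       (B, M, D) modality embeddings, e.g. M = 3 for T1, DTI, Clinical
    present:      (B, M) bool, False where a modality is missing or masked out
    pool_query:   (D,) learnable query shared across subjects
    modality_emb: (M, D) learnable embedding added so attention knows which token is which
    w_q, w_k, w_v, w_o: (D, D) projection matrices
    Assumes every subject has at least one present modality.
    """
    B, M, D = tokens.shape
    dh = D // n_heads
    kv = tokens + modality_emb[None]                        # add modality identity to each token
    q = (pool_query @ w_q).reshape(n_heads, dh)             # (H, dh)
    k = (kv @ w_k).reshape(B, M, n_heads, dh)               # (B, M, H, dh)
    v = (kv @ w_v).reshape(B, M, n_heads, dh)
    scores = np.einsum("hd,bmhd->bhm", q, k) / np.sqrt(dh)  # (B, H, M) scaled dot products
    scores = np.where(present[:, None, :], scores, -1e9)    # masked modalities get ~zero weight
    attn = softmax(scores, axis=-1)                         # per head, weights over modalities sum to 1
    z = np.einsum("bhm,bmhd->bhd", attn, v).reshape(B, D)   # concatenate the heads
    return z @ w_o, attn

# Usage with random weights (illustration only)
rng = np.random.default_rng(0)
B, M, D = 4, 3, 512
tokens = rng.standard_normal((B, M, D))
tokens /= np.linalg.norm(tokens, axis=-1, keepdims=True)    # ℓ2-normalized, as described above
present = np.array([[1, 1, 1], [1, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=bool)
w_q, w_k, w_v, w_o = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))
z_f, attn = cross_attention_fuse(tokens, present, rng.standard_normal(D),
                                 0.02 * rng.standard_normal((M, D)), w_q, w_k, w_v, w_o)
```

Using a single query rather than pooling by averaging lets the model learn how much each available modality should contribute, and the mask means a subject without DTI simply redistributes its attention over T1 and Clinical.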

Imaging encoders

Two independent 3D ResNet-18 networks, one for T1 MRI and one for DTI FA maps. Input volumes are 91×109×91 voxels; each network's 512-d output passes through a linear projection + LayerNorm.

Clinical encoder

MLP over 6 continuous features (CDR-SB, ADAS-11/13, MMSE, MoCA, AV45 SUVR) plus an APOE genotype embedding. Output: 512-d, ℓ2-normalized.
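A forward pass of this encoder can be sketched as concatenating the six continuous scores with a looked-up APOE embedding, running a small MLP, and ℓ2-normalizing the result. The APOE coding (count of ε4 alleles), the hidden width, the embedding size, input standardization, and the helper names are all assumptions for illustration; the page specifies only the inputs and the 512-d normalized output.

```python
import numpy as np

CONT_FEATURES = ["CDR-SB", "ADAS-11", "ADAS-13", "MMSE", "MoCA", "AV45 SUVR"]
N_APOE = 3  # hypothetical coding: number of APOE ε4 alleles (0, 1, 2)

def init_clinical_encoder(rng, apoe_dim=16, hidden=256, out_dim=512):
    """Random initial weights for a hypothetical 2-layer clinical MLP."""
    n_in = len(CONT_FEATURES) + apoe_dim
    return {
        "apoe_table": 0.1 * rng.standard_normal((N_APOE, apoe_dim)),  # learnable lookup table
        "w1": rng.standard_normal((n_in, hidden)) / np.sqrt(n_in),
        "b1": np.zeros(hidden),
        "w2": rng.standard_normal((hidden, out_dim)) / np.sqrt(hidden),
        "b2": np.zeros(out_dim),
    }

def clinical_encoder(params, x_cont, apoe):
    """x_cont: (B, 6) standardized scores; apoe: (B,) integer codes in [0, N_APOE)."""
    h = np.concatenate([x_cont, params["apoe_table"][apoe]], axis=1)  # 6 scores + APOE embedding
    h = np.maximum(h @ params["w1"] + params["b1"], 0.0)               # ReLU hidden layer
    z = h @ params["w2"] + params["b2"]                                # project to 512-d
    return z / np.linalg.norm(z, axis=1, keepdims=True)                # ℓ2-normalize

rng = np.random.default_rng(1)
params = init_clinical_encoder(rng)
z_clin = clinical_encoder(params, rng.standard_normal((5, 6)), np.array([0, 1, 2, 1, 0]))
```

Treating APOE as a categorical embedding rather than a raw number lets the model learn a separate representation per genotype group.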

Design decision: no label leakage

An earlier iteration fed DX_code, SEX_code, and Amyloid_code into the clinical encoder. Those are the same variables the task heads are trying to predict, so the accuracy numbers (DX3 at 97.6%, sex at 98.3%) were meaningless. The v2/v3 model removes these inputs entirely. DX, sex, age, and amyloid are prediction targets only; the clinical encoder receives 7 inputs (the 6 continuous scores plus APOE). Every number on this page comes from the v2/v3 model, without those leaked inputs.

04

Training Procedure

Training runs in two stages. Stage 1 is contrastive pre-training: a pairwise CLIP/InfoNCE loss between all three modality pairs (T1–DTI, T1–Clinical, DTI–Clinical), computed only on pairs where both modalities are present. Stage 2B is multi-task fine-tuning across six heads with differential learning rates (backbone 10⁻⁵, heads 5×10⁻⁴).
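The masked pairwise objective can be sketched as a symmetric CLIP-style InfoNCE loss computed separately for each modality pair on only the subjects that have both modalities, then averaged. The temperature (0.07, CLIP's usual starting value), the plain average over pairs, and the function names are assumptions; the page does not give those details.

```python
import numpy as np

def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def info_nce(a, b, temperature=0.07):
    """Symmetric InfoNCE for matched rows of a and b (both ℓ2-normalized, shape (N, D))."""
    logits = a @ b.T / temperature                       # (N, N) cosine similarity / temperature
    diag = np.diag(logits)                               # matched pairs sit on the diagonal
    loss_ab = (logsumexp(logits, axis=1) - diag).mean()  # a -> b direction
    loss_ba = (logsumexp(logits, axis=0) - diag).mean()  # b -> a direction
    return 0.5 * (loss_ab + loss_ba)

def pairwise_contrastive_loss(emb, present, temperature=0.07):
    """emb: dict name -> (B, D) embeddings; present: dict name -> (B,) bool availability."""
    names = list(emb)
    losses = []
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            both = present[names[i]] & present[names[j]]  # only subjects with both modalities
            if both.sum() >= 2:                           # need at least one negative per row
                losses.append(info_nce(emb[names[i]][both], emb[names[j]][both], temperature))
    return float(np.mean(losses)) if losses else 0.0
```

With three modalities this gives up to three terms (T1–DTI, T1–Clinical, DTI–Clinical); a batch where nobody has DTI simply contributes only the T1–Clinical term.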

During training, modalities are randomly dropped so the model learns to work with any subset at inference time. T1 is dropped 10% of the time, DTI 30%, Clinical 5%. DTI gets the highest drop rate because only 39.4% of subjects have it, so the model needs to handle "missing DTI" as the normal case, not the exception.
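The sampling step can be sketched as follows: a modality reaches the fusion layer only if the subject actually has it and it survives a per-modality Bernoulli drop. The rule that restores one available modality when all would be dropped is an assumption (the page does not say how that case is handled), as are the function and variable names.

```python
import numpy as np

DROP_P = {"t1": 0.10, "dti": 0.30, "clin": 0.05}  # per-modality drop rates from the text

def sample_modality_mask(available, rng, drop_p=DROP_P):
    """available: dict name -> (B,) bool, whether each subject has the modality at all.
    Returns dict name -> (B,) bool, the modalities fed to fusion for this training step."""
    names = list(drop_p)
    batch = len(available[names[0]])
    keep = {n: available[n] & (rng.random(batch) >= drop_p[n]) for n in names}
    # Assumption: if every available modality was dropped, restore one at random.
    none_kept = ~np.any([keep[n] for n in names], axis=0)
    for b in np.flatnonzero(none_kept):
        candidates = [n for n in names if available[n][b]]
        if candidates:
            keep[candidates[rng.integers(len(candidates))]][b] = True
    return keep

# Usage: DTI is naturally present for ~39.4% of subjects, then dropped a further 30% of the time
rng = np.random.default_rng(0)
B = 10000
available = {"t1": np.ones(B, bool), "dti": rng.random(B) < 0.394, "clin": np.ones(B, bool)}
keep = sample_modality_mask(available, rng)
```

Because natural missingness and training-time dropout stack, the model sees DTI for only about 28% of training examples in this toy setup, which is what makes "missing DTI" the normal case.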

Stage 1 · Contrastive: 30 epochs · pairwise InfoNCE · lr 1×10⁻⁴ · AdamW · cosine anneal

Stage 2B · Multi-task: 30 epochs · 6 joint heads · focal loss (γ=2) + smooth L1 · AMP FP16
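The Stage 2B objective can be sketched as a weighted sum of focal losses (γ=2) on the classification heads and smooth L1 on the regression heads (age, CDR-SB). The head weights are hypothetical: the page says only that the amyloid head is upweighted, not by how much. `beta=1.0` matches the default of PyTorch's `SmoothL1Loss`; whether the project changed it is unknown.

```python
import numpy as np

def log_softmax(logits):
    m = logits.max(axis=1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))

def focal_loss(logits, targets, gamma=2.0):
    """Multi-class focal loss: mean of -(1 - p_t)^gamma * log p_t; gamma = 0 gives cross-entropy."""
    log_pt = log_softmax(logits)[np.arange(len(targets)), targets]
    pt = np.exp(log_pt)
    return float(np.mean(-((1.0 - pt) ** gamma) * log_pt))

def smooth_l1(pred, target, beta=1.0):
    """Quadratic for small errors, linear for large ones."""
    diff = np.abs(pred - target)
    return float(np.mean(np.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)))

def multitask_loss(cls_logits, cls_targets, reg_preds, reg_targets, weights):
    """Weighted sum over heads; `weights` is a hypothetical dict, missing heads default to 1.0."""
    total = 0.0
    for name in cls_logits:
        total += weights.get(name, 1.0) * focal_loss(cls_logits[name], cls_targets[name])
    for name in reg_preds:
        total += weights.get(name, 1.0) * smooth_l1(reg_preds[name], reg_targets[name])
    return total
```

The focal term shrinks the loss on examples the model already classifies confidently, which shifts gradient toward the hard cases such as MCI and CN amyloid-positive subjects.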

[Figure: Stage 2B training history over 30 epochs (y-axis 0.30 to 0.85) for DX3 Bal. Acc., Amyloid Bal. Acc., CN Amyloid Bal. Acc., and Sex Acc.; best epoch (5) marked.]

Stage 2B training history. Best composite score at epoch 5; later epochs show overfitting, particularly on amyloid. The model used downstream is the epoch-5 checkpoint.

05

Results

Evaluated on the held-out 474-subject test set using all available modalities. Headline metrics on the best model (Stage 2B, epoch 5):

Task | Bal. Acc. | Macro F1 | AUC
DX 3-class (CN / MCI / Dem), primary | 0.707 | 0.703 | 0.865
DX Binary (CN vs Dem) | 0.933 | 0.932 | 0.981
Sex | 0.575 | 0.563 | 0.597
Amyloid (Aβ+ / Aβ−) | 0.733 | 0.733 | 0.806
CN Amyloid (preclinical), primary | 0.688 | 0.695 | 0.685
MCI Amyloid | 0.604 | 0.595 | 0.723
Dementia Amyloid | 0.484 | 0.456 | 0.906

Regression heads (MAE, lower is better): Age 6.31 years · CDR-SB 0.97

DX 3-class confusion matrix · full test set (n=474)

True \ Predicted | CN | MCI | Dem | Total
CN | 119 (71%) | 48 (29%) | 1 (1%) | 168
MCI | 40 (25%) | 90 (55%) | 33 (20%) | 163
Dem | 2 (1%) | 18 (13%) | 123 (86%) | 143

Rows = true class, columns = predicted. Each cell shows count and row-normalized %. MCI is the hardest class at 55% recall: it sits between CN and Dementia, so the model hedges in both directions (25% predicted CN, 20% predicted Dementia). Dementia recall is the strongest at 86%, with only 2 subjects misclassified as CN.
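Balanced accuracy is the unweighted mean of per-class recall, so the headline DX 3-class figure can be reproduced directly from the confusion matrix counts. A minimal NumPy check:

```python
import numpy as np

# Rows = true class, columns = predicted class, order CN, MCI, Dementia (test set, n = 474)
cm = np.array([
    [119, 48, 1],
    [40, 90, 33],
    [2, 18, 123],
])
recall = cm.diagonal() / cm.sum(axis=1)  # per-class recall: 119/168, 90/163, 123/143
balanced_acc = recall.mean()             # unweighted mean over the three classes
print(recall.round(3))
print(round(balanced_acc, 3))            # 0.707
```

Because each class counts equally regardless of size, the weak MCI recall pulls balanced accuracy down more than it would pull down plain accuracy.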

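CN amyloid (preclinical screen): 0.846 spec · 0.529 sens

High specificity, moderate sensitivity. That tradeoff is intentional for a triage screen: 84.6% of truly amyloid-negative CN subjects are correctly predicted negative, so few of them would be sent for an unnecessary PET scan. The price is that roughly half of amyloid-positive CN subjects are missed.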

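Dementia amyloid: 0.906 AUC

By the time disease is established, the model's scores separate amyloid-positive from amyloid-negative dementia subjects well (AUC 0.906). At the operating threshold used here, though, sensitivity is 0.969 while balanced accuracy is only 0.484, which implies specificity near zero: almost every dementia subject is called amyloid-positive, so the threshold would need recalibration for this group.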
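For a binary task, balanced accuracy is the mean of sensitivity and specificity, so two reported numbers determine the third. The sketch below recovers specificity from the table values and also shows why specificity is not the same as negative predictive value (NPV): NPV depends on amyloid prevalence too, and the 0.30 used here is a hypothetical value, not an ADNI figure.

```python
def implied_specificity(balanced_acc, sensitivity):
    """Binary balanced accuracy = (sensitivity + specificity) / 2, solved for specificity."""
    return 2 * balanced_acc - sensitivity

def npv(sensitivity, specificity, prevalence):
    """P(truly negative | predicted negative), via Bayes' rule."""
    true_neg = specificity * (1 - prevalence)
    false_neg = (1 - sensitivity) * prevalence
    return true_neg / (true_neg + false_neg)

# CN amyloid: 2 * 0.688 - 0.529 = 0.847, consistent with the reported 0.846 after rounding
print(round(implied_specificity(0.688, 0.529), 3))
# Dementia amyloid: 2 * 0.484 - 0.969 is about -0.001, i.e. specificity is essentially zero
print(round(implied_specificity(0.484, 0.969), 3))
# NPV for CN amyloid at a hypothetical 30% prevalence: about 0.807, not 0.846
print(round(npv(0.529, 0.846, 0.30), 3))
```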

06

Modality Ablation

Every one of the seven possible modality subsets was evaluated using the same trained model, with non-selected modalities masked at inference time. This shows where the signal actually comes from and which combination works best for each task.
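Enumerating the subsets and masking at inference time can be sketched as below. The helper names are hypothetical, and the docstring flags one detail the page does not cover: under single-modality subsets, some subjects have no remaining input at all.

```python
from itertools import combinations

MODALITIES = ("T1", "DTI", "Clin")

def modality_subsets(modalities=MODALITIES):
    """All non-empty subsets, largest first: 2**3 - 1 = 7 for three modalities."""
    return [frozenset(c) for r in range(len(modalities), 0, -1)
            for c in combinations(modalities, r)]

def inference_mask(available, subset):
    """A modality is used only if the subject has it AND the evaluated subset includes it.
    Under e.g. DTI-only, a subject without DTI ends up with no inputs; how such subjects
    were scored is not described on this page."""
    return {m: bool(available.get(m, False)) and m in subset for m in MODALITIES}

for s in modality_subsets():
    print(" + ".join(m for m in MODALITIES if m in s))
```

Since the same trained weights are reused for every subset, differences in the ablation come only from which inputs are visible, not from retraining.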

Balanced accuracy by modality subset

Subset | DX 3-class (CN / MCI / Dementia) | DX Binary (CN vs Dementia) | Amyloid (overall positivity) | CN Amyloid (preclinical screen)
T1 + DTI + Clin | 0.707 | 0.933 | 0.733 | 0.688
T1 + Clin | 0.692 | 0.932 | 0.699 | 0.642
DTI + Clin | 0.701 | 0.938 | 0.746 | 0.623
Clin only | 0.700 | 0.938 | 0.713 | 0.557
T1 + DTI | 0.598 | 0.848 | 0.621 | 0.549
T1 only | 0.587 | 0.833 | 0.587 | 0.494
DTI only | 0.388 | 0.528 | 0.575 | 0.552

DTI + Clin is best or tied for best on two of the four tasks: DX Binary (0.938, level with Clinical-only) and overall Amyloid (0.746). On DX 3-class it trails the full stack only slightly (0.701 vs. 0.707). For preclinical CN amyloid, though, you need all three: T1 + DTI + Clinical reaches 0.688 Bal. Acc. vs. Clinical-only at 0.557, a 13.1 point gap on the hardest task. DTI alone is the weakest combination overall, near chance on both diagnosis tasks (0.388 and 0.528), although on CN amyloid it edges out T1 only (0.552 vs. 0.494).

07

Retrieval-Augmented VQA Extension

The VLM gives you a prediction and a confidence score. What it doesn't give you is an explanation, or any way to ask follow-up questions in natural language. That's what the VQA extension adds. The frozen VLM encodes a query scan into a 512-d embedding. FAISS finds the 50 most similar training subjects by inner product. A cross-encoder reranks those 50 down to the top 5 most relevant matches. Those 5 captions become the context fed to a language model, which answers clinical questions about the scan.

[Figure: retrieval-augmented VQA pipeline. Frozen encoders (T1: 3D ResNet-18, DTI: 3D ResNet-18, Clinical: MLP) → attention fusion (8-head, pool query) → z_f ∈ ℝ⁵¹² → FAISS IndexFlatIP over 1,889 vectors → top-50 similar cases → cross-encoder rerank (top-50 → top-5) → top-5 captions as context → LLM (Mistral 7B, Gemma 4 26B, or MedGemma 4B) → VQA answer.]

Frozen multi-modal encoders produce the fused embedding z_f ∈ ℝ⁵¹². FAISS retrieves top-50 similar training subjects; a cross-encoder reranks down to top-5. The LLM generates answers from the retrieved context. Three LLM backbones are compared: Mistral 7B, Gemma 4 26B MoE, and MedGemma 1.5 4B.

Text encoder

all-MiniLM-L6-v2, MLM-pretrained on 26,889 clinical sentences (25K synthetic + 1,889 real captions), then contrastively aligned to the imaging embedding space. 384-d output projected up to 512-d.

Retrieval + rerank

FAISS IndexFlatIP over 1,889 ℓ2-normalized training vectors, exact inner-product search. Cross-encoder: ms-marco-MiniLM-L-6-v2, reranking top-50 to top-5.
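The retrieve-then-rerank step can be sketched with NumPy standing in for FAISS: exact inner-product search over ℓ2-normalized vectors is the computation `faiss.IndexFlatIP` performs, so the ranking is the same. The reranker here is a hypothetical callable; in the real pipeline it would be the ms-marco-MiniLM-L-6-v2 cross-encoder. Which text the cross-encoder scores captions against, and the prompt wording, are not specified on this page, so both are assumptions.

```python
import numpy as np

# FAISS equivalent (for reference): index = faiss.IndexFlatIP(512); faiss.normalize_L2(x);
# index.add(x); D, I = index.search(q, 50)
# Cross-encoder equivalent: CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2").predict(pairs)

def build_index(train_vectors):
    """ℓ2-normalize rows so the inner product equals cosine similarity."""
    x = np.asarray(train_vectors, dtype=np.float32)
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def search(index, query, k=50):
    """Exact top-k by inner product, highest first."""
    q = query / np.linalg.norm(query)
    scores = index @ q
    top = np.argsort(-scores)[:k]
    return top, scores[top]

def retrieve_and_rerank(index, captions, query_vec, query_text, rerank_score, k_retrieve=50, k_final=5):
    """rerank_score(query_text, caption) -> float is a hypothetical stand-in for the cross-encoder."""
    candidates, _ = search(index, query_vec, k_retrieve)
    ranked = sorted(candidates, key=lambda i: rerank_score(query_text, captions[i]), reverse=True)
    return [captions[i] for i in ranked[:k_final]]

def build_prompt(question, context_captions):
    """Hypothetical prompt template: retrieved captions as context, then the question."""
    context = "\n".join(f"[{n}] {c}" for n, c in enumerate(context_captions, 1))
    return f"Similar cases:\n{context}\n\nQuestion: {question}\nAnswer:"

def word_overlap(query_text, caption):
    """Toy reranking score: number of shared words (illustration only)."""
    return len(set(query_text.split()) & set(caption.split()))

# Toy usage: 1,889 random vectors and synthetic captions
rng = np.random.default_rng(0)
index = build_index(rng.standard_normal((1889, 512)))
captions = [f"subject {i} diagnosis {['CN', 'MCI', 'Dementia'][i % 3]}" for i in range(1889)]
top5 = retrieve_and_rerank(index, captions, index[7], "diagnosis Dementia", word_overlap)
prompt = build_prompt("What is the likely diagnosis?", top5)
```

The two-stage design keeps the expensive cross-encoder off the full index: exact search narrows 1,889 vectors to 50 cheaply, and the cross-encoder only scores those 50.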

08

LLM Backbone Comparison

Three models were given the same retrieved context: Mistral 7B Instruct v0.3 (general-purpose, dense), Gemma 4 26B MoE (larger, mixture-of-experts), and MedGemma 1.5 4B IT (smaller, fine-tuned on medical data). All quantized to 4-bit NF4. The question was simple: does a medical fine-tune beat a bigger general model on this task?
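Loading one of the backbones in 4-bit NF4 might look like the following with Hugging Face `transformers` and `bitsandbytes`. Only the 4-bit NF4 setting comes from the text; the bfloat16 compute dtype, double quantization, `device_map="auto"`, and the function name are assumptions, and the model ID is the public Mistral 7B Instruct v0.3 repository. Imports sit inside the function so the snippet can be defined without those libraries installed.

```python
def load_nf4(model_id="mistralai/Mistral-7B-Instruct-v0.3"):
    """Load a causal LM with 4-bit NF4 weights (sketch; settings beyond NF4 are assumptions)."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # 4-bit weights, as stated in the text
        bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: matmuls run in bf16
        bnb_4bit_use_double_quant=True,         # assumption: also quantize the quantization constants
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, device_map="auto"
    )
    return tokenizer, model
```

Quantizing all three models the same way keeps the comparison about the models themselves rather than about precision differences.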

Metric | Mistral 7B | Gemma 4 26B | MedGemma 4B
VQA diagnosis, standard (full modality) | 0.947 | 0.927 | 0.507
VQA diagnosis, DTI-only query | 0.753 | 0.533 | 0.627
BERTScore (contextual similarity) | 0.894 | 0.845 | 0.823
SBERT cosine similarity (sentence-level) | 0.811 | 0.810 | 0.428


Same retrieved context across all three models; only the generation model changes. Mistral 7B leads on every metric: diagnosis VQA accuracy (standard and DTI-only) and text quality (BERTScore, and SBERT, where it edges Gemma 4 26B by only 0.001). MedGemma's medical fine-tune loses to a general-purpose 7B model. At these model sizes, instruction-following matters more than domain knowledge.

Headline finding

Instruction-following beats domain tuning. Mistral 7B, a general-purpose dense model, outperforms both a larger 26B MoE and a medically fine-tuned 4B on every metric. The retrieved context already supplies the medical knowledge. What matters is whether the model can follow instructions and format its output correctly.

09

Key Findings

Imaging adds real signal to preclinical AD screening

CN amyloid Bal. Acc. goes from 0.557 (Clinical-only) to 0.688 (T1 + DTI + Clinical), a 13.1 point gain on the hardest and most clinically useful task. Specificity sits at 0.846, which is what you want for a triage tool.

The model works even without DTI

Modality dropout during training (T1 10%, DTI 30%, Clinical 5%) means the model handles any combination at inference. In practice it works for the 60% of subjects who only have T1 + Clinical, not just the 39.4% with full DTI coverage.

DTI + Clinical is strong, except for CN amyloid

DTI + Clinical ties for best on DX Binary (0.938, level with Clinical-only) and wins overall Amyloid (0.746). For CN amyloid specifically though, T1 matters: DTI + Clinical drops to 0.623 while the full stack reaches 0.688.

Mistral 7B is the best VQA model here

A general-purpose 7B dense model beats a 26B MoE and a medically fine-tuned 4B on both diagnosis accuracy (94.7% vs 92.7% vs 50.7%) and text quality. The retrieved context does the medical heavy lifting. The model just needs to read it and respond clearly.

VQA holds up even when retrieval is poor

Under a DTI-only query, FAISS retrieval@5 drops to 40.9%. Mistral 7B still reaches 75.3% VQA accuracy on those same queries, so the LLM can extract useful signal even from partially mismatched context.

Status

Paper in preparation · code to be released on publication

This work was completed at the Keck School of Medicine of USC. Once the paper is submitted I'll link the manuscript and GitHub repo here.
