Files
Klima-KI/app/climatenet.py
2026-04-13 11:32:51 +02:00

772 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Urban Climate U-Net
===================
Predicts nocturnal T_air (2 m AGL) and daytime PET at 5 m resolution
from morphological raster inputs, with aleatoric uncertainty estimation.
Architecture overview
---------------------
Shared U-Net encoder–decoder (~13.6 M parameters with the default base_ch=64, coarse_ch=128)
Multi-scale input:
fine — 256×256 patch at 5 m (6 channels)
coarse — 64×64 patch at 25 m (6 channels, broader context)
Encoder: 4 ConvBlock levels, Dropout2d on levels 2–4
Bottleneck: fuses fine enc-4 + coarse context, Dropout2d 0.2
Decoder: 4 UpBlock levels with skip connections, Dropout2d on level 3
Heads:
t_air_head Conv 32→1 predicted mean T_air
pet_head Conv 32→1 predicted mean PET
log_sigma_head Conv 32→2 pre-activation σ for T_air and PET (aleatoric); passed through softplus, not exp, to obtain σ
Uncertainty
-----------
Aleatoric (irreducible noise):
The model outputs σ per pixel alongside the mean prediction.
No label is provided — the Gaussian NLL loss teaches the model
to self-calibrate σ:
Loss = (y − μ)² / (2σ²) + ½ log(σ²)
Overclaiming certainty on a wrong prediction is penalised by the
first term; blanket hedging is penalised by the second. σ converges
to reflect genuine local predictability from morphology.
Epistemic (model ignorance, added in v1.1):
MC Dropout: keep Dropout2d active at inference, run N stochastic
forward passes, take std(predictions) as epistemic uncertainty.
σ_total = √(σ_aleatoric² + σ_epistemic²)
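Input channels (fine and coarse, identical schema)
---------------------------------------------------
0 Building height m
1 Impervious fraction 0–1
2 Vegetation cover 0–1
3 Albedo 0–1
4 DEM elevation m (z-scored per city)
5 Slope degrees (z-scored)
6 Distance to water/large green m (z-scored)
NOTE: seven channels (0–6) are listed here, but the model default is in_ch=6
and the dataset layout documents 6-channel patches — confirm which channel is
not part of the input, or raise in_ch to 7.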
Labels (from FITNAH-3D or equivalent hi-fi simulation)
-------------------------------------------------------
0 T_air nocturnal air temperature at 2 m [°C delta from city mean]
1 PET daytime PET at 1.1 m [°C delta from city mean]
Normalise labels city-by-city (subtract city mean, divide by city std)
before training to remove macro-climate offsets. Denormalise for output.
Dataset layout
--------------
root/
city_berlin/
fine/ 0001.npy … (6, 256, 256) float32
coarse/ 0001.npy … (6, 64, 64) float32
labels/ 0001.npy … (2, 256, 256) float32
masks/ 0001.npy … (1, 256, 256) bool [optional]
city_munich/
...
Requirements
------------
torch>=2.0 numpy (training only)
onnxruntime (QGIS inference, no torch needed)
"""
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple
import numpy as np
import torch # type: ignore
import torch.nn as nn # type: ignore
import torch.nn.functional as F # type: ignore
from torch.utils.data import DataLoader, Dataset # type: ignore
# ---------------------------------------------------------------------------
# Output container
# ---------------------------------------------------------------------------
class ModelOutput(NamedTuple):
    """
    Named 4-tuple produced by the network's forward pass.

    Because it is a ``NamedTuple`` it can be unpacked, indexed and passed
    around exactly like a plain tuple, which is what the ONNX export path
    relies on. Field order is (T_air mean, T_air σ, PET mean, PET σ).
    """

    # Predicted mean of T_air; the model emits it as (B, 1, H, W).
    t_air_mean: torch.Tensor
    # Aleatoric σ predicted for T_air; same shape as the mean.
    t_air_sigma: torch.Tensor
    # Predicted mean of PET; (B, 1, H, W).
    pet_mean: torch.Tensor
    # Aleatoric σ predicted for PET; (B, 1, H, W).
    pet_sigma: torch.Tensor
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class ConvBlock(nn.Module):
    """
    Two stacked Conv(3×3)–BatchNorm–ReLU stages, optionally followed by
    spatial dropout.

    ``Dropout2d`` zeroes whole feature maps instead of single activations.
    No dropout layer is created when ``dropout_p`` is 0; an ``nn.Identity``
    stands in so ``forward`` stays branch-free.

    Parameters
    ----------
    in_ch     : Channels of the incoming tensor.
    out_ch    : Channels produced by both convolutions.
    dropout_p : Dropout2d probability; values <= 0 disable dropout.
    """

    def __init__(self, in_ch: int, out_ch: int, dropout_p: float = 0.0):
        super().__init__()

        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        # Convolutions are bias-free because BatchNorm supplies the shift.
        stages: list[nn.Module] = []
        for stage_in in (in_ch, out_ch):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                )
            )
        self.conv = nn.Sequential(*stages)

        # Dropout sits after the activations. For MC Dropout inference these
        # layers are switched back on by enable_dropout().
        if dropout_p > 0.0:
            self.drop = nn.Dropout2d(p=dropout_p)
        else:
            self.drop = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages, then the (possibly no-op) dropout."""
        features = self.conv(x)
        return self.drop(features)
class UpBlock(nn.Module):
    """
    Decoder stage: ×2 bilinear upsampling, 1×1 channel projection, skip
    concatenation and a ConvBlock.

    The author prefers bilinear + 1×1 conv over ConvTranspose2d because it
    avoids checkerboard artefacts common in urban grid-like rasters.

    Parameters
    ----------
    in_ch     : Channels arriving from the previous (deeper) level.
    skip_ch   : Channels of the encoder skip tensor.
    out_ch    : Channels produced by this stage.
    dropout_p : Forwarded to the inner ConvBlock.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, dropout_p: float = 0.0):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        project = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.up = nn.Sequential(upsample, project)
        # After concatenation the ConvBlock sees projected + skip channels.
        self.conv = ConvBlock(out_ch + skip_ch, out_ch, dropout_p=dropout_p)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, align it to ``skip`` and fuse the two."""
        upsampled = self.up(x)
        target_hw = skip.shape[-2:]
        # Odd spatial sizes can leave the upsampled map a pixel off from the
        # skip tensor; resample to the skip's exact size in that case.
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(
                upsampled, size=target_hw, mode="bilinear", align_corners=False
            )
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)
class CoarseEncoder(nn.Module):
    """
    Small encoder for the coarse context patch.

    Three ConvBlocks (32, 64 and ``out_ch`` channels) are separated by two
    2×2 max-pools, so the spatial size shrinks by a factor of four. For the
    documented 64×64 coarse patch this gives a 16×16 map, the same spatial
    size as the main bottleneck, which allows direct concatenation there.

    Design intent (author's note): this implements the KLIMASCANNER
    principle of "decreasing information density with distance" — the
    coarse branch carries macro-climate context (e.g. a large park, urban
    fringe effect) lying outside the fine patch's 1.28 km window.

    Parameters
    ----------
    in_ch  : Channels of the coarse patch.
    out_ch : Channels of the produced context map.
    """

    def __init__(self, in_ch: int, out_ch: int = 128):
        super().__init__()
        stages = [
            ConvBlock(in_ch, 32),    # full coarse resolution (64×64)
            nn.MaxPool2d(2),         # -> 32×32
            ConvBlock(32, 64),       # 32×32
            nn.MaxPool2d(2),         # -> 16×16
            ConvBlock(64, out_ch),   # 16×16, matches the bottleneck
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the context feature map for coarse patch ``x``."""
        context = self.net(x)
        return context
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class UrbanClimateUNet(nn.Module):
    """
    U-Net for urban climate prediction with aleatoric uncertainty.

    A four-level encoder/decoder processes the fine patch. An optional
    CoarseEncoder branch supplies wider context that is concatenated with
    the pooled Enc-4 features ahead of the bottleneck. Three 1×1 heads
    produce the T_air mean, the PET mean and two σ channels.

    Parameters
    ----------
    in_ch       : Input channels (same for fine and coarse patches).
    base_ch     : Feature channels at Enc-1; doubles each encoder level.
    coarse_ch   : Output channels of CoarseEncoder; fused at bottleneck.
                  ``0`` disables the coarse branch entirely.
    dropout_enc : Dropout2d probability for encoder levels 2 and 3.
    dropout_bot : Dropout2d probability for encoder level 4 and bottleneck.
    dropout_dec : Dropout2d probability for decoder level 3 (Dec-3).

    Raises
    ------
    ValueError : If ``coarse_ch`` is negative.
    """

    def __init__(
        self,
        in_ch: int = 6,
        base_ch: int = 64,
        coarse_ch: int = 128,
        dropout_enc: float = 0.1,
        dropout_bot: float = 0.2,
        dropout_dec: float = 0.1,
    ):
        super().__init__()
        if coarse_ch < 0:
            raise ValueError(f"coarse_ch must be >= 0, got {coarse_ch}")
        c = base_ch  # shorthand: 64 by default

        # ── Encoder ──────────────────────────────────────────────────────
        # Spatial dropout is deliberately absent at Enc-1: the first layer
        # needs all raw input channels to avoid discarding morphological
        # signal before any abstractions are built.
        # Resolution comments assume a 256×256 fine patch.
        self.enc1 = ConvBlock(in_ch, c, dropout_p=0.0)              # 256²
        self.enc2 = ConvBlock(c, c * 2, dropout_p=dropout_enc)      # 128²
        self.enc3 = ConvBlock(c * 2, c * 4, dropout_p=dropout_enc)  # 64²
        self.enc4 = ConvBlock(c * 4, c * 8, dropout_p=dropout_bot)  # 32²
        self.pool = nn.MaxPool2d(2)

        # ── Coarse context ───────────────────────────────────────────────
        # The attribute is always defined (None when disabled) so that
        # forward() can test it without hitting an AttributeError.
        self.coarse_enc = CoarseEncoder(in_ch, coarse_ch) if coarse_ch > 0 else None

        # ── Bottleneck ───────────────────────────────────────────────────
        # Receives pooled Enc-4 (c*8 channels) plus the coarse context
        # (coarse_ch channels); with defaults 512 + 128 = 640 in, 512 out.
        self.bottleneck = ConvBlock(c * 8 + coarse_ch, c * 8, dropout_p=dropout_bot)

        # ── Decoder ──────────────────────────────────────────────────────
        # UpBlock(in_ch, skip_ch, out_ch):
        #   in_ch   = channels coming up from the previous decoder level
        #   skip_ch = channels of the matching encoder skip connection
        #   out_ch  = channels output by this decoder level
        self.dec4 = UpBlock(c * 8, c * 8, c * 4)                         # 32²
        self.dec3 = UpBlock(c * 4, c * 4, c * 2, dropout_p=dropout_dec)  # 64²
        self.dec2 = UpBlock(c * 2, c * 2, c)                             # 128²
        self.dec1 = UpBlock(c, c, c // 2)                                # 256²

        # ── Output heads (1×1 convolutions) ─────────────────────────────
        self.t_air_head = nn.Conv2d(c // 2, 1, kernel_size=1)
        self.pet_head = nn.Conv2d(c // 2, 1, kernel_size=1)
        # One head, two channels: [σ_t_air, σ_pet] before activation.
        # Despite the attribute name, forward() maps these through softplus
        # (not exp) to obtain σ > 0. The name is kept so existing
        # checkpoints continue to load.
        self.log_sigma_head = nn.Conv2d(c // 2, 2, kernel_size=1)

        self._init_weights()

    def _init_weights(self) -> None:
        """He-initialise every Conv2d, reset BatchNorm, zero all conv biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # Explicitly zero the σ-head bias: softplus(0) = ln 2 ≈ 0.69, so σ
        # starts at a moderate value instead of near 0, where the NLL's
        # 1/σ² term would make early training unstable.
        nn.init.zeros_(self.log_sigma_head.bias)

    def forward(self, fine: torch.Tensor, coarse: torch.Tensor | None = None) -> ModelOutput:
        """
        Run the network on one batch.

        Parameters
        ----------
        fine   : (B, in_ch, 256, 256) 5 m resolution morphology patch.
        coarse : (B, in_ch, 64, 64) 25 m resolution context patch. Required
                 when the model was built with ``coarse_ch > 0``; ignored
                 when the coarse branch is disabled.

        Returns
        -------
        ModelOutput with fields t_air_mean, t_air_sigma, pet_mean, pet_sigma,
        each shaped (B, 1, H, W) where H×W is the fine patch size.

        Raises
        ------
        ValueError : If the coarse branch is enabled but ``coarse`` is None.
        """
        use_coarse = self.coarse_enc is not None
        if use_coarse and coarse is None:
            raise ValueError(
                "coarse patch is required because the model was built with coarse_ch > 0"
            )

        # Shape comments below assume the defaults (base_ch=64, 256² input).
        # ── Encoder ──────────────────────────────────────────────────────
        e1 = self.enc1(fine)             # (B, 64, 256, 256)
        e2 = self.enc2(self.pool(e1))    # (B, 128, 128, 128)
        e3 = self.enc3(self.pool(e2))    # (B, 256, 64, 64)
        e4 = self.enc4(self.pool(e3))    # (B, 512, 32, 32)

        # ── Bottleneck (with optional coarse fusion) ─────────────────────
        x = self.pool(e4)                # (B, 512, 16, 16)
        if use_coarse:
            ctx = self.coarse_enc(coarse)    # (B, 128, 16, 16)
            x = torch.cat([x, ctx], dim=1)   # (B, 640, 16, 16)
        x = self.bottleneck(x)           # (B, 512, 16, 16)

        # ── Decoder ──────────────────────────────────────────────────────
        x = self.dec4(x, e4)             # (B, 256, 32, 32)
        x = self.dec3(x, e3)             # (B, 128, 64, 64)
        x = self.dec2(x, e2)             # (B, 64, 128, 128)
        x = self.dec1(x, e1)             # (B, 32, 256, 256)

        # ── Output heads ─────────────────────────────────────────────────
        t_air_mean = self.t_air_head(x)      # (B, 1, 256, 256)
        pet_mean = self.pet_head(x)          # (B, 1, 256, 256)
        raw_sigma = self.log_sigma_head(x)   # (B, 2, 256, 256)
        # softplus is smooth and always positive; the 1e-4 floor prevents σ → 0.
        sigma = F.softplus(raw_sigma) + 1e-4
        t_air_sigma = sigma[:, 0:1]          # (B, 1, 256, 256)
        pet_sigma = sigma[:, 1:2]            # (B, 1, 256, 256)
        return ModelOutput(t_air_mean, t_air_sigma, pet_mean, pet_sigma)
# ---------------------------------------------------------------------------
# Loss function — Gaussian NLL
# ---------------------------------------------------------------------------
def gaussian_nll_loss(
    mean: torch.Tensor,
    sigma: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Gaussian negative log-likelihood, averaged over the valid pixels.

        L = (y − μ)² / (2σ²) + ½ log(σ²)

    This is the mechanism by which the model learns σ without any label:
      • First term: penalises wrong predictions; large σ softens the penalty.
      • Second term: penalises large σ; prevents the model from always hedging.

    Parameters
    ----------
    mean   : Predicted mean (B, 1, H, W).
    sigma  : Predicted sigma (B, 1, H, W), must be > 0.
    target : Reference label (B, 1, H, W).
    mask   : Optional valid-pixel mask (B, 1, H, W) bool.
             Pixels outside the simulation extent should be masked out.

    Returns
    -------
    Scalar tensor. If ``mask`` selects no pixel at all the result is a
    differentiable 0 rather than NaN, so a fully masked batch contributes
    nothing to the gradient instead of poisoning the optimiser step.
    """
    # ½·log(σ²) is written as log(σ): identical value, but avoids squaring a
    # small σ before taking the logarithm.
    nll = (target - mean) ** 2 / (2.0 * sigma ** 2) + torch.log(sigma)

    if mask is None:
        return nll.mean()

    valid = nll[mask]
    if valid.numel() == 0:
        # mean() of an empty tensor is NaN; sum() is 0 and keeps the graph.
        return valid.sum()
    return valid.mean()
def combined_loss(
    preds: ModelOutput,
    t_air_target: torch.Tensor,
    pet_target: torch.Tensor,
    mask: torch.Tensor | None = None,
    t_air_weight: float = 1.0,
    pet_weight: float = 1.0,
) -> torch.Tensor:
    """
    Total training loss: weighted sum of the T_air and PET Gaussian NLLs.

    Parameters
    ----------
    preds        : Model output carrying the means and sigmas of both targets.
    t_air_target : Reference T_air, same shape as ``preds.t_air_mean``.
    pet_target   : Reference PET, same shape as ``preds.pet_mean``.
    mask         : Optional valid-pixel mask applied to both terms.
    t_air_weight : Multiplier for the T_air term.
    pet_weight   : Multiplier for the PET term.

    Returns
    -------
    Scalar tensor ``t_air_weight * NLL_t_air + pet_weight * NLL_pet``.
    """
    t_air_term = gaussian_nll_loss(
        preds.t_air_mean, preds.t_air_sigma, t_air_target, mask
    )
    pet_term = gaussian_nll_loss(
        preds.pet_mean, preds.pet_sigma, pet_target, mask
    )
    weighted_t_air = t_air_weight * t_air_term
    weighted_pet = pet_weight * pet_term
    return weighted_t_air + weighted_pet
# ---------------------------------------------------------------------------
# MC Dropout inference
# ---------------------------------------------------------------------------
def enable_dropout(model: nn.Module) -> None:
    """
    Put ``model`` in eval mode, then switch its dropout layers back on.

    The whole model is first set to eval (so BatchNorm uses its running
    statistics), after which every ``nn.Dropout`` / ``nn.Dropout2d``
    sub-module is returned to train mode. Author's rationale: running
    BatchNorm in train mode at inference would inject batch-statistic noise
    and contaminate the epistemic uncertainty estimate.
    """
    model.eval()
    stochastic_types = (nn.Dropout, nn.Dropout2d)
    dropout_layers = filter(
        lambda module: isinstance(module, stochastic_types), model.modules()
    )
    for layer in dropout_layers:
        layer.train()
@torch.no_grad()
def predict_with_uncertainty(
    model: UrbanClimateUNet,
    fine: torch.Tensor,
    coarse: torch.Tensor,
    n_passes: int = 20,
    device: torch.device | None = None,
) -> dict[str, object]:
    """
    Single-sample inference with full uncertainty decomposition.

        Aleatoric σ = mean of per-pass σ predictions
                      (model's self-reported irreducible noise)
        Epistemic σ = std of per-pass mean predictions
                      (model's disagreement = out-of-distribution signal)
        Total σ     = √(σ_al² + σ_ep²)

    The scene-level trust score mirrors the KLIMASCANNER Vertrauenswert and
    is derived from the T_air total σ only:
        gut           ≥ 80 % of pixels with total σ < 0.5
        befriedigend  ≥ 50 %
        ausreichend   < 50 %
    NOTE(review): the 0.5 threshold is documented as °C, but the dataset
    docs say labels are city-normalised; confirm σ is denormalised before
    this threshold is interpreted as degrees.

    With ``n_passes=1`` the model runs fully deterministically (dropout
    off), giving the aleatoric-only MVP output; the epistemic fields are
    then zero. With ``n_passes>1`` dropout is active for every pass.

    Parameters
    ----------
    model    : Trained UrbanClimateUNet. It is left in eval mode on return,
               including when a forward pass raises.
    fine     : (1, in_ch, 256, 256)
    coarse   : (1, in_ch, 64, 64)
    n_passes : Forward passes. 1 = aleatoric only (fast).
               20 = full decomposition (recommended for production).
    device   : Where to run. ``None`` (default) uses the device of the
               model's first parameter, falling back to CPU for a model
               without parameters. The model itself is never moved.

    Returns
    -------
    dict with tensors ``{t_air,pet}_mean``, ``*_sigma_aleatoric``,
    ``*_sigma_epistemic``, ``*_sigma_total`` plus ``trust_score`` (str) and
    ``pct_pixels_good`` (float in [0, 1]).

    Raises
    ------
    ValueError : If ``n_passes`` is smaller than 1.
    """
    if n_passes < 1:
        raise ValueError(f"n_passes must be >= 1, got {n_passes}")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    fine = fine.to(device)
    coarse = coarse.to(device)

    t_means, t_sigmas, p_means, p_sigmas = [], [], [], []
    try:
        if n_passes > 1:
            enable_dropout(model)
        else:
            # One stochastic pass would only add noise to the mean without
            # yielding a spread, so the single-pass mode runs deterministic.
            model.eval()
        for _ in range(n_passes):
            out = model(fine, coarse)
            t_means.append(out.t_air_mean)
            t_sigmas.append(out.t_air_sigma)
            p_means.append(out.pet_mean)
            p_sigmas.append(out.pet_sigma)
    finally:
        # Never leave dropout switched on, even if a forward pass failed.
        model.eval()

    def decompose(
        means: list[torch.Tensor], sigmas: list[torch.Tensor]
    ) -> tuple[torch.Tensor, ...]:
        """Reduce the per-pass outputs to (mean, σ_al, σ_ep, σ_total)."""
        m = torch.stack(means)   # (n_passes, B, 1, H, W)
        s = torch.stack(sigmas)
        mu = m.mean(0)           # mean prediction
        sigma_al = s.mean(0)     # aleatoric: mean of σ
        # std() over a single pass is undefined, hence the explicit zeros.
        sigma_ep = m.std(0) if n_passes > 1 else torch.zeros_like(mu)
        sigma_tot = (sigma_al ** 2 + sigma_ep ** 2).sqrt()
        return mu, sigma_al, sigma_ep, sigma_tot

    t_mu, t_al, t_ep, t_tot = decompose(t_means, t_sigmas)
    p_mu, p_al, p_ep, p_tot = decompose(p_means, p_sigmas)

    pct_good = (t_tot < 0.5).float().mean().item()
    if pct_good >= 0.8:
        trust = "gut"
    elif pct_good >= 0.5:
        trust = "befriedigend"
    else:
        trust = "ausreichend"

    return {
        "t_air_mean": t_mu,
        "t_air_sigma_aleatoric": t_al,
        "t_air_sigma_epistemic": t_ep,
        "t_air_sigma_total": t_tot,
        "pet_mean": p_mu,
        "pet_sigma_aleatoric": p_al,
        "pet_sigma_epistemic": p_ep,
        "pet_sigma_total": p_tot,
        "trust_score": trust,         # "gut" / "befriedigend" / "ausreichend"
        "pct_pixels_good": pct_good,  # fraction of pixels with σ < 0.5
    }
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class UrbanClimateDataset(Dataset):
    """
    Map-style dataset over pre-rasterised (fine, coarse, labels[, mask])
    patches stored as ``.npy`` files.

    Every sub-directory of ``root`` is treated as one city containing
    ``fine/``, ``coarse/``, ``labels/`` and optionally ``masks/``. A sample
    is registered for each ``fine/*.npy`` that has a same-named file in
    both ``coarse/`` and ``labels/``; others are skipped silently.

    No normalisation happens here: inputs and labels are expected to have
    been normalised per city before they were saved.

    Parameters
    ----------
    root    : Directory holding the city folders.
    augment : Apply a random 90° rotation / horizontal flip per item.

    Raises
    ------
    FileNotFoundError : If no complete sample is found under ``root``.
    """

    def __init__(self, root: str | Path, augment: bool = True):
        self.augment = augment
        self.samples: list[tuple[Path, Path, Path, Path | None]] = []

        city_dirs = (entry for entry in sorted(Path(root).iterdir()) if entry.is_dir())
        for city in city_dirs:
            masks_root = city / "masks"
            has_masks = masks_root.exists()
            for fine_path in sorted((city / "fine").glob("*.npy")):
                filename = f"{fine_path.stem}.npy"
                coarse_path = city / "coarse" / filename
                label_path = city / "labels" / filename
                if not (coarse_path.exists() and label_path.exists()):
                    continue
                # A mask path is recorded whenever masks/ exists; whether
                # the individual file exists is checked at load time.
                mask_path = masks_root / filename if has_masks else None
                self.samples.append((fine_path, coarse_path, label_path, mask_path))

        if not self.samples:
            raise FileNotFoundError(f"No samples found under {root}")

    def __len__(self) -> int:
        """Number of registered samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """
        Load sample ``idx``.

        Returns a dict with float tensors ``fine``, ``coarse``, ``t_air``
        (label channel 0), ``pet`` (label channel 1) and a bool ``mask``.
        If no mask file exists, the mask is all-True with shape
        (1, H, W) taken from the fine patch.
        """
        fine_path, coarse_path, label_path, mask_path = self.samples[idx]
        fine, coarse, labels = (
            torch.from_numpy(np.load(path)).float()
            for path in (fine_path, coarse_path, label_path)
        )

        if mask_path is not None and mask_path.exists():
            mask = torch.from_numpy(np.load(mask_path)).bool()
        else:
            height, width = fine.shape[-2], fine.shape[-1]
            mask = torch.ones(1, height, width, dtype=torch.bool)

        if self.augment:
            fine, coarse, labels, mask = self._random_rotate_flip(
                fine, coarse, labels, mask
            )

        sample = {"fine": fine, "coarse": coarse}
        sample["t_air"] = labels[0:1]  # slicing keeps the channel axis
        sample["pet"] = labels[1:2]
        sample["mask"] = mask
        return sample

    @staticmethod
    def _random_rotate_flip(
        fine: torch.Tensor,
        coarse: torch.Tensor,
        labels: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply one shared random transform to all four tensors.

        A rotation by k·90° (k drawn uniformly from 0..3) over the last two
        axes is followed, with probability one half, by a flip of the last
        axis.

        Author's note: wind direction changes with rotation. This is
        intentional — the model should learn morphological patterns, not
        wind-relative ones.
        """
        quarter_turns = int(torch.randint(4, ()).item())
        spatial_axes = (-2, -1)
        transformed = [
            torch.rot90(tensor, quarter_turns, spatial_axes)
            for tensor in (fine, coarse, labels, mask)
        ]
        if torch.rand(()).item() > 0.5:
            transformed = [tensor.flip(-1) for tensor in transformed]
        fine, coarse, labels, mask = transformed
        return fine, coarse, labels, mask
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train_one_epoch(
    model: UrbanClimateUNet,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.cuda.amp.GradScaler | None = None,
) -> float:
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch must be a dict with the keys ``fine``, ``coarse``, ``t_air``,
    ``pet`` and ``mask``. Gradients are clipped to a global norm of 1.0
    before every optimiser step.

    Parameters
    ----------
    model     : Network to train; switched to train mode.
    loader    : Batch iterable. Batches are counted while iterating, so
                loaders without ``__len__`` work as well.
    optimizer : Optimiser holding the model's parameters.
    device    : Device the batch tensors are moved to.
    scaler    : When given, the forward pass runs under CUDA autocast and
                the scaler handles loss scaling; when None, plain full
                precision is used.

    Returns
    -------
    Arithmetic mean of the per-batch loss values.

    Raises
    ------
    ValueError : If the loader yields no batch at all.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    batch_keys = ("fine", "coarse", "t_air", "pet", "mask")

    for batch in loader:
        # non_blocking lets host→device copies overlap with compute when the
        # loader uses pinned memory; it is a no-op otherwise.
        fine, coarse, t_air, pet, mask = (
            batch[key].to(device, non_blocking=True) for key in batch_keys
        )
        optimizer.zero_grad()

        if scaler is not None:
            with torch.autocast("cuda"):
                preds = model(fine, coarse)
                loss = combined_loss(preds, t_air, pet, mask)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            preds = model(fine, coarse)
            loss = combined_loss(preds, t_air, pet, mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("loader yielded no batches; cannot compute an epoch loss")
    return running_loss / n_batches
def train(
    data_root: str,
    epochs: int = 100,
    batch_size: int = 8,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
    checkpoint_dir: str = "checkpoints",
    device_str: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> UrbanClimateUNet:
    """
    Full training loop with AdamW, a cosine LR schedule and, on CUDA,
    mixed precision.

    A checkpoint is written every 10 epochs and after the final epoch. It
    stores model, optimiser, scheduler and (if used) grad-scaler state, so
    a run can be resumed faithfully.

    Recommended hyperparameters for MVP (10–15 training cities):
        epochs=150, batch_size=8, lr=1e-4
    For fine-tuning on a single new city:
        epochs=50, batch_size=4, lr=3e-5 (load pretrained weights first)

    Parameters
    ----------
    data_root      : Root folder passed to UrbanClimateDataset.
    epochs         : Number of epochs; also the cosine period ``T_max``.
    batch_size     : Batch size; also caps the number of loader workers at
                     ``min(4, batch_size)``.
    lr             : Initial learning rate; the schedule decays to 1 % of it.
    weight_decay   : AdamW weight decay.
    checkpoint_dir : Directory for ``epoch_XXXX_loss*.pt`` files (created
                     if missing).
    device_str     : Torch device string. The default is resolved once, at
                     import time.

    Returns
    -------
    The trained model, still on the training device.
    """
    device = torch.device(device_str)
    ckpt_root = Path(checkpoint_dir)
    ckpt_root.mkdir(parents=True, exist_ok=True)

    dataset = UrbanClimateDataset(data_root, augment=True)
    num_workers = min(4, batch_size)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"),
        # DataLoader rejects persistent_workers=True without workers.
        persistent_workers=num_workers > 0,
    )

    model = UrbanClimateUNet().to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=lr * 0.01
    )
    # Loss scaling is only meaningful for CUDA autocast.
    scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model : {n_params:,} parameters")
    print(f"Dataset: {len(dataset):,} patches")
    print(f"Device : {device}")

    for epoch in range(1, epochs + 1):
        loss = train_one_epoch(model, loader, optimizer, device, scaler)
        scheduler.step()

        if epoch % 10 == 0 or epoch == epochs:
            ckpt = ckpt_root / f"epoch_{epoch:04d}_loss{loss:.4f}.pt"
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "scaler_state": scaler.state_dict() if scaler is not None else None,
                    "loss": loss,
                    "hparams": {
                        "epochs": epochs,
                        "batch_size": batch_size,
                        "lr": lr,
                        "weight_decay": weight_decay,
                    },
                },
                ckpt,
            )
            # Separator added: the loss and the file name used to run together.
            print(f"Epoch {epoch:4d}/{epochs} loss {loss:.4f} → {ckpt.name}")

    return model
# ---------------------------------------------------------------------------
# ONNX export for QGIS plugin deployment
# ---------------------------------------------------------------------------
class _OnnxWrapper(nn.Module):
    """
    Adapter that turns the wrapped model's named output into a plain tuple.

    ``forward`` calls the wrapped model and returns its four output fields
    as an ordinary ``tuple`` in the fixed order (t_air_mean, t_air_sigma,
    pet_mean, pet_sigma), which is the form handed to ``torch.onnx.export``.
    """

    def __init__(self, model: UrbanClimateUNet):
        super().__init__()
        self.model = model

    def forward(
        self, fine: torch.Tensor, coarse: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the wrapped model and flatten its result to a 4-tuple."""
        result = self.model(fine, coarse)
        return (
            result.t_air_mean,
            result.t_air_sigma,
            result.pet_mean,
            result.pet_sigma,
        )
def export_onnx(
    model: UrbanClimateUNet,
    output_path: str = "urban_climate_unet.onnx",
    in_ch: int = 6,
    opset: int = 16,
) -> None:
    """
    Export trained model to ONNX for deployment in the QGIS plugin.

    The model is put in eval mode and traced with zero-filled dummy inputs
    of shape (1, in_ch, 256, 256) and (1, in_ch, 64, 64), created on the
    same device as the model. The batch axis is declared dynamic for every
    input and every output, so the exported graph accepts any batch size.

    The plugin loads the .onnx file with onnxruntime (pure Python wheel,
    ~15 MB, no PyTorch dependency at inference time).

    After export, verify with:
        import onnxruntime as ort
        sess = ort.InferenceSession("urban_climate_unet.onnx")
        print([o.name for o in sess.get_outputs()])

    Parameters
    ----------
    model       : Trained network; left in eval mode.
    output_path : Destination file.
    in_ch       : Channel count of the dummy inputs; must match the model.
    opset       : ONNX opset version.
    """
    model.eval()
    wrapper = _OnnxWrapper(model)

    # Trace on whatever device the weights live on to avoid a CPU/GPU mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_fine = torch.zeros(1, in_ch, 256, 256, device=device)
    dummy_coarse = torch.zeros(1, in_ch, 64, 64, device=device)

    input_names = ["fine_patch", "coarse_patch"]
    output_names = ["t_air_mean", "t_air_sigma", "pet_mean", "pet_sigma"]
    # Outputs must be listed too, otherwise their batch dimension is
    # recorded as the static size of the dummy input (1).
    dynamic_axes = {name: {0: "batch"} for name in (*input_names, *output_names)}

    torch.onnx.export(
        wrapper,
        (dummy_fine, dummy_coarse),
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"ONNX model exported → {output_path}")
    # Report the path actually written, not the default file name.
    print(f"Load in QGIS plugin with: ort.InferenceSession({output_path!r})")
# ---------------------------------------------------------------------------
# Quick sanity check
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Disabled sanity check, kept for reference. Uncommenting it exercises
    # the forward pass, the loss, MC Dropout inference and the ONNX export
    # on random inputs.
    #
    # torch.manual_seed(46)
    # model = UrbanClimateUNet()
    # model.eval()
    # B = 2
    # fine = torch.randn(B, 6, 256, 256)
    # coarse = torch.randn(B, 6, 64, 64)
    # # ── Forward pass ──────────────────────────────────────────────────
    # out = model(fine, coarse)
    # print("=== Forward pass (batch=2) ===")
    # for name, tensor in zip(out._fields, out):
    #     print(f" {name:<18} {tuple(tensor.shape)}")
    # # ── Loss ──────────────────────────────────────────────────────────
    # t_tgt = torch.randn(B, 1, 256, 256)
    # p_tgt = torch.randn(B, 1, 256, 256)
    # loss = combined_loss(out, t_tgt, p_tgt)
    # print(f"\nGaussian NLL loss : {loss.item():.4f}")
    # # ── Parameter count ───────────────────────────────────────────────
    # n = sum(p.numel() for p in model.parameters())
    # print(f"Parameters : {n:,}")
    # # ── MC Dropout uncertainty (fast, 5 passes) ───────────────────────
    # result = predict_with_uncertainty(
    #     model, fine[0:1], coarse[0:1], n_passes=5
    # )
    # print(f"\n=== MC Dropout (5 passes) ===")
    # print(f" Trust score : {result['trust_score']}")
    # print(f" Pct pixels good : {result['pct_pixels_good']:.1%}")
    # print(f" T_air σ aleat. : {result['t_air_sigma_aleatoric'].mean():.3f} °C")
    # print(f" T_air σ epist. : {result['t_air_sigma_epistemic'].mean():.3f} °C")
    # print(f" T_air σ total : {result['t_air_sigma_total'].mean():.3f} °C")
    # # ── ONNX export (requires onnx package) ───────────────────────────
    # try:
    #     export_onnx(model, "urban_climate_unet.onnx")
    # except Exception as e:
    #     print(f"\nONNX export skipped ({e})")
    print("climatenet class loaded.")