basic forward pass test
This commit is contained in:
861
app/main.py
861
app/main.py
@@ -1,766 +1,95 @@
|
||||
"""
|
||||
Urban Climate U-Net
|
||||
===================
|
||||
Predicts nocturnal T_air (2 m AGL) and daytime PET at 5 m resolution
|
||||
from morphological raster inputs, with aleatoric uncertainty estimation.
|
||||
|
||||
Architecture overview
|
||||
---------------------
|
||||
Shared U-Net encoder–decoder (~5–8 M parameters)
|
||||
Multi-scale input:
|
||||
fine — 256×256 patch at 5 m (7 channels)
|
||||
coarse — 64×64 patch at 25 m (7 channels, broader context)
|
||||
Encoder: 4 ConvBlock levels, Dropout2d on levels 2–4
|
||||
Bottleneck: fuses fine enc-4 + coarse context, Dropout2d 0.2
|
||||
Decoder: 4 UpBlock levels with skip connections, Dropout2d on level 3
|
||||
Heads:
|
||||
t_air_head Conv 32→1 predicted mean T_air
|
||||
pet_head Conv 32→1 predicted mean PET
|
||||
log_sigma_head Conv 32→2 log-σ for T_air and PET (aleatoric)
|
||||
|
||||
Uncertainty
|
||||
-----------
|
||||
Aleatoric (irreducible noise):
|
||||
The model outputs σ per pixel alongside the mean prediction.
|
||||
No label is provided — the Gaussian NLL loss teaches the model
|
||||
to self-calibrate σ:
|
||||
|
||||
Loss = (y − μ)² / (2σ²) + ½ log(σ²)
|
||||
|
||||
Overclaiming certainty on a wrong prediction is penalised by the
|
||||
first term; blanket hedging is penalised by the second. σ converges
|
||||
to reflect genuine local predictability from morphology.
|
||||
|
||||
Epistemic (model ignorance, added in v1.1):
|
||||
MC Dropout: keep Dropout2d active at inference, run N stochastic
|
||||
forward passes, take std(predictions) as epistemic uncertainty.
|
||||
σ_total = √(σ_aleatoric² + σ_epistemic²)
|
||||
|
||||
Input channels (fine and coarse, identical schema)
|
||||
---------------------------------------------------
|
||||
0 Building height m
|
||||
1 Impervious fraction 0–1
|
||||
2 Vegetation cover 0–1
|
||||
3 Albedo 0–1
|
||||
4 DEM elevation m (z-scored per city)
|
||||
5 Slope degrees (z-scored)
|
||||
6 Distance to water/large green m (z-scored)
|
||||
|
||||
Labels (from FITNAH-3D or equivalent hi-fi simulation)
|
||||
-------------------------------------------------------
|
||||
0 T_air nocturnal air temperature at 2 m [°C delta from city mean]
|
||||
1 PET daytime PET at 1.1 m [°C delta from city mean]
|
||||
|
||||
Normalise labels city-by-city (subtract city mean, divide by city std)
|
||||
before training to remove macro-climate offsets. Denormalise for output.
|
||||
|
||||
Dataset layout
|
||||
--------------
|
||||
root/
|
||||
city_berlin/
|
||||
fine/ 0001.npy … (7, 256, 256) float32
|
||||
coarse/ 0001.npy … (7, 64, 64) float32
|
||||
labels/ 0001.npy … (2, 256, 256) float32
|
||||
masks/ 0001.npy … (1, 256, 256) bool [optional]
|
||||
city_munich/
|
||||
...
|
||||
|
||||
Requirements
|
||||
------------
|
||||
torch>=2.0 numpy (training only)
|
||||
onnxruntime (QGIS inference, no torch needed)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch # type: ignore
|
||||
import torch.nn as nn # type: ignore
|
||||
import torch.nn.functional as F # type: ignore
|
||||
from torch.utils.data import DataLoader, Dataset # type: ignore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output container
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ModelOutput(NamedTuple):
    """
    Typed forward-pass result of :class:`UrbanClimateUNet`.

    Being a NamedTuple it also behaves as a plain 4-tuple, which keeps it
    usable for positional unpacking and ONNX export.
    """

    t_air_mean: torch.Tensor   # (B, 1, H, W) predicted mean T_air
    t_air_sigma: torch.Tensor  # (B, 1, H, W) aleatoric sigma in deg C
    pet_mean: torch.Tensor     # (B, 1, H, W) predicted mean PET
    pet_sigma: torch.Tensor    # (B, 1, H, W) aleatoric sigma for PET
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvBlock(nn.Module):
    """
    Double Conv(3x3)-BN-ReLU unit with optional channel-wise dropout.

    Dropout2d removes whole feature maps at a time rather than single
    units, which regularises dense spatial predictors better and keeps
    the dropout spatially calibrated.
    """

    def __init__(self, in_ch: int, out_ch: int, dropout_p: float = 0.0):
        super().__init__()
        layers: list[nn.Module] = []
        # First stage maps in_ch -> out_ch, second refines out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)
        # Dropout sits after both activations; it is an Identity when
        # disabled so MC-Dropout toggling (enable_dropout) only finds real
        # Dropout2d modules where they were configured.
        if dropout_p > 0.0:
            self.drop: nn.Module = nn.Dropout2d(p=dropout_p)
        else:
            self.drop = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv stages, then the (possibly no-op) spatial dropout."""
        features = self.conv(x)
        return self.drop(features)
|
||||
|
||||
|
||||
class UpBlock(nn.Module):
    """
    x2 bilinear upsampling + 1x1 channel projection, then a ConvBlock over
    the concatenation with the encoder skip tensor.

    Upsample-then-conv is used instead of ConvTranspose2d because it avoids
    the checkerboard artefacts transposed convolutions tend to produce on
    regular, grid-like urban rasters.
    """

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, dropout_p: float = 0.0):
        super().__init__()

        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        )
        self.conv = ConvBlock(out_ch + skip_ch, out_ch, dropout_p=dropout_p)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample x, align it to the skip grid if needed, fuse, convolve."""
        upsampled = self.up(x)
        # Odd spatial sizes can leave the upsampled map one pixel off the
        # skip tensor; resample exactly onto the skip grid in that case.
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled, size=skip.shape[-2:], mode="bilinear", align_corners=False
            )
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)
|
||||
|
||||
|
||||
class CoarseEncoder(nn.Module):
    """
    Small encoder for the 25 m context patch (64x64 input).

    Two pooling stages bring the map to 16x16 so it aligns spatially with
    the main bottleneck and can be concatenated there directly.

    This realises the KLIMASCANNER principle of "decreasing information
    density with distance": the coarse branch carries macro context (a
    large park, the urban fringe, ...) that lies outside the fine patch's
    1.28 km window yet still shapes the local microclimate.
    """

    def __init__(self, in_ch: int, out_ch: int = 128):
        super().__init__()
        stages = [
            ConvBlock(in_ch, 32),    # 64x64
            nn.MaxPool2d(2),         # -> 32x32
            ConvBlock(32, 64),       # 32x32
            nn.MaxPool2d(2),         # -> 16x16
            ConvBlock(64, out_ch),   # 16x16, matches bottleneck resolution
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, in_ch, 64, 64) -> (B, out_ch, 16, 16)."""
        return self.net(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UrbanClimateUNet(nn.Module):
    """
    U-Net for urban climate prediction with aleatoric uncertainty.

    Two inputs (fine 256x256 @5 m, coarse 64x64 @25 m) share the channel
    schema; the coarse branch is fused at the bottleneck. Three 1x1 heads
    emit mean T_air, mean PET and a 2-channel log-sigma map.

    Parameters
    ----------
    in_ch : Input channels (same for fine and coarse patches).
    base_ch : Feature channels at Enc-1; doubles each encoder level.
    coarse_ch : Output channels of CoarseEncoder; fused at bottleneck.
    dropout_enc : Dropout2d probability for encoder levels 2 and 3.
    dropout_bot : Dropout2d probability for encoder level 4 and bottleneck.
    dropout_dec : Dropout2d probability for decoder level 3 (Dec-3).
    """

    def __init__(
        self,
        in_ch: int = 7,
        base_ch: int = 64,
        coarse_ch: int = 128,
        dropout_enc: float = 0.1,
        dropout_bot: float = 0.2,
        dropout_dec: float = 0.1,
    ):
        super().__init__()
        c = base_ch  # shorthand: 64

        # ── Encoder ──────────────────────────────────────────────────────────
        # Spatial dropout is deliberately absent at Enc-1: the first layer
        # needs all raw input channels to avoid discarding morphological signal
        # before any abstractions are built.
        self.enc1 = ConvBlock(in_ch, c, dropout_p=0.0)              # 256²
        self.enc2 = ConvBlock(c, c * 2, dropout_p=dropout_enc)      # 128²
        self.enc3 = ConvBlock(c * 2, c * 4, dropout_p=dropout_enc)  # 64²
        self.enc4 = ConvBlock(c * 4, c * 8, dropout_p=dropout_bot)  # 32²
        self.pool = nn.MaxPool2d(2)

        # ── Coarse context ───────────────────────────────────────────────────
        self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)

        # ── Bottleneck ───────────────────────────────────────────────────────
        # Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch)
        # → 640 ch input, 512 ch output
        self.bottleneck = ConvBlock(c * 8 + coarse_ch, c * 8, dropout_p=dropout_bot)

        # ── Decoder ──────────────────────────────────────────────────────────
        # UpBlock(in_ch, skip_ch, out_ch):
        #   in_ch   = channels coming up from the previous decoder level
        #   skip_ch = channels of the matching encoder skip connection
        #   out_ch  = channels output by this decoder level
        self.dec4 = UpBlock(c * 8, c * 8, c * 4)                          # 32²
        self.dec3 = UpBlock(c * 4, c * 4, c * 2, dropout_p=dropout_dec)   # 64²
        self.dec2 = UpBlock(c * 2, c * 2, c)                              # 128²
        self.dec1 = UpBlock(c, c, c // 2)                                 # 256²

        # ── Output heads (1×1 convolutions) ─────────────────────────────────
        self.t_air_head = nn.Conv2d(c // 2, 1, kernel_size=1)
        self.pet_head = nn.Conv2d(c // 2, 1, kernel_size=1)
        # Single head outputs 2 channels: [log_σ_t_air, log_σ_pet]
        # Using log-σ for numerical stability; softplus converts to σ > 0.
        self.log_sigma_head = nn.Conv2d(c // 2, 2, kernel_size=1)

        self._init_weights()

    def _init_weights(self) -> None:
        """He initialisation for Conv layers; zero-init the σ head bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # Zero bias → softplus(0) ≈ 0.69, i.e. a moderate starting σ (~0.7 °C
        # before label denormalisation). Without this, early training can be
        # unstable if σ starts near 0.
        nn.init.zeros_(self.log_sigma_head.bias)

    def forward(self, fine: torch.Tensor, coarse: torch.Tensor) -> ModelOutput:
        """
        Parameters
        ----------
        fine   : (B, in_ch, 256, 256)  5 m resolution morphology patch
        coarse : (B, in_ch, 64, 64)    25 m resolution context patch

        Returns
        -------
        ModelOutput with fields t_air_mean, t_air_sigma, pet_mean, pet_sigma,
        each shaped (B, 1, 256, 256).
        """
        # ── Encoder ──────────────────────────────────────────────────────────
        e1 = self.enc1(fine)            # (B, 64, 256, 256)
        e2 = self.enc2(self.pool(e1))   # (B, 128, 128, 128)
        e3 = self.enc3(self.pool(e2))   # (B, 256, 64, 64)
        e4 = self.enc4(self.pool(e3))   # (B, 512, 32, 32)

        # ── Coarse context ───────────────────────────────────────────────────
        ctx = self.coarse_enc(coarse)   # (B, 128, 16, 16)

        # ── Bottleneck ───────────────────────────────────────────────────────
        x = self.pool(e4)               # (B, 512, 16, 16)
        x = torch.cat([x, ctx], dim=1)  # (B, 640, 16, 16)
        x = self.bottleneck(x)          # (B, 512, 16, 16)

        # ── Decoder ──────────────────────────────────────────────────────────
        x = self.dec4(x, e4)            # (B, 256, 32, 32)
        x = self.dec3(x, e3)            # (B, 128, 64, 64)
        x = self.dec2(x, e2)            # (B, 64, 128, 128)
        x = self.dec1(x, e1)            # (B, 32, 256, 256)

        # ── Output heads ─────────────────────────────────────────────────────
        t_air_mean = self.t_air_head(x)     # (B, 1, 256, 256)
        pet_mean = self.pet_head(x)         # (B, 1, 256, 256)

        log_sigma = self.log_sigma_head(x)  # (B, 2, 256, 256)
        # softplus is smooth and always positive; 1e-4 floor prevents σ → 0.
        sigma = F.softplus(log_sigma) + 1e-4
        t_air_sigma = sigma[:, 0:1]         # (B, 1, 256, 256)
        pet_sigma = sigma[:, 1:2]           # (B, 1, 256, 256)

        return ModelOutput(t_air_mean, t_air_sigma, pet_mean, pet_sigma)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loss function — Gaussian NLL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def gaussian_nll_loss(
    mean: torch.Tensor,
    sigma: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Gaussian negative log-likelihood.

        L = (y − μ)² / (2σ²) + ½ log(σ²)

    This is the mechanism by which the model learns σ without any label:
      • First term: penalises wrong predictions; large σ softens the penalty.
      • Second term: penalises large σ; prevents the model from always hedging.
    The model converges to the σ that minimises the sum for each pixel type.

    Parameters
    ----------
    mean   : Predicted mean (B, 1, H, W)
    sigma  : Predicted sigma (B, 1, H, W), must be > 0
    target : FITNAH label (B, 1, H, W)
    mask   : Optional valid-pixel mask (B, 1, H, W) bool.
             Pixels outside the simulation extent should be masked out.
    """
    # ½·log(σ²) == log(σ) for σ > 0; the direct form skips the squaring.
    nll = (target - mean) ** 2 / (2.0 * sigma ** 2) + torch.log(sigma)
    if mask is not None:
        nll = nll[mask]
        if nll.numel() == 0:
            # All pixels masked out: .mean() on an empty tensor would be NaN
            # and poison the backward pass. Contribute zero loss instead,
            # keeping the graph connected so the total loss still has grads.
            return mean.sum() * 0.0
    return nll.mean()
|
||||
|
||||
|
||||
def combined_loss(
    preds: ModelOutput,
    t_air_target: torch.Tensor,
    pet_target: torch.Tensor,
    mask: torch.Tensor | None = None,
    t_air_weight: float = 1.0,
    pet_weight: float = 1.0,
) -> torch.Tensor:
    """Weighted sum of the per-target Gaussian NLL losses (T_air and PET)."""
    loss_t_air = gaussian_nll_loss(
        preds.t_air_mean, preds.t_air_sigma, t_air_target, mask
    )
    loss_pet = gaussian_nll_loss(
        preds.pet_mean, preds.pet_sigma, pet_target, mask
    )
    return t_air_weight * loss_t_air + pet_weight * loss_pet
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MC Dropout inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def enable_dropout(model: nn.Module) -> None:
    """
    Switch only the dropout layers of *model* into training mode.

    Everything else — notably BatchNorm — stays in eval mode so inference
    keeps using running statistics. Letting BN consume the current batch's
    statistics would inject batch noise and contaminate the epistemic
    uncertainty estimate derived from MC-Dropout passes.
    """
    model.eval()
    dropout_types = (nn.Dropout, nn.Dropout2d)
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()
|
||||
|
||||
|
||||
@torch.no_grad()
def predict_with_uncertainty(
    model: UrbanClimateUNet,
    fine: torch.Tensor,
    coarse: torch.Tensor,
    n_passes: int = 20,
    device: torch.device = torch.device("cpu"),
) -> dict[str, object]:
    """
    MC-Dropout inference with full uncertainty decomposition.

    Runs *n_passes* stochastic forward passes (dropout active, BatchNorm in
    eval mode) and splits uncertainty as:

        aleatoric = mean of the per-pass σ maps   (self-reported noise)
        epistemic = std of the per-pass mean maps (model disagreement,
                    an out-of-distribution signal)
        total     = √(aleatoric² + epistemic²)

    The scene-level trust score mirrors the KLIMASCANNER Vertrauenswert:
    "gut" when ≥ 80 % of pixels have total σ < 0.5 °C, "befriedigend"
    when ≥ 50 %, otherwise "ausreichend".

    For the MVP (aleatoric only) call with n_passes=1 and ignore the
    epistemic fields — they will be zero.

    Parameters
    ----------
    model    : Trained UrbanClimateUNet.
    fine     : (1, in_ch, 256, 256) fine-resolution patch.
    coarse   : (1, in_ch, 64, 64) coarse context patch.
    n_passes : Stochastic forward passes. 1 = aleatoric only (fast),
               20 = full decomposition (recommended for production).
    device   : Device to run on; inputs are moved there.
    """
    enable_dropout(model)
    fine = fine.to(device)
    coarse = coarse.to(device)

    # Each pass samples a different dropout realisation.
    passes = [model(fine, coarse) for _ in range(n_passes)]

    def decompose(
        means: list[torch.Tensor], sigmas: list[torch.Tensor]
    ) -> tuple[torch.Tensor, ...]:
        """Fold per-pass tensors into (mean, aleatoric, epistemic, total)."""
        mean_stack = torch.stack(means)    # (n_passes, B, 1, H, W)
        sigma_stack = torch.stack(sigmas)
        mu = mean_stack.mean(0)
        sigma_al = sigma_stack.mean(0)
        sigma_ep = mean_stack.std(0) if n_passes > 1 else torch.zeros_like(mu)
        sigma_tot = (sigma_al ** 2 + sigma_ep ** 2).sqrt()
        return mu, sigma_al, sigma_ep, sigma_tot

    t_mu, t_al, t_ep, t_tot = decompose(
        [p.t_air_mean for p in passes], [p.t_air_sigma for p in passes]
    )
    p_mu, p_al, p_ep, p_tot = decompose(
        [p.pet_mean for p in passes], [p.pet_sigma for p in passes]
    )

    # Trust score is driven by the T_air total-σ map only.
    pct_good = (t_tot < 0.5).float().mean().item()
    if pct_good >= 0.8:
        trust = "gut"
    elif pct_good >= 0.5:
        trust = "befriedigend"
    else:
        trust = "ausreichend"

    model.eval()  # restore plain eval mode before handing the model back
    return {
        "t_air_mean": t_mu,
        "t_air_sigma_aleatoric": t_al,
        "t_air_sigma_epistemic": t_ep,
        "t_air_sigma_total": t_tot,
        "pet_mean": p_mu,
        "pet_sigma_aleatoric": p_al,
        "pet_sigma_epistemic": p_ep,
        "pet_sigma_total": p_tot,
        "trust_score": trust,           # "gut" / "befriedigend" / "ausreichend"
        "pct_pixels_good": pct_good,    # fraction of pixels with σ < 0.5 °C
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UrbanClimateDataset(Dataset):
    """
    Loads pre-rasterised patch triplets (fine, coarse, labels) from disk.

    All inputs must be normalised before saving (z-score per channel per city).
    Labels must be city-normalised (subtract city mean T/PET, divide by std)
    so the model learns local spatial patterns, not absolute temperatures.

    See dataset layout in module docstring.
    """

    def __init__(self, root: str | Path, augment: bool = True):
        # When True, __getitem__ applies a random 90° rotation + flip.
        self.augment = augment
        # Each entry: (fine_path, coarse_path, label_path, optional mask_path).
        self.samples: list[tuple[Path, Path, Path, Path | None]] = []

        for city_dir in sorted(Path(root).iterdir()):
            if not city_dir.is_dir():
                continue
            fine_dir = city_dir / "fine"
            coarse_dir = city_dir / "coarse"
            label_dir = city_dir / "labels"
            mask_dir = city_dir / "masks"

            # Pair files by stem; a sample is kept only if the coarse and
            # label rasters both exist (the mask is optional).
            for fine_p in sorted(fine_dir.glob("*.npy")):
                stem = fine_p.stem
                coarse_p = coarse_dir / f"{stem}.npy"
                label_p = label_dir / f"{stem}.npy"
                mask_p = (mask_dir / f"{stem}.npy") if mask_dir.exists() else None
                if coarse_p.exists() and label_p.exists():
                    self.samples.append((fine_p, coarse_p, label_p, mask_p))

        if not self.samples:
            raise FileNotFoundError(f"No samples found under {root}")

    def __len__(self) -> int:
        """Number of (fine, coarse, labels) triplets discovered on disk."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load one sample; optionally augmented. Keys: fine/coarse/t_air/pet/mask."""
        fine_p, coarse_p, label_p, mask_p = self.samples[idx]

        fine = torch.from_numpy(np.load(fine_p)).float()      # (7, 256, 256)
        coarse = torch.from_numpy(np.load(coarse_p)).float()  # (7, 64, 64)
        labels = torch.from_numpy(np.load(label_p)).float()   # (2, 256, 256)
        # Fall back to an all-valid mask when none is provided on disk.
        mask = (
            torch.from_numpy(np.load(mask_p)).bool()
            if mask_p and mask_p.exists()
            else torch.ones(1, fine.shape[-2], fine.shape[-1], dtype=torch.bool)
        )

        if self.augment:
            fine, coarse, labels, mask = self._random_rotate_flip(
                fine, coarse, labels, mask
            )

        return {
            "fine": fine,
            "coarse": coarse,
            "t_air": labels[0:1],  # (1, 256, 256)
            "pet": labels[1:2],    # (1, 256, 256)
            "mask": mask,          # (1, 256, 256) bool
        }

    @staticmethod
    def _random_rotate_flip(
        fine: torch.Tensor,
        coarse: torch.Tensor,
        labels: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Random 90° rotation + horizontal flip.

        Note: wind direction changes with rotation. This is intentional —
        the model should learn morphological patterns, not wind-relative
        ones (wind direction is a runtime input, not a training feature).
        """
        # The same k and the same flip decision are applied to all four
        # tensors so inputs, labels and mask stay spatially aligned.
        k = int(torch.randint(4, ()).item())
        dims = (-2, -1)
        fine, coarse, labels, mask = (
            torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask)
        )
        if torch.rand(()).item() > 0.5:
            fine, coarse, labels, mask = (
                t.flip(-1) for t in (fine, coarse, labels, mask)
            )
        return fine, coarse, labels, mask
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train_one_epoch(
    model: UrbanClimateUNet,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.cuda.amp.GradScaler | None = None,
) -> float:
    """
    Run one optimisation pass over *loader* and return the mean batch loss.

    With *scaler* set, forward/loss run under CUDA autocast and follow the
    GradScaler protocol (scale → backward → unscale → clip → step → update);
    otherwise a plain FP32 step is taken. Gradients are clipped to
    max-norm 1.0 in both paths.

    NOTE(review): `torch.cuda.amp.GradScaler` is deprecated in newer torch
    (2.4+) in favour of `torch.amp.GradScaler("cuda")` — annotation kept
    as-is for compatibility with torch>=2.0.
    """
    model.train()
    total = 0.0

    for batch in loader:
        fine = batch["fine"].to(device)
        coarse = batch["coarse"].to(device)
        t_air = batch["t_air"].to(device)
        pet = batch["pet"].to(device)
        mask = batch["mask"].to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.autocast("cuda"):
                preds = model(fine, coarse)
                loss = combined_loss(preds, t_air, pet, mask)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients, not the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            preds = model(fine, coarse)
            loss = combined_loss(preds, t_air, pet, mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total += loss.item()

    return total / len(loader)
|
||||
|
||||
|
||||
def train(
    data_root: str,
    epochs: int = 100,
    batch_size: int = 8,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
    checkpoint_dir: str = "checkpoints",
    device_str: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> UrbanClimateUNet:
    """
    Full training loop with cosine LR schedule and mixed-precision.

    Recommended hyperparameters for MVP (10–15 training cities):
        epochs=150, batch_size=8, lr=1e-4

    For fine-tuning on a single new city:
        epochs=50, batch_size=4, lr=3e-5 (load pretrained weights first)
    """
    # NOTE: the device_str default is evaluated once at import time, not
    # per call — pass device_str explicitly to override.
    device = torch.device(device_str)
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    dataset = UrbanClimateDataset(data_root, augment=True)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=min(4, batch_size),
        pin_memory=(device.type == "cuda"),
        persistent_workers=True,
    )

    model = UrbanClimateUNet().to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    # Cosine decay from lr down to 1 % of lr over the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=lr * 0.01
    )
    # Mixed precision only on CUDA; the CPU path trains in full FP32.
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in torch 2.4+.
    scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model : {n_params:,} parameters")
    print(f"Dataset: {len(dataset):,} patches")
    print(f"Device : {device}")

    for epoch in range(1, epochs + 1):
        loss = train_one_epoch(model, loader, optimizer, device, scaler)
        scheduler.step()

        # Checkpoint every 10 epochs and at the very end of training.
        if epoch % 10 == 0 or epoch == epochs:
            ckpt = Path(checkpoint_dir) / f"epoch_{epoch:04d}_loss{loss:.4f}.pt"
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "loss": loss,
                    "hparams": {
                        "epochs": epochs,
                        "batch_size": batch_size,
                        "lr": lr,
                    },
                },
                ckpt,
            )
            print(f"Epoch {epoch:4d}/{epochs} loss {loss:.4f} → {ckpt.name}")

    return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX export for QGIS plugin deployment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _OnnxWrapper(nn.Module):
|
||||
"""Thin wrapper so torch.onnx.export receives a plain tuple output."""
|
||||
|
||||
def __init__(self, model: UrbanClimateUNet):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self, fine: torch.Tensor, coarse: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
out = self.model(fine, coarse)
|
||||
return out.t_air_mean, out.t_air_sigma, out.pet_mean, out.pet_sigma
|
||||
|
||||
|
||||
def export_onnx(
    model: UrbanClimateUNet,
    output_path: str = "urban_climate_unet.onnx",
    in_ch: int = 7,
    opset: int = 17,
) -> None:
    """
    Serialise a trained model to ONNX for the QGIS plugin.

    The plugin loads the .onnx file with onnxruntime (pure Python wheel,
    ~15 MB, no PyTorch dependency at inference time).

    Verify an export with:

        import onnxruntime as ort
        sess = ort.InferenceSession("urban_climate_unet.onnx")
        print([o.name for o in sess.get_outputs()])
    """
    model.eval()
    wrapped = _OnnxWrapper(model)

    # Tracing inputs at the canonical patch sizes; batch axis stays dynamic.
    example_inputs = (
        torch.zeros(1, in_ch, 256, 256),  # fine patch
        torch.zeros(1, in_ch, 64, 64),    # coarse patch
    )

    torch.onnx.export(
        wrapped,
        example_inputs,
        output_path,
        input_names=["fine_patch", "coarse_patch"],
        output_names=["t_air_mean", "t_air_sigma", "pet_mean", "pet_sigma"],
        dynamic_axes={
            "fine_patch": {0: "batch"},
            "coarse_patch": {0: "batch"},
        },
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"ONNX model exported → {output_path}")
    print("Load in QGIS plugin with: ort.InferenceSession('urban_climate_unet.onnx')")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quick sanity check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: random inputs through the untrained model, loss evaluation,
    # MC-Dropout decomposition and a best-effort ONNX export.
    torch.manual_seed(42)

    model = UrbanClimateUNet()
    model.eval()

    B = 2
    fine = torch.randn(B, 7, 256, 256)
    coarse = torch.randn(B, 7, 64, 64)

    # ── Forward pass ──────────────────────────────────────────────────────
    out = model(fine, coarse)
    print("=== Forward pass (batch=2) ===")
    for name, tensor in zip(out._fields, out):
        print(f" {name:<18} {tuple(tensor.shape)}")

    # ── Loss ──────────────────────────────────────────────────────────────
    # Random targets: only checks that the loss computes, not that it is low.
    t_tgt = torch.randn(B, 1, 256, 256)
    p_tgt = torch.randn(B, 1, 256, 256)
    loss = combined_loss(out, t_tgt, p_tgt)
    print(f"\nGaussian NLL loss : {loss.item():.4f}")

    # ── Parameter count ───────────────────────────────────────────────────
    n = sum(p.numel() for p in model.parameters())
    print(f"Parameters : {n:,}")

    # ── MC Dropout uncertainty (fast, 5 passes) ───────────────────────────
    result = predict_with_uncertainty(
        model, fine[0:1], coarse[0:1], n_passes=5
    )
    print(f"\n=== MC Dropout (5 passes) ===")
    print(f" Trust score : {result['trust_score']}")
    print(f" Pct pixels good : {result['pct_pixels_good']:.1%}")
    print(f" T_air σ aleat. : {result['t_air_sigma_aleatoric'].mean():.3f} °C")
    print(f" T_air σ epist. : {result['t_air_sigma_epistemic'].mean():.3f} °C")
    print(f" T_air σ total : {result['t_air_sigma_total'].mean():.3f} °C")

    # ── ONNX export (requires onnx package) ───────────────────────────────
    # Best-effort: export is optional for the smoke test, so any failure
    # (typically the missing `onnx` package) is reported, not raised.
    try:
        export_onnx(model, "urban_climate_unet.onnx")
    except Exception as e:
        print(f"\nONNX export skipped ({e})")
|
||||
# NOTE(review): wildcard import — presumably pulls the model definitions
# (UrbanClimateUNet, load/save helpers' dependencies, ...) into scope;
# prefer explicit names once climatenet's public API is confirmed.
from climatenet import *
import rasterio  # geospatial raster I/O (GeoTIFF)




print("main.py started")
|
||||
|
||||
|
||||
def load_raster(file_path: str):
    """Read all bands of a raster file into a float32 tensor of shape (bands, H, W)."""
    with rasterio.open(file_path) as src:
        bands = src.read()

    return torch.from_numpy(bands.astype(np.float32))
|
||||
|
||||
|
||||
def save_raster_like(tensor: torch.Tensor, ref_path: str, file_path: str):
    """
    Write *tensor* (bands, H, W) to *file_path* as GeoTIFF, copying the
    georeferencing metadata from the raster at *ref_path* while adapting
    size, band count and dtype to the tensor being written.
    """
    array = tensor.detach().cpu().numpy()

    with rasterio.open(ref_path) as example:
        meta = example.meta.copy()

    meta.update(
        {
            "driver": "GTiff",
            "height": array.shape[1],
            "width": array.shape[2],
            "count": array.shape[0],
            "dtype": array.dtype,
        }
    )

    with rasterio.open(file_path, "w", **meta) as dst:
        dst.write(array)
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Smoke test: push one scene through the (untrained) network and save rasters
# ---------------------------------------------------------------------------

def _prep_fine(raster: torch.Tensor) -> torch.Tensor:
    """Zero-pad 3 px on every edge and add a batch axis: (C, H, W) → (1, C, H+6, W+6)."""
    return F.pad(raster, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)


def _to_coarse(raster: torch.Tensor, mode: str) -> torch.Tensor:
    """
    Downsample a padded (1, C, H, W) raster to the 64×64 coarse grid.

    mode='bilinear' for continuous fields; mode='nearest' for categorical
    rasters, where interpolation would invent meaningless class values.
    """
    if mode == "nearest":
        return F.interpolate(raster, size=(64, 64), mode="nearest")
    return F.interpolate(raster, size=(64, 64), mode="bilinear", align_corners=False)


# Continuous morphology inputs.
building_height = _prep_fine(load_raster("data/INPUT/building_height.tif"))
tree_height = _prep_fine(load_raster("data/INPUT/tree_height.tif"))
surface_height = _prep_fine(load_raster("data/INPUT/zt.tif"))
# Categorical inputs.
pavement_type = _prep_fine(load_raster("data/INPUT/pavement_type.tif"))
vegetation_type = _prep_fine(load_raster("data/INPUT/vegetation_type.tif"))
water_type = _prep_fine(load_raster("data/INPUT/water_type.tif"))

# Reference simulation outputs, loaded for shape comparison only.
output_pet = load_raster("data/OUTPUT/bio_pet_xy_av_14h.tif")
output_temp = load_raster("data/OUTPUT/ta_av_h001_1.0m_14h.tif")

# Coarse context: bilinear for continuous fields, nearest for categorical.
building_height_coarse = _to_coarse(building_height, "bilinear")
tree_height_coarse = _to_coarse(tree_height, "bilinear")
surface_height_coarse = _to_coarse(surface_height, "bilinear")
pavement_type_coarse = _to_coarse(pavement_type, "nearest")
vegetation_type_coarse = _to_coarse(vegetation_type, "nearest")
water_type_coarse = _to_coarse(water_type, "nearest")

print(building_height.shape)
print(tree_height.shape)
print(surface_height.shape)
print(pavement_type.shape)
print(vegetation_type.shape)
print(water_type.shape)
print(output_pet.shape)
print(output_temp.shape)

input_data_fine = torch.cat(
    [building_height, tree_height, surface_height,
     pavement_type, vegetation_type, water_type],
    dim=1,
)
print(f"total dim fine {input_data_fine.shape}")

input_data_coarse = torch.cat(
    [building_height_coarse, tree_height_coarse, surface_height_coarse,
     pavement_type_coarse, vegetation_type_coarse, water_type_coarse],
    dim=1,
)
print(f"total dim coarse {input_data_coarse.shape}")

# BUG FIX: six rasters are concatenated (presumably single-band, i.e.
# 6 channels) but the model's default in_ch is 7 — the first convolution
# would reject the input. Instantiate with the actual channel count.
model = UrbanClimateUNet(in_ch=input_data_fine.shape[1])
model.eval()

# NOTE(review): the U-Net pools 4×, so H and W of the padded fine patch must
# be divisible by 16 — presumably the 3-px padding was chosen for exactly
# that; confirm against the raster dimensions.
with torch.no_grad():  # pure inference — no autograd graph needed
    out = model(input_data_fine, input_data_coarse)

output_temp_pred = out.t_air_mean.squeeze(0)
output_pet_pred = out.pet_mean.squeeze(0)

print(f"PET pred shape: {output_pet_pred.shape}")
# BUG FIX: this line previously printed the label "PET pred shape" for the
# temperature tensor (copy-paste error).
print(f"T_air pred shape: {output_temp_pred.shape}")

save_raster_like(output_temp_pred, "data/OUTPUT/ta_av_h001_1.0m_14h.tif", "data/PRED/temp.tif")
save_raster_like(output_pet_pred, "data/OUTPUT/bio_pet_xy_av_14h.tif", "data/PRED/pet.tif")
|
||||
Reference in New Issue
Block a user