optionalised coarse feedthrough

This commit is contained in:
luptmoor
2026-04-13 11:32:51 +02:00
parent c926d72827
commit 2a2d4ca3c2
2 changed files with 52 additions and 48 deletions

View File

@@ -221,7 +221,7 @@ class UrbanClimateUNet(nn.Module):
self.pool = nn.MaxPool2d(2)
# ── Coarse context ───────────────────────────────────────────────────
self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)
if coarse_ch > 0: self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)
# ── Bottleneck ───────────────────────────────────────────────────────
# Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch)
@@ -261,7 +261,7 @@ class UrbanClimateUNet(nn.Module):
# Without this, early training can be unstable if σ starts near 0.
nn.init.zeros_(self.log_sigma_head.bias)
def forward(self, fine: torch.Tensor, coarse: torch.Tensor) -> ModelOutput:
def forward(self, fine: torch.Tensor, coarse: torch.Tensor | None = None) -> ModelOutput:
"""
Parameters
----------
@@ -280,11 +280,11 @@ class UrbanClimateUNet(nn.Module):
e4 = self.enc4(self.pool(e3)) # (B, 512, 32, 32)
# ── Coarse context ───────────────────────────────────────────────────
ctx = self.coarse_enc(coarse) # (B, 128, 16, 16)
if self.coarse_enc: ctx = self.coarse_enc(coarse) # (B, 128, 16, 16)
# ── Bottleneck ───────────────────────────────────────────────────────
x = self.pool(e4) # (B, 512, 16, 16)
x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16)
if self.coarse_enc: x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16)
x = self.bottleneck(x) # (B, 512, 16, 16)
# ── Decoder ──────────────────────────────────────────────────────────
@@ -539,6 +539,7 @@ class UrbanClimateDataset(Dataset):
fine, coarse, labels, mask = (
torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask)
)
if torch.rand(()).item() > 0.5:
fine, coarse, labels, mask = (
t.flip(-1) for t in (fine, coarse, labels, mask)
@@ -723,7 +724,7 @@ def export_onnx(
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.manual_seed(46)
# torch.manual_seed(46)
# model = UrbanClimateUNet()
# model.eval()
@@ -762,9 +763,9 @@ if __name__ == "__main__":
# # ── ONNX export (requires onnx package) ───────────────────────────────
# try:
# export_onnx(model, "urban_climate_unet.onnx")
# except Exception as e:
# except Exception as e:temp_coarse
# print(f"\nONNX export skipped ({e})")
print(f"climatenet class loaded.")
print(f"climatenet class loaded.")