optionalised coarse feedthrough

This commit is contained in:
luptmoor
2026-04-13 11:32:51 +02:00
parent c926d72827
commit 2a2d4ca3c2
2 changed files with 52 additions and 48 deletions

View File

@@ -221,7 +221,7 @@ class UrbanClimateUNet(nn.Module):
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
# ── Coarse context ─────────────────────────────────────────────────── # ── Coarse context ───────────────────────────────────────────────────
self.coarse_enc = CoarseEncoder(in_ch, coarse_ch) if coarse_ch > 0: self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)
# ── Bottleneck ─────────────────────────────────────────────────────── # ── Bottleneck ───────────────────────────────────────────────────────
# Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch) # Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch)
@@ -261,7 +261,7 @@ class UrbanClimateUNet(nn.Module):
# Without this, early training can be unstable if σ starts near 0. # Without this, early training can be unstable if σ starts near 0.
nn.init.zeros_(self.log_sigma_head.bias) nn.init.zeros_(self.log_sigma_head.bias)
def forward(self, fine: torch.Tensor, coarse: torch.Tensor) -> ModelOutput: def forward(self, fine: torch.Tensor, coarse: torch.Tensor | None = None) -> ModelOutput:
""" """
Parameters Parameters
---------- ----------
@@ -280,11 +280,11 @@ class UrbanClimateUNet(nn.Module):
e4 = self.enc4(self.pool(e3)) # (B, 512, 32, 32) e4 = self.enc4(self.pool(e3)) # (B, 512, 32, 32)
# ── Coarse context ─────────────────────────────────────────────────── # ── Coarse context ───────────────────────────────────────────────────
ctx = self.coarse_enc(coarse) # (B, 128, 16, 16) if self.coarse_enc: ctx = self.coarse_enc(coarse) # (B, 128, 16, 16)
# ── Bottleneck ─────────────────────────────────────────────────────── # ── Bottleneck ───────────────────────────────────────────────────────
x = self.pool(e4) # (B, 512, 16, 16) x = self.pool(e4) # (B, 512, 16, 16)
x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16) if self.coarse_enc: x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16)
x = self.bottleneck(x) # (B, 512, 16, 16) x = self.bottleneck(x) # (B, 512, 16, 16)
# ── Decoder ────────────────────────────────────────────────────────── # ── Decoder ──────────────────────────────────────────────────────────
@@ -539,6 +539,7 @@ class UrbanClimateDataset(Dataset):
fine, coarse, labels, mask = ( fine, coarse, labels, mask = (
torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask) torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask)
) )
if torch.rand(()).item() > 0.5: if torch.rand(()).item() > 0.5:
fine, coarse, labels, mask = ( fine, coarse, labels, mask = (
t.flip(-1) for t in (fine, coarse, labels, mask) t.flip(-1) for t in (fine, coarse, labels, mask)
@@ -723,7 +724,7 @@ def export_onnx(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
if __name__ == "__main__": if __name__ == "__main__":
torch.manual_seed(46) # torch.manual_seed(46)
# model = UrbanClimateUNet() # model = UrbanClimateUNet()
# model.eval() # model.eval()
@@ -762,9 +763,9 @@ if __name__ == "__main__":
# # ── ONNX export (requires onnx package) ─────────────────────────────── # # ── ONNX export (requires onnx package) ───────────────────────────────
# try: # try:
# export_onnx(model, "urban_climate_unet.onnx") # export_onnx(model, "urban_climate_unet.onnx")
# except Exception as e: # except Exception as e:temp_coarse
# print(f"\nONNX export skipped ({e})") # print(f"\nONNX export skipped ({e})")
print(f"climatenet class loaded.") print(f"climatenet class loaded.")

View File

@@ -33,63 +33,66 @@ def save_raster_like(tensor: torch.Tensor, ref_path: str, file_path: str):
if __name__ == "__main__":
torch.manual_seed(46)
building_height = load_raster("data/INPUT/building_height.tif")
tree_height = load_raster("data/INPUT/tree_height.tif")
surface_height = load_raster("data/INPUT/zt.tif")
pavement_type = load_raster("data/INPUT/pavement_type.tif")
vegetation_type = load_raster("data/INPUT/vegetation_type.tif")
water_type = load_raster("data/INPUT/water_type.tif")
output_pet = load_raster("data/OUTPUT/bio_pet_xy_av_14h.tif") building_height = load_raster("data/INPUT/building_height.tif")
output_temp = load_raster("data/OUTPUT/ta_av_h001_1.0m_14h.tif") tree_height = load_raster("data/INPUT/tree_height.tif")
surface_height = load_raster("data/INPUT/zt.tif")
pavement_type = load_raster("data/INPUT/pavement_type.tif")
vegetation_type = load_raster("data/INPUT/vegetation_type.tif")
water_type = load_raster("data/INPUT/water_type.tif")
output_pet = load_raster("data/OUTPUT/bio_pet_xy_av_14h.tif")
output_temp = load_raster("data/OUTPUT/ta_av_h001_1.0m_14h.tif")
building_height = F.pad(building_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) building_height = F.pad(building_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
tree_height = F.pad(tree_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) tree_height = F.pad(tree_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
surface_height = F.pad(surface_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) surface_height = F.pad(surface_height, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
pavement_type = F.pad(pavement_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) pavement_type = F.pad(pavement_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
vegetation_type = F.pad(vegetation_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) vegetation_type = F.pad(vegetation_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
water_type = F.pad(water_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0) water_type = F.pad(water_type, (3, 3, 3, 3), mode='constant', value=0).unsqueeze(0)
building_height_coarse = F.interpolate(building_height, size=(64, 64), mode='bilinear', align_corners=False) building_height_coarse = F.interpolate(building_height, size=(64, 64), mode='bilinear', align_corners=False)
tree_height_coarse = F.interpolate(tree_height, size=(64, 64), mode='bilinear', align_corners=False) tree_height_coarse = F.interpolate(tree_height, size=(64, 64), mode='bilinear', align_corners=False)
surface_height_coarse = F.interpolate(surface_height, size=(64, 64), mode='bilinear', align_corners=False) surface_height_coarse = F.interpolate(surface_height, size=(64, 64), mode='bilinear', align_corners=False)
pavement_type_coarse = F.interpolate(pavement_type, size=(64, 64), mode='nearest') pavement_type_coarse = F.interpolate(pavement_type, size=(64, 64), mode='nearest')
vegetation_type_coarse = F.interpolate(vegetation_type, size=(64, 64), mode='nearest') vegetation_type_coarse = F.interpolate(vegetation_type, size=(64, 64), mode='nearest')
water_type_coarse = F.interpolate(water_type, size=(64, 64), mode='nearest') water_type_coarse = F.interpolate(water_type, size=(64, 64), mode='nearest')
print(building_height.shape) print(building_height.shape)
print(tree_height.shape) print(tree_height.shape)
print(surface_height.shape) print(surface_height.shape)
print(pavement_type.shape) print(pavement_type.shape)
print(vegetation_type.shape) print(vegetation_type.shape)
print(water_type.shape) print(water_type.shape)
print(output_pet.shape) print(output_pet.shape)
print(output_temp.shape) print(output_temp.shape)
input_data_fine = torch.cat([building_height, tree_height, surface_height, pavement_type, vegetation_type, water_type], dim=1) input_data_fine = torch.cat([building_height, tree_height, surface_height, pavement_type, vegetation_type, water_type], dim=1)
print(f"total dim fine {input_data_fine.shape}") print(f"total dim fine {input_data_fine.shape}")
input_data_coarse = torch.cat([building_height_coarse, tree_height_coarse, surface_height_coarse, pavement_type_coarse, vegetation_type_coarse, water_type_coarse], dim=1) input_data_coarse = torch.cat([building_height_coarse, tree_height_coarse, surface_height_coarse, pavement_type_coarse, vegetation_type_coarse, water_type_coarse], dim=1)
print(f"total dim coarse {input_data_coarse.shape}") print(f"total dim coarse {input_data_coarse.shape}")
model = UrbanClimateUNet(); model = UrbanClimateUNet();
model.eval() model.eval()
out = model.forward(input_data_fine, input_data_coarse) out = model.forward(input_data_fine, input_data_coarse)
output_temp_pred = out.t_air_mean.squeeze(0) output_temp_pred = out.t_air_mean.squeeze(0)
output_pet_pred = out.pet_mean.squeeze(0) output_pet_pred = out.pet_mean.squeeze(0)
print(f"PET pred shape: {output_pet_pred.shape}"); print(f"PET pred shape: {output_pet_pred.shape}");
print(f"PET pred shape: {output_temp_pred.shape}"); print(f"PET pred shape: {output_temp_pred.shape}");
save_raster_like(output_temp_pred, "data/OUTPUT/ta_av_h001_1.0m_14h.tif", "data/PRED/temp.tif") save_raster_like(output_temp_pred, "data/OUTPUT/ta_av_h001_1.0m_14h.tif", "data/PRED/temp_coarse.tif")
save_raster_like(output_pet_pred, "data/OUTPUT/bio_pet_xy_av_14h.tif", "data/PRED/pet.tif") save_raster_like(output_pet_pred, "data/OUTPUT/bio_pet_xy_av_14h.tif", "data/PRED/pet.tif")