optionalised coarse feedthrough
This commit is contained in:
@@ -221,7 +221,7 @@ class UrbanClimateUNet(nn.Module):
|
|||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
# ── Coarse context ───────────────────────────────────────────────────
|
# ── Coarse context ───────────────────────────────────────────────────
|
||||||
self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)
|
if coarse_ch > 0: self.coarse_enc = CoarseEncoder(in_ch, coarse_ch)
|
||||||
|
|
||||||
# ── Bottleneck ───────────────────────────────────────────────────────
|
# ── Bottleneck ───────────────────────────────────────────────────────
|
||||||
# Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch)
|
# Receives: enc4-pooled (c*8 = 512 ch) + coarse context (128 ch)
|
||||||
@@ -261,7 +261,7 @@ class UrbanClimateUNet(nn.Module):
|
|||||||
# Without this, early training can be unstable if σ starts near 0.
|
# Without this, early training can be unstable if σ starts near 0.
|
||||||
nn.init.zeros_(self.log_sigma_head.bias)
|
nn.init.zeros_(self.log_sigma_head.bias)
|
||||||
|
|
||||||
def forward(self, fine: torch.Tensor, coarse: torch.Tensor) -> ModelOutput:
|
def forward(self, fine: torch.Tensor, coarse: torch.Tensor | None = None) -> ModelOutput:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -280,11 +280,11 @@ class UrbanClimateUNet(nn.Module):
|
|||||||
e4 = self.enc4(self.pool(e3)) # (B, 512, 32, 32)
|
e4 = self.enc4(self.pool(e3)) # (B, 512, 32, 32)
|
||||||
|
|
||||||
# ── Coarse context ───────────────────────────────────────────────────
|
# ── Coarse context ───────────────────────────────────────────────────
|
||||||
ctx = self.coarse_enc(coarse) # (B, 128, 16, 16)
|
if self.coarse_enc: ctx = self.coarse_enc(coarse) # (B, 128, 16, 16)
|
||||||
|
|
||||||
# ── Bottleneck ───────────────────────────────────────────────────────
|
# ── Bottleneck ───────────────────────────────────────────────────────
|
||||||
x = self.pool(e4) # (B, 512, 16, 16)
|
x = self.pool(e4) # (B, 512, 16, 16)
|
||||||
x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16)
|
if self.coarse_enc: x = torch.cat([x, ctx], dim=1) # (B, 640, 16, 16)
|
||||||
x = self.bottleneck(x) # (B, 512, 16, 16)
|
x = self.bottleneck(x) # (B, 512, 16, 16)
|
||||||
|
|
||||||
# ── Decoder ──────────────────────────────────────────────────────────
|
# ── Decoder ──────────────────────────────────────────────────────────
|
||||||
@@ -539,6 +539,7 @@ class UrbanClimateDataset(Dataset):
|
|||||||
fine, coarse, labels, mask = (
|
fine, coarse, labels, mask = (
|
||||||
torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask)
|
torch.rot90(t, k, dims) for t in (fine, coarse, labels, mask)
|
||||||
)
|
)
|
||||||
|
|
||||||
if torch.rand(()).item() > 0.5:
|
if torch.rand(()).item() > 0.5:
|
||||||
fine, coarse, labels, mask = (
|
fine, coarse, labels, mask = (
|
||||||
t.flip(-1) for t in (fine, coarse, labels, mask)
|
t.flip(-1) for t in (fine, coarse, labels, mask)
|
||||||
@@ -723,7 +724,7 @@ def export_onnx(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.manual_seed(46)
|
# torch.manual_seed(46)
|
||||||
|
|
||||||
# model = UrbanClimateUNet()
|
# model = UrbanClimateUNet()
|
||||||
# model.eval()
|
# model.eval()
|
||||||
@@ -762,7 +763,7 @@ if __name__ == "__main__":
|
|||||||
# # ── ONNX export (requires onnx package) ───────────────────────────────
|
# # ── ONNX export (requires onnx package) ───────────────────────────────
|
||||||
# try:
|
# try:
|
||||||
# export_onnx(model, "urban_climate_unet.onnx")
|
# export_onnx(model, "urban_climate_unet.onnx")
|
||||||
# except Exception as e:
|
# except Exception as e:temp_coarse
|
||||||
# print(f"\nONNX export skipped ({e})")
|
# print(f"\nONNX export skipped ({e})")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ def save_raster_like(tensor: torch.Tensor, ref_path: str, file_path: str):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(46)
|
||||||
|
|
||||||
|
|
||||||
building_height = load_raster("data/INPUT/building_height.tif")
|
building_height = load_raster("data/INPUT/building_height.tif")
|
||||||
tree_height = load_raster("data/INPUT/tree_height.tif")
|
tree_height = load_raster("data/INPUT/tree_height.tif")
|
||||||
@@ -91,5 +94,5 @@ output_pet_pred = out.pet_mean.squeeze(0)
|
|||||||
print(f"PET pred shape: {output_pet_pred.shape}");
|
print(f"PET pred shape: {output_pet_pred.shape}");
|
||||||
print(f"PET pred shape: {output_temp_pred.shape}");
|
print(f"PET pred shape: {output_temp_pred.shape}");
|
||||||
|
|
||||||
save_raster_like(output_temp_pred, "data/OUTPUT/ta_av_h001_1.0m_14h.tif", "data/PRED/temp.tif")
|
save_raster_like(output_temp_pred, "data/OUTPUT/ta_av_h001_1.0m_14h.tif", "data/PRED/temp_coarse.tif")
|
||||||
save_raster_like(output_pet_pred, "data/OUTPUT/bio_pet_xy_av_14h.tif", "data/PRED/pet.tif")
|
save_raster_like(output_pet_pred, "data/OUTPUT/bio_pet_xy_av_14h.tif", "data/PRED/pet.tif")
|
||||||
Reference in New Issue
Block a user