"""Benchmark a set of diffusers image/video VAEs on one image folder:
reconstruction quality (MSE, PSNR, LPIPS, Sobel-edge L1, KL) plus statistics
of the normalized latents (min/mean/max/std, per-channel skew and kurtosis),
and an extra shift/scale correction derived for the last model in VAE_LIST."""

import os
import json
import random
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
from torchvision.utils import save_image
import lpips

from diffusers import (
    AutoencoderKL,
    AutoencoderKLWan,
    AutoencoderKLLTXVideo,
    AutoencoderKLQwenImage,
)

from scipy.stats import skew, kurtosis

DEVICE = "cuda" |
|
|
DTYPE = torch.float16 |
|
|
IMAGE_FOLDER = "/home/recoilme/dataset/alchemist" |
|
|
MIN_SIZE = 1280 |
|
|
CROP_SIZE = 512 |
|
|
BATCH_SIZE = 5 |
|
|
MAX_IMAGES = 0 |
|
|
NUM_WORKERS = 4 |
|
|
SAMPLES_DIR = "test" |
|
|
|
|
|
VAE_LIST = [
    ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
    ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", "vae"),
    ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
    ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
    ("AiArtLab/wan16x_vae", AutoencoderKLWan, "AiArtLab/wan16x_vae", "vae"),
    ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
    ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
    ("AuraDiffusion/16ch-vae", AutoencoderKL, "AuraDiffusion/16ch-vae", None),
    ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
    ("AiArtLab/simplevae", AutoencoderKL, "AiArtLab/simplevae", "vae"),
]
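# Each entry is (display name, diffusers autoencoder class, Hugging Face repo id,
# optional subfolder holding the VAE weights). Entries that fail to load are
# skipped with a warning in main(), so the sweep degrades gracefully when a repo
# is gated or unreachable.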


def to_neg1_1(x: torch.Tensor) -> torch.Tensor:
    return x * 2 - 1


def to_0_1(x: torch.Tensor) -> torch.Tensor:
    return (x + 1) * 0.5


def safe_psnr(mse: float) -> float:
    if mse <= 1e-12:
        return float("inf")
    return 10.0 * float(np.log10(1.0 / mse))

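# safe_psnr assumes images in [0, 1], so the peak value is 1 and
# PSNR = 10 * log10(1 / MSE); identical images return +inf instead of raising
# a division error.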


def is_video_like_vae(vae) -> bool:
    return isinstance(vae, (AutoencoderKLWan, AutoencoderKLLTXVideo, AutoencoderKLQwenImage))


def add_time_dim_if_needed(x: torch.Tensor, vae) -> torch.Tensor:
    if is_video_like_vae(vae) and x.ndim == 4:
        return x.unsqueeze(2)
    return x


def strip_time_dim_if_possible(x: torch.Tensor, vae) -> torch.Tensor:
    if is_video_like_vae(vae) and x.ndim == 5 and x.shape[2] == 1:
        return x.squeeze(2)
    return x
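

# The video VAEs (Wan, LTX-Video, Qwen-Image) expect 5D input (B, C, T, H, W),
# so a still image is treated as a single-frame clip. A minimal round-trip
# sketch, assuming `vae` is an instance of one of the video-like classes above:
#
#   x = torch.randn(1, 3, 512, 512)           # (B, C, H, W) image batch
#   x5 = add_time_dim_if_needed(x, vae)       # -> (1, 3, 1, 512, 512)
#   z = vae.encode(x5).latent_dist.mode()
#   y5 = vae.decode(z).sample                 # still 5D with T == 1
#   y = strip_time_dim_if_possible(y5, vae)   # -> (1, 3, 512, 512)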


@torch.no_grad()
def sobel_edge_l1(real_0_1: torch.Tensor, fake_0_1: torch.Tensor) -> float:
    real = to_neg1_1(real_0_1)
    fake = to_neg1_1(fake_0_1)
    # Standard 3x3 Sobel kernels for horizontal/vertical gradients, applied
    # depthwise (groups=C) so channels stay independent.
    kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
    C = real.shape[1]
    kx = kx.to(real.dtype).repeat(C, 1, 1, 1)
    ky = ky.to(real.dtype).repeat(C, 1, 1, 1)

    def grad_mag(x):
        gx = F.conv2d(x, kx, padding=1, groups=C)
        gy = F.conv2d(x, ky, padding=1, groups=C)
        return torch.sqrt(gx * gx + gy * gy + 1e-12)

    return F.l1_loss(grad_mag(fake), grad_mag(real)).item()
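
# The edge metric is the L1 distance between the Sobel gradient-magnitude maps
# of the original and the reconstruction; lower is better, and it tends to
# expose over-smoothed decodes that plain MSE can miss.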

def flatten_channels(x: torch.Tensor) -> torch.Tensor:
    # Collapse a batch of latents to shape (C, N): one row of samples per channel.
    if x.ndim == 4:
        return x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)
    elif x.ndim == 5:
        return x.permute(1, 0, 2, 3, 4).reshape(x.shape[1], -1)
    else:
        raise ValueError(f"Unexpected tensor ndim={x.ndim}")
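

# flatten_channels feeds the per-channel statistics in main(): each row
# Z_ch[c] is every latent value of channel c pooled across the whole dataset,
# which is the layout scipy's skew/kurtosis and the per-channel correction
# expect.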


def _to_numpy_1d(x: Any) -> Optional[np.ndarray]:
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return None
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().float().numpy()
    elif isinstance(x, (list, tuple)):
        x = np.array(x, dtype=np.float32)
    elif isinstance(x, np.ndarray):
        x = x.astype(np.float32, copy=False)
    else:
        return None
    return x.reshape(-1)


def _to_float(x: Any) -> Optional[float]:
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if isinstance(x, np.ndarray) and x.size == 1:
        return float(x.item())
    if isinstance(x, torch.Tensor) and x.numel() == 1:
        return float(x.item())
    return None

def get_norm_tensors_and_summary(vae, latent_like: torch.Tensor):
    """
    Latent normalization: global and per-channel.
    Application order: global first (scalar), then per-channel (vector).
    If the config exposes several keys, they are accumulated.
    """
    cfg = getattr(vae, "config", vae)

    scale_keys = [
        "latents_std",
    ]
    shift_keys = [
        "latents_mean",
    ]

    C = latent_like.shape[1]
    nd = latent_like.ndim
    dev = latent_like.device
    dt = latent_like.dtype

    scale_global = getattr(cfg, "scaling_factor", 1.0)
    shift_global = getattr(cfg, "shift_factor", 0.0)
    if scale_global is None:
        scale_global = 1.0
    if shift_global is None:
        shift_global = 0.0

    scale_channel = np.ones(C, dtype=np.float32)
    shift_channel = np.zeros(C, dtype=np.float32)

    for k in scale_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            # Config stds are divisors (z = (z - mean) / std), as in the
            # diffusers Wan/LTX pipelines, so store the reciprocal: every
            # scale below is applied as a multiplier.
            scale_channel /= vec
        else:
            s = _to_float(v)
            if s is not None:
                scale_global *= s

    for k in shift_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            shift_channel += vec
        else:
            s = _to_float(v)
            if s is not None:
                shift_global += s

    g_shape = [1] * nd
    c_shape = [1] * nd
    c_shape[1] = C

    t_scale_g = torch.tensor(scale_global, dtype=dt, device=dev).view(*g_shape)
    t_shift_g = torch.tensor(shift_global, dtype=dt, device=dev).view(*g_shape)
    t_scale_c = torch.from_numpy(scale_channel).to(device=dev, dtype=dt).view(*c_shape)
    t_shift_c = torch.from_numpy(shift_channel).to(device=dev, dtype=dt).view(*c_shape)

    summary = {
        "scale_global": float(scale_global),
        "shift_global": float(shift_global),
        "scale_channel_min": float(scale_channel.min()),
        "scale_channel_mean": float(scale_channel.mean()),
        "scale_channel_max": float(scale_channel.max()),
        "shift_channel_min": float(shift_channel.min()),
        "shift_channel_mean": float(shift_channel.mean()),
        "shift_channel_max": float(shift_channel.max()),
    }
    return t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary
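

# The tensors above are consumed in main() as
#     z_model = ((z_raw - t_shift_g) * t_scale_g - t_shift_c) * t_scale_c
# and the inverse (e.g. before calling vae.decode on a normalized latent) is
#     z_raw = (z_model / t_scale_c + t_shift_c) / t_scale_g + t_shift_g
# Both follow directly from the apply order in the docstring: global first,
# then per-channel.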


@torch.no_grad()
def kl_divergence_per_image(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    kl_map = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kl_map.float().view(kl_map.shape[0], -1).mean(dim=1)
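

# Elementwise closed form of KL(N(mu, sigma^2) || N(0, 1)):
#     0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
# which equals -0.5 * (1 + logvar - mu^2 - exp(logvar)) with
# logvar = log(sigma^2); averaging over all latent elements gives one score
# per image.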


def sanitize_filename(name: str) -> str:
    name = name.replace("/", "_").replace("\\", "_").replace(" ", "_")
    return "".join(ch if (ch.isalnum() or ch in "._-") else "_" for ch in name)


class ImageFolderDataset(Dataset):
    def __init__(self, root_dir: str, extensions=(".png", ".jpg", ".jpeg", ".webp"), min_size=1024, crop_size=512, limit=None):
        paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(extensions):
                    paths.append(os.path.join(root, fname))
        if limit:
            paths = paths[:limit]

        valid = []
        for p in tqdm(paths, desc="Validating files"):
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except Exception:
                pass
        if not valid:
            raise RuntimeError(f"No valid images in {root_dir}")
        random.shuffle(valid)
        self.paths = valid
        print(f"Found {len(self.paths)} images")

        self.transform = Compose([
            Resize(min_size),
            CenterCrop(crop_size),
            ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        with Image.open(self.paths[idx]) as img:
            img = img.convert("RGB")
        return self.transform(img)
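

# A quick smoke test of the dataset in isolation (hypothetical path):
#
#   ds = ImageFolderDataset("/path/to/images", min_size=1280, crop_size=512, limit=16)
#   x = ds[0]   # float32 tensor in [0, 1], shape (3, 512, 512)
#
# Resize(min_size) scales the shorter side to min_size before CenterCrop, so
# every sample is a central crop of a sufficiently large image.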


def main():
    torch.set_grad_enabled(False)
    os.makedirs(SAMPLES_DIR, exist_ok=True)

    dataset = ImageFolderDataset(IMAGE_FOLDER, min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    lpips_net = lpips.LPIPS(net="vgg").to(DEVICE).eval()

    vaes: List[Tuple[str, object]] = []
    print("\nLoading VAEs...")
    for human_name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            vae = vae_class.from_pretrained(model_path, subfolder=subfolder, torch_dtype=DTYPE)
            vae = vae.to(DEVICE).eval()
            vaes.append((human_name, vae))
            print(f"  ✅ {human_name}")
        except Exception as e:
            print(f"  ❌ {human_name}: {e}")

    if not vaes:
        print("No VAEs loaded successfully. Exiting.")
        return

    per_model_metrics: Dict[str, Dict[str, float]] = {
        name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "kl": 0.0, "count": 0.0}
        for name, _ in vaes
    }

    buffers_zmodel: Dict[str, List[torch.Tensor]] = {name: [] for name, _ in vaes}
    norm_summaries: Dict[str, Dict[str, float]] = {}

    saved_first_for: Dict[str, bool] = {name: False for name, _ in vaes}
for batch_0_1 in tqdm(loader, desc="Батчи"): |
|
|
batch_0_1 = batch_0_1.to(DEVICE, torch.float32) |
|
|
batch_neg1_1 = to_neg1_1(batch_0_1).to(DTYPE) |
|
|
|
|
|
for model_name, vae in vaes: |
|
|
x_in = add_time_dim_if_needed(batch_neg1_1, vae) |
|
|
|
|
|
posterior = vae.encode(x_in).latent_dist |
|
|
mu, logvar = posterior.mean, posterior.logvar |
|
|
|
|
|
|
|
|
z_raw_mode = posterior.mode() |
|
|
x_dec = vae.decode(z_raw_mode).sample |
|
|
x_dec = strip_time_dim_if_possible(x_dec, vae) |
|
|
x_rec_0_1 = to_0_1(x_dec.float()).clamp(0, 1) |
|
|
|
|
|
|
|
|
z_raw_sample = posterior.sample() |
|
|
t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary = get_norm_tensors_and_summary(vae, z_raw_sample) |
|
|
|
|
|
if model_name not in norm_summaries: |
|
|
norm_summaries[model_name] = summary |
|
|
|
|
|
z_tmp = (z_raw_sample - t_shift_g) * t_scale_g |
|
|
z_model = (z_tmp - t_shift_c) * t_scale_c |
|
|
z_model = strip_time_dim_if_possible(z_model, vae) |
|
|
|
|
|
buffers_zmodel[model_name].append(z_model.detach().to("cpu", torch.float32)) |
|
|
|
|
|
|
|
|
if not saved_first_for[model_name]: |
|
|
safe = sanitize_filename(model_name) |
|
|
orig_path = os.path.join(SAMPLES_DIR, f"{safe}_original.png") |
|
|
dec_path = os.path.join(SAMPLES_DIR, f"{safe}_decoded.png") |
|
|
save_image(batch_0_1[0:1].cpu(), orig_path) |
|
|
save_image(x_rec_0_1[0:1].cpu(), dec_path) |
|
|
saved_first_for[model_name] = True |
|
|
|
|
|
|
|
|
B = batch_0_1.shape[0] |
|
|
for i in range(B): |
|
|
gt = batch_0_1[i:i+1] |
|
|
rec = x_rec_0_1[i:i+1] |
|
|
|
|
|
mse = F.mse_loss(gt, rec).item() |
|
|
psnr = safe_psnr(mse) |
|
|
lp = float(lpips_net(gt, rec, normalize=True).mean().item()) |
|
|
edge = sobel_edge_l1(gt, rec) |
|
|
|
|
|
per_model_metrics[model_name]["mse"] += mse |
|
|
per_model_metrics[model_name]["psnr"] += psnr |
|
|
per_model_metrics[model_name]["lpips"] += lp |
|
|
per_model_metrics[model_name]["edge"] += edge |
|
|
|
|
|
|
|
|
kl_pi = kl_divergence_per_image(mu, logvar) |
|
|
per_model_metrics[model_name]["kl"] += float(kl_pi.sum().item()) |
|
|
per_model_metrics[model_name]["count"] += B |
|
|
|
|
|
|
|
|

    # Convert accumulated sums into per-image averages.
    for name in per_model_metrics:
        c = max(1.0, per_model_metrics[name]["count"])
        for k in ["mse", "psnr", "lpips", "edge", "kl"]:
            per_model_metrics[name][k] /= c

    # Dataset-wide statistics of the normalized latents, per model.
    per_model_latent_stats = {}
    for name, _ in vaes:
        if not buffers_zmodel[name]:
            continue
        Z = torch.cat(buffers_zmodel[name], dim=0)

        z_min = float(Z.min().item())
        z_mean = float(Z.mean().item())
        z_max = float(Z.max().item())
        z_std = float(Z.std(unbiased=True).item())

        Z_ch = flatten_channels(Z).numpy()
        C = Z_ch.shape[0]
        sk = np.zeros(C, dtype=np.float64)
        ku = np.zeros(C, dtype=np.float64)
        for c in range(C):
            v = Z_ch[c]
            sk[c] = float(skew(v, bias=False))
            ku[c] = float(kurtosis(v, fisher=True, bias=False))

        skew_min, skew_mean, skew_max = float(sk.min()), float(sk.mean()), float(sk.max())
        kurt_min, kurt_mean, kurt_max = float(ku.min()), float(ku.mean()), float(ku.max())
        mean_abs_skew = float(np.mean(np.abs(sk)))
        mean_abs_kurt = float(np.mean(np.abs(ku)))

        per_model_latent_stats[name] = {
            "Z_min": z_min, "Z_mean": z_mean, "Z_max": z_max, "Z_std": z_std,
            "skew_min": skew_min, "skew_mean": skew_mean, "skew_max": skew_max,
            "kurt_min": kurt_min, "kurt_mean": kurt_mean, "kurt_max": kurt_max,
            "mean_abs_skew": mean_abs_skew, "mean_abs_kurt": mean_abs_kurt,
        }
print("\n=== Параметры нормализации латентов (как применялись) ===") |
|
|
for name, _ in vaes: |
|
|
if name not in norm_summaries: |
|
|
continue |
|
|
s = norm_summaries[name] |
|
|
print( |
|
|
f"{name:26s} | " |
|
|
f"shift_g={s['shift_global']:.6g} scale_g={s['scale_global']:.6g} | " |
|
|
f"shift_c[min/mean/max]=[{s['shift_channel_min']:.6g}, {s['shift_channel_mean']:.6g}, {s['shift_channel_max']:.6g}] | " |
|
|
f"scale_c[min/mean/max]=[{s['scale_channel_min']:.6g}, {s['scale_channel_mean']:.6g}, {s['scale_channel_max']:.6g}]" |
|
|
) |
|
|
|
|
|
|
|
|
print("\n=== Абсолютные метрики реконструкции и латентов ===") |
|
|
for name, _ in vaes: |
|
|
if name not in per_model_latent_stats: |
|
|
continue |
|
|
m = per_model_metrics[name] |
|
|
s = per_model_latent_stats[name] |
|
|
print( |
|
|
f"{name:26s} | " |
|
|
f"MSE={m['mse']:.3e} PSNR={m['psnr']:.2f} LPIPS={m['lpips']:.3f} Edge={m['edge']:.3f} KL={m['kl']:.3f} | " |
|
|
f"Z[min/mean/max/std]=[{s['Z_min']:.3f}, {s['Z_mean']:.3f}, {s['Z_max']:.3f}, {s['Z_std']:.3f}] | " |
|
|
f"Skew[min/mean/max]=[{s['skew_min']:.3f}, {s['skew_mean']:.3f}, {s['skew_max']:.3f}] | " |
|
|
f"Kurt[min/mean/max]=[{s['kurt_min']:.3f}, {s['kurt_mean']:.3f}, {s['kurt_max']:.3f}]" |
|
|
) |
|
|
|
|
|
|
|
|

    # Relative comparison against the first model in VAE_LIST. Every column is
    # oriented so that values above 100% mean better than the baseline.
    baseline = vaes[0][0]
    print("\n=== Comparison with the first model (percentages) ===")
    print(f"| {'Model':26s} | {'MSE':>9s} | {'PSNR':>9s} | {'LPIPS':>9s} | {'Edge':>9s} | {'Skew|0':>9s} | {'Kurt|0':>9s} |")
    print(f"|{'-'*28}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|")

    b_m = per_model_metrics[baseline]
    b_s = per_model_latent_stats[baseline]

    for name, _ in vaes:
        m = per_model_metrics[name]
        s = per_model_latent_stats[name]

        # Lower-is-better metrics use baseline / model; PSNR uses model / baseline.
        mse_pct = (b_m["mse"] / max(1e-12, m["mse"])) * 100.0
        psnr_pct = (m["psnr"] / max(1e-12, b_m["psnr"])) * 100.0
        lpips_pct = (b_m["lpips"] / max(1e-12, m["lpips"])) * 100.0
        edge_pct = (b_m["edge"] / max(1e-12, m["edge"])) * 100.0

        skew0_pct = (b_s["mean_abs_skew"] / max(1e-12, s["mean_abs_skew"])) * 100.0
        kurt0_pct = (b_s["mean_abs_kurt"] / max(1e-12, s["mean_abs_kurt"])) * 100.0

        if name == baseline:
            print(f"| {name:26s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} |")
        else:
            print(f"| {name:26s} | {mse_pct:8.1f}% | {psnr_pct:8.1f}% | {lpips_pct:8.1f}% | {edge_pct:8.1f}% | {skew0_pct:8.1f}% | {kurt0_pct:8.1f}% |")

    # Extra correction for the last model in VAE_LIST: global and per-channel
    # shift/scale, computed on top of the VAE's built-in normalization, that
    # bring z_model to zero mean and unit variance.
    last_name = vaes[-1][0]
    if buffers_zmodel[last_name]:
        Z = torch.cat(buffers_zmodel[last_name], dim=0)

        z_mean = float(Z.mean().item())
        z_std = float(Z.std(unbiased=True).item())
        correction_global = {
            "shift": -z_mean,
            "scale": (1.0 / z_std) if z_std > 1e-12 else 1.0,
        }

        Z_ch = flatten_channels(Z)
        ch_means_t = Z_ch.mean(dim=1)
        ch_stds_t = Z_ch.std(dim=1, unbiased=True) + 1e-12
        ch_means = [float(x) for x in ch_means_t.tolist()]
        ch_stds = [float(x) for x in ch_stds_t.tolist()]

        correction_per_channel = [
            {"shift": float(-m), "scale": float(1.0 / s)}
            for m, s in zip(ch_means, ch_stds)
        ]

        print(f"\n=== Extra correction for {last_name} (on top of the VAE normalization) ===")
        print(f"global_correction = {correction_global}")
        print(f"channelwise_means = {ch_means}")
        print(f"channelwise_stds = {ch_stds}")
        print(f"channelwise_correction = {correction_per_channel}")

        json_path = os.path.join(SAMPLES_DIR, f"{sanitize_filename(last_name)}_correction.json")
        to_save = {
            "model_name": last_name,
            "vae_normalization_summary": norm_summaries.get(last_name, {}),
            "global_correction": correction_global,
            "per_channel_means": ch_means,
            "per_channel_stds": ch_stds,
            "per_channel_correction": correction_per_channel,
            # The stored shifts are additive (z + shift), so the global step is
            # written as an addition here.
            "apply_order": {
                "forward": "z_model -> (z + global_shift) * global_scale -> (per-channel: (z - mean_c) / std_c)",
                "inverse": "z_corr -> (per-channel: z * std_c + mean_c) -> (z / global_scale - global_shift)"
            },
            "note": "These coefficients are computed on z_model (after the VAE's built-in shift/scale) to bring the distribution to N(0, 1)."
        }
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(to_save, f, ensure_ascii=False, indent=2)
        print("Corrections JSON saved to:", os.path.abspath(json_path))

    print("\n✅ Done. Samples saved to:", os.path.abspath(SAMPLES_DIR))


if __name__ == "__main__":
    main()