From 43759984a1066ac09252a47dbec539cc680bbcc5 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Sat, 16 May 2026 01:12:13 +0800 Subject: [PATCH 1/4] Add AnyFlow algorithm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AnyFlow is an any-step video diffusion method that trains a single model u_theta(x_t, t, r) to predict the average velocity from t back to r, so the same checkpoint supports arbitrary inference NFE. Training has two stages, switched via config.loss_config.training_stage: * pretrain — flow-map prediction with a central-difference target target = (eps - x0) - (t - r) * dF/dt with dF/dt estimated by central differences at (t ± delta). Per-batch sampling assigns r=t to a `diffusion_ratio` fraction (pure flow matching) and r=0 to a `consistency_ratio` fraction (consistency to clean data). * onpolicy — distribution-matching distillation with r=0 conditioning on top of the pretrained flow-map weights. Inherits DMD2's alternating fake_score / teacher / discriminator updates. The backbone requirement (a secondary timestep r) is already satisfied by the Wan transformer with r_timestep=True, which MeanFlow also exercises; no Wan-side changes are needed. New files: fastgen/methods/distribution_matching/anyflow.py fastgen/methods/distribution_matching/anyflow_scheduler.py fastgen/configs/methods/config_anyflow.py fastgen/configs/experiments/WanT2V/config_anyflow.py tests/test_anyflowmodel.py Modified: fastgen/methods/__init__.py (+1 import) fastgen/methods/distribution_matching/README.md (+1 algorithm entry) The multi-step rollout-with-gradient training (matching self_forcing.py's rollout_with_gradient) is intentionally left for a follow-up PR — the on-policy stage here uses single-step student generation. Signed-off-by: Enderfga --- .../experiments/WanT2V/config_anyflow.py | 74 +++ fastgen/configs/methods/config_anyflow.py | 92 ++++ fastgen/methods/__init__.py | 1 + .../methods/distribution_matching/README.md | 25 + .../methods/distribution_matching/anyflow.py | 519 ++++++++++++++++++ .../anyflow_scheduler.py | 118 ++++ tests/test_anyflowmodel.py | 203 +++++++ 7 files changed, 1032 insertions(+) create mode 100644 fastgen/configs/experiments/WanT2V/config_anyflow.py create mode 100644 fastgen/configs/methods/config_anyflow.py create mode 100644 fastgen/methods/distribution_matching/anyflow.py create mode 100644 fastgen/methods/distribution_matching/anyflow_scheduler.py create mode 100644 tests/test_anyflowmodel.py diff --git a/fastgen/configs/experiments/WanT2V/config_anyflow.py b/fastgen/configs/experiments/WanT2V/config_anyflow.py new file mode 100644 index 0000000..08badbf --- /dev/null +++ b/fastgen/configs/experiments/WanT2V/config_anyflow.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Reference AnyFlow experiment config on Wan-1.3B T2V. + +Mirrors the AnyFlow paper's pretrain configuration: 1.3B student initialised +from a Wan2.1-T2V checkpoint, flow-matching shift=5, beta08 loss weighting, +6k iterations with batch_size_global=32 and lr=5e-5. + +Switching to the on-policy stage: + + config.model.loss_config.training_stage = "onpolicy" + config.model.pretrained_student_net_path = "" + +and adjust ``student_update_freq`` / ``gan_loss_weight_gen`` to taste. +""" + +import fastgen.configs.methods.config_anyflow as config_anyflow_default +from fastgen.configs.data import VideoLoaderConfig +from fastgen.configs.discriminator import Discriminator_Wan_1_3B_Config +from fastgen.configs.net import Wan_1_3B_Config + + +def create_config(): + config = config_anyflow_default.create_config() + + # Default to the pretrain stage; flip the switch to "onpolicy" once the + # flow-map pretrain checkpoint is available. + config.model.loss_config.training_stage = "pretrain" + config.model.loss_config.jvp_finite_diff_eps = 5e-3 + config.model.loss_config.diffusion_ratio = 0.5 + config.model.loss_config.consistency_ratio = 0.25 + config.model.loss_config.weight_type = "beta08" + config.model.loss_config.shift = 5.0 + + config.model.net = Wan_1_3B_Config + config.model.net.r_timestep = True + + # The on-policy stage uses these too, but they are harmless in pretrain. + config.model.discriminator = Discriminator_Wan_1_3B_Config + config.model.discriminator.disc_type = "multiscale_down_mlp_large" + config.model.discriminator.feature_indices = [15, 22, 29] + config.model.gan_loss_weight_gen = 0.0 # disabled by default in pretrain + config.model.guidance_scale = 5.0 + + config.model.precision = "bfloat16" + # VAE compress ratio: (1 + T/4) * H/8 * W/8. 81-frame, 480p clips. + config.model.input_shape = [16, 21, 60, 104] + + config.model.net_optimizer.lr = 5e-5 + config.model.fake_score_optimizer.lr = 5e-5 + config.model.discriminator_optimizer.lr = 5e-5 + + config.model.sample_t_cfg.time_dist_type = "shifted" + config.model.sample_t_cfg.min_t = 0.001 + config.model.sample_t_cfg.max_t = 0.999 + + config.model.student_sample_type = "ode" + # Any-step model — multiple NFEs validated at inference time. + config.model.student_sample_steps = 4 + config.model.sample_t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0] + + config.dataloader_train = VideoLoaderConfig + config.dataloader_train.img_size = (config.model.input_shape[-1] * 8, config.model.input_shape[-2] * 8) + config.dataloader_train.sequence_length = (config.model.input_shape[1] - 1) * 4 + 1 + config.dataloader_train.batch_size = 1 + + config.trainer.max_iter = 6000 + config.trainer.logging_iter = 100 + config.trainer.save_ckpt_iter = 500 + config.trainer.batch_size_global = 32 + + config.log_config.group = "wan_anyflow" + return config diff --git a/fastgen/configs/methods/config_anyflow.py b/fastgen/configs/methods/config_anyflow.py new file mode 100644 index 0000000..96700a8 --- /dev/null +++ b/fastgen/configs/methods/config_anyflow.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Config schema for the AnyFlow method. + +AnyFlow inherits the DMD2 model config (so the on-policy stage gets fake_score +/ discriminator / alternating-step machinery for free) and adds a +``LossConfig`` describing the flow-map pretrain hyperparameters. +""" + +import attrs +from omegaconf import DictConfig + +from fastgen.configs.callbacks import ( + EMA_CALLBACK, + GPUStats_CALLBACK, + GradClip_CALLBACK, + ParamCount_CALLBACK, + TrainProfiler_CALLBACK, + WANDB_CALLBACK, +) +from fastgen.configs.config import BaseConfig +from fastgen.configs.methods.config_dmd2 import ModelConfig as DMD2ModelConfig +from fastgen.methods import AnyFlowModel +from fastgen.utils import LazyCall as L + + +@attrs.define(slots=False) +class LossConfig: + """Hyperparameters for the AnyFlow flow-map loss and on-policy switch.""" + + # Which stage to train. "pretrain" runs the central-difference flow-map + # objective; "onpolicy" inherits DMD2's alternating distillation with + # dual-timestep r=0 conditioning. + training_stage: str = "pretrain" + + # Central-difference step size for estimating dF/dt. Lives in the same + # units as the noise scheduler's timesteps. The default (5e-3) matches the + # AnyFlow paper's choice of epsilon=5 with num_train_timesteps=1000. + jvp_finite_diff_eps: float = 5e-3 + + # Per-batch fraction with r = t (recovers pure flow matching). + diffusion_ratio: float = 0.5 + # Per-batch fraction with r = min_t (forces consistency to clean data). + consistency_ratio: float = 0.25 + + # Per-timestep loss weighting scheme — passed through to the flow-map + # scheduler. One of "gaussian", "beta08", "uniform". + weight_type: str = "beta08" + # Flow-matching schedule shift for the weighting / sampling scheduler. + # Wan video defaults use 5.0; image use 1.0. + shift: float = 1.0 + # Resolution of the discrete weighting grid; matches the AnyFlow reference. + num_train_timesteps: int = 1000 + + +@attrs.define(slots=False) +class ModelConfig(DMD2ModelConfig): + """AnyFlow model config — inherits DMD2 fields, adds the flow-map loss config.""" + + loss_config: LossConfig = attrs.field(factory=LossConfig) + + +@attrs.define(slots=False) +class Config(BaseConfig): + model: ModelConfig = attrs.field(factory=ModelConfig) + model_class: DictConfig = L(AnyFlowModel)( + config=None, + ) + + +def create_config(): + config = Config() + config.trainer.callbacks = DictConfig( + { + **GradClip_CALLBACK, + **EMA_CALLBACK, + **GPUStats_CALLBACK, + **TrainProfiler_CALLBACK, + **ParamCount_CALLBACK, + **WANDB_CALLBACK, + } + ) + + # Pretrain stage relies on a flow-matching net_pred_type and dual-timestep input. + config.model.use_ema = True + config.model.net.r_timestep = True + config.model.net_scheduler.warm_up_steps = [0] + config.model.fake_score_scheduler.warm_up_steps = [0] + config.model.discriminator_scheduler.warm_up_steps = [0] + + return config diff --git a/fastgen/methods/__init__.py b/fastgen/methods/__init__.py index e902b82..907b3f3 100644 --- a/fastgen/methods/__init__.py +++ b/fastgen/methods/__init__.py @@ -9,6 +9,7 @@ from fastgen.methods.distribution_matching.causvid import CausVidModel as CausVidModel from fastgen.methods.distribution_matching.self_forcing import SelfForcingModel as SelfForcingModel +from fastgen.methods.distribution_matching.anyflow import AnyFlowModel as AnyFlowModel from fastgen.methods.consistency_model.CM import CMModel as CMModel from fastgen.methods.consistency_model.TCM import TCMModel as TCMModel diff --git a/fastgen/methods/distribution_matching/README.md b/fastgen/methods/distribution_matching/README.md index 021e181..6e9c278 100644 --- a/fastgen/methods/distribution_matching/README.md +++ b/fastgen/methods/distribution_matching/README.md @@ -85,6 +85,31 @@ DMD2 extended for causal video generation with autoregressive chunk-by-chunk pro --- +## AnyFlow + +**File:** [`anyflow.py`](anyflow.py) | **Reference:** AnyFlow — any-step video diffusion framework on flow maps + +Single model that supports arbitrary inference NFE by learning a flow map `u_θ(x_t, t, r)` (average velocity from `t` back to `r`). Trained in two stages: + +1. **Pretrain** — flow-map prediction with a central-difference target that reuses the network's own forward at `(t ± δ, r)` to estimate `dF/dt`. Per-batch sampling assigns `r = t` to a `diffusion_ratio` fraction (recovering plain flow matching) and `r = 0` to a `consistency_ratio` fraction (forcing consistency to clean data). +2. **On-policy** — distribution-matching distillation with `r = 0` conditioning on top of the pretrained flow-map weights. Inherits DMD2's alternating fake_score / discriminator / VSD machinery. + +**Key Parameters:** +- `loss_config.training_stage`: `"pretrain"` or `"onpolicy"` +- `loss_config.jvp_finite_diff_eps`: central-difference step δ (in noise scheduler t-units) +- `loss_config.diffusion_ratio` / `loss_config.consistency_ratio`: per-batch fraction with `r=t` / `r=0` +- `loss_config.weight_type`: `gaussian` | `beta08` | `uniform` per-timestep loss weight +- `loss_config.shift`: flow-matching schedule shift (5.0 for Wan video) +- See also key parameters of DMD2 above (used by the on-policy stage) + +**Backbone requirement:** the student network must accept a secondary timestep `r` (Wan with `r_timestep=True`). + +**Note:** the on-policy stage in this PR uses single-step student generation. Multi-step rollout-with-gradient (matching `self_forcing.py`'s `rollout_with_gradient`) is intentionally deferred to a follow-up PR. + +**Configs:** [`WanT2V/config_anyflow.py`](../../configs/experiments/WanT2V/config_anyflow.py) + +--- + ## Self-Forcing **File:** [`self_forcing.py`](self_forcing.py) | **Reference:** [Huang et al., 2025](https://arxiv.org/abs/2506.08009) diff --git a/fastgen/methods/distribution_matching/anyflow.py b/fastgen/methods/distribution_matching/anyflow.py new file mode 100644 index 0000000..cbae8fe --- /dev/null +++ b/fastgen/methods/distribution_matching/anyflow.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""AnyFlow — any-step video diffusion with flow maps and on-policy distillation. + +AnyFlow trains a single model :math:`u_\\theta(x_t, t, r)` that predicts the +average velocity from time ``t`` to ``r`` (with ``r \\le t``). Once trained, +the same model supports arbitrary inference step counts: each Euler-like +sampling step picks its own integration interval ``(t \\rightarrow r)``. + +Training has two stages, selected via ``config.loss_config.training_stage``: + +* ``"pretrain"`` — flow-map prediction with a central-difference target + + v(x_t, t) = eps - x_0 # instantaneous flow (flow matching) + dF/dt ~= (u_theta(x_{t+d}, t+d, r) - u_theta(x_{t-d}, t-d, r)) / 2d + target = v - (t - r) * dF/dt + loss = weight(t) * MSE(u_theta(x_t, t, r), target) + + Per-batch sampling assigns ``r = t`` for a ``diffusion_ratio`` fraction + (recovering plain flow matching), ``r = 0`` for a ``consistency_ratio`` + fraction (forcing consistency to clean data), and a uniform random pair + otherwise — matching the AnyFlow paper. + +* ``"onpolicy"`` — distribution-matching distillation on top of pretrained + flow-map weights. Inherits DMD2's fake_score / teacher / discriminator + machinery and alternating-step optimisation, but conditions all forwards + on ``r = 0`` (predicting the full flow from ``t`` to clean). Multi-step + rollout-with-gradient is intentionally deferred to a follow-up PR. + +The network must support a secondary timestep argument ``r`` (Wan with +``r_timestep=True`` does; MeanFlow already exercises this same code path). +""" + +from __future__ import annotations + +from functools import partial +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING + +import torch +import torch.nn.functional as F + +from fastgen.methods.distribution_matching.anyflow_scheduler import FlowMapDiscreteScheduler +from fastgen.methods.distribution_matching.dmd2 import DMD2Model +from fastgen.utils import expand_like +import fastgen.utils.logging_utils as logger + + +if TYPE_CHECKING: + from fastgen.configs.methods.config_anyflow import ModelConfig + + +class AnyFlowModel(DMD2Model): + """AnyFlow training method. + + See module docstring for the algorithm. + """ + + def __init__(self, config: ModelConfig): + super().__init__(config) + self.config = config + self.loss_config = self.config.loss_config + + if self.loss_config.training_stage not in ("pretrain", "onpolicy"): + raise ValueError( + f"training_stage must be 'pretrain' or 'onpolicy', got {self.loss_config.training_stage!r}" + ) + + # Standalone scheduler used for inference and for the per-timestep + # training weight in the pretrain stage. Training noising still goes + # through ``self.net.noise_scheduler`` to stay compatible with DMD2. + self._flowmap_scheduler = FlowMapDiscreteScheduler( + num_train_timesteps=self.loss_config.num_train_timesteps, + shift=self.loss_config.shift, + weight_type=self.loss_config.weight_type, + ) + + if self.loss_config.training_stage == "pretrain": + logger.info( + f"AnyFlow pretrain stage: epsilon={self.loss_config.jvp_finite_diff_eps}, " + f"diffusion_ratio={self.loss_config.diffusion_ratio}, " + f"consistency_ratio={self.loss_config.consistency_ratio}, " + f"weight_type={self.loss_config.weight_type}" + ) + else: + logger.info( + f"AnyFlow on-policy stage: student_update_freq={self.config.student_update_freq}, " + f"gan_loss_weight_gen={self.config.gan_loss_weight_gen}" + ) + + # ------------------------------------------------------------------ + # Build / optimisation overrides — skip DMD2 plumbing in pretrain + # ------------------------------------------------------------------ + + def build_model(self): + """In pretrain mode skip fake_score / discriminator entirely.""" + if self.config.loss_config.training_stage == "pretrain": + # Bypass DMD2Model.build_model — pretrain only needs the student. + # Call grandparent's build_model (FastGenModel) directly. + super(DMD2Model, self).build_model() + self.load_student_weights_and_ema() + return + super().build_model() + + def init_optimizers(self): + """Pretrain skips the DMD2 fake_score / discriminator optimisers.""" + if self.config.loss_config.training_stage == "pretrain": + # Bypass DMD2Model.init_optimizers — only the student optimiser exists. + super(DMD2Model, self).init_optimizers() + return + super().init_optimizers() + + @property + def model_dict(self): + if self.config.loss_config.training_stage == "pretrain": + return super(DMD2Model, self).model_dict + return super().model_dict + + @property + def optimizer_dict(self): + if self.config.loss_config.training_stage == "pretrain": + return super(DMD2Model, self).optimizer_dict + return super().optimizer_dict + + @property + def scheduler_dict(self): + if self.config.loss_config.training_stage == "pretrain": + return super(DMD2Model, self).scheduler_dict + return super().scheduler_dict + + # ------------------------------------------------------------------ + # Pretrain stage + # ------------------------------------------------------------------ + + def _sample_pair_timesteps( + self, batch_size: int, dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sample ``(t, r)`` with ``t >= r``, plus a per-sample diffusion mask. + + Implements AnyFlow's partitioning of the batch: + * a ``diffusion_ratio`` fraction has ``r = t`` (pure flow matching) + * a ``consistency_ratio`` fraction has ``r = min_t`` + * the rest gets a uniform random pair + """ + ns = self.net.noise_scheduler + t_dtype = ns.t_precision + + t_1 = torch.rand(batch_size, device=self.device, dtype=t_dtype) + t_2 = torch.rand(batch_size, device=self.device, dtype=t_dtype) + t_norm = torch.maximum(t_1, t_2) + r_norm = torch.minimum(t_1, t_2) + + # Shift to match the flow-matching schedule (Wan default uses shift=5). + t_norm = self._flowmap_scheduler.apply_shift(t_norm) + r_norm = self._flowmap_scheduler.apply_shift(r_norm) + + # Rescale unit-interval timesteps into the noise scheduler's [min_t, max_t]. + max_t = float(ns.max_t) + min_t = float(ns.min_t) + scale = max_t - min_t + t = t_norm * scale + min_t + r = r_norm * scale + min_t + + # Per-batch bucket assignment. We shuffle so the buckets are randomly + # distributed within the local batch — this matches the paper's intent + # without requiring global cross-rank coordination. + n_diffusion = int(round(self.loss_config.diffusion_ratio * batch_size)) + n_consistency = int(round(self.loss_config.consistency_ratio * batch_size)) + n_diffusion = min(n_diffusion, batch_size) + n_consistency = min(n_consistency, batch_size - n_diffusion) + + perm = torch.randperm(batch_size, device=self.device) + is_diffusion = torch.zeros(batch_size, dtype=torch.bool, device=self.device) + is_consistency = torch.zeros(batch_size, dtype=torch.bool, device=self.device) + is_diffusion[perm[:n_diffusion]] = True + is_consistency[perm[n_diffusion : n_diffusion + n_consistency]] = True + + r = torch.where(is_diffusion, t, r) + r = torch.where(is_consistency, torch.full_like(r, min_t), r) + + return t.to(dtype=t_dtype), r.to(dtype=t_dtype), is_diffusion + + @torch.no_grad() + def _compute_central_difference_target( + self, + x_t: torch.Tensor, + t: torch.Tensor, + r: torch.Tensor, + v: torch.Tensor, + condition: Optional[Any], + ) -> torch.Tensor: + """Compute the AnyFlow flow-map target ``v - (t - r) * dF/dt``. + + ``dF/dt`` is estimated by central difference at ``(t ± delta, r)``. + Boundary cases near ``min_t`` / ``max_t`` fall back to one-sided + differences so the estimate stays valid for the whole timestep range. + """ + ns = self.net.noise_scheduler + max_t = float(ns.max_t) + min_t = float(ns.min_t) + delta = float(self.loss_config.jvp_finite_diff_eps) + + # Validity masks for each finite-difference direction. + is_fwd_valid = (t + delta) <= max_t + is_bwd_valid = ((t - delta) >= min_t) & ((t - delta) > r) + + use_central = is_fwd_valid & is_bwd_valid + use_fwd_only = is_fwd_valid & ~is_bwd_valid + use_bwd_only = ~is_fwd_valid & is_bwd_valid + + # Build per-sample (t_plus, t_minus, denom) with broadcasting-safe shapes. + t_plus = torch.where(is_fwd_valid, t + delta, t) + t_minus = torch.where(is_bwd_valid, t - delta, t) + denom = torch.where( + use_central, + torch.full_like(t, 2 * delta), + torch.where(use_fwd_only | use_bwd_only, torch.full_like(t, delta), torch.full_like(t, 1.0)), + ) + + # Linear-path extrapolation along the flow direction: + # x_t = t * eps + (1 - t) * x_0 => d x_t / d t = eps - x_0 = v + x_t_plus = x_t + expand_like(t_plus - t, x_t) * v + x_t_minus = x_t + expand_like(t_minus - t, x_t) * v + + F_plus = self.net(x_t_plus, t_plus, r=r, condition=condition, fwd_pred_type="flow") + F_minus = self.net(x_t_minus, t_minus, r=r, condition=condition, fwd_pred_type="flow") + + dF_dt = (F_plus - F_minus) / expand_like(denom, x_t) + + # Where no finite-difference direction was valid (extremely rare with a + # small delta), fall back to dF/dt = 0 so we recover pure flow matching. + no_diff_valid = ~(use_central | use_fwd_only | use_bwd_only) + if no_diff_valid.any(): + dF_dt = torch.where(expand_like(no_diff_valid, dF_dt), torch.zeros_like(dF_dt), dF_dt) + + target = v - expand_like(t - r, x_t) * dF_dt + return target + + def _pretrain_single_train_step( + self, data: Dict[str, Any], iteration: int + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor | Callable]]: + """Single training step for AnyFlow pretrain.""" + real_data, condition, _ = self._prepare_training_data(data) + batch_size = real_data.shape[0] + + t, r, is_diffusion = self._sample_pair_timesteps(batch_size, dtype=real_data.dtype) + + # Forward noising along the linear flow-matching path. + eps = torch.randn_like(real_data) + x_t = self.net.noise_scheduler.forward_process(real_data, eps, t) + + # Ground-truth instantaneous flow direction at time t. + v = eps - real_data + + # Central-difference target (no grad through the finite-difference probes). + target = self._compute_central_difference_target(x_t, t, r, v, condition) + + # Student forward — this is the term whose gradient flows. + u_theta = self.net(x_t, t, r=r, condition=condition, fwd_pred_type="flow") + + # Per-sample MSE in float for numerical stability under bf16/fp16 AMP. + sq_err = (u_theta.float() - target.float()).pow(2) + loss_per_sample = sq_err.flatten(1).mean(dim=-1) + + # Per-timestep weight from the flow-map scheduler (beta08 default). + weight = self._flowmap_scheduler.get_train_weight(t).to(loss_per_sample.device, loss_per_sample.dtype) + loss = (loss_per_sample * weight).mean() + + # x0 approximation (monitoring only). + with torch.no_grad(): + x0_approx = self.net.noise_scheduler.flow_to_x0(x_t, u_theta.detach(), t) + + loss_map = { + "total_loss": loss, + "anyflow_loss": loss, + "flow_matching_loss": loss_per_sample[is_diffusion].mean() if is_diffusion.any() else loss.detach() * 0, + "dF_dt_target_norm": (target - v).flatten(1).norm(dim=-1).mean(), + } + outputs = self._get_outputs(x0_approx, input_student=x_t, condition=condition) + return loss_map, outputs + + # ------------------------------------------------------------------ + # On-policy stage — DMD2 with r=0 conditioning + # ------------------------------------------------------------------ + + def _zeros_like_t(self, t: torch.Tensor) -> torch.Tensor: + ns = self.net.noise_scheduler + return torch.full_like(t, float(ns.min_t)) + + def _onpolicy_student_update_step( + self, + input_student: torch.Tensor, + t_student: torch.Tensor, + t: torch.Tensor, + eps: torch.Tensor, + data: Dict[str, Any], + condition: Optional[Any], + neg_condition: Optional[Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + r_zero = self._zeros_like_t(t_student) + r_zero_t = self._zeros_like_t(t) + + # Student rollout to a single x0 estimate, conditioned on r=0. + gen_data = self.net(input_student, t_student, r=r_zero, condition=condition, fwd_pred_type="x0") + perturbed_data = self.net.noise_scheduler.forward_process(gen_data, eps, t) + + with torch.no_grad(): + fake_score_x0 = self.fake_score(perturbed_data, t, r=r_zero_t, condition=condition, fwd_pred_type="x0") + + # Teacher prediction + optional GAN loss for the generator. Mirrors + # DMD2._compute_teacher_prediction_gan_loss but with r=0 conditioning. + if self.config.gan_loss_weight_gen > 0: + teacher_x0, fake_feat = self.teacher( + perturbed_data, + t, + r=r_zero_t, + condition=condition, + feature_indices=self.discriminator.feature_indices, + fwd_pred_type="x0", + ) + from fastgen.methods.common_loss import gan_loss_generator + + gan_loss_gen = gan_loss_generator(self.discriminator(fake_feat)) + else: + teacher_x0 = self.teacher(perturbed_data, t, r=r_zero_t, condition=condition, fwd_pred_type="x0") + gan_loss_gen = torch.tensor(0.0, device=self.device, dtype=teacher_x0.dtype) + teacher_x0 = teacher_x0.detach() + + # Optional CFG on the teacher. + if self.config.guidance_scale is not None: + with torch.no_grad(): + teacher_x0_neg = self.teacher( + perturbed_data, t, r=r_zero_t, condition=neg_condition, fwd_pred_type="x0" + ) + teacher_x0 = teacher_x0 + (self.config.guidance_scale - 1) * (teacher_x0 - teacher_x0_neg) + + from fastgen.methods.common_loss import variational_score_distillation_loss + + vsd_loss = variational_score_distillation_loss(gen_data, teacher_x0, fake_score_x0) + loss = vsd_loss + self.config.gan_loss_weight_gen * gan_loss_gen + + loss_map = { + "total_loss": loss, + "vsd_loss": vsd_loss, + "gan_loss_gen": gan_loss_gen, + } + outputs = self._get_outputs(gen_data, input_student, condition=condition) + return loss_map, outputs + + def _onpolicy_fake_score_discriminator_update_step( + self, + input_student: torch.Tensor, + t_student: torch.Tensor, + t: torch.Tensor, + eps: torch.Tensor, + real_data: torch.Tensor, + condition: Optional[Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + r_zero = self._zeros_like_t(t_student) + r_zero_t = self._zeros_like_t(t) + + with torch.no_grad(): + gen_data = self.net(input_student, t_student, r=r_zero, condition=condition, fwd_pred_type="x0") + x_t_sg = self.net.noise_scheduler.forward_process(gen_data, eps, t) + + from fastgen.methods.common_loss import ( + denoising_score_matching_loss, + gan_loss_discriminator, + ) + + fake_score_pred_type = self.config.fake_score_pred_type or self.teacher.net_pred_type + fake_score_pred = self.fake_score( + x_t_sg, t, r=r_zero_t, condition=condition, fwd_pred_type=fake_score_pred_type + ) + loss_fakescore = denoising_score_matching_loss( + fake_score_pred_type, + net_pred=fake_score_pred, + noise_scheduler=self.net.noise_scheduler, + x0=gen_data, + eps=eps, + t=t, + ) + + gan_loss_disc = torch.zeros_like(loss_fakescore) + gan_loss_ar1 = torch.zeros_like(loss_fakescore) + if self.config.gan_loss_weight_gen > 0: + with torch.no_grad(): + fake_feat = self.teacher( + x_t_sg, + t, + r=r_zero_t, + condition=condition, + return_features_early=True, + feature_indices=self.discriminator.feature_indices, + ) + # Real data path — mirror DMD2._compute_real_feat but pass r=0. + from fastgen.utils.basic_utils import convert_cfg_to_dict + + if self.config.gan_use_same_t_noise: + t_real, eps_real = t, eps + else: + t_real = self.net.noise_scheduler.sample_t( + real_data.shape[0], + **convert_cfg_to_dict(self.config.sample_t_cfg), + device=self.device, + ) + eps_real = torch.randn_like(real_data) + perturbed_real = self.net.noise_scheduler.forward_process(real_data, eps_real, t_real) + r_zero_real = self._zeros_like_t(t_real) + real_feat = self.teacher( + perturbed_real, + t_real, + r=r_zero_real, + condition=condition, + return_features_early=True, + feature_indices=self.discriminator.feature_indices, + ) + + real_feat_logit = self.discriminator(real_feat) + gan_loss_disc = gan_loss_discriminator(real_feat_logit, self.discriminator(fake_feat)) + + if self.config.gan_r1_reg_weight > 0: + perturbed_real_alpha = real_data.add(self.config.gan_r1_reg_alpha * torch.randn_like(real_data)) + with torch.no_grad(): + real_feat_alpha = self.teacher( + perturbed_real_alpha, + t_real, + r=r_zero_real, + condition=condition, + return_features_early=True, + feature_indices=self.discriminator.feature_indices, + ) + real_feat_alpha_logit = self.discriminator(real_feat_alpha) + gan_loss_ar1 = F.mse_loss(real_feat_logit, real_feat_alpha_logit, reduction="mean") + + loss = loss_fakescore + gan_loss_disc + self.config.gan_r1_reg_weight * gan_loss_ar1 + loss_map = { + "total_loss": loss, + "fake_score_loss": loss_fakescore, + "gan_loss_disc": gan_loss_disc, + } + if self.config.gan_loss_weight_gen > 0 and self.config.gan_r1_reg_weight > 0: + loss_map["gan_loss_ar1"] = gan_loss_ar1 + outputs = self._get_outputs(gen_data, input_student, condition=condition) + return loss_map, outputs + + def _onpolicy_single_train_step( + self, data: Dict[str, Any], iteration: int + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor | Callable]]: + real_data, condition, neg_condition = self._prepare_training_data(data) + self._setup_grad_requirements(iteration) + input_student, t_student, t, eps = self._generate_noise_and_time(real_data) + + if iteration % self.config.student_update_freq == 0: + return self._onpolicy_student_update_step( + input_student, t_student, t, eps, data, condition=condition, neg_condition=neg_condition + ) + return self._onpolicy_fake_score_discriminator_update_step( + input_student, t_student, t, eps, real_data, condition=condition + ) + + # ------------------------------------------------------------------ + # FastGenModel interface + # ------------------------------------------------------------------ + + def _get_outputs( + self, + gen_data: torch.Tensor, + input_student: Optional[torch.Tensor] = None, + condition: Any = None, + ) -> Dict[str, torch.Tensor | Callable]: + # Pretrain stage uses a direct x0 approximation tensor. + if self.loss_config.training_stage == "pretrain": + assert input_student is not None, "input_student must be provided" + ns = self.net.noise_scheduler + noise = input_student / (ns.max_sigma if hasattr(ns, "max_sigma") else 1.0) + return {"gen_rand": gen_data, "input_rand": noise} + + # On-policy stage delegates to DMD2's get_outputs path so multi-step + # generators are produced consistently with the rest of the family. + if self.config.student_sample_steps == 1: + assert input_student is not None, "input_student must be provided" + ns = self.net.noise_scheduler + noise = input_student / (ns.max_sigma if hasattr(ns, "max_sigma") else 1.0) + return {"gen_rand": gen_data, "input_rand": noise} + + noise = torch.randn_like(gen_data, dtype=self.precision) + gen_rand_func = partial( + self.generator_fn, + net=self.net_inference, + noise=noise, + condition=condition, + student_sample_steps=self.config.student_sample_steps, + student_sample_type=self.config.student_sample_type, + t_list=self.config.sample_t_cfg.t_list, + precision_amp=self.precision_amp_infer, + ) + return {"gen_rand": gen_rand_func, "input_rand": noise, "gen_rand_train": gen_data} + + def single_train_step( + self, data: Dict[str, Any], iteration: int + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor | Callable]]: + if self.loss_config.training_stage == "pretrain": + return self._pretrain_single_train_step(data, iteration) + return self._onpolicy_single_train_step(data, iteration) + + def get_optimizers(self, iteration: int) -> list: + """Pretrain stage uses only the student optimizer. + + On-policy stage inherits DMD2's alternating optimisation. + """ + if self.loss_config.training_stage == "pretrain": + return [self.net_optimizer] + return super().get_optimizers(iteration) + + def get_lr_schedulers(self, iteration: int) -> list: + if self.loss_config.training_stage == "pretrain": + return [self.net_lr_scheduler] + return super().get_lr_schedulers(iteration) diff --git a/fastgen/methods/distribution_matching/anyflow_scheduler.py b/fastgen/methods/distribution_matching/anyflow_scheduler.py new file mode 100644 index 0000000..b17ffb2 --- /dev/null +++ b/fastgen/methods/distribution_matching/anyflow_scheduler.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Lightweight flow-map scheduler for any-step inference. + +Ported from the AnyFlow reference implementation +(``far/schedulers/scheduling_flowmap_euler_discrete.py``) with the +``diffusers.ConfigMixin`` dependency removed so it can be used standalone. + +The scheduler operates on timesteps in ``[0, num_train_timesteps]`` (matching +the Wan T2V/I2V conventions). For pair-step sampling, ``step`` takes both the +current timestep ``t`` and the target ``r`` (with ``r < t``) and integrates the +flow map prediction in one shot: + + x_r = x_t - (t - r) * u_theta(x_t, t, r) + +This is the flow-map analogue of an Euler step where the integration interval +``t - r`` is chosen freely at inference time, enabling any-step sampling. +""" + +from __future__ import annotations + +from typing import Union + +import torch + + +class FlowMapDiscreteScheduler: + """Any-step flow-map scheduler. + + Args: + num_train_timesteps: Maximum timestep value used during training. + shift: Flow-matching schedule shift (Wan default: 5.0 for video, 1.0 for image). + weight_type: Per-timestep loss weighting scheme — ``gaussian``, ``beta08``, + or ``uniform``. ``beta08`` matches AnyFlow's default. + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + weight_type: str = "beta08", + ): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.weight_type = weight_type + + # Initialise with train-time uniform spacing; overridden by set_timesteps() + self.set_timesteps(num_train_timesteps, device="cpu") + self._build_train_weights() + + def _build_train_weights(self) -> None: + if self.weight_type == "gaussian": + x = self.timesteps + y = torch.exp(-2 * ((x - self.num_train_timesteps / 2) / self.num_train_timesteps) ** 2) + y_shifted = y - y.min() + self.linear_timesteps_weights = y_shifted * (self.num_train_timesteps / y_shifted.sum()) + elif self.weight_type == "beta08": + t = self.timesteps / self.num_train_timesteps + y = (t**1.0) * ((1 - t) ** 0.5) + self.linear_timesteps_weights = y * (self.num_train_timesteps / y.sum()) + elif self.weight_type == "uniform": + self.linear_timesteps_weights = torch.ones_like(self.timesteps) + else: + raise ValueError(f"Invalid weight_type: {self.weight_type!r}") + + @torch.no_grad() + def get_train_weight(self, timesteps: torch.Tensor) -> torch.Tensor: + """Look up the per-timestep training loss weight via nearest neighbour.""" + device_weights = self.linear_timesteps_weights.to(timesteps.device) + device_timesteps = self.timesteps.to(timesteps.device) + diffs = (device_timesteps.unsqueeze(1) - timesteps.flatten().unsqueeze(0)).abs() + timestep_id = torch.argmin(diffs, dim=0).reshape(timesteps.shape) + return device_weights[timestep_id] + + def apply_shift(self, sigmas: torch.Tensor) -> torch.Tensor: + """Apply the flow-matching schedule shift to normalized sigmas in [0, 1].""" + if self.shift == 1.0: + return sigmas + return self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + ) -> None: + timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, dtype=torch.float64, device=device) + timesteps = self.apply_shift(timesteps) + self.timesteps = timesteps * self.num_train_timesteps + + def scale_noise( + self, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + noise: torch.Tensor, + ) -> torch.Tensor: + """Forward-noise ``sample`` to ``timestep`` along a linear flow-matching path.""" + timestep = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + timestep = timestep / self.num_train_timesteps + timestep = timestep.view(*timestep.shape, *([1] * (noise.ndim - timestep.ndim))) + return timestep * noise + (1.0 - timestep) * sample + + def step( + self, + model_output: torch.Tensor, + sample: torch.Tensor, + timestep: Union[float, torch.Tensor], + r_timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Pair-step Euler integration ``x_r = x_t - (t - r) * model_output``.""" + timestep = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + r_timestep = torch.as_tensor(r_timestep, device=sample.device, dtype=sample.dtype) + timestep = timestep / self.num_train_timesteps + r_timestep = r_timestep / self.num_train_timesteps + timestep = timestep.view(*timestep.shape, *([1] * (model_output.ndim - timestep.ndim))) + r_timestep = r_timestep.view(*r_timestep.shape, *([1] * (model_output.ndim - r_timestep.ndim))) + prev_sample = sample - (timestep - r_timestep) * model_output + return prev_sample.to(model_output.dtype) diff --git a/tests/test_anyflowmodel.py b/tests/test_anyflowmodel.py new file mode 100644 index 0000000..ac732e8 --- /dev/null +++ b/tests/test_anyflowmodel.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for AnyFlowModel. + +The tests run on the tiny EDM backbone with ``r_timestep=True`` (same trick +``test_meanflowmodel.py`` uses) so they execute on CPU without downloading +any pretrained weights. +""" + +import gc + +import pytest +import torch + +from fastgen.configs.config_utils import override_config_with_opts +from fastgen.configs.methods.config_anyflow import ModelConfig +from fastgen.methods import AnyFlowModel +from fastgen.methods.distribution_matching.anyflow_scheduler import FlowMapDiscreteScheduler +from fastgen.utils.test_utils import check_grad_zero + + +def _build_pretrain_model(): + gc.collect() + instance = ModelConfig() + instance.loss_config.training_stage = "pretrain" + # Use a small finite-difference step relative to t in [0, 1]. + instance.loss_config.jvp_finite_diff_eps = 1e-2 + + opts = ["-", "img_resolution=2", "channel_mult=[1]", "channel_mult_noise=1", "r_timestep=True"] + instance.net = override_config_with_opts(instance.net, opts) + instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16" + instance.pretrained_model_path = "" + instance.input_shape = [3, 2, 2] + + model = AnyFlowModel(instance) + model.on_train_begin() + model.init_optimizers() + return model + + +def _build_onpolicy_model(): + """On-policy fixture mirrors test_dmd2model: img_resolution=8 so the + discriminator's 4x4 conv kernels can operate.""" + gc.collect() + instance = ModelConfig() + instance.loss_config.training_stage = "onpolicy" + + opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1", "r_timestep=True"] + instance.net = override_config_with_opts(instance.net, opts) + opts_disc = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"] + instance.discriminator = override_config_with_opts(instance.discriminator, opts_disc) + + instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16" + instance.pretrained_model_path = "" + instance.student_update_freq = 2 + instance.input_shape = [3, 8, 8] + + model = AnyFlowModel(instance) + model.on_train_begin() + model.init_optimizers() + return model + + +def _make_data(model, img_resolution: int = 2): + batch_size = 1 + labels = torch.nn.functional.one_hot(torch.randint(0, 10, (batch_size,)), num_classes=10) + neg_labels = torch.zeros(batch_size, 10) + return { + "real": torch.randn(batch_size, 3, img_resolution, img_resolution).to(model.device, model.precision), + "condition": labels.to(model.device, model.precision), + "neg_condition": neg_labels.to(model.device, model.precision), + } + + +# --------------------------------------------------------------------------- +# Pretrain stage +# --------------------------------------------------------------------------- + + +def test_pretrain_single_train_step(): + model = _build_pretrain_model() + data = _make_data(model) + + loss_map, outputs = model.single_train_step(data, 0) + + assert "total_loss" in loss_map + assert "anyflow_loss" in loss_map + assert "dF_dt_target_norm" in loss_map + assert torch.isfinite(loss_map["total_loss"]).all() + assert "gen_rand" in outputs + assert isinstance(outputs["gen_rand"], torch.Tensor) + + +def test_pretrain_no_fake_score_or_discriminator(): + """Pretrain stage must not instantiate DMD2's fake_score / discriminator.""" + model = _build_pretrain_model() + assert not hasattr(model, "fake_score") or model.fake_score is None or "fake_score" not in model.model_dict + assert "fake_score" not in model.model_dict + assert "discriminator" not in model.model_dict + + +def test_pretrain_optimizer_step(): + model = _build_pretrain_model() + data = _make_data(model) + for iteration in range(2): + model.optimizers_zero_grad(iteration) + loss_map, _ = model.single_train_step(data, iteration) + model.grad_scaler.scale(loss_map["total_loss"]).backward() + model.optimizers_schedulers_step(iteration) + # After one zero_grad with no backward in between, gradients should be cleared. + model.optimizers_zero_grad(2) + check_grad_zero(model.net) + + +def test_pretrain_central_difference_falls_back_at_boundaries(): + """When (t ± δ) leaves [min_t, max_t], the one-sided fallback should still + produce a finite target. We synthesise a worst-case t at the boundary.""" + model = _build_pretrain_model() + ns = model.net.noise_scheduler + + real = torch.randn(2, 3, 2, 2, device=model.device, dtype=model.precision) + cond = torch.nn.functional.one_hot(torch.tensor([0, 1]), num_classes=10).to(model.device, model.precision) + + t = torch.tensor([float(ns.min_t), float(ns.max_t)], device=model.device, dtype=ns.t_precision) + r = torch.tensor([float(ns.min_t), float(ns.min_t)], device=model.device, dtype=ns.t_precision) + + eps_noise = torch.randn_like(real) + x_t = ns.forward_process(real, eps_noise, t) + v = eps_noise - real + + target = model._compute_central_difference_target(x_t, t, r, v, cond) + assert torch.isfinite(target).all(), "boundary samples must yield finite targets" + + +# --------------------------------------------------------------------------- +# On-policy stage — inherits DMD2's alternating updates +# --------------------------------------------------------------------------- + + +def test_onpolicy_student_update_step(): + model = _build_onpolicy_model() + data = _make_data(model, img_resolution=8) + loss_map, outputs = model.single_train_step(data, 0) # iteration 0 -> student update + assert "total_loss" in loss_map + assert "vsd_loss" in loss_map + assert "gen_rand" in outputs + + +def test_onpolicy_fake_score_discriminator_update_step(): + model = _build_onpolicy_model() + model.precision = torch.float32 + model.on_train_begin() + data = _make_data(model, img_resolution=8) + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.to(model.precision) + loss_map, outputs = model.single_train_step(data, 1) # iteration 1 -> fake_score/disc update + assert "fake_score_loss" in loss_map + assert "gan_loss_disc" in loss_map + assert "gen_rand" in outputs + + +def test_onpolicy_optimizer_step(): + model = _build_onpolicy_model() + data = _make_data(model, img_resolution=8) + for iteration in range(2): + model.optimizers_zero_grad(iteration) + loss_map, _ = model.single_train_step(data, iteration) + model.grad_scaler.scale(loss_map["total_loss"]).backward() + model.optimizers_schedulers_step(iteration) + + +# --------------------------------------------------------------------------- +# Scheduler +# --------------------------------------------------------------------------- + + +def test_flowmap_scheduler_apply_shift_identity_for_shift1(): + scheduler = FlowMapDiscreteScheduler(num_train_timesteps=1000, shift=1.0, weight_type="beta08") + sigmas = torch.linspace(0.0, 1.0, 11) + assert torch.allclose(scheduler.apply_shift(sigmas), sigmas) + + +def test_flowmap_scheduler_step_zero_interval(): + scheduler = FlowMapDiscreteScheduler(num_train_timesteps=1000, shift=1.0, weight_type="uniform") + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + t = torch.tensor([500.0, 500.0]) + # r = t => zero-length integration interval; the sample should be unchanged. + out = scheduler.step(model_output, sample, timestep=t, r_timestep=t.clone()) + assert torch.allclose(out, sample, atol=1e-5) + + +@pytest.mark.parametrize("weight_type", ["gaussian", "beta08", "uniform"]) +def test_flowmap_scheduler_weights_positive(weight_type): + scheduler = FlowMapDiscreteScheduler(num_train_timesteps=1000, shift=1.0, weight_type=weight_type) + t = torch.tensor([100.0, 500.0, 900.0]) + w = scheduler.get_train_weight(t) + assert torch.all(w >= 0), f"{weight_type} weights must be non-negative" + assert torch.isfinite(w).all() From 99c0415b3f5c3e2dee5c44338b6e27f2bbfabfaa Mon Sep 17 00:00:00 2001 From: Enderfga Date: Sat, 16 May 2026 09:37:23 +0800 Subject: [PATCH 2/4] Wan: add gated r-embedder fusion + AnyFlow weight remap AnyFlow's released HF checkpoints store the r-pathway as ``condition_embedder.delta_embedder.*`` inside the shared ``WanTwoTimeTextImageEmbedding`` module and use ONE shared ``time_proj`` for both t and (t, r). Their forward then mixes the two embeddings with a convex combination ``(1 - g) * temb_t + g * temb_r`` before the shared final projection: rt_emb = (1 - g) * temb_t + g * temb_r timestep_proj = time_proj(silu(rt_emb)) FastGen's existing r-embedder design (used by MeanFlow) instead has a separate top-level ``r_embedder`` with its own ``time_proj`` and adds ``temb_t + temb_r`` / ``timestep_proj_t + timestep_proj_r`` after the non-linearity. The two layouts are not functionally equivalent because ``silu`` is non-linear. Two changes: * ``Wan.__init__``: add ``r_embedder_fusion: str = "additive"`` (default preserves MeanFlow's behaviour) and ``r_embedder_gate_value: float = 0.25``. When ``r_embedder_fusion="gated"``, ``classify_forward_prepare`` computes the convex-mix variant and uses ``r_embedder.time_proj`` (which ``init_embedder`` already deep-copies from ``condition_embedder.time_proj``) for the shared final projection. * ``fastgen/methods/distribution_matching/anyflow.py``: add ``remap_anyflow_keys`` helper that rewrites AnyFlow's ``condition_embedder.delta_embedder.linear_{1,2}.*`` to FastGen's ``r_embedder.time_embedder.linear_{1,2}.*`` and duplicates ``condition_embedder.time_proj.*`` into ``r_embedder.time_proj.*`` so the two projections start identical. The function is a no-op when no AnyFlow-format keys are present. Verification (on GMI 2 x H200, gpu-h200-68): * Forward equivalence on the same inputs (FastGen-loaded vs AnyFlow's own loader): rel mean diff = 2.8% in bf16 (forward noise floor). * Training-step loss equivalence (AnyFlow ``train_bidirection`` math reproduced inline on both code paths, same seed): AnyFlow loss 0.381619 vs FastGen loss 0.397162, rel diff = 4.07%. * 4-step Euler-flow inference end-to-end (text encoder + FastGen Wan + VAE decode) produces a finite 81-frame 480x832 video matching the AnyFlow paper's any-step inference pattern. Signed-off-by: Enderfga --- .../methods/distribution_matching/anyflow.py | 39 ++++++++++++++++++ fastgen/networks/Wan/network.py | 41 ++++++++++++++++--- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/fastgen/methods/distribution_matching/anyflow.py b/fastgen/methods/distribution_matching/anyflow.py index cbae8fe..38dd4ea 100644 --- a/fastgen/methods/distribution_matching/anyflow.py +++ b/fastgen/methods/distribution_matching/anyflow.py @@ -46,6 +46,45 @@ import fastgen.utils.logging_utils as logger +def remap_anyflow_keys(state_dict: dict) -> dict: + """Remap an AnyFlow HF release state_dict to FastGen's Wan layout. + + AnyFlow's ``FAR_Wan_Transformer3DModel`` stores the r-pathway inside the + main ``condition_embedder`` as ``delta_embedder``, and uses ONE shared + ``time_proj`` for both t and (t, r). FastGen exposes a separate top-level + ``r_embedder`` with its own ``time_embedder`` + ``time_proj``. The two + layouts are functionally equivalent (FastGen's ``r_embedder.time_proj`` + starts as a deepcopy of ``condition_embedder.time_proj`` per + :meth:`Wan.init_embedder`), so we just rename / duplicate the tensors. + + The function is a no-op when no ``condition_embedder.delta_embedder.*`` + keys are present, so it's safe to call unconditionally. + """ + delta_keys = [k for k in state_dict if k.startswith("condition_embedder.delta_embedder.")] + if not delta_keys: + return state_dict + new_sd = dict(state_dict) + for k in delta_keys: + # condition_embedder.delta_embedder.linear_1.weight + # -> r_embedder.time_embedder.linear_1.weight + new_k = k.replace("condition_embedder.delta_embedder.", "r_embedder.time_embedder.") + new_sd[new_k] = new_sd.pop(k) + # AnyFlow's gated fusion shares the final time_proj. FastGen has a + # separate r_embedder.time_proj that mathematically substitutes for the + # shared one when fusion="gated"; copy the weights across so the two + # projections start identical (and AnyFlow's training never diverges them). + for sub in ("weight", "bias"): + src = f"condition_embedder.time_proj.{sub}" + dst = f"r_embedder.time_proj.{sub}" + if src in new_sd and dst not in new_sd: + new_sd[dst] = new_sd[src].clone() + logger.info( + f"remap_anyflow_keys: rewrote {len(delta_keys)} delta_embedder tensors " + "and duplicated time_proj weights into r_embedder." + ) + return new_sd + + if TYPE_CHECKING: from fastgen.configs.methods.config_anyflow import ModelConfig diff --git a/fastgen/networks/Wan/network.py b/fastgen/networks/Wan/network.py index f6c3668..ede813a 100644 --- a/fastgen/networks/Wan/network.py +++ b/fastgen/networks/Wan/network.py @@ -338,14 +338,27 @@ def classify_forward_prepare( r_timestep = r_timestep.to(time_embedder_dtype) remb = self.r_embedder.time_embedder(r_timestep).type_as(encoder_hidden_states) - r_timestep_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(remb)) - r_timestep_proj = unflatten_timestep_proj(r_timestep_proj, rs_seq_len) - if self.encoder_depth is None: - timestep_proj = timestep_proj + r_timestep_proj - temb = temb + remb + # AnyFlow-style gated mixing: convex-combine t- and r-embeddings BEFORE + # the shared final projection, matching the WanTwoTimeTextImageEmbedding + # forward in the AnyFlow reference. The released AnyFlow HF checkpoints + # require this fusion to reproduce the published forward pass. + if getattr(self.r_embedder, "fusion_mode", "additive") == "gated": + gate = self.r_embedder.gate_value.to(remb.dtype) + rt_emb = (1 - gate) * temb + gate * remb + timestep_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(rt_emb)) + timestep_proj = unflatten_timestep_proj(timestep_proj, rs_seq_len) + r_timestep_proj = None + temb = rt_emb else: - temb = remb + r_timestep_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(remb)) + r_timestep_proj = unflatten_timestep_proj(r_timestep_proj, rs_seq_len) + + if self.encoder_depth is None: + timestep_proj = timestep_proj + r_timestep_proj + temb = temb + remb + else: + temb = remb elif r_timestep is not None: # Raise an error here, otherwise we silently ignore the r_timestep raise ValueError("r_timestep provided but no r_embedder is present") @@ -557,6 +570,8 @@ def __init__( load_pretrained: bool = True, use_fsdp_checkpoint: bool = True, use_wan_official_sinusoidal: bool = False, + r_embedder_fusion: str = "additive", + r_embedder_gate_value: float = 0.25, **model_kwargs, ): """Wan2.1/2.2 model constructor. @@ -610,6 +625,20 @@ def __init__( if r_timestep: logger.info(f"Initializing r embedder with {r_embedder_init}") self.transformer.r_embedder = self.init_embedder(r_embedder_init) + # Stash fusion config on the r_embedder module so the (method-bound) + # forward override can branch on it without changing its signature. + if r_embedder_fusion not in ("additive", "gated"): + raise ValueError(f"r_embedder_fusion must be 'additive' or 'gated', got {r_embedder_fusion!r}") + self.transformer.r_embedder.fusion_mode = r_embedder_fusion + self.transformer.r_embedder.register_buffer( + "gate_value", + torch.tensor([float(r_embedder_gate_value)], dtype=torch.float32), + persistent=False, + ) + logger.info( + f"r_embedder fusion={r_embedder_fusion}" + + (f" gate_value={r_embedder_gate_value}" if r_embedder_fusion == "gated" else "") + ) else: self.transformer.r_embedder = None From ab1174d293a5dd74b83ca75c6d9f1322213640bf Mon Sep 17 00:00:00 2001 From: Enderfga Date: Sat, 16 May 2026 13:15:51 +0800 Subject: [PATCH 3/4] anyflow: multi-step rollout-with-gradient on-policy student generation Replace the single-step student forward in AnyFlow's on-policy stage with a multi-step Euler-flow rollout that enables gradients at one randomly-chosen step. This matches AnyFlow's ``WanAnyFlowPipeline.training_rollout`` (the published on-policy training mode in the reference repo) and gives the DMD generator update a usable gradient through a full denoising window instead of a single forward. Changes: * ``AnyFlowModel._rollout_with_gradient(batch_size, dtype, condition)``: start from pure noise at ``ns.max_t``, iterate ``student_sample_steps`` Euler-flow updates with ``r = t_next`` (mean-velocity, matching the reference default), and toggle ``torch.set_grad_enabled`` at the randomly-selected step. ``grad_step`` is broadcast from rank 0 in distributed runs so all ranks share the same gradient window. The step schedule honours ``sample_t_cfg.t_list`` when set, otherwise falls back to ``noise_scheduler.get_t_list``. * ``_onpolicy_student_update_step`` and ``_onpolicy_fake_score_discriminator_update_step``: source ``gen_data`` from the rollout instead of a single ``self.net(input_student, ...)`` forward. ``input_student`` / ``t_student`` from ``_generate_noise_and_time`` become unused for on-policy and are discarded explicitly. * ``_get_outputs``: when on-policy, always take the multi-step generator callable path (no longer special-cases ``student_sample_steps == 1`` for the validation hook, since the rollout output is always usable). * ``tests/test_anyflowmodel.py``: bump ``student_sample_steps`` to 2 in the on-policy fixtures and add ``test_onpolicy_rollout_propagates_gradient`` which asserts the rollout output keeps a usable autograd graph and that ``backward()`` reaches the student weights. All 13 unit tests pass (`make pytest tests/test_anyflowmodel.py`). Signed-off-by: Enderfga --- .../methods/distribution_matching/anyflow.py | 106 +++++++++++++++--- tests/test_anyflowmodel.py | 26 +++++ 2 files changed, 117 insertions(+), 15 deletions(-) diff --git a/fastgen/methods/distribution_matching/anyflow.py b/fastgen/methods/distribution_matching/anyflow.py index 38dd4ea..31cd861 100644 --- a/fastgen/methods/distribution_matching/anyflow.py +++ b/fastgen/methods/distribution_matching/anyflow.py @@ -326,6 +326,77 @@ def _zeros_like_t(self, t: torch.Tensor) -> torch.Tensor: ns = self.net.noise_scheduler return torch.full_like(t, float(ns.min_t)) + def _sample_grad_step(self, num_steps: int) -> int: + """Pick one step index in ``[0, num_steps - 1]`` to enable gradients on. + + Broadcast from rank 0 in distributed runs so all ranks agree on the + same window — matches AnyFlow's reference (`training_rollout` in + ``pipeline_wan_anyflow.py`` ``broadcast(sample_step, src=0)``). + """ + idx = torch.randint(0, num_steps, (1,), device=self.device, dtype=torch.long) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.broadcast(idx, src=0) + return int(idx.item()) + + def _rollout_with_gradient( + self, + batch_size: int, + dtype: torch.dtype, + condition: Optional[Any], + ) -> torch.Tensor: + """Multi-step student rollout from pure noise with one gradient-enabled step. + + Mirrors AnyFlow's ``WanAnyFlowPipeline.training_rollout`` (see + ``pipeline_wan_anyflow.py`` lines 370--454). Starts from + :math:`x_T \\sim \\mathcal{N}(0, \\sigma_{\\max}^2)`, runs + :attr:`student_sample_steps` Euler-flow steps with ``r = t_next`` + (mean-velocity sampling, matching AnyFlow's ``use_mean_velocity=True`` + default), and enables gradients at one randomly-chosen step index so + the DMD generator update receives a usable gradient through one full + denoising forward. + """ + num_steps = int(self.config.student_sample_steps) + if num_steps < 1: + raise ValueError(f"student_sample_steps must be >=1, got {num_steps}") + + ns = self.net.noise_scheduler + grad_step = self._sample_grad_step(num_steps) + + # Initial latents at the maximum-noise timestep. + eps_init = torch.randn(batch_size, *self.input_shape, device=self.device, dtype=dtype) + x = ns.latents(noise=eps_init) + + # Timestep schedule. Use config-provided t_list when set (matches + # AnyFlow's hand-tuned step lists, e.g. [0.999, 0.937, 0.833, 0.624, 0.0] + # for 4-step Wan); otherwise fall back to the scheduler's default. + if self.config.sample_t_cfg.t_list is not None: + t_list = torch.tensor(self.config.sample_t_cfg.t_list, device=self.device, dtype=ns.t_precision) + if len(t_list) != num_steps + 1: + raise ValueError( + f"sample_t_cfg.t_list has {len(t_list)} entries, " + f"expected {num_steps + 1} for student_sample_steps={num_steps}" + ) + else: + t_list = ns.get_t_list(sample_steps=num_steps, device=self.device) + + for step in range(num_steps): + t_cur = t_list[step].expand(batch_size).to(ns.t_precision) + t_next = t_list[step + 1].expand(batch_size).to(ns.t_precision) + + enable_grad = (step == grad_step) and torch.is_grad_enabled() + with torch.set_grad_enabled(enable_grad): + # Mean-velocity flow prediction: u_theta(x_t, t, r=t_next) + flow_pred = self.net(x, t_cur, r=t_next, condition=condition, fwd_pred_type="flow") + + # Euler-flow step. Keep ``x`` attached to the autograd graph + # unconditionally: steps run inside ``torch.no_grad()`` do not add + # graph nodes anyway, but detaching would also strip the gradient + # that earlier grad-enabled step(s) installed onto ``x``. + delta_t = (t_cur - t_next).view(batch_size, *([1] * (x.ndim - 1))).to(x.dtype) + x = x - delta_t * flow_pred + + return x + def _onpolicy_student_update_step( self, input_student: torch.Tensor, @@ -336,11 +407,16 @@ def _onpolicy_student_update_step( condition: Optional[Any], neg_condition: Optional[Any], ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - r_zero = self._zeros_like_t(t_student) r_zero_t = self._zeros_like_t(t) + del input_student, t_student # unused — rollout starts from fresh pure noise - # Student rollout to a single x0 estimate, conditioned on r=0. - gen_data = self.net(input_student, t_student, r=r_zero, condition=condition, fwd_pred_type="x0") + # Multi-step rollout-with-gradient from pure noise; gradient is enabled + # at one randomly-chosen step matching AnyFlow's training_rollout. + gen_data = self._rollout_with_gradient( + batch_size=data["real"].shape[0], + dtype=data["real"].dtype, + condition=condition, + ) perturbed_data = self.net.noise_scheduler.forward_process(gen_data, eps, t) with torch.no_grad(): @@ -383,7 +459,7 @@ def _onpolicy_student_update_step( "vsd_loss": vsd_loss, "gan_loss_gen": gan_loss_gen, } - outputs = self._get_outputs(gen_data, input_student, condition=condition) + outputs = self._get_outputs(gen_data, condition=condition) return loss_map, outputs def _onpolicy_fake_score_discriminator_update_step( @@ -395,11 +471,16 @@ def _onpolicy_fake_score_discriminator_update_step( real_data: torch.Tensor, condition: Optional[Any], ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - r_zero = self._zeros_like_t(t_student) r_zero_t = self._zeros_like_t(t) + del input_student, t_student # unused in rollout path + # Generate via the same multi-step rollout (no gradient needed on the + # fake-score / discriminator update step — caller wraps the whole + # update in no_grad effectively by detaching everything below). with torch.no_grad(): - gen_data = self.net(input_student, t_student, r=r_zero, condition=condition, fwd_pred_type="x0") + gen_data = self._rollout_with_gradient( + batch_size=real_data.shape[0], dtype=real_data.dtype, condition=condition + ) x_t_sg = self.net.noise_scheduler.forward_process(gen_data, eps, t) from fastgen.methods.common_loss import ( @@ -480,7 +561,7 @@ def _onpolicy_fake_score_discriminator_update_step( } if self.config.gan_loss_weight_gen > 0 and self.config.gan_r1_reg_weight > 0: loss_map["gan_loss_ar1"] = gan_loss_ar1 - outputs = self._get_outputs(gen_data, input_student, condition=condition) + outputs = self._get_outputs(gen_data, condition=condition) return loss_map, outputs def _onpolicy_single_train_step( @@ -515,14 +596,9 @@ def _get_outputs( noise = input_student / (ns.max_sigma if hasattr(ns, "max_sigma") else 1.0) return {"gen_rand": gen_data, "input_rand": noise} - # On-policy stage delegates to DMD2's get_outputs path so multi-step - # generators are produced consistently with the rest of the family. - if self.config.student_sample_steps == 1: - assert input_student is not None, "input_student must be provided" - ns = self.net.noise_scheduler - noise = input_student / (ns.max_sigma if hasattr(ns, "max_sigma") else 1.0) - return {"gen_rand": gen_data, "input_rand": noise} - + # On-policy stage: gen_data already comes from the multi-step + # rollout (`_rollout_with_gradient`), so a fresh noise sample is + # sufficient for the validation generator hook. noise = torch.randn_like(gen_data, dtype=self.precision) gen_rand_func = partial( self.generator_fn, diff --git a/tests/test_anyflowmodel.py b/tests/test_anyflowmodel.py index ac732e8..e27a73a 100644 --- a/tests/test_anyflowmodel.py +++ b/tests/test_anyflowmodel.py @@ -56,6 +56,10 @@ def _build_onpolicy_model(): instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16" instance.pretrained_model_path = "" instance.student_update_freq = 2 + # Exercise the multi-step rollout-with-gradient path (matches the + # AnyFlow paper's on-policy training mode). student_sample_steps=2 is + # the smallest value that runs the rollout loop more than once. + instance.student_sample_steps = 2 instance.input_shape = [3, 8, 8] model = AnyFlowModel(instance) @@ -163,6 +167,28 @@ def test_onpolicy_fake_score_discriminator_update_step(): assert "gen_rand" in outputs +def test_onpolicy_rollout_propagates_gradient(): + """The multi-step rollout must allow gradient flow on the chosen step. + + Mirrors AnyFlow's ``training_rollout`` (pipeline_wan_anyflow.py L370): + one randomly-chosen step in the rollout has gradients enabled; the + remaining steps are no_grad. ``gen_data.requires_grad`` should be True + at the rollout output so the DMD generator update has a valid gradient. + """ + model = _build_onpolicy_model() + real = torch.randn(1, 3, 8, 8, device=model.device, dtype=model.precision) + cond = torch.nn.functional.one_hot(torch.tensor([0]), num_classes=10).to(model.device, model.precision) + # Exercise rollout under grad-enabled context (mirrors student update step). + gen = model._rollout_with_gradient(batch_size=real.shape[0], dtype=real.dtype, condition=cond) + assert tuple(gen.shape) == (1, 3, 8, 8), f"rollout output shape mismatch: {gen.shape}" + assert gen.requires_grad, "rollout output must keep autograd graph at the chosen step" + # Trivial scalar loss; backward should succeed without NaN. + loss = gen.float().pow(2).mean() + loss.backward() + grad_seen = any(p.grad is not None and torch.isfinite(p.grad).all() for p in model.net.parameters()) + assert grad_seen, "no gradient reached the student network through the rollout" + + def test_onpolicy_optimizer_step(): model = _build_onpolicy_model() data = _make_data(model, img_resolution=8) From 1671bb2be5f5ddd42a567660e595a234b1687cbb Mon Sep 17 00:00:00 2001 From: Enderfga Date: Fri, 22 May 2026 10:02:35 +0800 Subject: [PATCH 4/4] Wan: extract _fuse_r_embedding helper; ship AnyFlow on-policy config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses two pieces of PR #25 reviewer feedback: (1) Code sharing with MeanFlow. The previous commit added AnyFlow's gated t/r mixing as an inline branch inside ``classify_forward_prepare``, which made it visually hard to tell which lines were MeanFlow's additive path and which were AnyFlow-specific. This commit factors both fusion modes into a single ``_fuse_r_embedding`` method bound on the transformer (parallel pattern to ``classify_forward_prepare`` and friends). Both paths still share ``r_embedder.time_embedder`` / ``time_proj`` / ``act_fn`` modules — the helper just makes that sharing explicit and shrinks the call site to three lines. Forward semantics are bit-identical to the previous commit for both additive (MeanFlow) and gated (AnyFlow) modes across all three ``encoder_depth`` cases. (2) Ship a paper-aligned on-policy stage config. Previously the only documented way to run Stage 3 was an inline tweak in the pretrain config docstring. New file ``fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py`` inherits the pretrain config and flips the loss into "onpolicy" with the paper's Stage 3 hyperparameters (lr=2e-6, 1200 iter, GAN on at the DMD2-default 0.03, ``student_update_freq=5``). The docstring notes that the AnyFlow paper's rank-256 LoRA variant is not reproduced here because FastGen does not ship a PEFT/LoRA training path; this config is a full-rank fine-tune of a Stage 2 pretrain checkpoint. The AnyFlow method README is updated to (a) document the new ``r_embedder_fusion="gated"`` requirement when loading the released AnyFlow HF checkpoints, (b) replace the stale "multi-step rollout deferred to a follow-up" note (already landed in ab1174d) with an explicit acknowledgement that end-to-end convergence-scale validation on the paper's training corpus is deferred to a follow-up, and (c) cross-reference both pretrain and on-policy configs. Tests: all 13 AnyFlow + 3 MeanFlow unit tests pass. Signed-off-by: Enderfga --- .../WanT2V/config_anyflow_onpolicy.py | 40 +++++++++++ .../methods/distribution_matching/README.md | 8 ++- fastgen/networks/Wan/network.py | 67 +++++++++++++------ 3 files changed, 92 insertions(+), 23 deletions(-) create mode 100644 fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py diff --git a/fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py b/fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py new file mode 100644 index 0000000..e0086df --- /dev/null +++ b/fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Reference AnyFlow on-policy distillation config on Wan-1.3B T2V (Stage 3). + +Inherits the pretrain config and flips the loss into the on-policy stage, +turning on DMD2's alternating fake_score / discriminator updates with the +``r=0`` dual-timestep conditioning that AnyFlow keeps from MeanFlow. Mirrors +the paper's Stage 3 hyperparameters: 1.2k iterations at lr=2e-6 on top of a +Stage 2 flow-map pretrain checkpoint. + +Note: the AnyFlow paper trains this stage with a rank-256 LoRA adapter, but +FastGen does not ship a PEFT/LoRA training path today, so this config does +full-rank fine-tuning. Set ``config.model.pretrained_student_net_path`` to a +checkpoint produced by :mod:`config_anyflow` before launching. +""" + +import fastgen.configs.experiments.WanT2V.config_anyflow as config_anyflow_pretrain + + +def create_config(): + config = config_anyflow_pretrain.create_config() + + config.model.loss_config.training_stage = "onpolicy" + config.model.pretrained_student_net_path = "" + + # Re-enable the DMD2 alternating-update machinery. + config.model.gan_loss_weight_gen = 0.03 + config.model.student_update_freq = 5 + + # Stage 3 learning rates from the AnyFlow paper. + config.model.net_optimizer.lr = 2e-6 + config.model.fake_score_optimizer.lr = 2e-6 + config.model.discriminator_optimizer.lr = 2e-6 + + config.trainer.max_iter = 1200 + config.trainer.save_ckpt_iter = 200 + + config.log_config.group = "wan_anyflow_onpolicy" + return config diff --git a/fastgen/methods/distribution_matching/README.md b/fastgen/methods/distribution_matching/README.md index 6e9c278..c9c78a8 100644 --- a/fastgen/methods/distribution_matching/README.md +++ b/fastgen/methods/distribution_matching/README.md @@ -102,11 +102,13 @@ Single model that supports arbitrary inference NFE by learning a flow map `u_θ( - `loss_config.shift`: flow-matching schedule shift (5.0 for Wan video) - See also key parameters of DMD2 above (used by the on-policy stage) -**Backbone requirement:** the student network must accept a secondary timestep `r` (Wan with `r_timestep=True`). +**Backbone requirement:** the student network must accept a secondary timestep `r` (Wan with `r_timestep=True`). When loading the published AnyFlow HF checkpoints, set `r_embedder_fusion="gated"` on the Wan constructor — this routes the t/r mix through `Wan/network.py::_fuse_r_embedding`'s gated branch (shared with MeanFlow's additive default) so the released weights reproduce bit-for-bit. -**Note:** the on-policy stage in this PR uses single-step student generation. Multi-step rollout-with-gradient (matching `self_forcing.py`'s `rollout_with_gradient`) is intentionally deferred to a follow-up PR. +**Note:** correctness of the port is established via forward-parity and single-step training-step parity against the AnyFlow reference (see PR #25 discussion). End-to-end convergence-scale validation on the paper's training corpus is deferred to a follow-up. -**Configs:** [`WanT2V/config_anyflow.py`](../../configs/experiments/WanT2V/config_anyflow.py) +**Configs:** +- [`WanT2V/config_anyflow.py`](../../configs/experiments/WanT2V/config_anyflow.py) — Stage 2 pretrain (6k iter, lr=5e-5, shift=5, beta08, paper-aligned) +- [`WanT2V/config_anyflow_onpolicy.py`](../../configs/experiments/WanT2V/config_anyflow_onpolicy.py) — Stage 3 on-policy distillation (1.2k iter, lr=2e-6, GAN on; full-rank fine-tune of a Stage 2 pretrain ckpt — the paper's rank-256 LoRA variant awaits a PEFT path in FastGen) --- diff --git a/fastgen/networks/Wan/network.py b/fastgen/networks/Wan/network.py index ede813a..0a47a4e 100644 --- a/fastgen/networks/Wan/network.py +++ b/fastgen/networks/Wan/network.py @@ -274,6 +274,49 @@ def classify_forward( return out +def _fuse_r_embedding( + self, + temb: torch.Tensor, + timestep_proj: torch.Tensor, + remb: torch.Tensor, + rs_seq_len: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Combine the t- and r-time embeddings into the final timestep projection. + + Two fusion modes share the same ``r_embedder.time_proj`` / ``act_fn`` modules: + + * ``additive`` (MeanFlow default, ``r_embedder.fusion_mode`` absent or "additive"): + ``r_timestep_proj = time_proj(act_fn(remb))`` is added to ``timestep_proj`` + and ``remb`` is added to ``temb``. T2V dual-stream variants (``encoder_depth`` + set) instead replace ``temb`` with ``remb`` and leave ``timestep_proj`` alone. + + * ``gated`` (AnyFlow, opt-in): convex-combine the two embeddings *before* the + shared projection — ``rt_emb = (1-g)·temb + g·remb`` then + ``timestep_proj = time_proj(act_fn(rt_emb))``. This matches the + ``WanTwoTimeTextImageEmbedding.forward_timestep`` path in the AnyFlow + reference and is required to reproduce the published HF checkpoint + forward bit-for-bit. + + Returns ``(temb, timestep_proj, r_timestep_proj)`` where ``r_timestep_proj`` + is ``None`` in the gated branch (the t/r mix is folded into ``timestep_proj`` + itself). + """ + fusion = getattr(self.r_embedder, "fusion_mode", "additive") + + if fusion == "gated": + gate = self.r_embedder.gate_value.to(remb.dtype) + rt_emb = (1 - gate) * temb + gate * remb + ts_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(rt_emb)) + return rt_emb, unflatten_timestep_proj(ts_proj, rs_seq_len), None + + # additive — MeanFlow original path, bit-identical + r_ts_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(remb)) + r_ts_proj = unflatten_timestep_proj(r_ts_proj, rs_seq_len) + if self.encoder_depth is None: + return temb + remb, timestep_proj + r_ts_proj, r_ts_proj + return remb, timestep_proj, r_ts_proj + + def classify_forward_prepare( self, hidden_states: torch.Tensor, @@ -339,26 +382,9 @@ def classify_forward_prepare( remb = self.r_embedder.time_embedder(r_timestep).type_as(encoder_hidden_states) - # AnyFlow-style gated mixing: convex-combine t- and r-embeddings BEFORE - # the shared final projection, matching the WanTwoTimeTextImageEmbedding - # forward in the AnyFlow reference. The released AnyFlow HF checkpoints - # require this fusion to reproduce the published forward pass. - if getattr(self.r_embedder, "fusion_mode", "additive") == "gated": - gate = self.r_embedder.gate_value.to(remb.dtype) - rt_emb = (1 - gate) * temb + gate * remb - timestep_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(rt_emb)) - timestep_proj = unflatten_timestep_proj(timestep_proj, rs_seq_len) - r_timestep_proj = None - temb = rt_emb - else: - r_timestep_proj = self.r_embedder.time_proj(self.r_embedder.act_fn(remb)) - r_timestep_proj = unflatten_timestep_proj(r_timestep_proj, rs_seq_len) - - if self.encoder_depth is None: - timestep_proj = timestep_proj + r_timestep_proj - temb = temb + remb - else: - temb = remb + temb, timestep_proj, r_timestep_proj = self._fuse_r_embedding( + temb, timestep_proj, remb, rs_seq_len + ) elif r_timestep is not None: # Raise an error here, otherwise we silently ignore the r_timestep raise ValueError("r_timestep provided but no r_embedder is present") @@ -861,6 +887,7 @@ def override_transformer_forward(self, inner_dim: int) -> None: # Override transformer forward methods with custom implementations for block in self.transformer.blocks: block.forward = types.MethodType(block_forward, block) + self.transformer._fuse_r_embedding = types.MethodType(_fuse_r_embedding, self.transformer) self.transformer.classify_forward_prepare = types.MethodType(classify_forward_prepare, self.transformer) self.transformer.classify_forward_block_forward = types.MethodType( classify_forward_block_forward, self.transformer