Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions fastgen/configs/experiments/WanT2V/config_anyflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Reference AnyFlow experiment config on Wan-1.3B T2V.

Mirrors the AnyFlow paper's pretrain configuration: 1.3B student initialised
from a Wan2.1-T2V checkpoint, flow-matching shift=5, beta08 loss weighting,
6k iterations with batch_size_global=32 and lr=5e-5.

Switching to the on-policy stage:

config.model.loss_config.training_stage = "onpolicy"
config.model.pretrained_student_net_path = "<path-to-pretrain-ckpt>"

and adjust ``student_update_freq`` / ``gan_loss_weight_gen`` to taste.
"""

import fastgen.configs.methods.config_anyflow as config_anyflow_default
from fastgen.configs.data import VideoLoaderConfig
from fastgen.configs.discriminator import Discriminator_Wan_1_3B_Config
from fastgen.configs.net import Wan_1_3B_Config


def create_config():
config = config_anyflow_default.create_config()

# Default to the pretrain stage; flip the switch to "onpolicy" once the
# flow-map pretrain checkpoint is available.
config.model.loss_config.training_stage = "pretrain"
config.model.loss_config.jvp_finite_diff_eps = 5e-3
config.model.loss_config.diffusion_ratio = 0.5
config.model.loss_config.consistency_ratio = 0.25
config.model.loss_config.weight_type = "beta08"
config.model.loss_config.shift = 5.0

config.model.net = Wan_1_3B_Config
config.model.net.r_timestep = True

# The on-policy stage uses these too, but they are harmless in pretrain.
config.model.discriminator = Discriminator_Wan_1_3B_Config
config.model.discriminator.disc_type = "multiscale_down_mlp_large"
config.model.discriminator.feature_indices = [15, 22, 29]
config.model.gan_loss_weight_gen = 0.0 # disabled by default in pretrain
config.model.guidance_scale = 5.0

config.model.precision = "bfloat16"
# VAE compress ratio: (1 + T/4) * H/8 * W/8. 81-frame, 480p clips.
config.model.input_shape = [16, 21, 60, 104]

config.model.net_optimizer.lr = 5e-5
config.model.fake_score_optimizer.lr = 5e-5
config.model.discriminator_optimizer.lr = 5e-5

config.model.sample_t_cfg.time_dist_type = "shifted"
config.model.sample_t_cfg.min_t = 0.001
config.model.sample_t_cfg.max_t = 0.999

config.model.student_sample_type = "ode"
# Any-step model — multiple NFEs validated at inference time.
config.model.student_sample_steps = 4
config.model.sample_t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0]

config.dataloader_train = VideoLoaderConfig
config.dataloader_train.img_size = (config.model.input_shape[-1] * 8, config.model.input_shape[-2] * 8)
config.dataloader_train.sequence_length = (config.model.input_shape[1] - 1) * 4 + 1
config.dataloader_train.batch_size = 1

config.trainer.max_iter = 6000
config.trainer.logging_iter = 100
config.trainer.save_ckpt_iter = 500
config.trainer.batch_size_global = 32

config.log_config.group = "wan_anyflow"
return config
40 changes: 40 additions & 0 deletions fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Reference AnyFlow on-policy distillation config on Wan-1.3B T2V (Stage 3).

Inherits the pretrain config and flips the loss into the on-policy stage,
turning on DMD2's alternating fake_score / discriminator updates with the
``r=0`` dual-timestep conditioning that AnyFlow keeps from MeanFlow. Mirrors
the paper's Stage 3 hyperparameters: 1.2k iterations at lr=2e-6 on top of a
Stage 2 flow-map pretrain checkpoint.

Note: the AnyFlow paper trains this stage with a rank-256 LoRA adapter, but
FastGen does not ship a PEFT/LoRA training path today, so this config does
full-rank fine-tuning. Set ``config.model.pretrained_student_net_path`` to a
checkpoint produced by :mod:`config_anyflow` before launching.
"""

import fastgen.configs.experiments.WanT2V.config_anyflow as config_anyflow_pretrain


def create_config():
config = config_anyflow_pretrain.create_config()

config.model.loss_config.training_stage = "onpolicy"
config.model.pretrained_student_net_path = "<path-to-stage2-pretrain-ckpt>"

# Re-enable the DMD2 alternating-update machinery.
config.model.gan_loss_weight_gen = 0.03
config.model.student_update_freq = 5

# Stage 3 learning rates from the AnyFlow paper.
config.model.net_optimizer.lr = 2e-6
config.model.fake_score_optimizer.lr = 2e-6
config.model.discriminator_optimizer.lr = 2e-6

config.trainer.max_iter = 1200
config.trainer.save_ckpt_iter = 200

config.log_config.group = "wan_anyflow_onpolicy"
return config
92 changes: 92 additions & 0 deletions fastgen/configs/methods/config_anyflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Config schema for the AnyFlow method.

AnyFlow inherits the DMD2 model config (so the on-policy stage gets fake_score
/ discriminator / alternating-step machinery for free) and adds a
``LossConfig`` describing the flow-map pretrain hyperparameters.
"""

import attrs
from omegaconf import DictConfig

from fastgen.configs.callbacks import (
EMA_CALLBACK,
GPUStats_CALLBACK,
GradClip_CALLBACK,
ParamCount_CALLBACK,
TrainProfiler_CALLBACK,
WANDB_CALLBACK,
)
from fastgen.configs.config import BaseConfig
from fastgen.configs.methods.config_dmd2 import ModelConfig as DMD2ModelConfig
from fastgen.methods import AnyFlowModel
from fastgen.utils import LazyCall as L


@attrs.define(slots=False)
class LossConfig:
"""Hyperparameters for the AnyFlow flow-map loss and on-policy switch."""

# Which stage to train. "pretrain" runs the central-difference flow-map
# objective; "onpolicy" inherits DMD2's alternating distillation with
# dual-timestep r=0 conditioning.
training_stage: str = "pretrain"

# Central-difference step size for estimating dF/dt. Lives in the same
# units as the noise scheduler's timesteps. The default (5e-3) matches the
# AnyFlow paper's choice of epsilon=5 with num_train_timesteps=1000.
jvp_finite_diff_eps: float = 5e-3

# Per-batch fraction with r = t (recovers pure flow matching).
diffusion_ratio: float = 0.5
# Per-batch fraction with r = min_t (forces consistency to clean data).
consistency_ratio: float = 0.25

# Per-timestep loss weighting scheme — passed through to the flow-map
# scheduler. One of "gaussian", "beta08", "uniform".
weight_type: str = "beta08"
# Flow-matching schedule shift for the weighting / sampling scheduler.
# Wan video defaults use 5.0; image use 1.0.
shift: float = 1.0
# Resolution of the discrete weighting grid; matches the AnyFlow reference.
num_train_timesteps: int = 1000


@attrs.define(slots=False)
class ModelConfig(DMD2ModelConfig):
"""AnyFlow model config — inherits DMD2 fields, adds the flow-map loss config."""

loss_config: LossConfig = attrs.field(factory=LossConfig)


@attrs.define(slots=False)
class Config(BaseConfig):
model: ModelConfig = attrs.field(factory=ModelConfig)
model_class: DictConfig = L(AnyFlowModel)(
config=None,
)


def create_config():
config = Config()
config.trainer.callbacks = DictConfig(
{
**GradClip_CALLBACK,
**EMA_CALLBACK,
**GPUStats_CALLBACK,
**TrainProfiler_CALLBACK,
**ParamCount_CALLBACK,
**WANDB_CALLBACK,
}
)

# Pretrain stage relies on a flow-matching net_pred_type and dual-timestep input.
config.model.use_ema = True
config.model.net.r_timestep = True
config.model.net_scheduler.warm_up_steps = [0]
config.model.fake_score_scheduler.warm_up_steps = [0]
config.model.discriminator_scheduler.warm_up_steps = [0]

return config
1 change: 1 addition & 0 deletions fastgen/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from fastgen.methods.distribution_matching.causvid import CausVidModel as CausVidModel
from fastgen.methods.distribution_matching.self_forcing import SelfForcingModel as SelfForcingModel
from fastgen.methods.distribution_matching.anyflow import AnyFlowModel as AnyFlowModel

from fastgen.methods.consistency_model.CM import CMModel as CMModel
from fastgen.methods.consistency_model.TCM import TCMModel as TCMModel
Expand Down
27 changes: 27 additions & 0 deletions fastgen/methods/distribution_matching/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,33 @@ DMD2 extended for causal video generation with autoregressive chunk-by-chunk pro

---

## AnyFlow

**File:** [`anyflow.py`](anyflow.py) | **Reference:** AnyFlow — any-step video diffusion framework on flow maps

Single model that supports arbitrary inference NFE by learning a flow map `u_θ(x_t, t, r)` (average velocity from `t` back to `r`). Trained in two stages:

1. **Pretrain** — flow-map prediction with a central-difference target that reuses the network's own forward at `(t ± δ, r)` to estimate `dF/dt`. Per-batch sampling assigns `r = t` to a `diffusion_ratio` fraction (recovering plain flow matching) and `r = 0` to a `consistency_ratio` fraction (forcing consistency to clean data).
2. **On-policy** — distribution-matching distillation with `r = 0` conditioning on top of the pretrained flow-map weights. Inherits DMD2's alternating fake_score / discriminator / VSD machinery.

**Key Parameters:**
- `loss_config.training_stage`: `"pretrain"` or `"onpolicy"`
- `loss_config.jvp_finite_diff_eps`: central-difference step δ (in noise scheduler t-units)
- `loss_config.diffusion_ratio` / `loss_config.consistency_ratio`: per-batch fraction with `r=t` / `r=0`
- `loss_config.weight_type`: `gaussian` | `beta08` | `uniform` per-timestep loss weight
- `loss_config.shift`: flow-matching schedule shift (5.0 for Wan video)
- See also key parameters of DMD2 above (used by the on-policy stage)

**Backbone requirement:** the student network must accept a secondary timestep `r` (Wan with `r_timestep=True`). When loading the published AnyFlow HF checkpoints, set `r_embedder_fusion="gated"` on the Wan constructor — this routes the t/r mix through `Wan/network.py::_fuse_r_embedding`'s gated branch (shared with MeanFlow's additive default) so the released weights reproduce bit-for-bit.

**Note:** correctness of the port is established via forward-parity and single-step training-step parity against the AnyFlow reference (see PR #25 discussion). End-to-end convergence-scale validation on the paper's training corpus is deferred to a follow-up.

**Configs:**
- [`WanT2V/config_anyflow.py`](../../configs/experiments/WanT2V/config_anyflow.py) — Stage 2 pretrain (6k iter, lr=5e-5, shift=5, beta08, paper-aligned)
- [`WanT2V/config_anyflow_onpolicy.py`](../../configs/experiments/WanT2V/config_anyflow_onpolicy.py) — Stage 3 on-policy distillation (1.2k iter, lr=2e-6, GAN on; full-rank fine-tune of a Stage 2 pretrain ckpt — the paper's rank-256 LoRA variant awaits a PEFT path in FastGen)

---

## Self-Forcing

**File:** [`self_forcing.py`](self_forcing.py) | **Reference:** [Huang et al., 2025](https://arxiv.org/abs/2506.08009)
Expand Down
Loading