Supports eagle3 training for Gemma3 27B and Gemma4 26B.#553
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for Gemma 3 and Gemma 4 models within the Eagle3 framework, including new configurations, training scripts, and a dedicated gemma-4 chat template. Key architectural improvements include a fast path for models where draft and target vocab sizes match, the ability to reuse and freeze the target model's LM head, and an improved weight initialization strategy for stable training. The training script now supports multiple data paths and directory resolution. Feedback focuses on preventing race conditions in distributed output directory creation, improving error handling for mismatched tool lists, and adhering to PEP-8 import standards.
| run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||
| args.output_dir = os.path.join(args.output_dir, run_timestamp) |
There was a problem hiding this comment.
Generating the run_timestamp independently on each rank can lead to different output directories across processes if they cross a second boundary during initialization. This will break distributed training and checkpoint saving. The timestamp should be generated on rank 0 and broadcasted to all other ranks.
| run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| args.output_dir = os.path.join(args.output_dir, run_timestamp) | |
| run_timestamp = [datetime.now().strftime("%Y%m%d_%H%M%S") if dist.get_rank() == 0 else None] | |
| dist.broadcast_object_list(run_timestamp, src=0) | |
| args.output_dir = os.path.join(args.output_dir, run_timestamp[0]) |
| if tools is None or len(tools) != len(conversations): | ||
| tools = [[] for _ in range(len(conversations))] |
There was a problem hiding this comment.
Silently replacing the tools list with empty lists when the length doesn't match conversations can hide data preparation bugs. It is safer to raise a ValueError if an explicitly provided tools list has an incorrect length.
| if tools is None or len(tools) != len(conversations): | |
| tools = [[] for _ in range(len(conversations))] | |
| if tools is None: | |
| tools = [[] for _ in range(len(conversations))] | |
| elif len(tools) != len(conversations): | |
| raise ValueError(f"Length of tools ({len(tools)}) does not match length of conversations ({len(conversations)})") |
| # transformers v5 mutating rope_scaling/rope_parameters and other | ||
| # fields in model.config during save_pretrained. | ||
| if getattr(args, "draft_model_config", None): | ||
| import json |
There was a problem hiding this comment.
Per PEP-8, imports should be placed at the top of the file. Moving import json to the module level improves readability and follows standard Python practices.
References
- Imports should be at the top of the file, after any module comments and docstrings, and before module globals and constants. (link)
0074bd7 to
c9910b2
Compare
Gemma3 27B and Gemma4 26B have a vocabulary size of 262144, which makes triton.next_power_of_2 round up to 262144 (==2^18). The previous limit of 131072 caused _calculate_settings() to raise RuntimeError before the log-softmax loss kernel could launch, preventing Eagle3 training on these targets.
Bump MAX_FUSED_SIZE to 262208 to fit Gemma3/4 vocab
Motivation
This PR supports eagle3 training for Gemma3 27B and Gemma4 26B. Other Gemma3/4 models should be supported as well but didn't verify.
Modifications
Besides the new models, it also supports the following features
For Gemma4, it requires transformers v5+.
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist