diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index bbc3a0a91..5b4fc2250 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -135,6 +135,9 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: "auto_map", "torch_dtype", "use_cache", + # Architecture-family marker some transformers v4 configs carry (e.g. LlamaConfig); dropped + # in v5, not consumed by Fast-LLM, and absent from a bare ``PretrainedConfig``. + "is_llama_config", # Token ids — generation/inference, not architecture (a bare v5 config omits these). "bos_token_id", "decoder_start_token_id", diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 8c6365a5f..a23801d5c 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -10,6 +10,7 @@ from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner @@ -113,7 +114,21 @@ def from_pretrained( stage_filter=stage_filter, ) - return cls(fast_llm_model, **kwargs) + model = cls(fast_llm_model, **kwargs) + model._apply_generation_token_ids(pretrained_model_name_or_path) + return model + + def _apply_generation_token_ids(self, pretrained: CheckpointLoadConfig) -> None: + # Honor the source HF config's generation token ids: Fast-LLM's import drops them (they are + # generation metadata, not architecture), so `generate` would otherwise never stop at EOS. 
+ # Only external (HF) checkpoints carry them; native Fast-LLM checkpoints leave the defaults. + handler_class = pretrained.format.get_handler_class() + if not issubclass(handler_class, HuggingfaceStateDictCheckpointHandler): + return + hf_config = handler_class._load_config(pretrained.path) + for key in ("bos_token_id", "eos_token_id", "pad_token_id"): + if (token_id := hf_config.get(key)) is not None: + setattr(self.generation_config, key, token_id) def _init_weights(self, module) -> None: raise NotImplementedError(module) @@ -249,3 +264,9 @@ def stop_workers(self): def inner_forward(self, *args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: # Meant to be overridden in derived classes raise NotImplementedError() + + @classmethod + def can_generate(cls) -> bool: + # `PreTrainedModel.can_generate` walks `__bases__` by name and stops at any base containing + # "PreTrainedModel"; this intermediate base hides the `GenerationMixin` inheritance from that check. + return True diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 7d1383b00..3d6a85c4b 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -47,6 +47,10 @@ def inner_forward( output_hidden_states: bool | None = None, return_dict: bool | None = None, return_all_prediction_heads: bool = False, + # `generate` passes version-dependent plumbing kwargs (`cache_position`, `logits_to_keep`, ...). + # They don't apply to the `use_cache=False` path: positions are reconstructed from `attention_mask`, + # and the full logits are computed and the last position selected downstream. 
+ **kwargs, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: return self._inner_forward( self._get_batch(input_ids, attention_mask), @@ -129,16 +133,13 @@ def _inner_forward( for name, (meta, tensor) in model_input.hidden_states.items() } - logits = hidden_states.pop(f"{self.fast_llm_base_model.head.module_name}.logits") - if return_all_prediction_heads: - logits = torch.stack( - [logits] - + [ - hidden_states.pop(f"{head.module_name}.logits") - for head in self.fast_llm_base_model.multi_token_prediction.heads - ], - dim=-2, - ) + # Every head emits its logits into the hidden-states namespace; pop them all so the prediction + # heads' logits don't leak into the returned hidden states. + head_logits = [ + hidden_states.pop(f"{head.module_name}.logits") + for head in (self.fast_llm_base_model.head, *self.fast_llm_base_model.multi_token_prediction.heads) + ] + logits = torch.stack(head_logits, dim=-2) if return_all_prediction_heads else head_logits[0] output = transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index c595b5148..6eabe8683 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -10,7 +10,9 @@ from fast_llm.models.gpt.config import PretrainedGPTModelConfig from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingGroup +from tests.utils.utils import requires_cuda def _prepare_data(tokenizer, use_batch_size2: bool): @@ -108,7 +110,9 @@ def _get_fast_llm_model_from_model( multi_stage.load_checkpoint(config.pretrained) - return HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner) + model = HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner) + model._apply_generation_token_ids(config.pretrained) + return 
model def _trim_output(output, inputs): @@ -151,7 +155,9 @@ def _test_for_batches( if tokenizer is not None: inputs = _prepare_data(tokenizer, use_batch_size2=False) else: - inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=False) + inputs = _prepare_rand_data( + fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=False + ) outputs = _generate( inputs, hf_model, @@ -163,7 +169,9 @@ def _test_for_batches( if tokenizer is not None: inputs = _prepare_data(tokenizer, use_batch_size2=True) else: - inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=True) + inputs = _prepare_rand_data( + fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=True + ) outputs = _generate( inputs, hf_model, @@ -204,6 +212,7 @@ def _test_generate( ) +@requires_cuda @pytest.mark.extra_slow @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", @@ -242,14 +251,18 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi if model_testing_config.checkpoint_format is None: pytest.skip(f"Conversion not supported for {model_testing_config.name}") run_test_script_for_all_models( - [ - "training.train_iters=1", - f"training.export.format={model_testing_config.checkpoint_format.name}", - "training.export.interval=1", - ], + distributed_testing_config=DistributedTestingConfig( + name="test_export_for_generate", + config_args=[ + "training.train_iters=1", + f"training.export.format={model_testing_config.checkpoint_format.name}", + "training.export.interval=1", + ], + ) ) +@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.parametrize( @@ -303,6 +316,7 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) ) 
+@requires_cuda @pytest.mark.extra_slow def test_generate_from_model( model_path, @@ -310,6 +324,7 @@ def test_generate_from_model( _test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat) +@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) @@ -334,7 +349,7 @@ def _test_forward_return_hidden_states( inputs_ids = torch.randint( 1, - fast_llm_model.config.fast_llm_config.base_model.vocab_size if vocab_size is None else vocab_size, + fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size if vocab_size is None else vocab_size, [1, 10], dtype=torch.int64, generator=torch.Generator().manual_seed(42), @@ -345,10 +360,13 @@ def _test_forward_return_hidden_states( input_ids=inputs_ids, output_hidden_states=True, return_dict=True, use_cache=False ) - # hidden_states include embeddings layer - assert len(res_fast_llm.hidden_states) - 1 == len(fast_llm_model.config.fast_llm_config.base_model.decoder) + # Embeddings + one state per decoder block except the last + one final-norm state per prediction head + # (the last block's output is carried by the heads' final norms). 
+ base_model = fast_llm_model.config.fast_llm_config.base_model + assert len(res_fast_llm.hidden_states) == base_model.decoder.num_blocks + base_model.head.prediction_heads +@requires_cuda @pytest.mark.extra_slow def test_forward_return_hidden_states(model_path): _test_forward_return_hidden_states( @@ -356,6 +374,7 @@ def test_forward_return_hidden_states(model_path): ) +@requires_cuda @pytest.mark.slow @pytest.mark.model_testing_group(ModelTestingGroup.generate) @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 7ae26c2d6..75a1b3e98 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -54,7 +54,7 @@ def do_get_lm_eval_config(base_path): # "gsm8k,xnli_en,wikitext" -@pytest.mark.model_testing_group(ModelTestingGroup.generate) +@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval) def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): run_test_script_for_all_models( distributed_testing_config=DistributedTestingConfig( @@ -75,7 +75,7 @@ def do_copy_training_output(distributed_testing_config: DistributedTestingConfig @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.generate) +@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval) def test_lm_eval_evaluation_last_checkpoint( run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config, copy_training_output ): @@ -89,7 +89,7 @@ def test_lm_eval_evaluation_last_checkpoint( @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.generate) +@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval) def test_lm_eval_evaluation_from_pretrained( run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config ): @@ -108,7 +108,7 @@ def 
test_lm_eval_evaluation_from_pretrained( # TODO: rewrite for a new distributed test function # @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) -# @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed) +# @pytest.mark.model_testing_group(ModelTestingGroup.lm_eval, ModelTestingGroup.distributed) # def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): # run_test_script_for_all_models( # distributed_testing_config=DistributedTestingConfig( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3a54be088..4f2da8b54 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -48,6 +48,7 @@ class ModelTestingGroup(enum.StrEnum): checkpoint = "checkpoint" convert = "convert" generate = "generate" + lm_eval = "lm_eval" megatron = "megatron" distributed = "distributed" streaming = "streaming" @@ -358,12 +359,13 @@ def update_and_add_testing_config( "--no-position-embedding", ], checkpoint_format=None, - # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.broken, + # No HF checkpoint format: the native conversion round-trip is redundant with other models, + # and the export-based generate tests can't run. + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, @@ -393,12 +395,11 @@ def update_and_add_testing_config( "--untie-embeddings-and-output-weights", ], checkpoint_format=LlamaCheckpointFormat, - # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, ModelTestingGroup.convert: ModelTestingGroupAction.main, - ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.generate: ModelTestingGroupAction.normal, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, ModelTestingGroup.streaming: ModelTestingGroupAction.normal, @@ -486,12 +487,11 @@ def update_and_add_testing_config( # Megatron doesn't support multi-token prediction. megatron_args=None, checkpoint_format=MTPLlamaCheckpointFormat, - # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.generate: ModelTestingGroupAction.normal, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, @@ -514,7 +514,9 @@ def update_and_add_testing_config( # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=Qwen2CheckpointFormat, - # TODO: Add back generate as `normal` when stable. + # `generate` matches HF in fp32 but diverges in bf16/flash: a near-tie argmax flips on numerical + # noise within the compared horizon. Stays `broken` pending a curated case free of near-tie + # (low-margin) argmax positions. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, @@ -560,12 +562,11 @@ def update_and_add_testing_config( # Megatron doesn't support sliding windows. megatron_args=None, checkpoint_format=MistralCheckpointFormat, - # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.generate: ModelTestingGroupAction.normal, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, @@ -653,7 +654,7 @@ def update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.generate: ModelTestingGroupAction.normal, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, },