Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
"auto_map",
"torch_dtype",
"use_cache",
# Architecture-family marker some transformers v4 configs carry (e.g. LlamaConfig); dropped
# in v5, not consumed by Fast-LLM, and absent from a bare ``PretrainedConfig``.
"is_llama_config",
# Token ids — generation/inference, not architecture (a bare v5 config omits these).
"bos_token_id",
"decoder_start_token_id",
Expand Down
23 changes: 22 additions & 1 deletion fast_llm/engine/inference/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.inference.config import _TRANSFORMERS_V4, HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
Expand Down Expand Up @@ -113,7 +114,21 @@ def from_pretrained(
stage_filter=stage_filter,
)

return cls(fast_llm_model, **kwargs)
model = cls(fast_llm_model, **kwargs)
model._apply_generation_token_ids(pretrained_model_name_or_path)
return model

def _apply_generation_token_ids(self, pretrained: CheckpointLoadConfig) -> None:
# Honor the source HF config's generation token ids: Fast-LLM's import drops them (they are
# generation metadata, not architecture), so `generate` would otherwise never stop at EOS.
# Only external (HF) checkpoints carry them; native Fast-LLM checkpoints leave the defaults.
handler_class = pretrained.format.get_handler_class()
if not issubclass(handler_class, HuggingfaceStateDictCheckpointHandler):
return
hf_config = handler_class._load_config(pretrained.path)
for key in ("bos_token_id", "eos_token_id", "pad_token_id"):
if (token_id := hf_config.get(key)) is not None:
setattr(self.generation_config, key, token_id)

def _init_weights(self, module) -> None:
raise NotImplementedError(module)
Expand Down Expand Up @@ -249,3 +264,9 @@ def stop_workers(self):
def inner_forward(self, *args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput:
# Meant to be overridden in derived classes
raise NotImplementedError()

@classmethod
def can_generate(cls) -> bool:
# `PreTrainedModel.can_generate` walks `__bases__` by name and stops at any base containing
# "PreTrainedModel"; this intermediate base hides the `GenerationMixin` inheritance from that check.
return True
21 changes: 11 additions & 10 deletions fast_llm/models/gpt/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def inner_forward(
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
return_all_prediction_heads: bool = False,
# `generate` passes version-dependent plumbing kwargs (`cache_position`, `logits_to_keep`, ...).
# They don't apply to the `use_cache=False` path: positions are reconstructed from `attention_mask`,
# and the full logits are computed and the last position selected downstream.
**kwargs,
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
return self._inner_forward(
self._get_batch(input_ids, attention_mask),
Expand Down Expand Up @@ -129,16 +133,13 @@ def _inner_forward(
for name, (meta, tensor) in model_input.hidden_states.items()
}

logits = hidden_states.pop(f"{self.fast_llm_base_model.head.module_name}.logits")
if return_all_prediction_heads:
logits = torch.stack(
[logits]
+ [
hidden_states.pop(f"{head.module_name}.logits")
for head in self.fast_llm_base_model.multi_token_prediction.heads
],
dim=-2,
)
# Every head emits its logits into the hidden-states namespace; pop them all so the prediction
# heads' logits don't leak into the returned hidden states.
head_logits = [
hidden_states.pop(f"{head.module_name}.logits")
for head in (self.fast_llm_base_model.head, *self.fast_llm_base_model.multi_token_prediction.heads)
]
logits = torch.stack(head_logits, dim=-2) if return_all_prediction_heads else head_logits[0]

output = transformers.modeling_outputs.CausalLMOutputWithPast(
logits=logits,
Expand Down
41 changes: 30 additions & 11 deletions tests/models/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from fast_llm.models.gpt.config import PretrainedGPTModelConfig
from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingGroup
from tests.utils.utils import requires_cuda


def _prepare_data(tokenizer, use_batch_size2: bool):
Expand Down Expand Up @@ -108,7 +110,9 @@ def _get_fast_llm_model_from_model(

multi_stage.load_checkpoint(config.pretrained)

return HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner)
model = HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner)
model._apply_generation_token_ids(config.pretrained)
return model


def _trim_output(output, inputs):
Expand Down Expand Up @@ -151,7 +155,9 @@ def _test_for_batches(
if tokenizer is not None:
inputs = _prepare_data(tokenizer, use_batch_size2=False)
else:
inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=False)
inputs = _prepare_rand_data(
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=False
)
outputs = _generate(
inputs,
hf_model,
Expand All @@ -163,7 +169,9 @@ def _test_for_batches(
if tokenizer is not None:
inputs = _prepare_data(tokenizer, use_batch_size2=True)
else:
inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=True)
inputs = _prepare_rand_data(
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=True
)
outputs = _generate(
inputs,
hf_model,
Expand Down Expand Up @@ -204,6 +212,7 @@ def _test_generate(
)


@requires_cuda
@pytest.mark.extra_slow
@pytest.mark.parametrize(
"use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2",
Expand Down Expand Up @@ -242,14 +251,18 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi
if model_testing_config.checkpoint_format is None:
pytest.skip(f"Conversion not supported for {model_testing_config.name}")
run_test_script_for_all_models(
[
"training.train_iters=1",
f"training.export.format={model_testing_config.checkpoint_format.name}",
"training.export.interval=1",
],
distributed_testing_config=DistributedTestingConfig(
name="test_export_for_generate",
config_args=[
"training.train_iters=1",
f"training.export.format={model_testing_config.checkpoint_format.name}",
"training.export.interval=1",
],
)
)


@requires_cuda
@pytest.mark.slow
@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -303,13 +316,15 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format)
)


@requires_cuda
@pytest.mark.extra_slow
def test_generate_from_model(
model_path,
):
_test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat)


@requires_cuda
@pytest.mark.slow
@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.generate)
Expand All @@ -334,7 +349,7 @@ def _test_forward_return_hidden_states(

inputs_ids = torch.randint(
1,
fast_llm_model.config.fast_llm_config.base_model.vocab_size if vocab_size is None else vocab_size,
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size if vocab_size is None else vocab_size,
[1, 10],
dtype=torch.int64,
generator=torch.Generator().manual_seed(42),
Expand All @@ -345,17 +360,21 @@ def _test_forward_return_hidden_states(
input_ids=inputs_ids, output_hidden_states=True, return_dict=True, use_cache=False
)

# hidden_states include embeddings layer
assert len(res_fast_llm.hidden_states) - 1 == len(fast_llm_model.config.fast_llm_config.base_model.decoder)
# Embeddings + one state per decoder block + one final-norm state per prediction head
# (the last block's output is carried by the heads' final norms).
base_model = fast_llm_model.config.fast_llm_config.base_model
assert len(res_fast_llm.hidden_states) == base_model.decoder.num_blocks + base_model.head.prediction_heads


@requires_cuda
@pytest.mark.extra_slow
def test_forward_return_hidden_states(model_path):
_test_forward_return_hidden_states(
model_path, LlamaCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size
)


@requires_cuda
@pytest.mark.slow
@pytest.mark.model_testing_group(ModelTestingGroup.generate)
@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"])
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def do_get_lm_eval_config(base_path):
# "gsm8k,xnli_en,wikitext"


@pytest.mark.model_testing_group(ModelTestingGroup.generate)
@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval)
def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config):
run_test_script_for_all_models(
distributed_testing_config=DistributedTestingConfig(
Expand All @@ -75,7 +75,7 @@ def do_copy_training_output(distributed_testing_config: DistributedTestingConfig


@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.generate)
@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval)
def test_lm_eval_evaluation_last_checkpoint(
run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config, copy_training_output
):
Expand All @@ -89,7 +89,7 @@ def test_lm_eval_evaluation_last_checkpoint(


@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.generate)
@pytest.mark.model_testing_group(ModelTestingGroup.lm_eval)
def test_lm_eval_evaluation_from_pretrained(
run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config
):
Expand All @@ -108,7 +108,7 @@ def test_lm_eval_evaluation_from_pretrained(

# TODO: rewrite for a new distributed test function
# @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"])
# @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed)
# @pytest.mark.model_testing_group(ModelTestingGroup.lm_eval, ModelTestingGroup.distributed)
# def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config):
# run_test_script_for_all_models(
# distributed_testing_config=DistributedTestingConfig(
Expand Down
23 changes: 12 additions & 11 deletions tests/utils/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ModelTestingGroup(enum.StrEnum):
checkpoint = "checkpoint"
convert = "convert"
generate = "generate"
lm_eval = "lm_eval"
megatron = "megatron"
distributed = "distributed"
streaming = "streaming"
Expand Down Expand Up @@ -358,12 +359,13 @@ def update_and_add_testing_config(
"--no-position-embedding",
],
checkpoint_format=None,
# TODO: Add back generate as `normal` when stable.
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
# No HF checkpoint format: the native conversion round-trip is redundant with other models,
# and the export-based generate tests can't run.
ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant,
ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
},
Expand Down Expand Up @@ -393,12 +395,11 @@ def update_and_add_testing_config(
"--untie-embeddings-and-output-weights",
],
checkpoint_format=LlamaCheckpointFormat,
# TODO: Add back generate as `normal` when stable.
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.main,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.main,
ModelTestingGroup.convert: ModelTestingGroupAction.main,
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.normal,
ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
ModelTestingGroup.streaming: ModelTestingGroupAction.normal,
Expand Down Expand Up @@ -486,12 +487,11 @@ def update_and_add_testing_config(
# Megatron doesn't support multi-token prediction.
megatron_args=None,
checkpoint_format=MTPLlamaCheckpointFormat,
# TODO: Add back generate as `normal` when stable.
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.normal,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
},
Expand All @@ -514,7 +514,9 @@ def update_and_add_testing_config(
# Megatron doesn't support per sub layer biases.
megatron_args=None,
checkpoint_format=Qwen2CheckpointFormat,
# TODO: Add back generate as `normal` when stable.
# `generate` matches HF in fp32 but diverges in bf16/flash: a near-tie argmax flips on numerical
# noise within the compared horizon. Stays `broken` pending a curated case free of near-tie
# (low-margin) argmax positions.
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
Expand Down Expand Up @@ -560,12 +562,11 @@ def update_and_add_testing_config(
# Megatron doesn't support sliding windows.
megatron_args=None,
checkpoint_format=MistralCheckpointFormat,
# TODO: Add back generate as `normal` when stable.
groups={
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.normal,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
},
Expand Down Expand Up @@ -653,7 +654,7 @@ def update_and_add_testing_config(
ModelTestingGroup.basic: ModelTestingGroupAction.normal,
ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
ModelTestingGroup.generate: ModelTestingGroupAction.normal,
ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
},
Expand Down
Loading