Skip to content

[rank5]: TypeError: SortOp.forward() takes from 2 to 3 positional arguments but 5 were given when running the MoE script #164

@rtmadduri

Description

@rtmadduri
[rank5]: Traceback (most recent call last):
[rank5]:   File "/root/Stanford-Megatron-LM/pretrain_gpt.py", line 154, in <module>
[rank5]:     pretrain(train_valid_test_datasets_provider, model_provider,
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/training.py", line 147, in pretrain
[rank5]:     iteration = train(forward_step_func,
[rank5]:                 ^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/training.py", line 712, in train
[rank5]:     train_step(forward_step_func,
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/training.py", line 421, in train_step
[rank5]:     losses_reduced = forward_backward_func(
[rank5]:                      ^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/schedules.py", line 263, in forward_backward_no_pipelining
[rank5]:     output_tensor = forward_step(forward_step_func, data_iterator,
[rank5]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/schedules.py", line 133, in forward_step
[rank5]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank5]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/pretrain_gpt.py", line 124, in forward_step
[rank5]:     output_tensor = model(tokens, position_ids, attention_mask,
[rank5]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/distributed.py", line 59, in forward
[rank5]:     return self.module(*inputs, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/module.py", line 184, in forward
[rank5]:     outputs = self.module(*inputs, **kwargs)
[rank5]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/gpt_model.py", line 80, in forward
[rank5]:     lm_output = self.language_model(
[rank5]:                 ^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/language_model.py", line 432, in forward
[rank5]:     encoder_output = self.encoder(
[rank5]:                      ^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/transformer.py", line 1227, in forward
[rank5]:     hidden_states = layer(
[rank5]:                     ^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/transformer.py", line 800, in forward
[rank5]:     mlp_output, mlp_bias = self.mlp(layernorm_output)
[rank5]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/root/Stanford-Megatron-LM/megatron/model/transformer.py", line 202, in forward
[rank5]:     return self.moe.forward(x)
[rank5]:            ^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/megablocks/layers/moe.py", line 468, in forward
[rank5]:     out = self.experts(x, scores, expert_weights, top_experts)
[rank5]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/megablocks/layers/moe.py", line 429, in forward
[rank5]:     x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
[rank5]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/megablocks/layers/moe.py", line 262, in parallel_forward_once
[rank5]:     indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
[rank5]:                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/megablocks/layers/moe.py", line 161, in indices_and_bins
[rank5]:     output = ops.sort(top_expert, self.sort_end_bit)
[rank5]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
[rank5]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/megablocks/ops/sort.py", line 34, in forward
[rank5]:     ops.sort(x, end_bit, x_out, iota_out)
[rank5]:   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
[rank5]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: TypeError: SortOp.forward() takes from 2 to 3 positional arguments but 5 were given

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions