diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index b48a1090..ae1eec3b 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -173,6 +173,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { to_device_vec(image_bound), to_device_vec(tgt_sizes), visual_token_ranges, + to_device(target_hidden_states), }; infinilm::global_state::get_forward_context().attn_metadata = { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 480f58ec..4c0c0345 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -48,7 +48,7 @@ class InferEngine { std::vector state_dict_keys(); - // Run a single forward pass on all workers and return the outputs from all ranks + // Run a single forward pass on all workers and return sampled token IDs. Output forward(const Input &input); void compile(); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 57b5b23b..975db561 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -407,8 +407,10 @@ void RankWorker::thread_loop() { std::lock_guard lk(mutex_); infinicore::Tensor logits; - // Try to get compiled graph - if (compiler_ != nullptr) { + infinicore::Tensor hidden_states; + // All-position speculative/MTP runs need eager mode because + // hidden states are not part of compiled graph outputs. + if (!local_args.sample_all_positions && compiler_ != nullptr) { auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu())); if (graph != nullptr && output != nullptr) { graph->run(); @@ -418,7 +420,9 @@ void RankWorker::thread_loop() { // Fall back to eager mode if (!logits) { auto model_args = local_args.to_model_input(rank_info_.device); - logits = model_->forward(model_args).logits; + auto model_output = model_->forward(model_args); + logits = model_output.logits; + hidden_states = model_output.hidden_states; } // Random sampling (rank 0 only) @@ -435,10 +439,16 @@ void RankWorker::thread_loop() { auto n_req = local_args.input_offsets.value()->size(0) - 1; int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data(); - auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; + const bool sample_all_positions = local_args.sample_all_positions; + const size_t n_out = sample_all_positions ? static_cast(input_offsets[n_req]) : n_req; + auto output_ids{infinicore::Tensor::empty({n_out}, infinicore::DataType::I64, rank_info_.device)}; - for (auto i{decltype(n_req)(0)}; i < n_req; ++i) { - auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})}; + for (size_t i{0}; i < n_out; ++i) { + size_t score_idx = i; + if (!sample_all_positions) { + score_idx = static_cast(input_offsets[i + 1] - 1); + } + auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, score_idx, 1}})->view({vocab_size})}; auto out{output_ids->narrow({{0, i, 1}})->view({})}; float random_val = std::uniform_real_distribution(0, 1)(rng_); infinicore::op::random_sample_( @@ -449,7 +459,7 @@ void RankWorker::thread_loop() { infinicore::context::syncStream(); - auto out{Output{output_ids}}; + auto out{Output{output_ids, logits, hidden_states}}; output_ = std::move(out); } diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 69516a5b..5bbd355a 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -62,6 +62,10 @@ class RankWorker { std::optional> image_req_ids; /// Flattened [start, end) visual token ranges in the packed language sequence. std::optional> visual_token_ranges; + /// Target model hidden states for draft/MTP models. + std::optional target_hidden_states; + /// Sample logits at every packed input position instead of one token per request. + bool sample_all_positions{false}; float temperature{1}; @@ -74,6 +78,8 @@ class RankWorker { struct Output { infinicore::Tensor output_ids; + infinicore::Tensor logits; + infinicore::Tensor hidden_states; }; RankWorker(std::shared_ptr infinilm_config, @@ -96,7 +102,7 @@ class RankWorker { std::vector state_dict_keys(); - // Submit a run (forward) job. + // Submit a run (forward + sampling) job. void run(const Input &args); // Reset the internal cache with a new configuration diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index fdebc49d..a8ba29c9 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -45,11 +45,15 @@ class InfinilmModel : public infinicore::nn::Module { std::optional> tgt_sizes; /// Flattened [start, end) visual token ranges in the packed language sequence. std::optional> visual_token_ranges; + /// Target model hidden states consumed by draft/MTP models. + std::optional target_hidden_states; }; struct Output { /// Logits. infinicore::Tensor logits; + /// Optional final hidden states, used by MTP/Eagle draft models. + infinicore::Tensor hidden_states; }; virtual ~InfinilmModel() = default; diff --git a/csrc/models/minicpm4/minicpm4_for_causal_lm.cpp b/csrc/models/minicpm4/minicpm4_for_causal_lm.cpp new file mode 100644 index 00000000..d4da29f4 --- /dev/null +++ b/csrc/models/minicpm4/minicpm4_for_causal_lm.cpp @@ -0,0 +1,173 @@ +#include "minicpm4_for_causal_lm.hpp" +#include "../../global_state/global_state.hpp" +#include "../models_registry.hpp" + +#include +#include +#include + +namespace infinilm::models::minicpm4 { + +namespace { +float residual_scale(const std::shared_ptr &model_config) { + const float scale_depth = model_config->get_or("scale_depth", 1.0f); + if (model_config->get_or("model_type", "") == "minicpm_eagle") { + const float mup_denominator = model_config->get_or("mup_denominator", 1.0f); + return scale_depth / std::sqrt(mup_denominator); + } + const float num_hidden_layers = static_cast(model_config->get("num_hidden_layers")); + return scale_depth / std::sqrt(num_hidden_layers); +} +} // namespace + +MiniCPM4Attention::MiniCPM4Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : Attention(model_config, layer_idx, device) { + o_proj_->set_alpha(residual_scale(model_config)); +} + +MiniCPM4MLP::MiniCPM4MLP(std::shared_ptr model_config, + const infinicore::Device &device) + : MLP(model_config, device) { + down_proj_->set_alpha(residual_scale(model_config)); +} + +MiniCPM4DecoderLayer::MiniCPM4DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + const auto &dtype = model_config->get_dtype(); + const size_t hidden_size = model_config->get("hidden_size"); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, model_config, device); +} + +std::tuple MiniCPM4DecoderLayer::forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) { + input_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = self_attn_->forward(positions, hidden_states); + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = mlp_->forward(hidden_states); + return std::make_tuple(hidden_states, residual); +} + +infinicore::Tensor MiniCPM4DecoderLayer::forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states) { + auto residual = hidden_states; + hidden_states = input_layernorm_->forward(hidden_states); + hidden_states = self_attn_->forward(positions, hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + + residual = hidden_states; + hidden_states = post_attention_layernorm_->forward(hidden_states); + hidden_states = mlp_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + return hidden_states; +} + +MiniCPM4Model::MiniCPM4Model(std::shared_ptr model_config, + const infinicore::Device &device) { + const auto &dtype = model_config->get_dtype(); + const size_t vocab_size = model_config->get("vocab_size"); + const size_t hidden_size = model_config->get("hidden_size"); + const size_t num_hidden_layers = model_config->get("num_hidden_layers"); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(embed_tokens, vocab_size, hidden_size, std::nullopt, dtype, device); + layers_.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; ++i) { + layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + } + INFINICORE_NN_MODULE_INIT(norm, hidden_size, rms_norm_eps, dtype, device); +} + +infinicore::Tensor MiniCPM4Model::forward(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + auto positions = input.position_ids.value(); + auto hidden_states = embed_tokens_->forward(input_ids); + + infinicore::Tensor residual; + for (const auto &layer : layers_) { + layer->forward(positions, hidden_states, residual); + } + + norm_->forward_inplace(hidden_states, residual); + return hidden_states; +} + +infinicore::Tensor MiniCPM4Model::embed_tokens(const infinicore::Tensor &input_ids) const { + return embed_tokens_->forward(input_ids); +} + +MiniCPM4ForCausalLM::MiniCPM4ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + const auto &dtype = model_config->get_dtype(); + const size_t hidden_size = model_config->get("hidden_size"); + const size_t vocab_size = model_config->get("vocab_size"); + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); + + if (model_config->get_config_json().contains("dim_model_base")) { + const float dim_model_base = model_config->get("dim_model_base"); + lm_head_->set_alpha(dim_model_base / static_cast(hidden_size)); + } +} + +infinilm::InfinilmModel::Output MiniCPM4ForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = forward_hidden(input); + auto logits = lm_head_->forward(hidden_states); + return {logits, hidden_states}; +} + +infinicore::Tensor MiniCPM4ForCausalLM::forward_hidden(const Input &input) const { + return model_->forward(input); +} + +infinicore::Tensor MiniCPM4ForCausalLM::logits_from_hidden(const infinicore::Tensor &hidden_states) const { + return lm_head_->forward(const_cast(hidden_states)); +} + +std::shared_ptr create_minicpm4_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if ("minicpm4" != model_type && "minicpm" != model_type) { + throw std::runtime_error("infinilm::models::minicpm4::create_minicpm4_model_config: model_type is not minicpm4"); + } + + auto &json = model_config->get_config_json(); + if (!json.contains("head_dim")) { + json["head_dim"] = model_config->get("hidden_size") / model_config->get("num_attention_heads"); + } + if (!json.contains("rope_theta")) { + json["rope_theta"] = 10000.0; + } + if (json.contains("bias")) { + json["attention_bias"] = json["bias"]; + json["mlp_bias"] = json["bias"]; + } + if (!json.contains("attention_bias")) { + json["attention_bias"] = false; + } + if (!json.contains("mlp_bias")) { + json["mlp_bias"] = false; + } + if (!json.contains("attention_output_bias")) { + json["attention_output_bias"] = false; + } + return model_config; +} + +} // namespace infinilm::models::minicpm4 + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + minicpm4, + infinilm::models::minicpm4::MiniCPM4ForCausalLM, + infinilm::models::minicpm4::create_minicpm4_model_config); +} // namespace diff --git a/csrc/models/minicpm4/minicpm4_for_causal_lm.hpp b/csrc/models/minicpm4/minicpm4_for_causal_lm.hpp new file mode 100644 index 00000000..084a868d --- /dev/null +++ b/csrc/models/minicpm4/minicpm4_for_causal_lm.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include "../../models/infinilm_model.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/rmsnorm.hpp" + +#include +#include + +namespace infinilm::models::minicpm4 { + +class MiniCPM4Attention : public infinilm::layers::attention::Attention { +public: + MiniCPM4Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); +}; + +class MiniCPM4MLP : public infinilm::layers::mlp::MLP { +public: + MiniCPM4MLP(std::shared_ptr model_config, + const infinicore::Device &device); +}; + +class MiniCPM4DecoderLayer : public infinicore::nn::Module { +public: + MiniCPM4DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + std::tuple forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states); + + void process_weights_after_loading() override { + self_attn_->process_weights_after_loading(); + mlp_->process_weights_after_loading(); + } + + void reset_runtime_state() const override { + self_attn_->reset_runtime_state(); + mlp_->reset_runtime_state(); + } + +protected: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(MiniCPM4Attention, self_attn); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(MiniCPM4MLP, mlp); +}; + +class MiniCPM4Model : public infinicore::nn::Module { +public: + MiniCPM4Model(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const; + + infinicore::Tensor embed_tokens(const infinicore::Tensor &input_ids) const; + +protected: + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + INFINICORE_NN_MODULE_VEC(MiniCPM4DecoderLayer, layers); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); +}; + +class MiniCPM4ForCausalLM : public InfinilmModel { +public: + MiniCPM4ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + + infinicore::Tensor forward_hidden(const Input &input) const; + + infinicore::Tensor logits_from_hidden(const infinicore::Tensor &hidden_states) const; + +protected: + INFINICORE_NN_MODULE(MiniCPM4Model, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_minicpm4_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::minicpm4 diff --git a/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.cpp b/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.cpp new file mode 100644 index 00000000..ee3df529 --- /dev/null +++ b/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.cpp @@ -0,0 +1,122 @@ +#include "minicpm_eagle_for_causal_lm.hpp" +#include "../../utils.hpp" +#include "../models_registry.hpp" + +#include "infinicore/ops.hpp" + +#include +#include + +namespace infinilm::models::minicpm_eagle { + +MiniCPMEagleModel::MiniCPMEagleModel(std::shared_ptr model_config, + const infinicore::Device &device) + : dtype_(model_config->get_dtype()), + device_(device), + hidden_size_(model_config->get("hidden_size")) { + const size_t vocab_size = model_config->get("vocab_size"); + const size_t num_hidden_layers = model_config->get("num_hidden_layers"); + const double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(embed_tokens, vocab_size, hidden_size_, std::nullopt, dtype_, device_); + // Eagle/MTP-specific input projection weights. The draft model fuses the + // previous draft token embedding with target-model hidden states through + // input_norm1, input_norm2, and fc before running eagle_layers. + INFINICORE_NN_MODULE_INIT(input_norm1, hidden_size_, rms_norm_eps, dtype_, device_); + INFINICORE_NN_MODULE_INIT(input_norm2, hidden_size_, rms_norm_eps, dtype_, device_); + INFINICORE_NN_MODULE_INIT(fc, hidden_size_ * 2, hidden_size_, false, dtype_, device_); + + eagle_layers_.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; ++i) { + eagle_layers_.push_back(this->register_module("eagle_layers." + std::to_string(i), model_config, i, device_)); + } + + INFINICORE_NN_MODULE_INIT(norm, hidden_size_, rms_norm_eps, dtype_, device_); +} + +infinicore::Tensor MiniCPMEagleModel::embed_input_ids(const infinicore::Tensor &input_ids) const { + return embed_tokens_->forward(input_ids); +} + +infinicore::Tensor MiniCPMEagleModel::forward_with_hidden(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + const infinicore::Tensor &target_hidden_states) const { + auto input_embeds = input_norm1_->forward(embed_input_ids(input_ids)); + auto target_hidden = input_norm2_->forward(target_hidden_states); + auto fused_shape = input_embeds->shape(); + fused_shape.back() = hidden_size_ * 2; + auto fused_input = infinicore::Tensor::empty(fused_shape, input_embeds->dtype(), input_embeds->device()); + fused_input->narrow({{fused_shape.size() - 1, 0, hidden_size_}})->copy_from(input_embeds); + fused_input->narrow({{fused_shape.size() - 1, hidden_size_, hidden_size_}})->copy_from(target_hidden); + auto hidden_states = fc_->forward(fused_input); + + for (const auto &layer : eagle_layers_) { + hidden_states = layer->forward(position_ids, hidden_states); + } + + return hidden_states; +} + +infinicore::Tensor MiniCPMEagleModel::forward(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + auto positions = input.position_ids.value(); + auto zero_hidden_states = infinicore::Tensor::zeros({input_ids->shape()[0], input_ids->shape()[1], hidden_size_}, dtype_, device_); + return forward_with_hidden(input_ids, positions, zero_hidden_states); +} + +MiniCPMEagleForCausalLM::MiniCPMEagleForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + const auto &dtype = model_config->get_dtype(); + const size_t hidden_size = model_config->get("hidden_size"); + const size_t vocab_size = model_config->get("vocab_size"); + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); + + if (model_config->get_config_json().contains("dim_model_base")) { + const float dim_model_base = model_config->get("dim_model_base"); + lm_head_->set_alpha(dim_model_base / static_cast(hidden_size)); + } +} + +infinilm::InfinilmModel::Output MiniCPMEagleForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + infinicore::Tensor hidden_states; + if (input.target_hidden_states.has_value()) { + hidden_states = model_->forward_with_hidden(input.input_ids.value(), input.position_ids.value(), input.target_hidden_states.value()); + } else { + hidden_states = model_->forward(input); + } + auto logits = lm_head_->forward(hidden_states); + return {logits, hidden_states}; +} + +std::shared_ptr create_minicpm_eagle_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if ("minicpm_eagle" != model_type && "minicpm" != model_type) { + throw std::runtime_error("infinilm::models::minicpm_eagle::create_minicpm_eagle_model_config: model_type is not minicpm_eagle"); + } + + auto &json = model_config->get_config_json(); + if (!json.contains("rope_theta")) { + json["rope_theta"] = 10000.0; + } + if (json.contains("bias")) { + json["attention_bias"] = json["bias"]; + json["mlp_bias"] = json["bias"]; + } + if (!json.contains("attention_output_bias")) { + json["attention_output_bias"] = false; + } + + return model_config; +} + +} // namespace infinilm::models::minicpm_eagle + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + minicpm_eagle, + infinilm::models::minicpm_eagle::MiniCPMEagleForCausalLM, + infinilm::models::minicpm_eagle::create_minicpm_eagle_model_config); +} // namespace diff --git a/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.hpp b/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.hpp new file mode 100644 index 00000000..af870280 --- /dev/null +++ b/csrc/models/minicpm_eagle/minicpm_eagle_for_causal_lm.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include "../../models/infinilm_model.hpp" +#include "../minicpm4/minicpm4_for_causal_lm.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/rmsnorm.hpp" + +#include +#include + +namespace infinilm::models::minicpm_eagle { + +class MiniCPMEagleModel : public infinicore::nn::Module { +public: + MiniCPMEagleModel(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor embed_input_ids(const infinicore::Tensor &input_ids) const; + + infinicore::Tensor forward_with_hidden(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + const infinicore::Tensor &target_hidden_states) const; + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const; + +protected: + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_norm1); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_norm2); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, fc); + INFINICORE_NN_MODULE_VEC(infinilm::models::minicpm4::MiniCPM4DecoderLayer, eagle_layers); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + + infinicore::DataType dtype_; + infinicore::Device device_; + size_t hidden_size_; +}; + +class MiniCPMEagleForCausalLM : public InfinilmModel { +public: + MiniCPMEagleForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + +protected: + INFINICORE_NN_MODULE(MiniCPMEagleModel, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_minicpm_eagle_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::minicpm_eagle diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 96bae0cb..c0178d90 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -145,6 +145,8 @@ inline void bind_infer_engine(py::module &m) { std::optional> tgt_sizes, std::optional> image_req_ids, std::optional> visual_token_ranges, + std::optional target_hidden_states, + bool sample_all_positions, py::kwargs kwargs) { InferEngine::Input input{ std::move(input_ids), @@ -160,6 +162,8 @@ inline void bind_infer_engine(py::module &m) { std::move(tgt_sizes), std::move(image_req_ids), std::move(visual_token_ranges), + std::move(target_hidden_states), + sample_all_positions, }; // Explicit defaults @@ -205,7 +209,9 @@ inline void bind_infer_engine(py::module &m) { py::arg("image_bound") = std::nullopt, py::arg("tgt_sizes") = std::nullopt, py::arg("image_req_ids") = std::nullopt, - py::arg("visual_token_ranges") = std::nullopt) + py::arg("visual_token_ranges") = std::nullopt, + py::arg("target_hidden_states") = std::nullopt, + py::arg("sample_all_positions") = false) .def_readwrite("input_ids", &InferEngine::Input::input_ids) .def_readwrite("position_ids", &InferEngine::Input::position_ids) .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths) @@ -219,12 +225,16 @@ inline void bind_infer_engine(py::module &m) { .def_readwrite("tgt_sizes", &InferEngine::Input::tgt_sizes) .def_readwrite("image_req_ids", &InferEngine::Input::image_req_ids) .def_readwrite("visual_token_ranges", &InferEngine::Input::visual_token_ranges) + .def_readwrite("target_hidden_states", &InferEngine::Input::target_hidden_states) + .def_readwrite("sample_all_positions", &InferEngine::Input::sample_all_positions) .def_readwrite("temperature", &InferEngine::Input::temperature) .def_readwrite("top_k", &InferEngine::Input::top_k) .def_readwrite("top_p", &InferEngine::Input::top_p); py::class_(infer_engine, "Output") - .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor"); + .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Sampled token IDs") + .def_readwrite("logits", &InferEngine::Output::logits, "Raw logits tensor") + .def_readwrite("hidden_states", &InferEngine::Output::hidden_states, "Raw hidden states tensor"); } } // namespace infinilm::engine diff --git a/examples/bench.py b/examples/bench.py index 392a0a2e..1c8127e1 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -4,6 +4,23 @@ import time from collections import OrderedDict +import infinicore +import numpy as np +from infinilm.base_config import BaseConfig +from infinilm.cache import PagedKVCacheConfig, StaticKVCacheConfig +from infinilm.distributed import DistConfig +from infinilm.infer_engine import GenerationConfig, InferEngine +from infinilm.llm.llm import LLM +from infinilm.llm.sampling_params import SamplingParams +from infinilm.modeling_utils import load_model_state_dict_by_file +from infinilm.moe_config import configure_moe_ep_backend +from infinilm.processors import AutoInfinilmProcessor +from tqdm import tqdm + +import sys +import time +from collections import OrderedDict + import infinicore import numpy as np from infinilm.base_config import BaseConfig @@ -180,6 +197,8 @@ class TestModel: def __init__( self, model_path, + draft_model_path=None, + num_draft_tokens=4, infini_device=infinicore.device("cpu", 0), tp=1, skip_load=False, @@ -192,6 +211,30 @@ def __init__( moe_ep_size=1, ) -> None: model_path = os.path.expanduser(model_path) + self.draft_model_path = draft_model_path + self.num_draft_tokens = num_draft_tokens + self.model_path = model_path + self.device_str = infini_device.type + self.tp = tp + self.cache_config = cache_config + self.enable_graph = enable_graph + self.attn_backend = attn_backend + self.use_mla = use_mla + self.weight_load_mode = weight_load_mode + self.skip_load = skip_load + + if draft_model_path is not None: + self.processor = AutoInfinilmProcessor.from_pretrained(model_path) + self.tokenizer = self.processor.get_tokenizer() + input_content = self.processor.apply_chat_template( + conversation=[{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + self.input_ids_list = [self.tokenizer.encode(input_content)] + self.model = None + return + # ---------------------------------------------------------------------------- # # 创建模型, # ---------------------------------------------------------------------------- # @@ -240,6 +283,16 @@ def __init__( self.model = model self.input_ids_list = input_ids_list + self.draft_model_path = draft_model_path + self.model_path = model_path + self.device_str = infini_device.type + self.tp = tp + self.cache_config = cache_config + self.enable_graph = enable_graph + self.attn_backend = attn_backend + self.use_mla = use_mla + self.weight_load_mode = weight_load_mode + self.skip_load = skip_load def run( self, @@ -256,6 +309,45 @@ def run( # ---------------------------------------------------------------------------- # # 自回归生成 # ---------------------------------------------------------------------------- # + if self.draft_model_path is not None: + prompt_text = self.tokenizer.decode(input_ids, skip_special_tokens=False) + llm = LLM( + model_path=self.model_path, + draft_model_path=self.draft_model_path, + num_draft_tokens=self.num_draft_tokens, + device=self.device_str, + tensor_parallel_size=self.tp, + cache_type="paged" if self.cache_config is not None else "static", + max_batch_size=batch_size, + max_tokens=output_len, + temperature=temperature, + top_p=top_p, + top_k=top_k, + enable_graph=self.enable_graph, + attn_backend=self.attn_backend, + use_mla=self.use_mla, + weight_load_mode=self.weight_load_mode, + skip_load=self.skip_load, + ) + t1 = time.time() + print("=================== start generate ====================") + outputs = llm.generate( + prompts=[prompt_text] * batch_size, + sampling_params=SamplingParams(max_tokens=output_len, ignore_eos=True), + use_tqdm=False, + ) + t2 = time.time() + if cfg.verbose and not skip_load: + if output_len <= 256: + for output in outputs: + print(output.outputs[0].text) + else: + print( + f"[bench] output text omitted because output_len={output_len} > 256." + ) + print(f"total_time: {round((t2 - t1) * 1000, 2)} ms") + return + input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64) t1 = time.time() @@ -350,6 +442,8 @@ def run( test = TestModel( model_path, + draft_model_path=cfg.draft_model, + num_draft_tokens=cfg.num_draft_tokens, infini_device=infini_device, tp=tp, skip_load=skip_load, diff --git a/examples/test_infer.py b/examples/test_infer.py index 6da1498c..131f7497 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -10,6 +10,8 @@ def test( prompts: list[str], model_path, + draft_model_path=None, + num_draft_tokens=4, max_new_tokens=100, device="cpu", tp=1, @@ -38,6 +40,8 @@ def test( model = LLM( model_path=model_path, + draft_model_path=draft_model_path, + num_draft_tokens=num_draft_tokens, device=device, tensor_parallel_size=tp, moe_ep_backend=moe_ep_backend, @@ -120,7 +124,9 @@ def test( test( prompts, model_path, - max_new_tokens, + draft_model_path=cfg.draft_model, + num_draft_tokens=cfg.num_draft_tokens, + max_new_tokens=max_new_tokens, device=device_str, tp=tp, moe_ep_backend=moe_ep_backend, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index 4023e603..e0b4cde7 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -59,6 +59,8 @@ def __init__(self): ) self.model = self.args.model + self.draft_model = self.args.draft_model + self.num_draft_tokens = self.args.num_draft_tokens self.device = self.args.device self.tp = self.args.tp self.dp = self.args.dp @@ -174,6 +176,18 @@ def _force_sync_for_metax(self): def _add_common_args(self): # --- base configuration --- self.parser.add_argument("--model", type=str, required=True) + self.parser.add_argument( + "--draft-model", + type=str, + default=None, + help="optional Eagle/MTP draft model directory", + ) + self.parser.add_argument( + "--num-draft-tokens", + type=int, + default=4, + help="number of Eagle draft tokens to verify per target step", + ) self.parser.add_argument( "--device", type=str, diff --git a/python/infinilm/config/engine_config.py b/python/infinilm/config/engine_config.py index 9eeb97cc..1bb1f733 100644 --- a/python/infinilm/config/engine_config.py +++ b/python/infinilm/config/engine_config.py @@ -10,6 +10,8 @@ class EngineConfig: Attributes: model_path: Path to the model directory. + draft_model_path: Optional Eagle/MTP draft model directory. + num_draft_tokens: Number of Eagle draft tokens to verify per step. device: Device type string ('cpu', 'cuda', 'mlu', etc.). dtype: Data type string ('float16', 'bfloat16', 'float32'). tensor_parallel_size: Number of devices for tensor parallelism. @@ -33,6 +35,8 @@ class EngineConfig: """ model_path: str + draft_model_path: Optional[str] = None + num_draft_tokens: int = 4 device: str = "cuda" dtype: str = "float16" tensor_parallel_size: int = 1 @@ -56,6 +60,9 @@ class EngineConfig: kv_transfer_config: Optional[KVTransferConfig] = None def __post_init__(self) -> None: + if self.num_draft_tokens < 1: + raise ValueError("num_draft_tokens must be >= 1") + if self.weight_load_mode not in {"async", "sync"}: raise ValueError("weight_load_mode must be either 'async' or 'sync'") diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 3625cca3..3f2ed42f 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -60,12 +60,27 @@ def read_hf_config(model_path): f"`model_type` is not specified in the config file `{config_path}`." ) + if config_dict.get("model_type") == "minicpm": + model_dir_name = os.path.basename(os.path.normpath(model_path)).lower() + if "eagle" in model_dir_name: + config_dict["model_type"] = "minicpm_eagle" + else: + config_dict["model_type"] = "minicpm4" + config_dict.setdefault("rope_theta", 10000.0) + if "bias" in config_dict: + config_dict["attention_bias"] = config_dict["bias"] + config_dict["mlp_bias"] = config_dict["bias"] + config_dict.setdefault("attention_bias", False) + config_dict.setdefault("mlp_bias", False) + config_dict.setdefault("attention_output_bias", False) + config_dict = _apply_torch_dtype_defaults(config_dict) config_dict = _normalize_videonsa_config(config_dict) return config_dict + # config.json (required) defines model architecture, while generation_config.json # (optional) defines generation behavior. They are kept as separate readers # because: 1) config.json must exist and requires model_type validation, @@ -184,7 +199,7 @@ def eos_token_id(self): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward( + def _build_input( self, input_ids, *, @@ -200,54 +215,99 @@ def forward( tgt_sizes=None, image_req_ids=None, visual_token_ranges=None, + target_hidden_states=None, + sample_all_positions=False, temperature=None, top_k=None, top_p=None, ): - try: - # TODO: Remove `_underlying` and simplify the corresponding code. - input_ids = input_ids._underlying if input_ids is not None else None - position_ids = ( - position_ids._underlying if position_ids is not None else None - ) - past_kv_lengths = ( - past_kv_lengths._underlying if past_kv_lengths is not None else None - ) - total_kv_lengths = ( - total_kv_lengths._underlying if total_kv_lengths is not None else None - ) - input_offsets = ( - input_offsets._underlying if input_offsets is not None else None - ) - block_tables = ( - block_tables._underlying if block_tables is not None else None - ) - cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None - slot_mapping = ( - slot_mapping._underlying if slot_mapping is not None else None - ) - - def convert_tensor_list(tensor_list_): - if tensor_list_ is None: - return None - if not isinstance(tensor_list_, list): - tensor_list_ = [tensor_list_] - if len(tensor_list_) == 0: - return None - return [tensor._underlying for tensor in tensor_list_] + # TODO: Remove `_underlying` and simplify the corresponding code. + input_ids = input_ids._underlying if input_ids is not None else None + position_ids = position_ids._underlying if position_ids is not None else None + past_kv_lengths = ( + past_kv_lengths._underlying if past_kv_lengths is not None else None + ) + total_kv_lengths = ( + total_kv_lengths._underlying if total_kv_lengths is not None else None + ) + input_offsets = input_offsets._underlying if input_offsets is not None else None + block_tables = block_tables._underlying if block_tables is not None else None + cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None + slot_mapping = slot_mapping._underlying if slot_mapping is not None else None + target_hidden_states = ( + target_hidden_states._underlying + if target_hidden_states is not None + else None + ) - pixel_values = convert_tensor_list(pixel_values) - image_bound = convert_tensor_list(image_bound) - tgt_sizes = convert_tensor_list(tgt_sizes) + def convert_tensor_list(tensor_list_): + if tensor_list_ is None: + return None + if not isinstance(tensor_list_, list): + tensor_list_ = [tensor_list_] + if len(tensor_list_) == 0: + return None + return [tensor._underlying for tensor in tensor_list_] + + pixel_values = convert_tensor_list(pixel_values) + image_bound = convert_tensor_list(image_bound) + tgt_sizes = convert_tensor_list(tgt_sizes) + + temperature = 1.0 if temperature is None else temperature + top_k = 1 if top_k is None else top_k + top_p = 1.0 if top_p is None else top_p + + return super().Input( + input_ids, + position_ids=position_ids, + past_sequence_lengths=past_kv_lengths, + total_sequence_lengths=total_kv_lengths, + input_offsets=input_offsets, + cu_seqlens=cu_seqlens, + block_tables=block_tables, + slot_mapping=slot_mapping, + pixel_values=pixel_values, + image_bound=image_bound, + tgt_sizes=tgt_sizes, + image_req_ids=image_req_ids, + visual_token_ranges=visual_token_ranges, + target_hidden_states=target_hidden_states, + sample_all_positions=sample_all_positions, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + def forward( + self, + input_ids, + *, + position_ids=None, + past_kv_lengths=None, + total_kv_lengths=None, + input_offsets=None, + cu_seqlens=None, + block_tables=None, + slot_mapping=None, + pixel_values=None, + image_bound=None, + tgt_sizes=None, + image_req_ids=None, + visual_token_ranges=None, + target_hidden_states=None, + temperature=None, + top_k=None, + top_p=None, + ): + try: return infinicore.Tensor( super() .forward( - super().Input( + self._build_input( input_ids, position_ids=position_ids, - past_sequence_lengths=past_kv_lengths, - total_sequence_lengths=total_kv_lengths, + past_kv_lengths=past_kv_lengths, + total_kv_lengths=total_kv_lengths, input_offsets=input_offsets, cu_seqlens=cu_seqlens, block_tables=block_tables, @@ -257,6 +317,7 @@ def convert_tensor_list(tensor_list_): tgt_sizes=tgt_sizes, image_req_ids=image_req_ids, visual_token_ranges=visual_token_ranges, + target_hidden_states=target_hidden_states, temperature=temperature, top_k=top_k, top_p=top_p, @@ -268,6 +329,60 @@ def convert_tensor_list(tensor_list_): handle_oom_and_exit(e) raise + def forward_raw( + self, + input_ids, + *, + position_ids=None, + past_kv_lengths=None, + total_kv_lengths=None, + input_offsets=None, + cu_seqlens=None, + block_tables=None, + slot_mapping=None, + pixel_values=None, + image_bound=None, + tgt_sizes=None, + image_req_ids=None, + visual_token_ranges=None, + target_hidden_states=None, + sample_all_positions=True, + temperature=None, + top_k=None, + top_p=None, + ): + try: + output = super().forward( + self._build_input( + input_ids, + position_ids=position_ids, + past_kv_lengths=past_kv_lengths, + total_kv_lengths=total_kv_lengths, + input_offsets=input_offsets, + cu_seqlens=cu_seqlens, + block_tables=block_tables, + slot_mapping=slot_mapping, + pixel_values=pixel_values, + image_bound=image_bound, + tgt_sizes=tgt_sizes, + image_req_ids=image_req_ids, + visual_token_ranges=visual_token_ranges, + target_hidden_states=target_hidden_states, + sample_all_positions=sample_all_positions, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + ) + return { + "output_ids": infinicore.Tensor(output.output_ids), + "logits": infinicore.Tensor(output.logits), + "hidden_states": infinicore.Tensor(output.hidden_states), + } + except BaseException as e: + handle_oom_and_exit(e) + raise + def generate( self, input_ids, diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index 8e8007cd..239fb974 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -62,9 +62,9 @@ def compute_hash( return h.intdigest() def __init__(self, num_blocks: int, block_size: int): - assert num_blocks > 0 and block_size > 0, ( - "num_blocks and block_size must be positive" - ) + assert ( + num_blocks > 0 and block_size > 0 + ), "num_blocks and block_size must be positive" self.num_blocks = num_blocks self.block_size = block_size @@ -105,9 +105,9 @@ def _allocate_full_block(self) -> Block: def _deallocate_block(self, block_id: int): """Deallocate a block and return it to free list.""" block = self.blocks[block_id] - assert block.ref_count == 0, ( - f"Block {block_id} ref_count not zero, cannot deallocate" - ) + assert ( + block.ref_count == 0 + ), f"Block {block_id} ref_count not zero, cannot deallocate" if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: del self.hash_to_block_id[block.hash] @@ -172,7 +172,9 @@ def get_computed_blocks( cache_miss = False mm_start_counter = 0 mm_caching_queue = deque() - blocks_blueprint = [] # [{"prefix_hash": int or -1 if not a full block, "block_id": int or -1 if not cached}, ...] + blocks_blueprint = ( + [] + ) # [{"prefix_hash": int or -1 if not a full block, "block_id": int or -1 if not cached}, ...] max_blocks_to_reuse = num_full_blocks for block_idx in range(num_blocks): @@ -337,8 +339,87 @@ def allocate_slots( return block_table, slot_mapping + def append_slots( + self, + block_table: List[int], + start_num_tokens: int, + num_slots: int, + total_token_ids: List[int] = None, + update_hash: bool = True, + ) -> tuple[List[int], List[int]]: + """Append multiple decode slots for speculative target verification.""" + slots = [] + for offset in range(num_slots): + block_table, slot = self.append_slot( + block_table, + start_num_tokens + offset, + total_token_ids, + update_hash=update_hash, + ) + slots.append(slot) + return block_table, slots + + def truncate_blocks( + self, block_table: List[int], keep_num_tokens: int + ) -> List[int]: + """Trim block_table to the logical token length after speculative verify. + + KV tensors are not physically cleared; future attention only sees slots reachable + from the returned block table and sequence lengths. Any newly allocated blocks + past keep_num_tokens are dereferenced, and hash metadata for the partial tail is + invalidated so rejected draft tokens are never reused as a prefix hit. + """ + assert keep_num_tokens > 0, "keep_num_tokens must be greater than 0" + keep_blocks = (keep_num_tokens + self.block_size - 1) // self.block_size + keep_blocks = min(keep_blocks, len(block_table)) + + removed = block_table[keep_blocks:] + for block_id in removed: + block = self.blocks[block_id] + block.ref_count = 0 + self._deallocate_block(block_id) + + truncated = block_table[:keep_blocks] + if truncated and keep_num_tokens % self.block_size != 0: + tail_id = truncated[-1] + tail = self.blocks[tail_id] + if tail.hash != -1 and self.hash_to_block_id.get(tail.hash) == tail_id: + del self.hash_to_block_id[tail.hash] + tail.hash = -1 + tail.token_ids = [] + + return truncated + + def commit_blocks_hash( + self, block_table: List[int], token_ids: List[int], num_tokens: int + ) -> None: + """Register hashes for full blocks whose tokens are finalized.""" + assert num_tokens <= len(token_ids), "num_tokens exceeds token_ids length" + num_full_blocks = min(num_tokens // self.block_size, len(block_table)) + prefix_hash = -1 + for block_idx in range(num_full_blocks): + block_id = block_table[block_idx] + block = self.blocks[block_id] + block_start = block_idx * self.block_size + block_end = block_start + self.block_size + block_tokens = token_ids[block_start:block_end] + current_hash = self.compute_hash(block_tokens, prefix_hash) + + if block.hash != -1 and block.hash != current_hash: + if self.hash_to_block_id.get(block.hash) == block_id: + del self.hash_to_block_id[block.hash] + + if block.hash != current_hash or block.token_ids != block_tokens: + block.update(current_hash, block_tokens) + self.hash_to_block_id[current_hash] = block_id + prefix_hash = current_hash + def append_slot( - self, block_table: List[int], num_tokens: int, total_token_ids: List[int] = None + self, + block_table: List[int], + num_tokens: int, + total_token_ids: List[int] = None, + update_hash: bool = True, ) -> tuple[List[int], int]: """Append slot for decode phase (generate one new token). @@ -354,26 +435,27 @@ def append_slot( assert num_tokens > 0, "num_tokens must be greater than 0" if num_tokens % self.block_size == 1: - # Previous block is full, update its hash for future prefix caching - last_block_id = block_table[-1] - last_block = self.blocks[last_block_id] - - # Only update if block's token_ids is empty (avoid duplicate updates) - if len(last_block.token_ids) == 0: - block_start_idx = num_tokens - self.block_size - 1 - block_end_idx = num_tokens - 1 - block_tokens = total_token_ids[block_start_idx:block_end_idx] - - # Compute prefix_hash using previous block's hash if available - if len(block_table) > 1: - prev_block = self.blocks[block_table[-2]] - prefix_hash = prev_block.hash - else: - prefix_hash = -1 - - current_hash = self.compute_hash(block_tokens, prefix_hash) - last_block.update(current_hash, block_tokens) - self.hash_to_block_id[current_hash] = last_block_id + if update_hash: + # Previous block is full, update its hash for future prefix caching. + last_block_id = block_table[-1] + last_block = self.blocks[last_block_id] + + # Only update if block's token_ids is empty (avoid duplicate updates) + if len(last_block.token_ids) == 0: + block_start_idx = num_tokens - self.block_size - 1 + block_end_idx = num_tokens - 1 + block_tokens = total_token_ids[block_start_idx:block_end_idx] + + # Compute prefix_hash using previous block's hash if available + if len(block_table) > 1: + prev_block = self.blocks[block_table[-2]] + prefix_hash = prev_block.hash + else: + prefix_hash = -1 + + current_hash = self.compute_hash(block_tokens, prefix_hash) + last_block.update(current_hash, block_tokens) + self.hash_to_block_id[current_hash] = last_block_id # Need new block if not self.free_block_ids: @@ -426,9 +508,9 @@ def update_blocks_hash(self, block_table: List[int], num_local_cached_tokens: in num_local_cached_tokens: Number of locally cached tokens (must be a multiple of block_size). """ - assert num_local_cached_tokens % self.block_size == 0, ( - "num_local_cached_tokens must be multiple of block_size" - ) + assert ( + num_local_cached_tokens % self.block_size == 0 + ), "num_local_cached_tokens must be multiple of block_size" for idx in range(num_local_cached_tokens // self.block_size, len(block_table)): block_id = block_table[idx] block = self.blocks[block_id] diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index a07536fd..b75e7544 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -159,7 +159,7 @@ def _update_requests( case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") pending = [] - for req, token_id in zip(requests, sampled_tokens): + for req, token_ids in zip(requests, sampled_tokens): if req.is_aborted(): logger.info( f"Request {req.request_id} aborted by client, skipping update" @@ -170,57 +170,63 @@ def _update_requests( req.mark_canceled() continue - req.generated_token_ids.append(token_id) - pending_tokens = req.generated_token_ids[req._token_decode_offset :] - delta = self.tokenizer.decode(pending_tokens) - holds_back = bool(delta) and delta.endswith("\ufffd") + if not isinstance(token_ids, list): + token_ids = [token_ids] - last_committed_text = req.generated_text + for token_id in token_ids: + if req.is_finished(): + break + req.generated_token_ids.append(token_id) + pending_tokens = req.generated_token_ids[req._token_decode_offset :] + delta = self.tokenizer.decode(pending_tokens) + holds_back = bool(delta) and delta.endswith("\ufffd") - if not holds_back: - req.generated_text = last_committed_text + delta - req._token_decode_offset = len(req.generated_token_ids) + last_committed_text = req.generated_text - is_finished = self._check_request_finished(req, token_id) + if not holds_back: + req.generated_text = last_committed_text + delta + req._token_decode_offset = len(req.generated_token_ids) - # vLLM-style replacement character handling is primarily relevant for streaming. - # For offline generation (no output queue), keep the fast incremental path. - if req._output_queue is None: - if is_finished: - req.mark_finished(req.finish_reason) - - else: - if holds_back and not is_finished: - token_text = "" - else: - if is_finished and req.finish_reason in ( - FinishReason.EOS_TOKEN, - FinishReason.LENGTH, - FinishReason.STOP_STRING, - ): - token_text = "" - else: - token_text = req.generated_text[req._text_output_offset :] - if token_text: - req._text_output_offset = len(req.generated_text) + is_finished = self._check_request_finished(req, token_id) + # vLLM-style replacement character handling is primarily relevant for streaming. + # For offline generation (no output queue), keep the fast incremental path. + if req._output_queue is None: if is_finished: req.mark_finished(req.finish_reason) - output = TokenOutput( - request_id=req.request_id, - token_id=token_id, - token_text=token_text, - finished=is_finished, - finish_reason=req.finish_reason if is_finished else None, - generated_text=req.generated_text, - ) - if req.is_aborted(): - logger.info( - f"Request {req.request_id} aborted before putting token" + else: + if holds_back and not is_finished: + token_text = "" + else: + if is_finished and req.finish_reason in ( + FinishReason.EOS_TOKEN, + FinishReason.LENGTH, + FinishReason.STOP_STRING, + ): + token_text = "" + else: + token_text = req.generated_text[req._text_output_offset :] + if token_text: + req._text_output_offset = len(req.generated_text) + + if is_finished: + req.mark_finished(req.finish_reason) + + output = TokenOutput( + request_id=req.request_id, + token_id=token_id, + token_text=token_text, + finished=is_finished, + finish_reason=req.finish_reason if is_finished else None, + generated_text=req.generated_text, ) - continue - pending.append((req.output_queue.async_q, output)) + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted before putting token" + ) + continue + pending.append((req.output_queue.async_q, output)) self.scheduler.complete_requests(requests) return pending @@ -287,6 +293,8 @@ class LLM: def __init__( self, model_path: str, + draft_model_path: Optional[str] = None, + num_draft_tokens: int = 4, device: str = "cuda", dtype: str = "float16", tensor_parallel_size: int = 1, @@ -331,6 +339,8 @@ def __init__( """ config = EngineConfig( model_path=model_path, + draft_model_path=draft_model_path, + num_draft_tokens=num_draft_tokens, device=device, dtype=dtype, tensor_parallel_size=tensor_parallel_size, @@ -490,6 +500,8 @@ class AsyncLLMEngine: def __init__( self, model_path: str, + draft_model_path: Optional[str] = None, + num_draft_tokens: int = 4, device: str = "cuda", dtype: str = "float16", tensor_parallel_size: int = 1, @@ -537,6 +549,8 @@ def __init__( """ config = EngineConfig( model_path=model_path, + draft_model_path=draft_model_path, + num_draft_tokens=num_draft_tokens, device=device, dtype=dtype, tensor_parallel_size=tensor_parallel_size, diff --git a/python/infinilm/llm/model_runner/model_runner.py b/python/infinilm/llm/model_runner/model_runner.py index e7b9ac58..b72ef54b 100644 --- a/python/infinilm/llm/model_runner/model_runner.py +++ b/python/infinilm/llm/model_runner/model_runner.py @@ -14,6 +14,7 @@ ) from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.processors import AutoInfinilmProcessor +from infinilm.llm.model_runner.speculative_runner import SpeculativeRunner logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ class KVConnectorOutput: class ModelRunnerOutput: # [num_reqs] req_ids: list[str] = field(default_factory=list) - sampled_token_ids: list[int] = field(default_factory=list) + sampled_token_ids: list[int | list[int]] = field(default_factory=list) kv_connector_output: KVConnectorOutput | None = None @@ -81,12 +82,25 @@ def __init__(self, config: EngineConfig): skip_legacy_moe=config.skip_legacy_moe, ) + if self.model_engine.model_type == "minicpm_eagle": + raise RuntimeError( + "MiniCPM4 Eagle-vLLM is a speculative draft head, not a standalone " + "causal LM. Use the MiniCPM4-8B base model as --model and pass " + "this checkpoint through --draft-model for Eagle speculative decoding." + ) + # Load model weights if not self.config.skip_load: load_model_state_dict_by_file( self.model_engine, config.model_path, dtype=self.model_engine.dtype ) + self.speculative_runner = None + if config.draft_model_path is not None: + self.speculative_runner = SpeculativeRunner( + config, self.model_engine, self.device + ) + # Initialize processor self.processor = AutoInfinilmProcessor.from_pretrained(config.model_path) @@ -180,12 +194,18 @@ def _model_forward(self, scheduler_output): self.config.top_k, ) + if self.speculative_runner is not None: + return self._model_forward_with_speculative(scheduler_output, model_input) + # Run inference sampled_tokens = self.model_engine.forward(**model_input) sampled_tokens_list = sampled_tokens.to_numpy().tolist() return sampled_tokens_list + def _model_forward_with_speculative(self, scheduler_output, model_input): + return self.speculative_runner.forward(scheduler_output, model_input) + @contextmanager def maybe_get_kv_connector_output( self, scheduler_output: Any diff --git a/python/infinilm/llm/model_runner/speculative_runner.py b/python/infinilm/llm/model_runner/speculative_runner.py new file mode 100644 index 00000000..9f842f5a --- /dev/null +++ b/python/infinilm/llm/model_runner/speculative_runner.py @@ -0,0 +1,317 @@ +import infinicore + +from infinilm.cache.cache import StaticKVCacheConfig +from infinilm.distributed import DistConfig +from infinilm.infer_engine import InferEngine +from infinilm.modeling_utils import load_model_state_dict_by_file + + +class SpeculativeRunner: + def __init__(self, config, target_model_engine, device): + self.config = config + self.target_model_engine = target_model_engine + self.num_draft_tokens = config.num_draft_tokens + self.draft_max_batch_size = config.max_batch_size + self.eagle_accept_count = 0 + self.eagle_total_count = 0 + + draft_cache_config = StaticKVCacheConfig( + max_batch_size=config.max_batch_size, max_cache_len=config.max_cache_len + ) + self.draft_model_engine = InferEngine( + model_path=config.draft_model_path, + device=device, + distributed_config=DistConfig(config.tensor_parallel_size), + cache_config=draft_cache_config, + enable_graph_compiling=config.enable_graph, + attention_backend="default", + use_mla=False, + weight_load_mode=config.weight_load_mode, + ) + if self.draft_model_engine.model_type != "minicpm_eagle": + raise RuntimeError( + f"draft_model_path must point to a MiniCPM Eagle draft model, " + f"got model_type={self.draft_model_engine.model_type}" + ) + if not config.skip_load: + load_model_state_dict_by_file( + self.draft_model_engine, + config.draft_model_path, + dtype=self.draft_model_engine.dtype, + ) + + def forward(self, scheduler_output, model_input): + cache_ops = getattr(scheduler_output, "speculative_cache_ops", None) + if cache_ops is None: + sampled_tokens = self.target_model_engine.forward(**model_input) + return sampled_tokens.to_numpy().tolist() + + # Keep non-greedy sampling on the established target path. Correct stochastic + # speculative sampling needs distribution-level acceptance, while current MTP + # verification is exact for greedy decoding. + if self.config.top_k != 1 or self.config.temperature != 1.0: + sampled_tokens = self.target_model_engine.forward(**model_input) + return sampled_tokens.to_numpy().tolist() + + requests = scheduler_output.scheduled_requests + if not requests: + return [] + + target_output = self.target_model_engine.forward_raw(**model_input) + target_token_ids = target_output["output_ids"].to_numpy().tolist() + if not target_token_ids: + return target_token_ids + + input_offsets = model_input["input_offsets"].to_numpy().tolist() + hidden_states = target_output["hidden_states"] + output_tokens_by_req: list[list[int]] = [[] for _ in requests] + draft_jobs = [] + + for req_idx, req in enumerate(requests): + last_input_idx = int(input_offsets[req_idx + 1]) - 1 + target_token = int(target_token_ids[last_input_idx]) + max_tokens = req.sampling_params.max_tokens + remaining = ( + None + if max_tokens is None + else max_tokens - req.get_num_generated_tokens() + ) + if remaining is not None and remaining <= 1: + output_tokens_by_req[req_idx] = [target_token] + continue + + draft_budget = self.num_draft_tokens + if remaining is not None: + draft_budget = min(draft_budget, max(1, remaining - 1)) + if draft_budget <= 0: + output_tokens_by_req[req_idx] = [target_token] + continue + + source_token, source_position = self._get_last_input_token_and_position( + req, scheduler_output.is_prefill + ) + draft_jobs.append( + { + "req_idx": req_idx, + "req": req, + "target_token": target_token, + "remaining": remaining, + "source_token": source_token, + "source_position": source_position, + "target_hidden": hidden_states.narrow(1, last_input_idx, 1), + "num_tokens": draft_budget, + } + ) + + draft_results = self._draft_eagle_tokens_batch(draft_jobs) + verify_candidates = [] + for job, draft_tokens in zip(draft_jobs, draft_results): + req_idx = job["req_idx"] + req = job["req"] + target_token = job["target_token"] + if not draft_tokens: + output_tokens_by_req[req_idx] = [target_token] + continue + + self.eagle_total_count += len(draft_tokens) + if draft_tokens[0] != target_token: + output_tokens_by_req[req_idx] = [target_token] + continue + + base_len = req.get_total_length() + total_token_ids = req.get_all_token_ids() + draft_tokens + verify_block_table, verify_slots = cache_ops.append_verify_slots( + list(req.block_table), + base_len + 1, + len(draft_tokens), + total_token_ids, + ) + req.block_table = verify_block_table + req.num_blocks = len(req.block_table) + verify_candidates.append( + { + "req_idx": req_idx, + "req": req, + "base_len": base_len, + "remaining": job["remaining"], + "draft_tokens": draft_tokens, + "slot_mapping": verify_slots, + } + ) + + if verify_candidates: + verify_output = self.target_model_engine.forward_raw( + **self._build_paged_verify_batch_input(verify_candidates) + ) + verify_token_ids = verify_output["output_ids"].to_numpy().tolist() + verify_offsets = [0] + for candidate in verify_candidates: + verify_offsets.append( + verify_offsets[-1] + len(candidate["draft_tokens"]) + ) + + for idx, candidate in enumerate(verify_candidates): + req = candidate["req"] + req_idx = candidate["req_idx"] + draft_tokens = candidate["draft_tokens"] + segment = verify_token_ids[ + verify_offsets[idx] : verify_offsets[idx + 1] + ] + accepted = 1 + correction = None + for draft_idx in range(1, len(draft_tokens)): + expected = int(segment[draft_idx - 1]) + if draft_tokens[draft_idx] != expected: + correction = expected + break + accepted += 1 + + if correction is None: + correction = int(segment[len(draft_tokens) - 1]) + + self.eagle_accept_count += accepted + keep_tokens = candidate["base_len"] + accepted + req.block_table = cache_ops.rollback_to_length( + req.block_table, keep_tokens + ) + req.num_blocks = len(req.block_table) + req.slot_mapping = [] + accepted_token_ids = req.get_all_token_ids() + draft_tokens[:accepted] + cache_ops.commit_accepted_tokens( + req.block_table, accepted_token_ids, keep_tokens + ) + + output_tokens = draft_tokens[:accepted] + [correction] + remaining = candidate["remaining"] + if remaining is not None: + output_tokens = output_tokens[:remaining] + output_tokens_by_req[req_idx] = output_tokens + + return output_tokens_by_req + + def _get_last_input_token_and_position(self, req, is_prefill): + if is_prefill: + return req.prompt_token_ids[-1], req.prompt_length - 1 + token = ( + req.generated_token_ids[-1] + if req.generated_token_ids + else req.prompt_token_ids[-1] + ) + return token, req.get_total_length() - 1 + + def _draft_eagle_tokens_batch(self, jobs: list[dict]) -> list[list[int]]: + if not jobs: + return [] + + draft_tokens_by_job: list[list[int]] = [[] for _ in jobs] + current_tokens = [int(job["source_token"]) for job in jobs] + current_hiddens = [job["target_hidden"] for job in jobs] + max_steps = max(int(job["num_tokens"]) for job in jobs) + if max_steps <= 0: + return draft_tokens_by_job + + real_batch = len(jobs) + draft_batch = max(self.draft_max_batch_size, real_batch) + if real_batch > self.draft_max_batch_size: + raise RuntimeError( + f"Eagle draft batch {real_batch} exceeds configured max_batch_size " + f"{self.draft_max_batch_size}. Increase max_batch_size when creating LLM." + ) + dummy_token = current_tokens[0] + dummy_hidden = current_hiddens[0] + + for step in range(max_steps): + input_tokens = [ + current_tokens[idx] if idx < real_batch else dummy_token + for idx in range(draft_batch) + ] + positions = [ + int(jobs[idx]["source_position"]) + step if idx < real_batch else 0 + for idx in range(draft_batch) + ] + hidden_inputs = [ + current_hiddens[idx] if idx < real_batch else dummy_hidden + for idx in range(draft_batch) + ] + target_hidden = infinicore.cat(hidden_inputs, dim=0) + seq_len = step + 1 + + draft_output = self.draft_model_engine.forward_raw( + input_ids=infinicore.from_list( + [[token] for token in input_tokens], dtype=infinicore.int64 + ), + position_ids=infinicore.from_list( + [[pos] for pos in positions], dtype=infinicore.int64 + ), + past_kv_lengths=infinicore.from_list( + [step] * draft_batch, dtype=infinicore.int32 + ), + total_kv_lengths=infinicore.from_list( + [seq_len] * draft_batch, dtype=infinicore.int32 + ), + input_offsets=infinicore.from_list( + list(range(draft_batch + 1)), dtype=infinicore.int32 + ), + cu_seqlens=infinicore.from_list( + [i * seq_len for i in range(draft_batch + 1)], + dtype=infinicore.int32, + ), + target_hidden_states=target_hidden, + temperature=1.0, + top_k=1, + top_p=1.0, + ) + token_ids = draft_output["output_ids"].to_numpy().tolist() + draft_hidden = draft_output["hidden_states"] + for job_idx, job in enumerate(jobs): + token = int(token_ids[job_idx]) + if step < int(job["num_tokens"]): + draft_tokens_by_job[job_idx].append(token) + current_tokens[job_idx] = token + current_hiddens[job_idx] = draft_hidden.narrow(0, job_idx, 1) + + return draft_tokens_by_job + + def _build_paged_verify_batch_input(self, candidates: list[dict]) -> dict: + tokens = [] + position_ids = [] + past_lens = [] + seq_lens = [] + input_offsets = [0] + cu_seqlens = [0] + slot_mapping = [] + block_tables = [] + max_block_table_len = max( + len(candidate["req"].block_table) for candidate in candidates + ) + + for candidate in candidates: + req = candidate["req"] + base_len = candidate["base_len"] + draft_tokens = candidate["draft_tokens"] + tokens.extend(draft_tokens) + position_ids.extend(range(base_len, base_len + len(draft_tokens))) + past_lens.append(base_len) + seq_lens.append(base_len + len(draft_tokens)) + input_offsets.append(input_offsets[-1] + len(draft_tokens)) + cu_seqlens.append(cu_seqlens[-1] + base_len + len(draft_tokens)) + slot_mapping.extend(candidate["slot_mapping"]) + block_tables.append( + req.block_table + [-1] * (max_block_table_len - len(req.block_table)) + ) + + return { + "input_ids": infinicore.from_list([tokens], dtype=infinicore.int64), + "position_ids": infinicore.from_list(position_ids, dtype=infinicore.int64), + "past_kv_lengths": infinicore.from_list(past_lens, dtype=infinicore.int32), + "total_kv_lengths": infinicore.from_list(seq_lens, dtype=infinicore.int32), + "input_offsets": infinicore.from_list( + input_offsets, dtype=infinicore.int32 + ), + "cu_seqlens": infinicore.from_list(cu_seqlens, dtype=infinicore.int32), + "block_tables": infinicore.from_list(block_tables, dtype=infinicore.int32), + "slot_mapping": infinicore.from_list(slot_mapping, dtype=infinicore.int64), + "temperature": 1.0, + "top_k": 1, + "top_p": 1.0, + } diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index c99c54f5..087a87a8 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -12,6 +12,36 @@ logger = logging.getLogger(__name__) +class SpeculativeCacheOps: + """Limited cache operations needed by speculative verification.""" + + def __init__(self, cache_manager: BlockManager): + self._cache_manager = cache_manager + + def append_verify_slots( + self, + block_table: List[int], + start_length: int, + num_slots: int, + token_ids: List[int], + ): + return self._cache_manager.append_slots( + block_table, + start_length, + num_slots, + token_ids, + update_hash=False, + ) + + def rollback_to_length(self, block_table: List[int], keep_tokens: int): + return self._cache_manager.truncate_blocks(block_table, keep_tokens) + + def commit_accepted_tokens( + self, block_table: List[int], token_ids: List[int], num_tokens: int + ) -> None: + self._cache_manager.commit_blocks_hash(block_table, token_ids, num_tokens) + + class SchedulerOutput: """Scheduler output containing scheduled requests and execution phase info.""" @@ -19,10 +49,12 @@ def __init__( self, scheduled_requests: List[InferenceRequest], is_prefill: bool = False, + speculative_cache_ops: Optional[SpeculativeCacheOps] = None, ): self.scheduled_requests = scheduled_requests self.num_requests = len(scheduled_requests) self.is_prefill = is_prefill + self.speculative_cache_ops = speculative_cache_ops self.kv_connector_metadata = None @@ -54,6 +86,7 @@ def __init__( self.remote_kv_requests: dict[str, InferenceRequest] = {} self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) + self.speculative_cache_ops = SpeculativeCacheOps(self.cache_manager) self.block_size = block_size self.max_num_batched_tokens = max_num_batched_tokens self.connector = connector @@ -236,6 +269,7 @@ def schedule(self) -> Optional[SchedulerOutput]: scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, is_prefill=is_prefill, + speculative_cache_ops=self.speculative_cache_ops, ) if self.connector is not None: meta = self.connector.build_connector_meta() @@ -298,6 +332,7 @@ def schedule(self) -> Optional[SchedulerOutput]: scheduler_output = SchedulerOutput( scheduled_requests=scheduled_requests, is_prefill=is_prefill, + speculative_cache_ops=self.speculative_cache_ops, ) if self.connector is not None: @@ -306,7 +341,10 @@ def schedule(self) -> Optional[SchedulerOutput]: return scheduler_output if self.connector is not None: - scheduler_output = SchedulerOutput(scheduled_requests=[]) + scheduler_output = SchedulerOutput( + scheduled_requests=[], + speculative_cache_ops=self.speculative_cache_ops, + ) meta = self.connector.build_connector_meta() scheduler_output.kv_connector_metadata = meta return scheduler_output diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index f9216e8e..db18e1be 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -17,7 +17,7 @@ def _get_scale_emb(model_path: str) -> float: raise FileNotFoundError(f"config.json not found at {config_path}") with open(config_path, "r") as f: config = json.load(f) - if config.get("model_type") != "fm9g": + if config.get("model_type") not in ("fm9g", "minicpm"): return 1.0 return config.get("scale_emb", 1.0) @@ -196,6 +196,7 @@ def load_model_state_dict_by_file( already_loaded_keys = [] embed_tokens_torch_unscaled = None + weights_processed = False remapper = _WEIGHT_REMAPPER.get(model_type) @@ -261,6 +262,7 @@ def load_model_state_dict_by_file( gc.collect() model.process_weights_after_loading() + weights_processed = True elif os.path.exists(os.path.join(model_path, "pytorch_model.bin")): file_path = os.path.join(model_path, "pytorch_model.bin") @@ -309,6 +311,9 @@ def load_model_state_dict_by_file( check_parameters(model_keys, already_loaded_keys) + if not weights_processed: + model.process_weights_after_loading() + t2 = time.time() print(f" load weights over! {(t2 - t1) * 1000} ms \n") diff --git a/python/infinilm/processors/sentencepiece_processor.py b/python/infinilm/processors/sentencepiece_processor.py index 19eefbb3..715bb21e 100644 --- a/python/infinilm/processors/sentencepiece_processor.py +++ b/python/infinilm/processors/sentencepiece_processor.py @@ -75,6 +75,12 @@ class MistralProcessor(SentencePieceProcessor): pass +@register_processor("minicpm") +@register_processor("minicpm4") +class MiniCPMProcessor(SentencePieceProcessor): + pass + + @register_processor("deepseek_v2") class DeepSeekV2Processor(BasicLLMProcessor): pass