diff --git a/csrc/models/mistral/mistral_for_causal_lm.cpp b/csrc/models/mistral/mistral_for_causal_lm.cpp
index a862add7f..181e39ceb 100644
--- a/csrc/models/mistral/mistral_for_causal_lm.cpp
+++ b/csrc/models/mistral/mistral_for_causal_lm.cpp
@@ -22,6 +22,10 @@ std::shared_ptr create_mistral_model_config(std::
         config_json["attention_bias"] = false;
     }
 
+    if (!config_json.contains("torch_dtype")) {
+        config_json["torch_dtype"] = "bfloat16";
+    }
+
     return model_config;
 }
 
diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py
index 10cf58be2..31c0ecf1b 100644
--- a/python/infinilm/infer_engine.py
+++ b/python/infinilm/infer_engine.py
@@ -98,9 +98,12 @@ def __init__(
 
     @property
     def dtype(self):
-        torch_dtype = self.hf_config.get("torch_dtype")
-        if torch_dtype is None:
-            torch_dtype = self.hf_config.get("dtype")
+        # Prefer "torch_dtype", then "dtype"; default to bfloat16 when both are missing or falsy.
+        torch_dtype = (
+            self.hf_config.get("torch_dtype")
+            or self.hf_config.get("dtype")
+            or "bfloat16"
+        )
         return parse_dtype(torch_dtype)
 
     @property
diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py
index 94f4016a9..e39d9a938 100644
--- a/python/infinilm/modeling_utils.py
+++ b/python/infinilm/modeling_utils.py
@@ -183,7 +183,21 @@ def load_model_state_dict_by_file(
     already_loaded_keys = []
     embed_tokens_torch_unscaled = None
 
-    file_list = glob.glob(os.path.join(model_path, "*.safetensors"))
+    index_file_path = os.path.join(model_path, "model.safetensors.index.json")
+    if os.path.exists(index_file_path):
+        # Priority 1: If the index file exists, load exactly the shard files it maps to,
+        # regardless of their filename prefix; sort them so the load order is deterministic.
+        print(f"Found index file: {index_file_path}. Loading shards by index.")
+        with open(index_file_path, "r", encoding="utf-8") as f:
+            index_data = json.load(f)
+        weight_map = index_data.get("weight_map", {})
+        unique_filenames = sorted(set(weight_map.values()))
+        file_list = [os.path.join(model_path, fname) for fname in unique_filenames]
+    else:
+        # Priority 2: If no index file, scan all safetensors files.
+        print("No index file found. Scanning all safetensors files...")
+        file_list = glob.glob(os.path.join(model_path, "*.safetensors"))
+
     if len(file_list) > 0:
         for file_path in tqdm(file_list, desc="Processing files"):
             tqdm.write(f"Processing: {os.path.basename(file_path)}")