Efficiency of treetensor.torch.as_tensor #19

Description

@LamannaLeonardo

I am seeing a performance problem that appears to come from calling treetensor.torch.as_tensor on a dictionary. Here is a minimal reproducible example (Python 3.10, DI-treetensor==0.5.0, torch==2.6.0):

import treetensor.torch as ttorch
import torch

if __name__ == '__main__':
    obs = torch.randn(1, 28, 28)
    for i in range(5000):
        # fast = ttorch.as_tensor(obs)
        slow = ttorch.as_tensor({'key': obs})

With the loop above, ttorch.as_tensor(obs) takes about 1 second, while ttorch.as_tensor({'key': obs}) takes about 16 seconds. Profiling the code shows that the __repr__ method of each tensor obs is being called, and this accounts for about 15 of the 16 seconds.
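For anyone who wants to check the same thing, here is a minimal sketch of how the number of `__repr__` calls can be counted with the standard-library `cProfile` and `pstats` modules. The helper `count_calls`, the class `Payload` and the function `wrap_with_hidden_repr` are hypothetical stand-ins written for this illustration; they are not part of treetensor or torch.

```python
import cProfile
import pstats


def count_calls(fn, func_name):
    """Run fn() under cProfile; return (calls, cumulative seconds) for functions named func_name."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        fn()
    finally:
        profiler.disable()
    stats = pstats.Stats(profiler)
    calls, seconds = 0, 0.0
    # stats.stats maps (file, line, name) -> (primitive calls, total calls, tottime, cumtime, callers)
    for (_file, _line, name), (_cc, nc, _tt, ct, _callers) in stats.stats.items():
        if name == func_name:
            calls += nc
            seconds += ct
    return calls, seconds


class Payload:
    """Hypothetical object with a deliberately expensive __repr__."""

    def __repr__(self):
        return "Payload(" + ", ".join(str(i) for i in range(1000)) + ")"


def wrap_with_hidden_repr(value):
    # Hypothetical conversion that formats its input as a side effect.
    _ = repr(value)
    return {"key": value}


calls, seconds = count_calls(
    lambda: [wrap_with_hidden_repr(Payload()) for _ in range(100)],
    "__repr__",
)
print(f"__repr__ called {calls} times, {seconds:.4f}s cumulative")
```

In the reproduction above, the same helper could presumably be used by passing a function that runs the `ttorch.as_tensor({'key': obs})` loop instead of the stand-in.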

In my case, I ran into this slowdown when collecting rollouts in DI-engine, where transition dictionaries are converted to tensors with ttorch.as_tensor at this line.

Can this slowdown be mitigated? I understand that converting a dictionary to a tree tensor requires some additional computation, but I wonder whether any optimisation is possible. Moreover, adding a call to torch.set_printoptions(precision=3, threshold=10) significantly reduces the run time (from 16 to 2.5 seconds), but I am not sure about possible side effects, or how the print options affect the conversion of a dictionary to a tree tensor.
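On the side-effect question: torch.set_printoptions changes process-wide state, so every tensor printed afterwards is formatted with the new settings, while the tensor values themselves are unchanged. A minimal sketch of how the change could be scoped is below. The context manager `compact_printoptions` is a hypothetical helper, and it assumes that resetting to torch's default profile on exit is acceptable; it does not restore any custom options that were set earlier.

```python
import contextlib

import torch


@contextlib.contextmanager
def compact_printoptions(precision=3, threshold=10):
    """Temporarily shorten tensor reprs (hypothetical helper, not a torch API)."""
    torch.set_printoptions(precision=precision, threshold=threshold)
    try:
        yield
    finally:
        # Assumption: going back to torch's defaults is fine here;
        # previously customised options are not restored.
        torch.set_printoptions(profile="default")


x = torch.arange(100.0)

with compact_printoptions():
    short = repr(x)  # summarised, since 100 elements exceed threshold=10

full = repr(x)  # default threshold is 1000, so all elements are printed

print(len(short), len(full))
```

Since the profile above attributes most of the time to `__repr__`, a lower threshold makes each repr cheaper, which is consistent with the drop from 16 to 2.5 seconds. It does not remove the repr call itself.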

Thank you
