
I want to split a torch tensor into consecutive chunks whose lengths are given by a list of split sizes.

For example, say my input tensor is torch.arange(20):

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

and my list of split sizes is splits = [1, 2, 5, 10].

Then my result would be:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))

Assume my input tensor is always at least as long as the sum of the split sizes (any leftover elements at the end are discarded).

4 Answers


You could use tensor_split on the cumulative sum of the splits (e.g. computed with np.cumsum) and drop the last chunk, which holds the leftover elements:

import torch
import numpy as np

t = torch.arange(20)
splits = [1,2,5,10]

t.tensor_split(np.cumsum(splits).tolist())[:-1]

Output:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
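If you'd rather not depend on NumPy just for the cumulative sum, the cut points can also be computed in PyTorch with cumsum. A minimal sketch, assuming the same t and splits as in the question:

```python
import torch

t = torch.arange(20)
splits = [1, 2, 5, 10]

# Cut points are the running totals of the sizes: [1, 3, 8, 18]
cut_points = torch.tensor(splits).cumsum(0).tolist()

# tensor_split yields len(cut_points) + 1 pieces; the last one is the
# leftover tail t[18:], which is dropped here
chunks = t.tensor_split(cut_points)[:-1]
print(chunks)
```

The pieces returned by tensor_split are views of t, so no data is copied.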

Another option is to slice the tensor down to the total length first, then split it:

import torch

t = torch.arange(20)

splits = [1, 2, 5, 10]

out = torch.split(t[: sum(splits)], splits)

Output:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
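The same idea extends to multi-dimensional tensors. Below is a sketch with a hypothetical helper, split_by_sizes, that uses narrow to drop the tail along a chosen dim before splitting; the length check that raises ValueError is my own addition, not something torch.split does for you in this form:

```python
import torch


def split_by_sizes(x: torch.Tensor, sizes: list[int], dim: int = 0):
    """Hypothetical helper: split x along `dim` into pieces of the given
    sizes, ignoring whatever is left over past sum(sizes)."""
    total = sum(sizes)
    if total > x.size(dim):
        raise ValueError(f"sizes sum to {total}, but dim {dim} has length {x.size(dim)}")
    # narrow() returns a view of the first `total` elements along `dim`
    return x.narrow(dim, 0, total).split(sizes, dim=dim)


m = torch.arange(40).reshape(2, 20)
pieces = split_by_sizes(m, [1, 2, 5, 10], dim=1)
print([tuple(p.shape) for p in pieces])
```

For the 2 x 20 example this gives pieces of shape (2, 1), (2, 2), (2, 5) and (2, 10).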


I used NumPy for my example, but the same slicing approach works on tensors too.

import numpy as np

array = np.array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

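def split(array, split_list: list[int]):
    # Refuse to split if the sizes need more elements than the array has
    if sum(split_list) > len(array):
        return None
    tup = ()
    for i in split_list:
        # Take the next i elements, then drop them from the front
        tup = tup + (array[:i].tolist(),)
        array = array[i:]
    return tup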

print(split(array, [1, 2, 5]))

And I got this as output:

([0], [1, 2], [3, 4, 5, 6, 7])

I hope this is what you meant.
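The same loop runs unchanged on a torch tensor; dropping .tolist() keeps each piece as a tensor instead of a Python list. A minimal sketch, assuming you want tensors back:

```python
import torch


def split(tensor: torch.Tensor, split_list: list[int]):
    # Same loop as above, but each piece stays a tensor
    if sum(split_list) > len(tensor):
        return None
    pieces = []
    for size in split_list:
        pieces.append(tensor[:size])  # take the next `size` elements
        tensor = tensor[size:]        # advance past them
    return tuple(pieces)


print(split(torch.arange(20), [1, 2, 5, 10]))
```

Because basic slicing returns views, this does not copy the underlying data either.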


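You can achieve this by slicing the tensor in a comprehension, using the running total of the sizes as the start offset of each chunk. (Calling input_array.split(split) once per size does not work: with an int argument, split cuts the whole tensor into equal chunks of that size, producing a tuple of tuples instead of four consecutive pieces.)

import itertools
import torch

input_array = torch.arange(20)
splits = [1, 2, 5, 10]

# Start offset of each chunk: [0, 1, 3, 8]
offsets = [0, *itertools.accumulate(splits)][:-1]
result = tuple(input_array[start:start + size] for start, size in zip(offsets, splits))
print(result)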
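The pitfall with calling split once per size comes from the two forms torch.split accepts for its second argument. A short illustration:

```python
import torch

x = torch.arange(20)

# An int means "equal chunks of this size" (the last one may be shorter)
by_int = x.split(5)
print([c.tolist() for c in by_int])

# A list means "one chunk per entry, with exactly these sizes";
# the sizes must add up to the length of the dimension being split
by_list = x.split([1, 2, 5, 12])
print([c.tolist() for c in by_list])
```

This is also why the slicing answer above trims the tensor to sum(splits) before passing the list.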
