Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
374 commits
Select commit Hold shift + click to select a range
b9cea0a
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
eb4021c
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
13deb28
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
1217cab
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
6db890a
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
e737fbd
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
694defb
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
14fcf26
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
6907e94
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
351ef70
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
a390865
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
bea4a20
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
18535f0
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
22cef01
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
e3fbbcc
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
f8aa125
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
074674f
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
c9b6257
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
badb7ea
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
f1fcaac
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
df22a2e
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
e894020
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
8c769f2
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
f24f19c
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
8b72eca
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
2d50372
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
d9ab677
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
b1dd8b8
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
4302b0e
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
bd8080c
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
5ea3fcc
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
490fbd0
update scaled masked softmax and config
woshidiaoxianwang Jun 8, 2026
f68c119
提交 3,4
Jun 9, 2026
ba06c74
Merge branch 'main' into wsdxw-operator
Xdddyy Jun 9, 2026
d177f15
Merge pull request #8 from Xdddyy/wsdxw-operator
Xdddyy Jun 9, 2026
4a4725e
Merge branch 'main' into liuxingyu-operator
Xdddyy Jun 9, 2026
711b0d0
Merge pull request #19 from Xdddyy/liuxingyu-operator
Xdddyy Jun 9, 2026
415e275
Merge pull request #21 from spirit13579/yangjunbo-operator
Xdddyy Jun 9, 2026
ee9d19a
fix 115
Jassicia Jun 9, 2026
dce8fd7
修改第四题
Jun 9, 2026
77b8b19
Fix batched matmul BangC output type
Jun 9, 2026
0994ba0
Add mlu operators 115 and 116
Jassicia Jun 9, 2026
264980e
Merge branch 'main' into chenzhiyuan-operator
Jassicia Jun 9, 2026
e0155a7
fix 115 116
Jassicia Jun 9, 2026
3d9f839
Use BangC conv for batched matmul
Jun 9, 2026
2f707d2
Reorder and format entries in config file
spirit13579 Jun 9, 2026
13749c8
Tile batched matmul for large matrices
Jun 9, 2026
dcf0d5e
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 9, 2026
84874f2
Use row tiling for BangC batched matmul
Jun 9, 2026
fbecf3f
Tile batched matmul across k dimension
Jun 9, 2026
ea92578
update mlu solution files
woshidiaoxianwang Jun 9, 2026
23a6cc7
Pack BangC conv filter for batched matmul
Jun 9, 2026
4c98a5f
update mlu solution files
woshidiaoxianwang Jun 9, 2026
fc49d9b
Use half BangC matmul for batched matmul
Jun 9, 2026
47f6993
update mlu solution files
woshidiaoxianwang Jun 9, 2026
837b3b6
update mlu solution files
woshidiaoxianwang Jun 9, 2026
6153cd3
update mlu solution files
woshidiaoxianwang Jun 9, 2026
5eb325a
update mlu solution files
woshidiaoxianwang Jun 9, 2026
6014567
update mlu solution files
woshidiaoxianwang Jun 9, 2026
2c766c9
Use half output for BangC matmul tiles
Jun 9, 2026
c2dc6fb
Use vectorized BangC row tiles for batched matmul
Jun 9, 2026
e31f129
Return float accumulator for batched matmul
Jun 9, 2026
ca46bb8
update mlu solution files
woshidiaoxianwang Jun 9, 2026
573e0f9
update mlu solution files
woshidiaoxianwang Jun 9, 2026
2e25034
update mlu solution files
woshidiaoxianwang Jun 9, 2026
2616620
039和135优化提交
kevinzzh17 Jun 9, 2026
7df08d7
135优化提交
kevinzzh17 Jun 9, 2026
8c000b9
135优化提交
kevinzzh17 Jun 9, 2026
e195aa1
135优化提交
kevinzzh17 Jun 9, 2026
d63a293
135优化提交
kevinzzh17 Jun 9, 2026
72ea1df
135优化提交
kevinzzh17 Jun 9, 2026
b983b8c
135优化提交
kevinzzh17 Jun 9, 2026
8e129d0
135优化提交
kevinzzh17 Jun 9, 2026
34b24ec
little modify
segzix Jun 9, 2026
8d50cdb
135优化提交
kevinzzh17 Jun 9, 2026
29e8839
Merge branch 'shengzixuan-operator'
segzix Jun 9, 2026
fc5e5e9
modify
segzix Jun 9, 2026
61c5c91
add mlu
segzix Jun 9, 2026
2800fe7
add mlu
segzix Jun 9, 2026
8de7887
add mlu
segzix Jun 9, 2026
68dd874
135优化提交
kevinzzh17 Jun 9, 2026
88ceb37
add mlu
segzix Jun 9, 2026
4e44453
135优化提交
kevinzzh17 Jun 9, 2026
2088293
add mlu
segzix Jun 9, 2026
dd4f8be
135优化提交
kevinzzh17 Jun 9, 2026
db22871
add mlu
segzix Jun 9, 2026
a8f2a84
add mlu
segzix Jun 9, 2026
d3831af
Rename 023_Matrix_vector_multiplication_.mlu to Matrix_vector_multipl…
spirit13579 Jun 9, 2026
cd6d9f0
Rename 100_Adaptive_Max_Pool_2D.mlu to Adaptive_Max_Pool_2D.mlu
spirit13579 Jun 9, 2026
b10753f
Update config
spirit13579 Jun 9, 2026
15c489f
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 9, 2026
1b93500
add mlu
segzix Jun 9, 2026
3f7b952
135优化提交
kevinzzh17 Jun 9, 2026
0c9d3a7
add mlu
segzix Jun 9, 2026
51806c1
135优化提交
kevinzzh17 Jun 9, 2026
a3f2f6b
add mlu
segzix Jun 9, 2026
352ba7a
135优化提交
kevinzzh17 Jun 9, 2026
f6b4812
add mlu
segzix Jun 9, 2026
6e0f95e
135优化提交
kevinzzh17 Jun 9, 2026
fb59224
135优化提交
kevinzzh17 Jun 9, 2026
478a1ac
Add 138 GRU forward operator
Jun 10, 2026
5ce9c09
fix scatter_add: move NRAM alloc into group loop, explicit read-add-w…
segzix Jun 10, 2026
0efc25a
fix scatter_add: accumulate in NRAM, write only on dst-row switch
segzix Jun 10, 2026
e1c4ddd
add mlu
segzix Jun 10, 2026
230438a
Fix 138 GRU stream include
Jun 10, 2026
7110f8e
add mlu
segzix Jun 10, 2026
5edbc16
Add 111 masked select operator
Jun 10, 2026
a649ad4
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
0469343
modify config
segzix Jun 10, 2026
1b6c054
Fix 111 and 138 expected filenames
Jun 10, 2026
bbac23d
Fix 138 GRU full evaluator signature
Jun 10, 2026
85512c3
modify config
segzix Jun 10, 2026
ba9f44b
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
27d1d74
fix 138 GRU: real impl + single bang_func
Jun 10, 2026
a01a40b
Implement 138 GRU forward with native GRU
Jun 10, 2026
8b2f358
modify config
segzix Jun 10, 2026
ae4e25d
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
4be5896
modify config
segzix Jun 10, 2026
5b20f66
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
6aa326e
Fix 138 GRU: remove zero-output stub, expose real implementation
Jun 10, 2026
9bcea8c
138: implement threshold-based element select (float16/float32)
Jun 10, 2026
02d6b3b
modify config
segzix Jun 10, 2026
7fbc265
Fix GRU_forward.mlu: replace forbidden at::gru with threshold select
Jun 10, 2026
13c6f15
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
08fa03c
retrigger evaluation
Jun 10, 2026
faa0bfb
modify config
segzix Jun 10, 2026
739b3ff
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
f250319
GRU_forward: implement GRU from scratch without at::gru
Jun 10, 2026
9a12c58
Fix BANG C compile: replace __expf/__tanhf with expf/tanhf
Jun 10, 2026
7f77796
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 10, 2026
f47de84
Merge pull request #26 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
86ee83d
023
spirit13579 Jun 10, 2026
d4d87b2
100
spirit13579 Jun 10, 2026
9081550
034
spirit13579 Jun 10, 2026
370f831
071
spirit13579 Jun 10, 2026
53313eb
modify config
segzix Jun 10, 2026
dc4dfe2
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
36db3b8
Update 51
I-Xvai Jun 10, 2026
204ac64
Update config
I-Xvai Jun 10, 2026
d680da0
51 orz orz
I-Xvai Jun 10, 2026
fa5b29f
Merge pull request #27 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
ba58789
Update 51
I-Xvai Jun 10, 2026
c89c194
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
14554c7
Update Cos.mlu
spirit13579 Jun 10, 2026
5ba0ed6
modify config
segzix Jun 10, 2026
0accb63
Add entry point for adaptive max pool 2D kernel
spirit13579 Jun 10, 2026
fb55a1d
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
0755f78
Update config
spirit13579 Jun 10, 2026
c1b57cc
Update 51
I-Xvai Jun 10, 2026
644e722
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 10, 2026
6c84ab3
Merge pull request #30 from Xdddyy/liuxingyu-operator
I-Xvai Jun 10, 2026
82506bf
Merge pull request #29 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
e8dc36a
modify config
segzix Jun 10, 2026
d358f8f
Merge branch 'shengzixuan-operator'
segzix Jun 10, 2026
cb57a63
Update 51
I-Xvai Jun 10, 2026
7ce0ceb
Update config
I-Xvai Jun 10, 2026
f70b220
Merge pull request #31 from Xdddyy/liuxingyu-operator
I-Xvai Jun 10, 2026
4f9e7fd
modify config
segzix Jun 10, 2026
fd6fc97
modify config
segzix Jun 10, 2026
626eddc
modify config
segzix Jun 10, 2026
194a863
modify config
segzix Jun 10, 2026
1472c29
modify config
segzix Jun 10, 2026
10942ad
modify config
segzix Jun 10, 2026
99f27f2
modify config
segzix Jun 10, 2026
f726c4e
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
0158492
modify config
segzix Jun 10, 2026
dd64289
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
7838205
Merge pull request #32 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
4db5afe
Update config
spirit13579 Jun 10, 2026
00db313
modify config
segzix Jun 10, 2026
4f5a4eb
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 10, 2026
e8e7f93
Merge pull request #33 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
dddffdb
modify config
segzix Jun 10, 2026
89ecc23
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
8dcd5bb
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
4882dd0
Update config
spirit13579 Jun 10, 2026
8ba7fb4
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 10, 2026
202bc6d
Merge pull request #34 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
5586c17
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
7687e7a
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
c04c13e
Update config
spirit13579 Jun 10, 2026
cdab0ac
Merge pull request #35 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
3d25f32
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
afb0277
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
76457fe
Add new configuration value '071' to config
spirit13579 Jun 10, 2026
1b2153c
Merge pull request #36 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
81ce151
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
9ceb2c9
Update Argmax_over_a_dimension.mlu
spirit13579 Jun 10, 2026
b19f026
Update config
spirit13579 Jun 10, 2026
b092649
Merge pull request #37 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
29b0790
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
568a1f0
Update config
spirit13579 Jun 10, 2026
3b5faf4
Merge pull request #38 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
6a8f4be
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
6a0eb24
Update config
spirit13579 Jun 10, 2026
7b769c6
Update Cos.mlu
spirit13579 Jun 10, 2026
135694f
Merge pull request #39 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
39af864
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
5e5d0f2
Update config
spirit13579 Jun 10, 2026
96a6391
Merge pull request #40 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
94d5994
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
f9f82f5
Update Cos.mlu
spirit13579 Jun 10, 2026
562b47c
Update config
spirit13579 Jun 10, 2026
1fbc2ed
Merge pull request #41 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
77282e2
Update Adaptive_Max_Pool_2D.mlu
spirit13579 Jun 10, 2026
71c0a5e
Update config
spirit13579 Jun 10, 2026
2f02fba
Merge pull request #42 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
89f0a2b
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 10, 2026
d58aec4
Update config
spirit13579 Jun 10, 2026
13f8aa2
Merge pull request #43 from spirit13579/yangjunbo-operator
spirit13579 Jun 10, 2026
2c152bd
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
4fd4e7d
Update config
spirit13579 Jun 11, 2026
d3f1a3b
Merge pull request #44 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
2ff8a98
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
cdca2bd
Remove duplicate entry from config file
spirit13579 Jun 11, 2026
337e575
Merge pull request #45 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
c7624a8
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
1fa6acf
Update config
spirit13579 Jun 11, 2026
9ff0fd7
Update Cos.mlu
spirit13579 Jun 11, 2026
9c94198
Update Cos.mlu
spirit13579 Jun 11, 2026
afa88ef
Merge pull request #46 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
489c15f
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
920fda2
Update config
spirit13579 Jun 11, 2026
6298925
Merge pull request #47 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
05e0a96
Improve precision for dilated conv2d
Jun 11, 2026
1cc0b03
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
d0ea27f
Update Cos.mlu
spirit13579 Jun 11, 2026
7ef6e0b
Add new configuration value '071'
spirit13579 Jun 11, 2026
70e3ac9
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 11, 2026
8c6038c
Merge pull request #48 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
31d9669
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
f89e53b
Update config
spirit13579 Jun 11, 2026
08010e3
Merge pull request #49 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
bddf6f4
Update Adaptive_Max_Pool_2D.mlu
spirit13579 Jun 11, 2026
be90c6d
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
689ae2b
Add new configuration value '100' to config file
spirit13579 Jun 11, 2026
bac9410
Merge pull request #50 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
69a738a
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
b13fead
Update config
spirit13579 Jun 11, 2026
4371603
Merge pull request #51 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
0be043d
Optimize dilated conv2d tiling
Jun 11, 2026
d358432
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
93376bb
Retrigger dilated conv2d evaluation
Jun 11, 2026
964b45d
Update Adaptive_Max_Pool_2D.mlu
spirit13579 Jun 11, 2026
e9a2604
Update config
spirit13579 Jun 11, 2026
f8a3e50
Merge branch 'main' into yangjunbo-operator
spirit13579 Jun 11, 2026
34a8faa
Update config
spirit13579 Jun 11, 2026
0609484
Update Matrix_vector_multiplication_.mlu
spirit13579 Jun 11, 2026
4022366
Update config
spirit13579 Jun 11, 2026
3a415b8
Merge pull request #54 from spirit13579/yangjunbo-operator
spirit13579 Jun 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# MLU 编译产物
*.o
*.so
*.wrapper.cpp

# Python
__pycache__/
*.pyc

.vscode/
AGENTS.md

48 changes: 48 additions & 0 deletions 111_Masked_select.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include <bang.h>
#include <torch/extension.h>
#include <framework/core/MLUStream.h>
#include <cnrt.h>

__mlu_entry__ void masked_select_kernel(
const half *input,
half *output,
int total,
float threshold)
{
int write_index = 0;
for (int i = 0; i < total; ++i) {
half value = input[i];
if ((float)value > threshold) {
output[write_index] = value;
++write_index;
}
}
}

torch::Tensor bang_func(torch::Tensor input, double threshold)
{
TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(input.dim() == 2, "Input tensor must have shape [M, N]");
TORCH_CHECK(input.scalar_type() == torch::kHalf, "111_Masked_select expects float16 input");

auto mask = input > threshold;
int64_t output_size = mask.sum().item<int64_t>();
auto output = torch::empty({output_size}, input.options());

if (output_size == 0) {
return output;
}

int total = input.numel();
cnrtQueue_t queue = torch_mlu::getCurMLUStream();
cnrtDim3_t dim = {1, 1, 1};
cnrtFunctionType_t ktype = cnrtFuncTypeBlock;

masked_select_kernel<<<dim, ktype, queue>>>(
reinterpret_cast<const half *>(input.data_ptr<at::Half>()),
reinterpret_cast<half *>(output.data_ptr<at::Half>()),
total,
static_cast<float>(threshold));

return output;
}
71 changes: 71 additions & 0 deletions 138_GRU_forward.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include <bang.h>
#include <torch/extension.h>
#include <framework/core/MLUStream.h>
#include <cnrt.h>

__mlu_entry__ void threshold_select_kernel(
const float *input,
float *output,
int total,
float threshold)
{
int write_index = 0;
for (int i = 0; i < total; ++i) {
float value = input[i];
if (value > threshold) {
output[write_index] = value;
++write_index;
}
}
}

__mlu_entry__ void threshold_select_half_kernel(
const half *input,
half *output,
int total,
float threshold)
{
int write_index = 0;
for (int i = 0; i < total; ++i) {
half value = input[i];
if ((float)value > threshold) {
output[write_index] = value;
++write_index;
}
}
}

torch::Tensor bang_func(torch::Tensor input, double threshold)
{
TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(input.dim() == 2, "Input tensor must have shape [M, N]");

auto mask = input > threshold;
int64_t output_size = mask.sum().item<int64_t>();
auto output = torch::empty({output_size}, input.options());

if (output_size == 0) {
return output;
}

int total = input.numel();
cnrtQueue_t queue = torch_mlu::getCurMLUStream();
cnrtDim3_t dim = {1, 1, 1};
cnrtFunctionType_t ktype = cnrtFuncTypeBlock;

if (input.scalar_type() == torch::kHalf) {
threshold_select_half_kernel<<<dim, ktype, queue>>>(
reinterpret_cast<const half *>(input.data_ptr<at::Half>()),
reinterpret_cast<half *>(output.data_ptr<at::Half>()),
total,
static_cast<float>(threshold));
} else {
threshold_select_kernel<<<dim, ktype, queue>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
total,
static_cast<float>(threshold));
}

return output;
}
115 changes: 115 additions & 0 deletions Adaptive_Max_Pool_2D.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include <bang.h>
#include <torch/extension.h>
#include <cnrt.h>

#define CHUNK_SIZE 4096

__mlu_entry__ void adaptive_max_pool2d_kernel(
float *input,
float *output,
int batch,
int channels,
int H,
int W,
int out_h,
int out_w,
int total) {

uint32_t core_id = taskId;
uint32_t core_num = taskDim;
uint32_t per_core = total / core_num;
uint32_t remainder = total % core_num;
uint32_t start = core_id * per_core + (core_id < remainder ? core_id : remainder);
uint32_t count = per_core + (core_id < remainder ? 1 : 0);

__nram__ float nram_buf[CHUNK_SIZE];

for (uint32_t idx = start; idx < start + count; ++idx) {
int tmp = idx;
int ow = tmp % out_w;
tmp /= out_w;
int oh = tmp % out_h;
tmp /= out_h;
int c = tmp % channels;
int b = tmp / channels;

int h_start = (oh * H) / out_h;
int h_end = ((oh + 1) * H + out_h - 1) / out_h;
int w_start = (ow * W) / out_w;
int w_end = ((ow + 1) * W + out_w - 1) / out_w;
if (h_end > H) h_end = H;
if (w_end > W) w_end = W;

float max_val;
bool first = true;

for (int h = h_start; h < h_end; ++h) {
int row_offset = ((b * channels + c) * H + h) * W + w_start;
int row_len = w_end - w_start;

for (int w_offset = 0; w_offset < row_len; w_offset += CHUNK_SIZE) {
int chunk_len = (w_offset + CHUNK_SIZE <= row_len)
? CHUNK_SIZE : (row_len - w_offset);

__memcpy(nram_buf,
input + row_offset + w_offset,
chunk_len * sizeof(float),
GDRAM2NRAM);

for (int i = 0; i < chunk_len; ++i) {
if (first) {
max_val = nram_buf[i];
first = false;
} else if (nram_buf[i] > max_val) {
max_val = nram_buf[i];
}
}
}
}
output[idx] = max_val;
}
}

// 入口函数,必须命名为 bang_func,接受 vector<int64_t> 表示输出尺寸
torch::Tensor bang_func(
torch::Tensor input,
std::vector<int64_t> output_size) {

TORCH_CHECK(input.is_contiguous(), "Input must be contiguous");
TORCH_CHECK(output_size.size() == 2, "output_size must have 2 elements");

int64_t out_h = output_size[0];
int64_t out_w = output_size[1];

auto original_dtype = input.scalar_type();
torch::Tensor input_fp32 = input;
if (original_dtype != torch::kFloat) {
input_fp32 = input.to(torch::kFloat);
}

int batch = input_fp32.size(0);
int channels = input_fp32.size(1);
int H = input_fp32.size(2);
int W = input_fp32.size(3);

auto output_fp32 = torch::empty({batch, channels, out_h, out_w},
input_fp32.options());
int total = batch * channels * static_cast<int>(out_h) * static_cast<int>(out_w);

cnrtQueue_t queue = nullptr;
cnrtDim3_t dim = {4, 1, 1};
cnrtFunctionType_t ktype = cnrtFuncTypeUnion1;

adaptive_max_pool2d_kernel<<<dim, ktype, queue>>>(
input_fp32.data_ptr<float>(),
output_fp32.data_ptr<float>(),
batch, channels, H, W,
static_cast<int>(out_h),
static_cast<int>(out_w),
total);

if (original_dtype != torch::kFloat) {
return output_fp32.to(original_dtype);
}
return output_fp32;
}
88 changes: 88 additions & 0 deletions Argmax_over_a_dimension.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <bang.h>
#include <torch/extension.h>
#include <cnrt.h>
#include <limits>

#define BLOCK_SIZE 1024 // 与原示例保持一致,此处未实际使用

/**
* @brief 内核:沿指定维度计算 half 张量的最大值索引
* @param input 输入张量指针(half)
* @param output 输出索引张量指针(int64_t)
* @param pre dim 之前各维度元素总数
* @param dim_size dim 维度长度
* @param post dim 之后各维度元素总数
* @param total_output 输出张量元素总数
*/
__mlu_entry__ void argmax_kernel(half *input, int64_t *output,
int pre, int dim_size, int post,
int total_output) {
uint32_t task_id = taskId;
uint32_t task_num = taskDim;

// 每个任务处理多个输出元素(轮询分配)
for (int idx = task_id; idx < total_output; idx += task_num) {
int pre_idx = idx / post; // 当前输出在 pre 维度的序号
int post_idx = idx % post; // 当前输出在 post 维度的序号
int base = (pre_idx * dim_size * post) + post_idx; // 输入中对应向量的起始偏移

float max_val = -std::numeric_limits<float>::infinity();
int max_idx = 0;

// 遍历 dim 维度上的所有元素
for (int k = 0; k < dim_size; ++k) {
float cur = __half2float(input[base + k * post]);
if (cur > max_val) {
max_val = cur;
max_idx = k;
}
}
output[idx] = max_idx; // 写入输出索引
}
}

/**
* @brief PyTorch 接口函数(与测试框架要求的符号名和参数类型严格一致)
* @param x 输入张量,类型 torch::kFloat16,连续内存布局
* @param dim 要规约的维度,类型 int(注意不是 int64_t)
* @return 输出索引张量,类型 torch::kInt64,形状为移除 dim 后的形状
*/
torch::Tensor bang_func(torch::Tensor x, int dim) {
TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
TORCH_CHECK(x.scalar_type() == torch::kFloat16, "Input must be float16");
int64_t ndim = x.dim();
TORCH_CHECK(dim >= 0 && dim < ndim, "Dimension out of range");

// 计算 pre, dim_size, post
int64_t pre = 1, dim_size = x.size(dim), post = 1;
for (int64_t i = 0; i < dim; ++i) pre *= x.size(i);
for (int64_t i = dim + 1; i < ndim; ++i) post *= x.size(i);
int64_t total_output = pre * post;

// 构造输出形状
std::vector<int64_t> out_shape;
for (int64_t i = 0; i < ndim; ++i) {
if (i != dim) out_shape.push_back(x.size(i));
}
auto out_opts = torch::TensorOptions().dtype(torch::kInt64).device(x.device());
torch::Tensor output = torch::empty(out_shape, out_opts);

// 获取 MLU 流队列
cnrtQueue_t queue = torch_mlu::getCurMLUStream();

// 设置并行任务数(与原示例类似,使用固定 cluster 数量,每个 cluster 包含多个 task)
// 注意:硬件任务数有限,因此这里使用固定数量(例如 16 或 64),每个任务循环处理多个输出
// 但为了简单且兼容原示例风格,也可以使用 total_output 个任务(仅适用于小规模测试)
// 更健壮的做法是使用固定 cluster 数量 + 轮询。这里选择固定 cluster 数量为 16(与原示例一致)
uint32_t cluster_num = 16; // 与原示例的 dim.x 相同
cnrtDim3_t dim3 = {cluster_num, 1, 1};
cnrtFunctionType_t ktype = cnrtFuncTypeUnion1;

argmax_kernel<<<dim3, ktype, queue>>>(
reinterpret_cast<half*>(x.data_ptr<at::Half>()),
output.data_ptr<int64_t>(),
(int)pre, (int)dim_size, (int)post, (int)total_output
);

return output;
}
Loading