Compare commits

...

2 Commits

2 changed files with 21 additions and 43 deletions

View File

@ -25,7 +25,7 @@ fn main() -> anyhow::Result<()> {
let device_id = *get_all_devices(CL_DEVICE_TYPE_GPU)?
.first()
.expect("no device found in platform");
let size = {
let max_size = {
match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE_AMD) {
Ok(size) => size.to_size(),
Err(err) => {
@ -48,8 +48,8 @@ fn main() -> anyhow::Result<()> {
};
let device = Device::new(device_id);
let worker_count: cl_int = 100 as cl_int;
println!("设备最大队列长度: {} real count: {}", size, worker_count);
let worker_count: cl_int = max_size as cl_int;
println!("设备最大队列长度: {} real count: {}", max_size, worker_count);
// Create a Context on an OpenCL device
let context = Context::from_device(&device).expect("Context::from_device failed");
@ -98,26 +98,20 @@ fn main() -> anyhow::Result<()> {
}
};
let team_raw_vec = vec!["x"; worker_count as usize];
let team_raw = "x";
let team_bytes = team_raw.as_bytes();
let name_raw_vec = vec!["x"; worker_count as usize];
let team_bytes_vec = team_raw_vec
.iter()
.map(|s| s.as_bytes())
.collect::<Vec<&[u8]>>();
let name_bytes_vec = name_raw_vec
.iter()
.map(|s| s.as_bytes())
.collect::<Vec<&[u8]>>();
let t_len_vec = team_bytes_vec
.iter()
.map(|s| s.len() as cl_int + 1)
.collect::<Vec<i32>>();
let t_len = team_bytes.len() as cl_int;
let n_len_vec = name_bytes_vec
.iter()
.map(|s| s.len() as cl_int)
.collect::<Vec<i32>>();
let work_count = team_bytes_vec.len();
let work_count = name_raw_vec.len();
println!("开始准备buffer");
@ -126,7 +120,7 @@ fn main() -> anyhow::Result<()> {
Buffer::<cl_uchar>::create(
&context,
CL_MEM_READ_ONLY,
BLOCK_SIZE * work_count,
BLOCK_SIZE,
ptr::null_mut(),
)?
};
@ -138,24 +132,12 @@ fn main() -> anyhow::Result<()> {
ptr::null_mut(),
)?
};
let mut t_len = unsafe {
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
};
let mut n_len = unsafe {
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
};
let mut output = SvmVec::<cl_uchar>::allocate(&context, BLOCK_SIZE * work_count)?;
// 准备一下数据, 都给拼成一维数组
// 填充成 256 * len
let team_data_vec = {
let mut vec = Vec::new();
for data in team_bytes_vec {
let left_over = BLOCK_SIZE - data.len();
vec.extend_from_slice(data);
vec.extend_from_slice(&vec![0; left_over]);
}
vec
};
let name_data_vec = {
let mut vec = Vec::new();
for data in name_bytes_vec {
@ -169,18 +151,15 @@ fn main() -> anyhow::Result<()> {
println!("开始写入buffer");
// 阻塞写
// let _team_write_event =
// unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, &team_data_vec, &[]) }?;
// _team_write_event.wait()?;
// let _name_write_event =
// unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
// _name_write_event.wait()?;
// let _t_len_write_event =
// unsafe { queue.enqueue_write_buffer(&mut t_len, CL_BLOCKING, 0, &t_len_vec, &[]) }?;
// _t_len_write_event.wait()?;
// let _n_len_write_event =
// unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
// _n_len_write_event.wait()?;
let _team_write_event =
unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, team_bytes, &[]) }?;
_team_write_event.wait()?;
let _name_write_event =
unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
_name_write_event.wait()?;
let _n_len_write_event =
unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
_n_len_write_event.wait()?;
println!("开始执行kernel");

View File

@ -7,13 +7,13 @@ uchar median(uchar a, uchar b, uchar c) {
// 输入: 1~255 长度的 u8 数组
// 输出: 255 长度的 u8 数组
kernel void load_team(
global const uchar* all_team_bytes,
global const int* all_t_len,
global const uchar* g_team_bytes,
const int t_len,
global const uchar* all_name_bytes,
global const int* all_n_len,
// 一个 svm [u8; 256] * worker_count
global uchar* all_val,
int worker_count
const int worker_count
) {
int gid = get_global_id(0);
if (gid >= worker_count) {
@ -27,10 +27,9 @@ kernel void load_team(
val[i] = i;
}
for (int i = 0; i < 256; i += 4) {
vstore4(vload4(0, &all_team_bytes[256 * gid + i]), i, team_bytes);
vstore4(vload4(0, &g_team_bytes[i]), i, team_bytes);
vstore4(vload4(0, &all_name_bytes[256 * gid + i]), i, name_bytes);
}
int t_len = all_t_len[gid];
int n_len = all_n_len[gid];
// 外面初始化好了