diff --git a/src/main.rs b/src/main.rs index 9e2765a..e08a7d3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ fn main() -> anyhow::Result<()> { let device_id = *get_all_devices(CL_DEVICE_TYPE_GPU)? .first() .expect("no device found in platform"); - let size = { + let max_size = { match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE_AMD) { Ok(size) => size.to_size(), Err(err) => { @@ -48,8 +48,8 @@ fn main() -> anyhow::Result<()> { }; let device = Device::new(device_id); - let worker_count: cl_int = 100 as cl_int; - println!("设备最大队列长度: {} real count: {}", size, worker_count); + let worker_count: cl_int = max_size as cl_int; + println!("设备最大队列长度: {} real count: {}", max_size, worker_count); // Create a Context on an OpenCL device let context = Context::from_device(&device).expect("Context::from_device failed"); @@ -98,26 +98,20 @@ fn main() -> anyhow::Result<()> { } }; - let team_raw_vec = vec!["x"; worker_count as usize]; + let team_raw = "x"; + let team_bytes = team_raw.as_bytes(); let name_raw_vec = vec!["x"; worker_count as usize]; - let team_bytes_vec = team_raw_vec - .iter() - .map(|s| s.as_bytes()) - .collect::>(); let name_bytes_vec = name_raw_vec .iter() .map(|s| s.as_bytes()) .collect::>(); - let t_len_vec = team_bytes_vec - .iter() - .map(|s| s.len() as cl_int + 1) - .collect::>(); + let t_len = team_bytes.len() as cl_int; let n_len_vec = name_bytes_vec .iter() .map(|s| s.len() as cl_int) .collect::>(); - let work_count = team_bytes_vec.len(); + let work_count = name_raw_vec.len(); println!("开始准备buffer"); @@ -126,7 +120,7 @@ fn main() -> anyhow::Result<()> { Buffer::::create( &context, CL_MEM_READ_ONLY, - BLOCK_SIZE * work_count, + BLOCK_SIZE, ptr::null_mut(), )? }; @@ -138,24 +132,15 @@ fn main() -> anyhow::Result<()> { ptr::null_mut(), )? }; - let mut t_len = unsafe { - Buffer::::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())? - }; + // let mut t_len = unsafe { + // Buffer::::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())? + // }; let mut n_len = unsafe { Buffer::::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())? }; let mut output = SvmVec::::allocate(&context, BLOCK_SIZE * work_count)?; // 准备一下数据, 都给拼成一维数组 // 填充成 256 * len - let team_data_vec = { - let mut vec = Vec::new(); - for data in team_bytes_vec { - let left_over = BLOCK_SIZE - data.len(); - vec.extend_from_slice(data); - vec.extend_from_slice(&vec![0; left_over]); - } - vec - }; let name_data_vec = { let mut vec = Vec::new(); for data in name_bytes_vec { @@ -169,18 +154,18 @@ fn main() -> anyhow::Result<()> { println!("开始写入buffer"); // 阻塞写 - // let _team_write_event = - // unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, &team_data_vec, &[]) }?; - // _team_write_event.wait()?; - // let _name_write_event = - // unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?; - // _name_write_event.wait()?; + let _team_write_event = + unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, team_bytes, &[]) }?; + _team_write_event.wait()?; + let _name_write_event = + unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?; + _name_write_event.wait()?; // let _t_len_write_event = // unsafe { queue.enqueue_write_buffer(&mut t_len, CL_BLOCKING, 0, &t_len_vec, &[]) }?; // _t_len_write_event.wait()?; - // let _n_len_write_event = - // unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?; - // _n_len_write_event.wait()?; + let _n_len_write_event = + unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?; + _n_len_write_event.wait()?; println!("开始执行kernel"); diff --git a/src/program.cl b/src/program.cl index 48fe614..ae3c3f4 100644 --- a/src/program.cl +++ b/src/program.cl @@ -7,13 +7,13 @@ uchar median(uchar a, uchar b, uchar c) { // 输入: 1~255 长度的 u8 数组 // 输出: 255 长度的 u8 数组 kernel void load_team( - global const uchar* all_team_bytes, - global const int* all_t_len, + global const uchar* g_team_bytes, + const int t_len, global const uchar* all_name_bytes, global const int* all_n_len, // 一个 svm 的 [u8; 256] * worker_count global uchar* all_val, - int worker_count + const int worker_count ) { int gid = get_global_id(0); if (gid >= worker_count) { @@ -27,10 +27,9 @@ kernel void load_team( val[i] = i; } for (int i = 0; i < 256; i += 4) { - vstore4(vload4(0, &all_team_bytes[256 * gid + i]), i, team_bytes); + vstore4(vload4(0, &g_team_bytes[i]), i, team_bytes); vstore4(vload4(0, &all_name_bytes[256 * gid + i]), i, name_bytes); } - int t_len = all_t_len[gid]; int n_len = all_n_len[gid]; // 外面初始化好了