跑不起来,但我还是这么整了一下(
This commit is contained in:
parent
5673a7eb63
commit
7909a649e6
55
src/main.rs
55
src/main.rs
@ -25,7 +25,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let device_id = *get_all_devices(CL_DEVICE_TYPE_GPU)?
|
||||
.first()
|
||||
.expect("no device found in platform");
|
||||
let size = {
|
||||
let max_size = {
|
||||
match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE_AMD) {
|
||||
Ok(size) => size.to_size(),
|
||||
Err(err) => {
|
||||
@ -48,8 +48,8 @@ fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
let device = Device::new(device_id);
|
||||
|
||||
let worker_count: cl_int = 100 as cl_int;
|
||||
println!("设备最大队列长度: {} real count: {}", size, worker_count);
|
||||
let worker_count: cl_int = max_size as cl_int;
|
||||
println!("设备最大队列长度: {} real count: {}", max_size, worker_count);
|
||||
|
||||
// Create a Context on an OpenCL device
|
||||
let context = Context::from_device(&device).expect("Context::from_device failed");
|
||||
@ -98,26 +98,20 @@ fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let team_raw_vec = vec!["x"; worker_count as usize];
|
||||
let team_raw = "x";
|
||||
let team_bytes = team_raw.as_bytes();
|
||||
let name_raw_vec = vec!["x"; worker_count as usize];
|
||||
let team_bytes_vec = team_raw_vec
|
||||
.iter()
|
||||
.map(|s| s.as_bytes())
|
||||
.collect::<Vec<&[u8]>>();
|
||||
let name_bytes_vec = name_raw_vec
|
||||
.iter()
|
||||
.map(|s| s.as_bytes())
|
||||
.collect::<Vec<&[u8]>>();
|
||||
let t_len_vec = team_bytes_vec
|
||||
.iter()
|
||||
.map(|s| s.len() as cl_int + 1)
|
||||
.collect::<Vec<i32>>();
|
||||
let t_len = team_bytes.len() as cl_int;
|
||||
let n_len_vec = name_bytes_vec
|
||||
.iter()
|
||||
.map(|s| s.len() as cl_int)
|
||||
.collect::<Vec<i32>>();
|
||||
|
||||
let work_count = team_bytes_vec.len();
|
||||
let work_count = name_raw_vec.len();
|
||||
|
||||
println!("开始准备buffer");
|
||||
|
||||
@ -126,7 +120,7 @@ fn main() -> anyhow::Result<()> {
|
||||
Buffer::<cl_uchar>::create(
|
||||
&context,
|
||||
CL_MEM_READ_ONLY,
|
||||
BLOCK_SIZE * work_count,
|
||||
BLOCK_SIZE,
|
||||
ptr::null_mut(),
|
||||
)?
|
||||
};
|
||||
@ -138,24 +132,15 @@ fn main() -> anyhow::Result<()> {
|
||||
ptr::null_mut(),
|
||||
)?
|
||||
};
|
||||
let mut t_len = unsafe {
|
||||
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
||||
};
|
||||
// let mut t_len = unsafe {
|
||||
// Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
||||
// };
|
||||
let mut n_len = unsafe {
|
||||
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
||||
};
|
||||
let mut output = SvmVec::<cl_uchar>::allocate(&context, BLOCK_SIZE * work_count)?;
|
||||
// 准备一下数据, 都给拼成一维数组
|
||||
// 填充成 256 * len
|
||||
let team_data_vec = {
|
||||
let mut vec = Vec::new();
|
||||
for data in team_bytes_vec {
|
||||
let left_over = BLOCK_SIZE - data.len();
|
||||
vec.extend_from_slice(data);
|
||||
vec.extend_from_slice(&vec![0; left_over]);
|
||||
}
|
||||
vec
|
||||
};
|
||||
let name_data_vec = {
|
||||
let mut vec = Vec::new();
|
||||
for data in name_bytes_vec {
|
||||
@ -169,18 +154,18 @@ fn main() -> anyhow::Result<()> {
|
||||
println!("开始写入buffer");
|
||||
|
||||
// 阻塞写
|
||||
// let _team_write_event =
|
||||
// unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, &team_data_vec, &[]) }?;
|
||||
// _team_write_event.wait()?;
|
||||
// let _name_write_event =
|
||||
// unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
|
||||
// _name_write_event.wait()?;
|
||||
let _team_write_event =
|
||||
unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, team_bytes, &[]) }?;
|
||||
_team_write_event.wait()?;
|
||||
let _name_write_event =
|
||||
unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
|
||||
_name_write_event.wait()?;
|
||||
// let _t_len_write_event =
|
||||
// unsafe { queue.enqueue_write_buffer(&mut t_len, CL_BLOCKING, 0, &t_len_vec, &[]) }?;
|
||||
// _t_len_write_event.wait()?;
|
||||
// let _n_len_write_event =
|
||||
// unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
|
||||
// _n_len_write_event.wait()?;
|
||||
let _n_len_write_event =
|
||||
unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
|
||||
_n_len_write_event.wait()?;
|
||||
|
||||
println!("开始执行kernel");
|
||||
|
||||
|
@ -7,13 +7,13 @@ uchar median(uchar a, uchar b, uchar c) {
|
||||
// 输入: 1~255 长度的 u8 数组
|
||||
// 输出: 255 长度的 u8 数组
|
||||
kernel void load_team(
|
||||
global const uchar* all_team_bytes,
|
||||
global const int* all_t_len,
|
||||
global const uchar* g_team_bytes,
|
||||
const int t_len,
|
||||
global const uchar* all_name_bytes,
|
||||
global const int* all_n_len,
|
||||
// 一个 svm 的 [u8; 256] * worker_count
|
||||
global uchar* all_val,
|
||||
int worker_count
|
||||
const int worker_count
|
||||
) {
|
||||
int gid = get_global_id(0);
|
||||
if (gid >= worker_count) {
|
||||
@ -27,10 +27,9 @@ kernel void load_team(
|
||||
val[i] = i;
|
||||
}
|
||||
for (int i = 0; i < 256; i += 4) {
|
||||
vstore4(vload4(0, &all_team_bytes[256 * gid + i]), i, team_bytes);
|
||||
vstore4(vload4(0, &g_team_bytes[i]), i, team_bytes);
|
||||
vstore4(vload4(0, &all_name_bytes[256 * gid + i]), i, name_bytes);
|
||||
}
|
||||
int t_len = all_t_len[gid];
|
||||
int n_len = all_n_len[gid];
|
||||
|
||||
// 外面初始化好了
|
||||
|
Loading…
Reference in New Issue
Block a user