跑不起来,但我还是这么整了一下(
This commit is contained in:
parent
5673a7eb63
commit
7909a649e6
55
src/main.rs
55
src/main.rs
@ -25,7 +25,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let device_id = *get_all_devices(CL_DEVICE_TYPE_GPU)?
|
let device_id = *get_all_devices(CL_DEVICE_TYPE_GPU)?
|
||||||
.first()
|
.first()
|
||||||
.expect("no device found in platform");
|
.expect("no device found in platform");
|
||||||
let size = {
|
let max_size = {
|
||||||
match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE_AMD) {
|
match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE_AMD) {
|
||||||
Ok(size) => size.to_size(),
|
Ok(size) => size.to_size(),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@ -48,8 +48,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
};
|
};
|
||||||
let device = Device::new(device_id);
|
let device = Device::new(device_id);
|
||||||
|
|
||||||
let worker_count: cl_int = 100 as cl_int;
|
let worker_count: cl_int = max_size as cl_int;
|
||||||
println!("设备最大队列长度: {} real count: {}", size, worker_count);
|
println!("设备最大队列长度: {} real count: {}", max_size, worker_count);
|
||||||
|
|
||||||
// Create a Context on an OpenCL device
|
// Create a Context on an OpenCL device
|
||||||
let context = Context::from_device(&device).expect("Context::from_device failed");
|
let context = Context::from_device(&device).expect("Context::from_device failed");
|
||||||
@ -98,26 +98,20 @@ fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let team_raw_vec = vec!["x"; worker_count as usize];
|
let team_raw = "x";
|
||||||
|
let team_bytes = team_raw.as_bytes();
|
||||||
let name_raw_vec = vec!["x"; worker_count as usize];
|
let name_raw_vec = vec!["x"; worker_count as usize];
|
||||||
let team_bytes_vec = team_raw_vec
|
|
||||||
.iter()
|
|
||||||
.map(|s| s.as_bytes())
|
|
||||||
.collect::<Vec<&[u8]>>();
|
|
||||||
let name_bytes_vec = name_raw_vec
|
let name_bytes_vec = name_raw_vec
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.as_bytes())
|
.map(|s| s.as_bytes())
|
||||||
.collect::<Vec<&[u8]>>();
|
.collect::<Vec<&[u8]>>();
|
||||||
let t_len_vec = team_bytes_vec
|
let t_len = team_bytes.len() as cl_int;
|
||||||
.iter()
|
|
||||||
.map(|s| s.len() as cl_int + 1)
|
|
||||||
.collect::<Vec<i32>>();
|
|
||||||
let n_len_vec = name_bytes_vec
|
let n_len_vec = name_bytes_vec
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.len() as cl_int)
|
.map(|s| s.len() as cl_int)
|
||||||
.collect::<Vec<i32>>();
|
.collect::<Vec<i32>>();
|
||||||
|
|
||||||
let work_count = team_bytes_vec.len();
|
let work_count = name_raw_vec.len();
|
||||||
|
|
||||||
println!("开始准备buffer");
|
println!("开始准备buffer");
|
||||||
|
|
||||||
@ -126,7 +120,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
Buffer::<cl_uchar>::create(
|
Buffer::<cl_uchar>::create(
|
||||||
&context,
|
&context,
|
||||||
CL_MEM_READ_ONLY,
|
CL_MEM_READ_ONLY,
|
||||||
BLOCK_SIZE * work_count,
|
BLOCK_SIZE,
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
@ -138,24 +132,15 @@ fn main() -> anyhow::Result<()> {
|
|||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
let mut t_len = unsafe {
|
// let mut t_len = unsafe {
|
||||||
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
// Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
||||||
};
|
// };
|
||||||
let mut n_len = unsafe {
|
let mut n_len = unsafe {
|
||||||
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
Buffer::<cl_int>::create(&context, CL_MEM_READ_ONLY, work_count, ptr::null_mut())?
|
||||||
};
|
};
|
||||||
let mut output = SvmVec::<cl_uchar>::allocate(&context, BLOCK_SIZE * work_count)?;
|
let mut output = SvmVec::<cl_uchar>::allocate(&context, BLOCK_SIZE * work_count)?;
|
||||||
// 准备一下数据, 都给拼成一维数组
|
// 准备一下数据, 都给拼成一维数组
|
||||||
// 填充成 256 * len
|
// 填充成 256 * len
|
||||||
let team_data_vec = {
|
|
||||||
let mut vec = Vec::new();
|
|
||||||
for data in team_bytes_vec {
|
|
||||||
let left_over = BLOCK_SIZE - data.len();
|
|
||||||
vec.extend_from_slice(data);
|
|
||||||
vec.extend_from_slice(&vec![0; left_over]);
|
|
||||||
}
|
|
||||||
vec
|
|
||||||
};
|
|
||||||
let name_data_vec = {
|
let name_data_vec = {
|
||||||
let mut vec = Vec::new();
|
let mut vec = Vec::new();
|
||||||
for data in name_bytes_vec {
|
for data in name_bytes_vec {
|
||||||
@ -169,18 +154,18 @@ fn main() -> anyhow::Result<()> {
|
|||||||
println!("开始写入buffer");
|
println!("开始写入buffer");
|
||||||
|
|
||||||
// 阻塞写
|
// 阻塞写
|
||||||
// let _team_write_event =
|
let _team_write_event =
|
||||||
// unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, &team_data_vec, &[]) }?;
|
unsafe { queue.enqueue_write_buffer(&mut team, CL_BLOCKING, 0, team_bytes, &[]) }?;
|
||||||
// _team_write_event.wait()?;
|
_team_write_event.wait()?;
|
||||||
// let _name_write_event =
|
let _name_write_event =
|
||||||
// unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
|
unsafe { queue.enqueue_write_buffer(&mut name, CL_BLOCKING, 0, &name_data_vec, &[]) }?;
|
||||||
// _name_write_event.wait()?;
|
_name_write_event.wait()?;
|
||||||
// let _t_len_write_event =
|
// let _t_len_write_event =
|
||||||
// unsafe { queue.enqueue_write_buffer(&mut t_len, CL_BLOCKING, 0, &t_len_vec, &[]) }?;
|
// unsafe { queue.enqueue_write_buffer(&mut t_len, CL_BLOCKING, 0, &t_len_vec, &[]) }?;
|
||||||
// _t_len_write_event.wait()?;
|
// _t_len_write_event.wait()?;
|
||||||
// let _n_len_write_event =
|
let _n_len_write_event =
|
||||||
// unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
|
unsafe { queue.enqueue_write_buffer(&mut n_len, CL_BLOCKING, 0, &n_len_vec, &[]) }?;
|
||||||
// _n_len_write_event.wait()?;
|
_n_len_write_event.wait()?;
|
||||||
|
|
||||||
println!("开始执行kernel");
|
println!("开始执行kernel");
|
||||||
|
|
||||||
|
@ -7,13 +7,13 @@ uchar median(uchar a, uchar b, uchar c) {
|
|||||||
// 输入: 1~255 长度的 u8 数组
|
// 输入: 1~255 长度的 u8 数组
|
||||||
// 输出: 255 长度的 u8 数组
|
// 输出: 255 长度的 u8 数组
|
||||||
kernel void load_team(
|
kernel void load_team(
|
||||||
global const uchar* all_team_bytes,
|
global const uchar* g_team_bytes,
|
||||||
global const int* all_t_len,
|
const int t_len,
|
||||||
global const uchar* all_name_bytes,
|
global const uchar* all_name_bytes,
|
||||||
global const int* all_n_len,
|
global const int* all_n_len,
|
||||||
// 一个 svm 的 [u8; 256] * worker_count
|
// 一个 svm 的 [u8; 256] * worker_count
|
||||||
global uchar* all_val,
|
global uchar* all_val,
|
||||||
int worker_count
|
const int worker_count
|
||||||
) {
|
) {
|
||||||
int gid = get_global_id(0);
|
int gid = get_global_id(0);
|
||||||
if (gid >= worker_count) {
|
if (gid >= worker_count) {
|
||||||
@ -27,10 +27,9 @@ kernel void load_team(
|
|||||||
val[i] = i;
|
val[i] = i;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < 256; i += 4) {
|
for (int i = 0; i < 256; i += 4) {
|
||||||
vstore4(vload4(0, &all_team_bytes[256 * gid + i]), i, team_bytes);
|
vstore4(vload4(0, &g_team_bytes[i]), i, team_bytes);
|
||||||
vstore4(vload4(0, &all_name_bytes[256 * gid + i]), i, name_bytes);
|
vstore4(vload4(0, &all_name_bytes[256 * gid + i]), i, name_bytes);
|
||||||
}
|
}
|
||||||
int t_len = all_t_len[gid];
|
|
||||||
int n_len = all_n_len[gid];
|
int n_len = all_n_len[gid];
|
||||||
|
|
||||||
// 外面初始化好了
|
// 外面初始化好了
|
||||||
|
Loading…
Reference in New Issue
Block a user