1use js_sys::Promise;
15use spmc::{channel, Receiver, Sender};
16use wasm_bindgen::prelude::*;
17
18#[cfg(feature = "nodejs")]
19use js_sys::JsString;
20
/// Global Rayon pool used by `run_in_pool`.
///
/// Set by `PoolBuilder::build` and reset to `None` by `exit_thread_pool`.
// NOTE(review): this `static mut` is read and written without any
// synchronization; soundness relies on callers never racing `run_in_pool`
// against `PoolBuilder::build` or `exit_thread_pool`.
static mut THREAD_POOL: Option<rayon::ThreadPool> = None;
22
23pub fn run_in_pool<OP, R>(op: OP) -> R
24where
25 OP: FnOnce() -> R + Send,
26 R: Send,
27{
28 let pool = unsafe { THREAD_POOL.as_ref().unwrap() };
29 pool.install(op)
30}
31
/// State handed to the JS `startWorkers` import while the pool is set up.
///
/// Holds the requested thread count and both ends of an `spmc` channel:
/// `PoolBuilder::build` sends each Rayon `ThreadBuilder` through `sender`,
/// and `receiver` is exposed to JS as a raw pointer by `PoolBuilder::receiver`.
#[wasm_bindgen]
#[doc(hidden)]
pub struct PoolBuilder {
    // Number of Rayon threads requested (reported to JS as `numThreads`).
    num_threads: usize,
    // Producer end; used by the spawn handler installed in `build`.
    sender: Sender<rayon::ThreadBuilder>,
    // Consumer end; handed out by raw pointer through `receiver()`.
    receiver: Receiver<rayon::ThreadBuilder>,
}
39
// JS import for the non-`nodejs` build. `init_thread_pool` passes it
// `wasm_bindgen::module()`, `wasm_bindgen::memory()` and a fresh `PoolBuilder`,
// and returns the resulting promise unchanged.
// NOTE(review): the JS implementation is not visible here — presumably the
// accompanying worker glue; confirm.
#[cfg(not(feature = "nodejs"))]
#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(js_name = startWorkers)]
    fn start_workers(module: JsValue, memory: JsValue, builder: PoolBuilder) -> Promise;
}
46
// JS import for the `nodejs` build. Same shape as the non-Node variant, except
// the first argument is a `JsString` (the `worker_source` given to
// `init_thread_pool`) instead of the wasm module value.
#[cfg(feature = "nodejs")]
#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(js_name = startWorkers)]
    fn start_workers(module: JsString, memory: JsValue, builder: PoolBuilder) -> Promise;
}
53
// JS import bound to `terminateWorkers`; `exit_thread_pool` calls it and
// returns its promise to the JS caller.
#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(js_name = terminateWorkers)]
    fn terminate_workers() -> Promise;
}
59
60#[wasm_bindgen]
61impl PoolBuilder {
62 fn new(num_threads: usize) -> Self {
63 let (sender, receiver) = channel();
64 Self {
65 num_threads,
66 sender,
67 receiver,
68 }
69 }
70
71 #[wasm_bindgen(js_name = numThreads)]
72 pub fn num_threads(&self) -> usize {
73 self.num_threads
74 }
75
76 pub fn receiver(&self) -> *const Receiver<rayon::ThreadBuilder> {
77 &self.receiver
78 }
79
80 pub fn build(&mut self) {
84 unsafe {
85 THREAD_POOL = Some(
86 rayon::ThreadPoolBuilder::new()
87 .num_threads(self.num_threads)
88 .spawn_handler(move |thread| {
93 self.sender.send(thread).unwrap_throw();
97 Ok(())
98 })
99 .build()
100 .unwrap_throw(),
101 )
102 }
103 }
104}
105
/// Starts the worker pool in the `nodejs` build.
///
/// `worker_source` is forwarded as the first argument of the JS `startWorkers`
/// import (the slot where the non-Node build passes `wasm_bindgen::module()`);
/// the promise returned by that import is returned as-is.
#[cfg(feature = "nodejs")]
#[wasm_bindgen(js_name = initThreadPool)]
#[doc(hidden)]
pub fn init_thread_pool(num_threads: usize, worker_source: JsString) -> Promise {
    let memory = wasm_bindgen::memory();
    let builder = PoolBuilder::new(num_threads);
    start_workers(worker_source, memory, builder)
}
116
117#[cfg(not(feature = "nodejs"))]
118#[wasm_bindgen(js_name = initThreadPool)]
119#[doc(hidden)]
120pub fn init_thread_pool(num_threads: usize) -> Promise {
121 start_workers(
122 wasm_bindgen::module(),
123 wasm_bindgen::memory(),
124 PoolBuilder::new(num_threads),
125 )
126}
127
128#[wasm_bindgen(js_name = exitThreadPool)]
129#[doc(hidden)]
130pub fn exit_thread_pool() -> Promise {
131 unsafe {
132 let promise = terminate_workers();
133 THREAD_POOL = None;
134 promise
135 }
136}
137
138#[wasm_bindgen]
139#[allow(clippy::not_unsafe_ptr_arg_deref)]
140#[doc(hidden)]
141pub fn wbg_rayon_start_worker(receiver: *const Receiver<rayon::ThreadBuilder>)
142where
143 Receiver<rayon::ThreadBuilder>: Sync,
145{
146 let receiver = unsafe { &*receiver };
154 receiver.recv().unwrap_throw().run();
159}