plonk_wasm/
rayon.rs

/*
 * Copyright 2022 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
13
14use js_sys::Promise;
15use spmc::{channel, Receiver, Sender};
16use wasm_bindgen::prelude::*;
17
18#[cfg(feature = "nodejs")]
19use js_sys::JsString;
20
21static mut THREAD_POOL: Option<rayon::ThreadPool> = None;
22
23pub fn run_in_pool<OP, R>(op: OP) -> R
24where
25    OP: FnOnce() -> R + Send,
26    R: Send,
27{
28    let pool = unsafe { THREAD_POOL.as_ref().unwrap() };
29    pool.install(op)
30}
31
32#[wasm_bindgen]
33#[doc(hidden)]
34pub struct PoolBuilder {
35    num_threads: usize,
36    sender: Sender<rayon::ThreadBuilder>,
37    receiver: Receiver<rayon::ThreadBuilder>,
38}
39
40#[cfg(not(feature = "nodejs"))]
41#[wasm_bindgen]
42extern "C" {
43    #[wasm_bindgen(js_name = startWorkers)]
44    fn start_workers(module: JsValue, memory: JsValue, builder: PoolBuilder) -> Promise;
45}
46
#[cfg(feature = "nodejs")]
#[wasm_bindgen]
extern "C" {
    /// JS import `startWorkers` (Node.js build); its implementation lives in
    /// the JS glue and is not visible here.
    ///
    /// Unlike the browser build, the first argument is a `JsString` rather
    /// than the wasm module object.
    /// NOTE(review): presumably the source or path of the worker script —
    /// confirm against the JS side.
    #[wasm_bindgen(js_name = "startWorkers")]
    fn start_workers(module: JsString, memory: JsValue, builder: PoolBuilder) -> Promise;
}
53
54#[wasm_bindgen]
55extern "C" {
56    #[wasm_bindgen(js_name = terminateWorkers)]
57    fn terminate_workers() -> Promise;
58}
59
60#[wasm_bindgen]
61impl PoolBuilder {
62    fn new(num_threads: usize) -> Self {
63        let (sender, receiver) = channel();
64        Self {
65            num_threads,
66            sender,
67            receiver,
68        }
69    }
70
71    #[wasm_bindgen(js_name = numThreads)]
72    pub fn num_threads(&self) -> usize {
73        self.num_threads
74    }
75
76    pub fn receiver(&self) -> *const Receiver<rayon::ThreadBuilder> {
77        &self.receiver
78    }
79
80    // This should be called by the JS side once all the Workers are spawned.
81    // Important: it must take `self` by reference, otherwise
82    // `start_worker_thread` will try to receive a message on a moved value.
83    pub fn build(&mut self) {
84        unsafe {
85            THREAD_POOL = Some(
86                rayon::ThreadPoolBuilder::new()
87                    .num_threads(self.num_threads)
88                    // We could use postMessage here instead of Rust channels,
89                    // but currently we can't due to a Chrome bug that will cause
90                    // the main thread to lock up before it even sends the message:
91                    // https://bugs.chromium.org/p/chromium/issues/detail?id=1075645
92                    .spawn_handler(move |thread| {
93                        // Note: `send` will return an error if there are no receivers.
94                        // We can use it because all the threads are spawned and ready to accept
95                        // messages by the time we call `build()` to instantiate spawn handler.
96                        self.sender.send(thread).unwrap_throw();
97                        Ok(())
98                    })
99                    .build()
100                    .unwrap_throw(),
101            )
102        }
103    }
104}
105
/// Entry point (Node.js build) for starting the Rayon thread pool.
///
/// Creates a `PoolBuilder` for `num_threads` threads and hands it, together
/// with `worker_source` and the shared wasm memory, to the JS `startWorkers`
/// import. Returns the `Promise` produced by `startWorkers`.
#[cfg(feature = "nodejs")]
#[wasm_bindgen(js_name = initThreadPool)]
#[doc(hidden)]
pub fn init_thread_pool(num_threads: usize, worker_source: JsString) -> Promise {
    let shared_memory = wasm_bindgen::memory();
    let pool_builder = PoolBuilder::new(num_threads);
    start_workers(worker_source, shared_memory, pool_builder)
}
116
117#[cfg(not(feature = "nodejs"))]
118#[wasm_bindgen(js_name = initThreadPool)]
119#[doc(hidden)]
120pub fn init_thread_pool(num_threads: usize) -> Promise {
121    start_workers(
122        wasm_bindgen::module(),
123        wasm_bindgen::memory(),
124        PoolBuilder::new(num_threads),
125    )
126}
127
128#[wasm_bindgen(js_name = exitThreadPool)]
129#[doc(hidden)]
130pub fn exit_thread_pool() -> Promise {
131    unsafe {
132        let promise = terminate_workers();
133        THREAD_POOL = None;
134        promise
135    }
136}
137
138#[wasm_bindgen]
139#[allow(clippy::not_unsafe_ptr_arg_deref)]
140#[doc(hidden)]
141pub fn wbg_rayon_start_worker(receiver: *const Receiver<rayon::ThreadBuilder>)
142where
143    // Statically assert that it's safe to accept `Receiver` from another thread.
144    Receiver<rayon::ThreadBuilder>: Sync,
145{
146    // This is safe, because we know it came from a reference to PoolBuilder,
147    // allocated on the heap by wasm-bindgen and dropped only once all the
148    // threads are running.
149    //
150    // The only way to violate safety is if someone externally calls
151    // `exports.wbg_rayon_start_worker(garbageValue)`, but then no Rust tools
152    // would prevent us from issues anyway.
153    let receiver = unsafe { &*receiver };
154    // Wait for a task (`ThreadBuilder`) on the channel, and, once received,
155    // start executing it.
156    //
157    // On practice this will start running Rayon's internal event loop.
158    receiver.recv().unwrap_throw().run();
159}