// p2p/service_impl/webrtc/webrtc_rs.rs

1use std::{future::Future, sync::Arc};
2
3use webrtc::{
4    api::APIBuilder,
5    data_channel::{data_channel_init::RTCDataChannelInit, RTCDataChannel},
6    ice_transport::{
7        ice_credential_type::RTCIceCredentialType, ice_gatherer_state::RTCIceGathererState,
8        ice_gathering_state::RTCIceGatheringState, ice_server::RTCIceServer,
9    },
10    peer_connection::{
11        configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState,
12        policy::ice_transport_policy::RTCIceTransportPolicy,
13        sdp::session_description::RTCSessionDescription, RTCPeerConnection,
14    },
15};
16
17use crate::{
18    connection::P2pConnectionResponse,
19    webrtc::{Answer, Offer},
20};
21
22use super::{OnConnectionStateChangeHdlrFn, RTCChannelConfig, RTCConfig};
23
24pub type Result<T> = std::result::Result<T, webrtc::Error>;
25
26pub type RTCConnectionState = RTCPeerConnectionState;
27
28pub type Api = Arc<webrtc::api::API>;
29
30pub type RTCCertificate = webrtc::peer_connection::certificate::RTCCertificate;
31
32pub fn certificate_from_pem_key(pem_str: &str) -> RTCCertificate {
33    let keypair = rcgen::KeyPair::from_pem(pem_str).expect("valid pem");
34    RTCCertificate::from_key_pair(keypair).expect("keypair is compatible")
35}
36
37pub fn build_api() -> Api {
38    APIBuilder::new().build().into()
39}
40
/// Handle to a `webrtc` peer connection.
///
/// Field `0` is the shared peer connection. Field `1` is `true` for the handle
/// returned by `RTCConnection::create` and `false` for handles produced by
/// `Clone` (it is what `is_main` reports).
pub struct RTCConnection(Arc<RTCPeerConnection>, bool);
42
/// Cloneable handle to a `webrtc` data channel; clones share the same
/// underlying `RTCDataChannel` through the `Arc`.
#[derive(Clone)]
pub struct RTCChannel(Arc<RTCDataChannel>);
45
46#[derive(thiserror::Error, derive_more::From, Debug)]
47pub enum RTCSignalingError {
48    #[error("serialization failed: {0}")]
49    Serialize(serde_json::Error),
50    #[error("http request failed: {0}")]
51    Http(reqwest::Error),
52}
53
54impl RTCConnection {
55    pub async fn create(api: &Api, config: RTCConfig) -> Result<Self> {
56        let mut configuration = RTCConfiguration::from(config);
57        // try default certificate, TODO(vlad): do it right
58        configuration.certificates.clear();
59        api.new_peer_connection(configuration)
60            .await
61            .map(|v| Self(v.into(), true))
62    }
63
64    pub fn is_main(&self) -> bool {
65        self.1
66    }
67
68    pub async fn channel_create(&self, config: RTCChannelConfig) -> Result<RTCChannel> {
69        self.0
70            .create_data_channel(
71                config.label,
72                Some(RTCDataChannelInit {
73                    ordered: Some(true),
74                    max_packet_life_time: None,
75                    max_retransmits: None,
76                    negotiated: config.negotiated,
77                    ..Default::default()
78                }),
79            )
80            .await
81            .map(RTCChannel)
82    }
83
84    pub async fn offer_create(&self) -> Result<RTCSessionDescription> {
85        self.0.create_offer(None).await
86    }
87
88    pub async fn answer_create(&self) -> Result<RTCSessionDescription> {
89        self.0.create_answer(None).await
90    }
91
92    pub async fn local_desc_set(&self, desc: RTCSessionDescription) -> Result<()> {
93        self.0.set_local_description(desc).await
94    }
95
96    pub async fn remote_desc_set(&self, desc: RTCSessionDescription) -> Result<()> {
97        self.0.set_remote_description(desc).await
98    }
99
100    pub async fn local_sdp(&self) -> Option<String> {
101        self.0.local_description().await.map(|v| v.sdp)
102    }
103
104    pub fn connection_state(&self) -> RTCConnectionState {
105        self.0.connection_state()
106    }
107
108    pub async fn wait_for_ice_gathering_complete(&self) {
109        if !matches!(self.0.ice_gathering_state(), RTCIceGatheringState::Complete) {
110            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
111            let mut tx = Some(tx);
112            self.0.on_ice_gathering_state_change(Box::new(move |state| {
113                if matches!(state, RTCIceGathererState::Complete) {
114                    if let Some(tx) = tx.take() {
115                        let _ = tx.send(());
116                    }
117                }
118                Box::pin(std::future::ready(()))
119            }));
120            let _ = rx.await;
121        }
122    }
123
124    pub fn on_connection_state_change(&self, handler: OnConnectionStateChangeHdlrFn) {
125        self.0.on_peer_connection_state_change(handler)
126    }
127
128    pub async fn close(self) {
129        if let Err(error) = self.0.close().await {
130            openmina_core::warn!(
131                openmina_core::log::system_time();
132                summary = "CONNECTION LEAK: Failure when closing RTCConnection",
133                error = error.to_string(),
134            )
135        }
136    }
137}
138
139impl RTCChannel {
140    pub fn on_open<Fut>(&self, f: impl FnOnce() -> Fut + Send + Sync + 'static)
141    where
142        Fut: Future<Output = ()> + Send + 'static,
143    {
144        self.0.on_open(Box::new(move || Box::pin(f())))
145    }
146
147    pub fn on_message<Fut>(&self, mut f: impl FnMut(&[u8]) -> Fut + Send + Sync + 'static)
148    where
149        Fut: Future<Output = ()> + Send + 'static,
150    {
151        self.0
152            .on_message(Box::new(move |msg| Box::pin(f(msg.data.as_ref()))));
153    }
154
155    pub fn on_error<Fut>(&self, mut f: impl FnMut(webrtc::Error) -> Fut + Send + Sync + 'static)
156    where
157        Fut: Future<Output = ()> + Send + 'static,
158    {
159        self.0.on_error(Box::new(move |err| Box::pin(f(err))))
160    }
161
162    pub fn on_close<Fut>(&self, mut f: impl FnMut() -> Fut + Send + Sync + 'static)
163    where
164        Fut: Future<Output = ()> + Send + 'static,
165    {
166        self.0.on_close(Box::new(move || Box::pin(f())))
167    }
168
169    pub async fn send(&self, data: &bytes::Bytes) -> Result<usize> {
170        self.0.send(data).await
171    }
172
173    pub async fn close(&self) {
174        let _ = self.0.close().await;
175    }
176}
177
178pub async fn webrtc_signal_send(
179    url: &str,
180    offer: Offer,
181) -> std::result::Result<P2pConnectionResponse, RTCSignalingError> {
182    let client = reqwest::Client::new();
183    let res = client
184        .post(url)
185        .body(serde_json::to_string(&offer)?)
186        .send()
187        .await?
188        .json()
189        .await?;
190    Ok(res)
191}
192
193impl Clone for RTCConnection {
194    fn clone(&self) -> Self {
195        Self(self.0.clone(), false)
196    }
197}
198
199impl From<RTCConfig> for RTCConfiguration {
200    fn from(value: RTCConfig) -> Self {
201        RTCConfiguration {
202            ice_servers: value.ice_servers.0.into_iter().map(Into::into).collect(),
203            ice_transport_policy: RTCIceTransportPolicy::All,
204            certificates: vec![value.certificate],
205            seed: Some(value.seed.to_vec()),
206            ..Default::default()
207        }
208    }
209}
210
211impl From<super::RTCConfigIceServer> for RTCIceServer {
212    fn from(value: super::RTCConfigIceServer) -> Self {
213        let credential_type = match value.credential.is_some() {
214            false => RTCIceCredentialType::Unspecified,
215            true => RTCIceCredentialType::Password,
216        };
217        RTCIceServer {
218            urls: value.urls,
219            username: value.username.unwrap_or_default(),
220            credential: value.credential.unwrap_or_default(),
221            credential_type,
222        }
223    }
224}
225
226impl TryFrom<Offer> for RTCSessionDescription {
227    type Error = webrtc::Error;
228
229    fn try_from(value: Offer) -> Result<Self> {
230        Self::offer(value.sdp)
231    }
232}
233
234impl TryFrom<Answer> for RTCSessionDescription {
235    type Error = webrtc::Error;
236
237    fn try_from(value: Answer) -> Result<Self> {
238        Self::answer(value.sdp)
239    }
240}