// This file is part of Substrate.

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0

// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

/// An error signifying that a task has been cancelled due to a timeout.
#[derive(Debug)]
pub struct Timeout;

impl std::error::Error for Timeout {}

impl std::fmt::Display for Timeout {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.write_str("task has been running too long")
    }
}

/// A handle which can be used to check whether the task has been cancelled due to a timeout.
///
/// The handle wraps a shared boolean flag; once that flag has been set to `true`
/// every subsequent check reports a [`Timeout`].
#[repr(transparent)]
pub struct IsTimedOut(Arc<AtomicBool>);

impl IsTimedOut {
    /// Checks whether the shared cancellation flag has been raised.
    ///
    /// Intended to be used with `?` inside a long-running task so that the task
    /// bails out promptly once it has been cancelled.
    ///
    /// # Errors
    ///
    /// Returns [`Timeout`] if the flag is set; returns `Ok(())` otherwise.
    #[must_use]
    pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
        // The flag is a standalone boolean: this handle reads no other memory that
        // would have to be ordered against it, so a relaxed load is sufficient.
        if self.0.load(Ordering::Relaxed) {
            Err(Timeout)
        } else {
            Ok(())
        }
    }
}
#[derive(Debug)] pub enum SpawnWithTimeoutError { JoinError(tokio::task::JoinError), Timeout, } impl std::error::Error for SpawnWithTimeoutError {} impl std::fmt::Display for SpawnWithTimeoutError { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { match self { SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt), SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt), } } } struct CancelOnDrop(Arc); impl Drop for CancelOnDrop { fn drop(&mut self) { self.0.store(true, Ordering::Relaxed) } } /// Spawns a new blocking task with a given `timeout`. /// /// The `callback` should continuously call [`IsTimedOut::check_if_timed_out`], /// which will return an error once the task runs for longer than `timeout`. /// /// If `timeout` is `None` then this works just as a regular `spawn_blocking`. pub async fn spawn_blocking_with_timeout( timeout: Option, callback: impl FnOnce(IsTimedOut) -> std::result::Result + Send + 'static, ) -> Result where R: Send + 'static, { let is_timed_out_arc = Arc::new(AtomicBool::new(false)); let is_timed_out = IsTimedOut(is_timed_out_arc.clone()); let _cancel_on_drop = CancelOnDrop(is_timed_out_arc); let task = tokio::task::spawn_blocking(move || callback(is_timed_out)); let result = if let Some(timeout) = timeout { tokio::select! { // Shouldn't really matter, but make sure the task is polled before the timeout, // in case the task finishes after the timeout and the timeout is really short. 
biased; task_result = task => task_result, _ = tokio::time::sleep(timeout) => Ok(Err(Timeout)) } } else { task.await }; match result { Ok(Ok(result)) => Ok(result), Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout), Err(error) => Err(SpawnWithTimeoutError::JoinError(error)), } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn spawn_blocking_with_timeout_works() { let task: Result<(), SpawnWithTimeoutError> = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { std::thread::sleep(Duration::from_millis(200)); is_timed_out.check_if_timed_out()?; unreachable!(); }) .await; assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout)); let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| { std::thread::sleep(Duration::from_millis(20)); is_timed_out.check_if_timed_out()?; Ok(()) }) .await; assert_matches::assert_matches!(task, Ok(())); } }