From 3e73b7557e06189faf4a7df5f7cb2429eabd0d20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20K=C3=B6cher?= Date: Thu, 16 Mar 2023 12:24:20 +0100 Subject: [PATCH] Try to fix flaky `temp-base-path-work` test (#13505) * Try to fix flaky `temp-base-path-work` test The test is most of the time failing when checking if the database path was deleted. The assumption is that it takes a little bit more time by the OS to actually clean up the temp path under high load. The pr tries to fix this by checking multiple times if the path was deleted. Besides that it also ensures that the tests that require the benchmark feature don't fail when compiled without the feature. * ".git/.scripts/commands/fmt/fmt.sh" * Capture signals earlier * Rewrite tests to let them having one big timeout * Remove unneeded dep * Update bin/node/cli/tests/common.rs Co-authored-by: Koute * Review feedback * Update bin/node/cli/tests/common.rs Co-authored-by: Anton --------- Co-authored-by: command-bot <> Co-authored-by: Koute Co-authored-by: Anton --- .../node/cli/tests/benchmark_pallet_works.rs | 2 + .../node/cli/tests/benchmark_storage_works.rs | 2 + substrate/bin/node/cli/tests/common.rs | 136 ++++++++---------- .../tests/running_the_node_and_interrupt.rs | 127 ++++++++-------- substrate/bin/node/cli/tests/telemetry.rs | 111 +++++++------- .../node/cli/tests/temp_base_path_works.rs | 53 ++++--- substrate/client/cli/src/lib.rs | 14 +- substrate/client/cli/src/runner.rs | 95 +++++++----- 8 files changed, 274 insertions(+), 266 deletions(-) diff --git a/substrate/bin/node/cli/tests/benchmark_pallet_works.rs b/substrate/bin/node/cli/tests/benchmark_pallet_works.rs index 1a278f985d..053516b952 100644 --- a/substrate/bin/node/cli/tests/benchmark_pallet_works.rs +++ b/substrate/bin/node/cli/tests/benchmark_pallet_works.rs @@ -16,6 +16,8 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +#![cfg(feature = "runtime-benchmarks")] + use assert_cmd::cargo::cargo_bin; use std::process::Command; diff --git a/substrate/bin/node/cli/tests/benchmark_storage_works.rs b/substrate/bin/node/cli/tests/benchmark_storage_works.rs index 1f1181218d..953c07ca7f 100644 --- a/substrate/bin/node/cli/tests/benchmark_storage_works.rs +++ b/substrate/bin/node/cli/tests/benchmark_storage_works.rs @@ -16,6 +16,8 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +#![cfg(feature = "runtime-benchmarks")] + use assert_cmd::cargo::cargo_bin; use std::{ path::Path, diff --git a/substrate/bin/node/cli/tests/common.rs b/substrate/bin/node/cli/tests/common.rs index a0de7300fc..4c4824391f 100644 --- a/substrate/bin/node/cli/tests/common.rs +++ b/substrate/bin/node/cli/tests/common.rs @@ -20,54 +20,26 @@ use assert_cmd::cargo::cargo_bin; use nix::{ - sys::signal::{kill, Signal::SIGINT}, + sys::signal::{kill, Signal, Signal::SIGINT}, unistd::Pid, }; use node_primitives::{Hash, Header}; +use regex::Regex; use std::{ io::{BufRead, BufReader, Read}, ops::{Deref, DerefMut}, - path::Path, - process::{self, Child, Command, ExitStatus}, + path::{Path, PathBuf}, + process::{self, Child, Command}, time::Duration, }; -use tokio::time::timeout; -/// Wait for the given `child` the given number of `secs`. -/// -/// Returns the `Some(exit status)` or `None` if the process did not finish in the given time. -pub fn wait_for(child: &mut Child, secs: u64) -> Result { - let result = wait_timeout::ChildExt::wait_timeout(child, Duration::from_secs(5.min(secs))) - .map_err(|_| ())?; - if let Some(exit_status) = result { - Ok(exit_status) - } else { - if secs > 5 { - eprintln!("Child process taking over 5 seconds to exit gracefully"); - let result = wait_timeout::ChildExt::wait_timeout(child, Duration::from_secs(secs - 5)) - .map_err(|_| ())?; - if let Some(exit_status) = result { - return Ok(exit_status) - } - } - eprintln!("Took too long to exit (> {} seconds). Killing...", secs); - let _ = child.kill(); - child.wait().unwrap(); - Err(()) - } -} - -/// Wait for at least n blocks to be finalized within a specified time. -pub async fn wait_n_finalized_blocks( - n: usize, - timeout_secs: u64, - url: &str, -) -> Result<(), tokio::time::error::Elapsed> { - timeout(Duration::from_secs(timeout_secs), wait_n_finalized_blocks_from(n, url)).await +/// Run the given `future` and panic if the `timeout` is hit. +pub async fn run_with_timeout(timeout: Duration, future: impl futures::Future) { + tokio::time::timeout(timeout, future).await.expect("Hit timeout"); } /// Wait for at least n blocks to be finalized from a specified node -pub async fn wait_n_finalized_blocks_from(n: usize, url: &str) { +pub async fn wait_n_finalized_blocks(n: usize, url: &str) { use substrate_rpc_client::{ws_client, ChainApi}; let mut built_blocks = std::collections::HashSet::new(); @@ -87,47 +59,55 @@ pub async fn wait_n_finalized_blocks_from(n: usize, url: &str) { /// Run the node for a while (3 blocks) pub async fn run_node_for_a_while(base_path: &Path, args: &[&str]) { - let mut cmd = Command::new(cargo_bin("substrate")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .args(args) - .arg("-d") - .arg(base_path) - .spawn() - .unwrap(); + run_with_timeout(Duration::from_secs(60 * 10), async move { + let mut cmd = Command::new(cargo_bin("substrate")) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .args(args) + .arg("-d") + .arg(base_path) + .spawn() + .unwrap(); - let stderr = cmd.stderr.take().unwrap(); + let stderr = cmd.stderr.take().unwrap(); - let mut child = KillChildOnDrop(cmd); + let mut child = KillChildOnDrop(cmd); - let (ws_url, _) = find_ws_url_from_output(stderr); + let ws_url = extract_info_from_output(stderr).0.ws_url; - // Let it produce some blocks. - let _ = wait_n_finalized_blocks(3, 30, &ws_url).await; + // Let it produce some blocks. + wait_n_finalized_blocks(3, &ws_url).await; - assert!(child.try_wait().unwrap().is_none(), "the process should still be running"); + child.assert_still_running(); - // Stop the process - kill(Pid::from_raw(child.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(wait_for(&mut child, 40).map(|x| x.success()).unwrap()); -} - -/// Run the node asserting that it fails with an error -pub fn run_node_assert_fail(base_path: &Path, args: &[&str]) { - let mut cmd = Command::new(cargo_bin("substrate")); - - let mut child = KillChildOnDrop(cmd.args(args).arg("-d").arg(base_path).spawn().unwrap()); - - // Let it produce some blocks, but it should die within 10 seconds. - assert_ne!( - wait_timeout::ChildExt::wait_timeout(&mut *child, Duration::from_secs(10)).unwrap(), - None, - "the process should not be running anymore" - ); + // Stop the process + child.stop(); + }) + .await } pub struct KillChildOnDrop(pub Child); +impl KillChildOnDrop { + /// Stop the child and wait until it is finished. + /// + /// Asserts if the exit status isn't success. + pub fn stop(&mut self) { + self.stop_with_signal(SIGINT); + } + + /// Same as [`Self::stop`] but takes the `signal` that is sent to stop the child. + pub fn stop_with_signal(&mut self, signal: Signal) { + kill(Pid::from_raw(self.id().try_into().unwrap()), signal).unwrap(); + assert!(self.wait().unwrap().success()); + } + + /// Asserts that the child is still running. + pub fn assert_still_running(&mut self) { + assert!(self.try_wait().unwrap().is_none(), "the process should still be running"); + } +} + impl Drop for KillChildOnDrop { fn drop(&mut self) { let _ = self.0.kill(); @@ -148,18 +128,22 @@ impl DerefMut for KillChildOnDrop { } } -/// Read the WS address from the output. +/// Information extracted from a running node. +pub struct NodeInfo { + pub ws_url: String, + pub db_path: PathBuf, +} + +/// Extract [`NodeInfo`] from a running node by parsing its output. /// -/// This is hack to get the actual binded sockaddr because -/// substrate assigns a random port if the specified port was already binded. -pub fn find_ws_url_from_output(read: impl Read + Send) -> (String, String) { +/// Returns the [`NodeInfo`] and all the read data. +pub fn extract_info_from_output(read: impl Read + Send) -> (NodeInfo, String) { let mut data = String::new(); let ws_url = BufReader::new(read) .lines() .find_map(|line| { - let line = - line.expect("failed to obtain next line from stdout for WS address discovery"); + let line = line.expect("failed to obtain next line while extracting node info"); data.push_str(&line); data.push_str("\n"); @@ -176,5 +160,9 @@ pub fn find_ws_url_from_output(read: impl Read + Send) -> (String, String) { panic!("We should get a WebSocket address") }); - (ws_url, data) + // Database path is printed before the ws url! + let re = Regex::new(r"Database: .+ at (\S+)").unwrap(); + let db_path = PathBuf::from(re.captures(data.as_str()).unwrap().get(1).unwrap().as_str()); + + (NodeInfo { ws_url, db_path }, data) } diff --git a/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs b/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs index fc0bf69a09..3d5598f3fb 100644 --- a/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs +++ b/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs @@ -18,94 +18,81 @@ #![cfg(unix)] use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{ - kill, - Signal::{self, SIGINT, SIGTERM}, - }, - unistd::Pid, +use nix::sys::signal::Signal::{self, SIGINT, SIGTERM}; +use std::{ + process::{self, Child, Command}, + time::Duration, }; -use std::process::{self, Child, Command}; use tempfile::tempdir; pub mod common; #[tokio::test] async fn running_the_node_works_and_can_be_interrupted() { - async fn run_command_and_kill(signal: Signal) { - let base_path = tempdir().expect("could not create a temp dir"); - let mut cmd = common::KillChildOnDrop( - Command::new(cargo_bin("substrate")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .args(&["--dev", "-d"]) - .arg(base_path.path()) - .arg("--db=paritydb") - .arg("--no-hardware-benchmarks") - .spawn() - .unwrap(), - ); + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + async fn run_command_and_kill(signal: Signal) { + let base_path = tempdir().expect("could not create a temp dir"); + let mut cmd = common::KillChildOnDrop( + Command::new(cargo_bin("substrate")) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .args(&["--dev", "-d"]) + .arg(base_path.path()) + .arg("--db=paritydb") + .arg("--no-hardware-benchmarks") + .spawn() + .unwrap(), + ); - let stderr = cmd.stderr.take().unwrap(); + let stderr = cmd.stderr.take().unwrap(); - let (ws_url, _) = common::find_ws_url_from_output(stderr); + let ws_url = common::extract_info_from_output(stderr).0.ws_url; - common::wait_n_finalized_blocks(3, 30, &ws_url) - .await - .expect("Blocks are produced in time"); - assert!(cmd.try_wait().unwrap().is_none(), "the process should still be running"); - kill(Pid::from_raw(cmd.id().try_into().unwrap()), signal).unwrap(); - assert_eq!( - common::wait_for(&mut cmd, 30).map(|x| x.success()), - Ok(true), - "the process must exit gracefully after signal {}", - signal, - ); - // Check if the database was closed gracefully. If it was not, - // there may exist a ref cycle that prevents the Client from being dropped properly. - // - // parity-db only writes the stats file on clean shutdown. - let stats_file = base_path.path().join("chains/dev/paritydb/full/stats.txt"); - assert!(std::path::Path::exists(&stats_file)); - } + common::wait_n_finalized_blocks(3, &ws_url).await; - run_command_and_kill(SIGINT).await; - run_command_and_kill(SIGTERM).await; + cmd.assert_still_running(); + + cmd.stop_with_signal(signal); + + // Check if the database was closed gracefully. If it was not, + // there may exist a ref cycle that prevents the Client from being dropped properly. + // + // parity-db only writes the stats file on clean shutdown. + let stats_file = base_path.path().join("chains/dev/paritydb/full/stats.txt"); + assert!(std::path::Path::exists(&stats_file)); + } + + run_command_and_kill(SIGINT).await; + run_command_and_kill(SIGTERM).await; + }) + .await; } #[tokio::test] async fn running_two_nodes_with_the_same_ws_port_should_work() { - fn start_node() -> Child { - Command::new(cargo_bin("substrate")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .args(&["--dev", "--tmp", "--ws-port=45789", "--no-hardware-benchmarks"]) - .spawn() - .unwrap() - } + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + fn start_node() -> Child { + Command::new(cargo_bin("substrate")) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .args(&["--dev", "--tmp", "--ws-port=45789", "--no-hardware-benchmarks"]) + .spawn() + .unwrap() + } - let mut first_node = common::KillChildOnDrop(start_node()); - let mut second_node = common::KillChildOnDrop(start_node()); + let mut first_node = common::KillChildOnDrop(start_node()); + let mut second_node = common::KillChildOnDrop(start_node()); - let stderr = first_node.stderr.take().unwrap(); - let (ws_url, _) = common::find_ws_url_from_output(stderr); + let stderr = first_node.stderr.take().unwrap(); + let ws_url = common::extract_info_from_output(stderr).0.ws_url; - common::wait_n_finalized_blocks(3, 30, &ws_url).await.unwrap(); + common::wait_n_finalized_blocks(3, &ws_url).await; - assert!(first_node.try_wait().unwrap().is_none(), "The first node should still be running"); - assert!(second_node.try_wait().unwrap().is_none(), "The second node should still be running"); + first_node.assert_still_running(); + second_node.assert_still_running(); - kill(Pid::from_raw(first_node.id().try_into().unwrap()), SIGINT).unwrap(); - kill(Pid::from_raw(second_node.id().try_into().unwrap()), SIGINT).unwrap(); - - assert_eq!( - common::wait_for(&mut first_node, 30).map(|x| x.success()), - Ok(true), - "The first node must exit gracefully", - ); - assert_eq!( - common::wait_for(&mut second_node, 30).map(|x| x.success()), - Ok(true), - "The second node must exit gracefully", - ); + first_node.stop(); + second_node.stop(); + }) + .await; } diff --git a/substrate/bin/node/cli/tests/telemetry.rs b/substrate/bin/node/cli/tests/telemetry.rs index 633cc996ca..a68746a2c0 100644 --- a/substrate/bin/node/cli/tests/telemetry.rs +++ b/substrate/bin/node/cli/tests/telemetry.rs @@ -17,78 +17,75 @@ // along with this program. If not, see . use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{kill, Signal::SIGINT}, - unistd::Pid, -}; -use std::process; +use std::{process, time::Duration}; + +use crate::common::KillChildOnDrop; pub mod common; pub mod websocket_server; #[tokio::test] async fn telemetry_works() { - let config = websocket_server::Config { - capacity: 1, - max_frame_size: 1048 * 1024, - send_buffer_len: 32, - bind_address: "127.0.0.1:0".parse().unwrap(), - }; - let mut server = websocket_server::WsServer::new(config).await.unwrap(); + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + let config = websocket_server::Config { + capacity: 1, + max_frame_size: 1048 * 1024, + send_buffer_len: 32, + bind_address: "127.0.0.1:0".parse().unwrap(), + }; + let mut server = websocket_server::WsServer::new(config).await.unwrap(); - let addr = server.local_addr().unwrap(); + let addr = server.local_addr().unwrap(); - let server_task = tokio::spawn(async move { - loop { - use websocket_server::Event; - match server.next_event().await { - // New connection on the listener. - Event::ConnectionOpen { address } => { - println!("New connection from {:?}", address); - server.accept(); - }, + let server_task = tokio::spawn(async move { + loop { + use websocket_server::Event; + match server.next_event().await { + // New connection on the listener. + Event::ConnectionOpen { address } => { + println!("New connection from {:?}", address); + server.accept(); + }, - // Received a message from a connection. - Event::BinaryFrame { message, .. } => { - let json: serde_json::Value = serde_json::from_slice(&message).unwrap(); - let object = - json.as_object().unwrap().get("payload").unwrap().as_object().unwrap(); - if matches!(object.get("best"), Some(serde_json::Value::String(_))) { - break - } - }, + // Received a message from a connection. + Event::BinaryFrame { message, .. } => { + let json: serde_json::Value = serde_json::from_slice(&message).unwrap(); + let object = + json.as_object().unwrap().get("payload").unwrap().as_object().unwrap(); + if matches!(object.get("best"), Some(serde_json::Value::String(_))) { + break + } + }, - Event::TextFrame { .. } => panic!("Got a TextFrame over the socket, this is a bug"), + Event::TextFrame { .. } => + panic!("Got a TextFrame over the socket, this is a bug"), - // Connection has been closed. - Event::ConnectionError { .. } => {}, + // Connection has been closed. + Event::ConnectionError { .. } => {}, + } } - } - }); + }); - let mut substrate = process::Command::new(cargo_bin("substrate")); + let mut substrate = process::Command::new(cargo_bin("substrate")); - let mut substrate = substrate - .args(&["--dev", "--tmp", "--telemetry-url"]) - .arg(format!("ws://{} 10", addr)) - .arg("--no-hardware-benchmarks") - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .stdin(process::Stdio::null()) - .spawn() - .unwrap(); + let mut substrate = KillChildOnDrop( + substrate + .args(&["--dev", "--tmp", "--telemetry-url"]) + .arg(format!("ws://{} 10", addr)) + .arg("--no-hardware-benchmarks") + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .stdin(process::Stdio::null()) + .spawn() + .unwrap(), + ); - server_task.await.expect("server task panicked"); + server_task.await.expect("server task panicked"); - assert!(substrate.try_wait().unwrap().is_none(), "the process should still be running"); + substrate.assert_still_running(); - // Stop the process - kill(Pid::from_raw(substrate.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(common::wait_for(&mut substrate, 40).map(|x| x.success()).unwrap_or_default()); - - let output = substrate.wait_with_output().unwrap(); - - println!("{}", String::from_utf8(output.stdout).unwrap()); - eprintln!("{}", String::from_utf8(output.stderr).unwrap()); - assert!(output.status.success()); + // Stop the process + substrate.stop(); + }) + .await; } diff --git a/substrate/bin/node/cli/tests/temp_base_path_works.rs b/substrate/bin/node/cli/tests/temp_base_path_works.rs index 4e743f2d3a..89e901c00e 100644 --- a/substrate/bin/node/cli/tests/temp_base_path_works.rs +++ b/substrate/bin/node/cli/tests/temp_base_path_works.rs @@ -19,45 +19,42 @@ #![cfg(unix)] use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{kill, Signal::SIGINT}, - unistd::Pid, -}; -use regex::Regex; use std::{ - io::Read, - path::PathBuf, process::{Command, Stdio}, + time::Duration, }; pub mod common; #[tokio::test] async fn temp_base_path_works() { - let mut cmd = Command::new(cargo_bin("substrate")); - let mut child = common::KillChildOnDrop( - cmd.args(&["--dev", "--tmp", "--no-hardware-benchmarks"]) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .unwrap(), - ); + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + let mut cmd = Command::new(cargo_bin("substrate")); + let mut child = common::KillChildOnDrop( + cmd.args(&["--dev", "--tmp", "--no-hardware-benchmarks"]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(), + ); - let mut stderr = child.stderr.take().unwrap(); - let (ws_url, mut data) = common::find_ws_url_from_output(&mut stderr); + let mut stderr = child.stderr.take().unwrap(); + let node_info = common::extract_info_from_output(&mut stderr).0; - // Let it produce some blocks. - common::wait_n_finalized_blocks(3, 30, &ws_url).await.unwrap(); - assert!(child.try_wait().unwrap().is_none(), "the process should still be running"); + // Let it produce some blocks. + common::wait_n_finalized_blocks(3, &node_info.ws_url).await; - // Stop the process - kill(Pid::from_raw(child.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(common::wait_for(&mut child, 40).map(|x| x.success()).unwrap_or_default()); + // Ensure the db path exists while the node is running + assert!(node_info.db_path.exists()); - // Ensure the database has been deleted - stderr.read_to_string(&mut data).unwrap(); - let re = Regex::new(r"Database: .+ at (\S+)").unwrap(); - let db_path = PathBuf::from(re.captures(data.as_str()).unwrap().get(1).unwrap().as_str()); + child.assert_still_running(); - assert!(!db_path.exists()); + // Stop the process + child.stop(); + + if node_info.db_path.exists() { + panic!("Database path `{}` wasn't deleted!", node_info.db_path.display()); + } + }) + .await; } diff --git a/substrate/client/cli/src/lib.rs b/substrate/client/cli/src/lib.rs index adbb299317..e73321ecce 100644 --- a/substrate/client/cli/src/lib.rs +++ b/substrate/client/cli/src/lib.rs @@ -197,10 +197,15 @@ pub trait SubstrateCli: Sized { command: &T, ) -> error::Result> { let tokio_runtime = build_runtime()?; + + // `capture` needs to be called in a tokio context. + // Also capture them as early as possible. + let signals = tokio_runtime.block_on(async { Signals::capture() })?; + let config = command.create_configuration(self, tokio_runtime.handle().clone())?; command.init(&Self::support_url(), &Self::impl_version(), |_, _| {}, &config)?; - Runner::new(config, tokio_runtime) + Runner::new(config, tokio_runtime, signals) } /// Create a runner for the command provided in argument. The `logger_hook` can be used to setup @@ -231,10 +236,15 @@ pub trait SubstrateCli: Sized { F: FnOnce(&mut LoggerBuilder, &Configuration), { let tokio_runtime = build_runtime()?; + + // `capture` needs to be called in a tokio context. + // Also capture them as early as possible. + let signals = tokio_runtime.block_on(async { Signals::capture() })?; + let config = command.create_configuration(self, tokio_runtime.handle().clone())?; command.init(&Self::support_url(), &Self::impl_version(), logger_hook, &config)?; - Runner::new(config, tokio_runtime) + Runner::new(config, tokio_runtime, signals) } /// Native runtime version. fn native_runtime_version(chain_spec: &Box) -> &'static RuntimeVersion; diff --git a/substrate/client/cli/src/runner.rs b/substrate/client/cli/src/runner.rs index 8917a37c49..47adfcf89f 100644 --- a/substrate/client/cli/src/runner.rs +++ b/substrate/client/cli/src/runner.rs @@ -18,54 +18,72 @@ use crate::{error::Error as CliError, Result, SubstrateCli}; use chrono::prelude::*; -use futures::{future, future::FutureExt, pin_mut, select, Future}; +use futures::{ + future::{self, BoxFuture, FutureExt}, + pin_mut, select, Future, +}; use log::info; use sc_service::{Configuration, Error as ServiceError, TaskManager}; use sc_utils::metrics::{TOKIO_THREADS_ALIVE, TOKIO_THREADS_TOTAL}; use std::{marker::PhantomData, time::Duration}; -#[cfg(target_family = "unix")] -async fn main(func: F) -> std::result::Result<(), E> -where - F: Future> + future::FusedFuture, - E: std::error::Error + Send + Sync + 'static + From, -{ - use tokio::signal::unix::{signal, SignalKind}; +/// Abstraction over OS signals to handle the shutdown of the node smoothly. +/// +/// On `unix` this represents `SigInt` and `SigTerm`. +pub struct Signals(BoxFuture<'static, ()>); - let mut stream_int = signal(SignalKind::interrupt()).map_err(ServiceError::Io)?; - let mut stream_term = signal(SignalKind::terminate()).map_err(ServiceError::Io)?; +impl Signals { + /// Capture the relevant signals to handle shutdown of the node smoothly. + /// + /// Needs to be called in a Tokio context to have access to the tokio reactor. + #[cfg(target_family = "unix")] + pub fn capture() -> std::result::Result { + use tokio::signal::unix::{signal, SignalKind}; - let t1 = stream_int.recv().fuse(); - let t2 = stream_term.recv().fuse(); - let t3 = func; + let mut stream_int = signal(SignalKind::interrupt()).map_err(ServiceError::Io)?; + let mut stream_term = signal(SignalKind::terminate()).map_err(ServiceError::Io)?; - pin_mut!(t1, t2, t3); - - select! { - _ = t1 => {}, - _ = t2 => {}, - res = t3 => res?, + Ok(Signals( + async move { + future::select(stream_int.recv().boxed(), stream_term.recv().boxed()).await; + } + .boxed(), + )) } - Ok(()) + /// Capture the relevant signals to handle shutdown of the node smoothly. + /// + /// Needs to be called in a Tokio context to have access to the tokio reactor. + #[cfg(not(unix))] + pub fn capture() -> std::result::Result { + use tokio::signal::ctrl_c; + + Ok(Signals( + async move { + let _ = ctrl_c().await; + } + .boxed(), + )) + } + + /// A dummy signal that never returns. + pub fn dummy() -> Self { + Self(future::pending().boxed()) + } } -#[cfg(not(unix))] -async fn main(func: F) -> std::result::Result<(), E> +async fn main(func: F, signals: impl Future) -> std::result::Result<(), E> where F: Future> + future::FusedFuture, - E: std::error::Error + Send + Sync + 'static + From, + E: std::error::Error + Send + Sync + 'static, { - use tokio::signal::ctrl_c; + let signals = signals.fuse(); - let t1 = ctrl_c().fuse(); - let t2 = func; - - pin_mut!(t1, t2); + pin_mut!(func, signals); select! { - _ = t1 => {}, - res = t2 => res?, + _ = signals => {}, + res = func => res?, } Ok(()) @@ -89,6 +107,7 @@ fn run_until_exit( tokio_runtime: tokio::runtime::Runtime, future: F, task_manager: TaskManager, + signals: impl Future, ) -> std::result::Result<(), E> where F: Future> + future::Future, @@ -97,7 +116,7 @@ where let f = future.fuse(); pin_mut!(f); - tokio_runtime.block_on(main(f))?; + tokio_runtime.block_on(main(f, signals))?; drop(task_manager); Ok(()) @@ -107,13 +126,18 @@ where pub struct Runner { config: Configuration, tokio_runtime: tokio::runtime::Runtime, + signals: Signals, phantom: PhantomData, } impl Runner { /// Create a new runtime with the command provided in argument - pub fn new(config: Configuration, tokio_runtime: tokio::runtime::Runtime) -> Result> { - Ok(Runner { config, tokio_runtime, phantom: PhantomData }) + pub fn new( + config: Configuration, + tokio_runtime: tokio::runtime::Runtime, + signals: Signals, + ) -> Result> { + Ok(Runner { config, tokio_runtime, signals, phantom: PhantomData }) } /// Log information about the node itself. @@ -147,7 +171,7 @@ impl Runner { self.print_node_infos(); let mut task_manager = self.tokio_runtime.block_on(initialize(self.config))?; - let res = self.tokio_runtime.block_on(main(task_manager.future().fuse())); + let res = self.tokio_runtime.block_on(main(task_manager.future().fuse(), self.signals.0)); // We need to drop the task manager here to inform all tasks that they should shut down. // // This is important to be done before we instruct the tokio runtime to shutdown. Otherwise @@ -210,7 +234,7 @@ impl Runner { E: std::error::Error + Send + Sync + 'static + From + From, { let (future, task_manager) = runner(self.config)?; - run_until_exit::<_, E>(self.tokio_runtime, future, task_manager) + run_until_exit::<_, E>(self.tokio_runtime, future, task_manager, self.signals.0) } /// Get an immutable reference to the node Configuration @@ -369,6 +393,7 @@ mod tests { runtime_cache_size: 2, }, runtime, + Signals::dummy(), ) .unwrap();