Fix the storage_size/state_getStorageSize RPC call (#13154)

* Have `KeyIterator` clone the `prefix` it receives

* Stream keys in `storage_size` RPC and add a runtime limit

* Update client/rpc/Cargo.toml

Co-authored-by: Bastian Köcher <git@kchr.de>

* Update client/rpc/src/state/utils.rs

Co-authored-by: Bastian Köcher <git@kchr.de>

* Rename the types to signify that the cancellation is due to a timeout

* Move the test into a `mod tests`

* Add a comment regarding `biased` in `tokio::select`

* Make the `clone` explicit when calling `KeyIterator::{new, new_child}`

Co-authored-by: Bastian Köcher <git@kchr.de>
This commit is contained in:
Koute
2023-01-18 15:28:27 +09:00
committed by GitHub
parent 3310f7ae96
commit 338d75d0a3
8 changed files with 232 additions and 51 deletions
+13 -13
View File
@@ -303,17 +303,17 @@ pub trait AuxStore {
}
/// An `Iterator` that iterates keys in a given block under a prefix.
pub struct KeyIterator<'a, State, Block> {
pub struct KeyIterator<State, Block> {
state: State,
child_storage: Option<ChildInfo>,
prefix: Option<&'a StorageKey>,
prefix: Option<StorageKey>,
current_key: Vec<u8>,
_phantom: PhantomData<Block>,
}
impl<'a, State, Block> KeyIterator<'a, State, Block> {
impl<State, Block> KeyIterator<State, Block> {
/// create a KeyIterator instance
pub fn new(state: State, prefix: Option<&'a StorageKey>, current_key: Vec<u8>) -> Self {
pub fn new(state: State, prefix: Option<StorageKey>, current_key: Vec<u8>) -> Self {
Self { state, child_storage: None, prefix, current_key, _phantom: PhantomData }
}
@@ -321,14 +321,14 @@ impl<'a, State, Block> KeyIterator<'a, State, Block> {
pub fn new_child(
state: State,
child_info: ChildInfo,
prefix: Option<&'a StorageKey>,
prefix: Option<StorageKey>,
current_key: Vec<u8>,
) -> Self {
Self { state, child_storage: Some(child_info), prefix, current_key, _phantom: PhantomData }
}
}
impl<'a, State, Block> Iterator for KeyIterator<'a, State, Block>
impl<State, Block> Iterator for KeyIterator<State, Block>
where
Block: BlockT,
State: StateBackend<HashFor<Block>>,
@@ -344,7 +344,7 @@ where
.ok()
.flatten()?;
// this terminates the iterator the first time it fails.
if let Some(prefix) = self.prefix {
if let Some(ref prefix) = self.prefix {
if !next_key.starts_with(&prefix.0[..]) {
return None
}
@@ -387,12 +387,12 @@ pub trait StorageProvider<Block: BlockT, B: Backend<Block>> {
/// Given a block's `Hash` and a key prefix, return a `KeyIterator` iterates matching storage
/// keys in that block.
fn storage_keys_iter<'a>(
fn storage_keys_iter(
&self,
hash: Block::Hash,
prefix: Option<&'a StorageKey>,
prefix: Option<&StorageKey>,
start_key: Option<&StorageKey>,
) -> sp_blockchain::Result<KeyIterator<'a, B::State, Block>>;
) -> sp_blockchain::Result<KeyIterator<B::State, Block>>;
/// Given a block's `Hash`, a key and a child storage key, return the value under the key in
/// that block.
@@ -414,13 +414,13 @@ pub trait StorageProvider<Block: BlockT, B: Backend<Block>> {
/// Given a block's `Hash` and a key `prefix` and a child storage key,
/// return a `KeyIterator` that iterates matching storage keys in that block.
fn child_storage_keys_iter<'a>(
fn child_storage_keys_iter(
&self,
hash: Block::Hash,
child_info: ChildInfo,
prefix: Option<&'a StorageKey>,
prefix: Option<&StorageKey>,
start_key: Option<&StorageKey>,
) -> sp_blockchain::Result<KeyIterator<'a, B::State, Block>>;
) -> sp_blockchain::Result<KeyIterator<B::State, Block>>;
/// Given a block's `Hash`, a key and a child storage key, return the hash under the key in that
/// block.
+2 -2
View File
@@ -71,8 +71,8 @@ pub trait StateApi<Hash> {
fn storage_hash(&self, key: StorageKey, hash: Option<Hash>) -> RpcResult<Option<Hash>>;
/// Returns the size of a storage entry at a block's state.
#[method(name = "state_getStorageSize", aliases = ["state_getStorageSizeAt"], blocking)]
fn storage_size(&self, key: StorageKey, hash: Option<Hash>) -> RpcResult<Option<u64>>;
#[method(name = "state_getStorageSize", aliases = ["state_getStorageSizeAt"])]
async fn storage_size(&self, key: StorageKey, hash: Option<Hash>) -> RpcResult<Option<u64>>;
/// Returns the runtime metadata as an opaque blob.
#[method(name = "state_getMetadata", blocking)]
+2 -2
View File
@@ -36,7 +36,7 @@ sp-runtime = { version = "7.0.0", path = "../../primitives/runtime" }
sp-session = { version = "4.0.0-dev", path = "../../primitives/session" }
sp-version = { version = "5.0.0", path = "../../primitives/version" }
tokio = { version = "1.22.0", optional = true }
tokio = "1.22.0"
[dev-dependencies]
env_logger = "0.9"
@@ -51,4 +51,4 @@ sp-io = { version = "7.0.0", path = "../../primitives/io" }
substrate-test-runtime-client = { version = "2.0.0", path = "../../test-utils/runtime/client" }
[features]
test-helpers = ["tokio"]
test-helpers = []
+15 -4
View File
@@ -19,6 +19,7 @@
//! Substrate state API.
mod state_full;
mod utils;
#[cfg(test)]
mod tests;
@@ -28,7 +29,7 @@ use std::sync::Arc;
use crate::SubscriptionTaskExecutor;
use jsonrpsee::{
core::{server::rpc_module::SubscriptionSink, Error as JsonRpseeError, RpcResult},
core::{async_trait, server::rpc_module::SubscriptionSink, Error as JsonRpseeError, RpcResult},
types::SubscriptionResult,
};
@@ -53,6 +54,7 @@ use sp_blockchain::{HeaderBackend, HeaderMetadata};
const STORAGE_KEYS_PAGED_MAX_COUNT: u32 = 1000;
/// State backend API.
#[async_trait]
pub trait StateBackend<Block: BlockT, Client>: Send + Sync + 'static
where
Block: BlockT + 'static,
@@ -107,10 +109,11 @@ where
///
/// If data is available at `key`, it is returned. Else, the sum of values who's key has `key`
/// prefix is returned, i.e. all the storage (double) maps that have this prefix.
fn storage_size(
async fn storage_size(
&self,
block: Option<Block::Hash>,
key: StorageKey,
deny_unsafe: DenyUnsafe,
) -> Result<Option<u64>, Error>;
/// Returns the runtime metadata as an opaque blob.
@@ -202,6 +205,7 @@ pub struct State<Block, Client> {
deny_unsafe: DenyUnsafe,
}
#[async_trait]
impl<Block, Client> StateApiServer<Block::Hash> for State<Block, Client>
where
Block: BlockT + 'static,
@@ -262,8 +266,15 @@ where
self.backend.storage_hash(block, key).map_err(Into::into)
}
fn storage_size(&self, key: StorageKey, block: Option<Block::Hash>) -> RpcResult<Option<u64>> {
self.backend.storage_size(block, key).map_err(Into::into)
async fn storage_size(
&self,
key: StorageKey,
block: Option<Block::Hash>,
) -> RpcResult<Option<u64>> {
self.backend
.storage_size(block, key, self.deny_unsafe)
.await
.map_err(Into::into)
}
fn metadata(&self, block: Option<Block::Hash>) -> RpcResult<Bytes> {
+47 -20
View File
@@ -18,17 +18,20 @@
//! State API backend for full nodes.
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};
use super::{
client_err,
error::{Error, Result},
ChildStateBackend, StateBackend,
};
use crate::SubscriptionTaskExecutor;
use crate::{DenyUnsafe, SubscriptionTaskExecutor};
use futures::{future, stream, FutureExt, StreamExt};
use jsonrpsee::{core::Error as JsonRpseeError, SubscriptionSink};
use jsonrpsee::{
core::{async_trait, Error as JsonRpseeError},
SubscriptionSink,
};
use sc_client_api::{
Backend, BlockBackend, BlockchainEvents, CallExecutor, ExecutorProvider, ProofProvider,
StorageProvider,
@@ -48,6 +51,9 @@ use sp_core::{
use sp_runtime::{generic::BlockId, traits::Block as BlockT};
use sp_version::RuntimeVersion;
/// The maximum time allowed for an RPC call when running without unsafe RPC enabled.
const MAXIMUM_SAFE_RPC_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// Ranges to query in state_queryStorage.
struct QueryStorageRange<Block: BlockT> {
/// Hashes of all the blocks in the range.
@@ -166,6 +172,7 @@ where
}
}
#[async_trait]
impl<BE, Block, Client> StateBackend<Block, Client> for FullState<BE, Block, Client>
where
Block: BlockT + 'static,
@@ -251,33 +258,53 @@ where
.map_err(client_err)
}
fn storage_size(
async fn storage_size(
&self,
block: Option<Block::Hash>,
key: StorageKey,
deny_unsafe: DenyUnsafe,
) -> std::result::Result<Option<u64>, Error> {
let block = match self.block_or_best(block) {
Ok(b) => b,
Err(e) => return Err(client_err(e)),
};
match self.client.storage(block, &key) {
Ok(Some(d)) => return Ok(Some(d.0.len() as u64)),
Err(e) => return Err(client_err(e)),
Ok(None) => {},
}
let client = self.client.clone();
let timeout = match deny_unsafe {
DenyUnsafe::Yes => Some(MAXIMUM_SAFE_RPC_CALL_TIMEOUT),
DenyUnsafe::No => None,
};
self.client
.storage_pairs(block, &key)
.map(|kv| {
let item_sum = kv.iter().map(|(_, v)| v.0.len() as u64).sum::<u64>();
if item_sum > 0 {
Some(item_sum)
} else {
None
}
})
.map_err(client_err)
super::utils::spawn_blocking_with_timeout(timeout, move |is_timed_out| {
// Does the key point to a concrete entry in the database?
match client.storage(block, &key) {
Ok(Some(d)) => return Ok(Ok(Some(d.0.len() as u64))),
Err(e) => return Ok(Err(client_err(e))),
Ok(None) => {},
}
// The key doesn't point to anything, so it's probably a prefix.
let iter = match client.storage_keys_iter(block, Some(&key), None).map_err(client_err) {
Ok(iter) => iter,
Err(e) => return Ok(Err(e)),
};
let mut sum = 0;
for storage_key in iter {
let value = client.storage(block, &storage_key).ok().flatten().unwrap_or_default();
sum += value.0.len() as u64;
is_timed_out.check_if_timed_out()?;
}
if sum > 0 {
Ok(Ok(Some(sum)))
} else {
Ok(Ok(None))
}
})
.await
.map_err(|error| Error::Client(Box::new(error)))?
}
fn storage_hash(
+5 -2
View File
@@ -70,9 +70,12 @@ async fn should_return_storage() {
client.storage_hash(key.clone(), Some(genesis_hash).into()).map(|x| x.is_some()),
Ok(true)
);
assert_eq!(client.storage_size(key.clone(), None).unwrap().unwrap() as usize, VALUE.len(),);
assert_eq!(
client.storage_size(StorageKey(b":map".to_vec()), None).unwrap().unwrap() as usize,
client.storage_size(key.clone(), None).await.unwrap().unwrap() as usize,
VALUE.len(),
);
assert_eq!(
client.storage_size(StorageKey(b":map".to_vec()), None).await.unwrap().unwrap() as usize,
2 + 3,
);
assert_eq!(
+140
View File
@@ -0,0 +1,140 @@
// This file is part of Substrate.
// Copyright (C) 2017-2023 Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
/// An error signifying that a task has been cancelled due to a timeout.
#[derive(Debug)]
pub struct Timeout;
impl std::error::Error for Timeout {}
impl std::fmt::Display for Timeout {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.write_str("task has been running too long")
}
}
/// A handle which can be used to check whether the task has been cancelled due to a timeout.
#[repr(transparent)]
pub struct IsTimedOut(Arc<AtomicBool>);
impl IsTimedOut {
#[must_use]
pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
if self.0.load(Ordering::Relaxed) {
Err(Timeout)
} else {
Ok(())
}
}
}
/// An error for a task which either panicked, or has been cancelled due to a timeout.
#[derive(Debug)]
pub enum SpawnWithTimeoutError {
JoinError(tokio::task::JoinError),
Timeout,
}
impl std::error::Error for SpawnWithTimeoutError {}
impl std::fmt::Display for SpawnWithTimeoutError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt),
SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt),
}
}
}
struct CancelOnDrop(Arc<AtomicBool>);
impl Drop for CancelOnDrop {
fn drop(&mut self) {
self.0.store(true, Ordering::Relaxed)
}
}
/// Spawns a new blocking task with a given `timeout`.
///
/// The `callback` should continuously call [`IsTimedOut::check_if_timed_out`],
/// which will return an error once the task runs for longer than `timeout`.
///
/// If `timeout` is `None` then this works just as a regular `spawn_blocking`.
pub async fn spawn_blocking_with_timeout<R>(
timeout: Option<Duration>,
callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
) -> Result<R, SpawnWithTimeoutError>
where
R: Send + 'static,
{
let is_timed_out_arc = Arc::new(AtomicBool::new(false));
let is_timed_out = IsTimedOut(is_timed_out_arc.clone());
let _cancel_on_drop = CancelOnDrop(is_timed_out_arc);
let task = tokio::task::spawn_blocking(move || callback(is_timed_out));
let result = if let Some(timeout) = timeout {
tokio::select! {
// Shouldn't really matter, but make sure the task is polled before the timeout,
// in case the task finishes after the timeout and the timeout is really short.
biased;
task_result = task => task_result,
_ = tokio::time::sleep(timeout) => Ok(Err(Timeout))
}
} else {
task.await
};
match result {
Ok(Ok(result)) => Ok(result),
Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout),
Err(error) => Err(SpawnWithTimeoutError::JoinError(error)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn spawn_blocking_with_timeout_works() {
let task: Result<(), SpawnWithTimeoutError> =
spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
std::thread::sleep(Duration::from_millis(200));
is_timed_out.check_if_timed_out()?;
unreachable!();
})
.await;
assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));
let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
std::thread::sleep(Duration::from_millis(20));
is_timed_out.check_if_timed_out()?;
Ok(())
})
.await;
assert_matches::assert_matches!(task, Ok(()));
}
}
@@ -1432,27 +1432,27 @@ where
Ok(keys)
}
fn storage_keys_iter<'a>(
fn storage_keys_iter(
&self,
hash: <Block as BlockT>::Hash,
prefix: Option<&'a StorageKey>,
prefix: Option<&StorageKey>,
start_key: Option<&StorageKey>,
) -> sp_blockchain::Result<KeyIterator<'a, B::State, Block>> {
) -> sp_blockchain::Result<KeyIterator<B::State, Block>> {
let state = self.state_at(hash)?;
let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new);
Ok(KeyIterator::new(state, prefix, start_key))
Ok(KeyIterator::new(state, prefix.cloned(), start_key))
}
fn child_storage_keys_iter<'a>(
fn child_storage_keys_iter(
&self,
hash: <Block as BlockT>::Hash,
child_info: ChildInfo,
prefix: Option<&'a StorageKey>,
prefix: Option<&StorageKey>,
start_key: Option<&StorageKey>,
) -> sp_blockchain::Result<KeyIterator<'a, B::State, Block>> {
) -> sp_blockchain::Result<KeyIterator<B::State, Block>> {
let state = self.state_at(hash)?;
let start_key = start_key.or(prefix).map(|key| key.0.clone()).unwrap_or_else(Vec::new);
Ok(KeyIterator::new_child(state, child_info, prefix, start_key))
Ok(KeyIterator::new_child(state, child_info, prefix.cloned(), start_key))
}
fn storage(