Skip to content
27 changes: 27 additions & 0 deletions crates/providers/src/data_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,33 @@ impl<T> DataTree<T> {
}
}

/// Iterate over direct children, yielding `(optional_key, child)` pairs in index order.
///
/// # Example
/// ```rust
/// use qiskit_providers::DataTree;
/// let mut tree = DataTree::new();
/// tree.push_leaf(10); // unnamed
/// tree.insert_leaf("b", 20); // named
/// tree.push_leaf(30); // unnamed
/// let children: Vec<_> = tree.iter_children().collect();
/// assert_eq!(children[0], (None, &DataTree::Leaf(10)));
/// assert_eq!(children[1], (Some("b"), &DataTree::Leaf(20)));
/// assert_eq!(children[2], (None, &DataTree::Leaf(30)));
/// ```
pub fn iter_children(&self) -> impl Iterator<Item = (Option<&str>, &DataTree<T>)> + '_ {
let branch = match self {
Self::Branch(branch) => branch,
Self::Leaf(_) => panic!("called iter_children() on a leaf node"),
};
let rev: HashMap<usize, &str> = branch.keys.iter().map(|(k, &v)| (v, k.as_str())).collect();
branch
.data
.iter()
.enumerate()
.map(move |(i, child)| (rev.get(&i).copied(), child))
}

/// Insert a new leaf node with an associated string key
///
/// If a key is provided that is already in the tree the new value will be associated with
Expand Down
8 changes: 8 additions & 0 deletions crates/providers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,13 @@
// that they have been altered from the originals.

mod data_tree;
pub mod math_nodes;
mod program_node;
mod quantum_program;
mod store;
pub mod tensor;

pub use data_tree::{DataTree, PathEntry};
pub use program_node::{ProgramNode, ProgramNodeError, require_leaf_arg, require_typed_leaf_arg};
pub use quantum_program::{OwnedPathEntry, Port, QuantumProgram, QuantumProgramError};
pub use store::Store;
287 changes: 287 additions & 0 deletions crates/providers/src/math_nodes/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2026
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use crate::data_tree::DataTree;
use crate::program_node::{ProgramNode, require_leaf_arg};
use crate::tensor::{DTypeLike, Tensor, TensorType, promotion};
use std::sync::LazyLock;

/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`.
static INPUT_TYPES: LazyLock<DataTree<TensorType>> = LazyLock::new(|| {
let mut types = DataTree::with_capacity(2);
types.insert_leaf(
"x",
TensorType {
dtype: DTypeLike::Var("x".into()),
shape: vec![],
broadcastable: true,
},
);
types.insert_leaf(
"y",
TensorType {
dtype: DTypeLike::Var("y".into()),
shape: vec![],
broadcastable: true,
},
);
types
});

/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype.
static OUTPUT_TYPES: LazyLock<DataTree<TensorType>> = LazyLock::new(|| {
DataTree::new_leaf(TensorType {
dtype: DTypeLike::Promotion(
vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(),
),
shape: vec![],
broadcastable: true,
})
});

/// Generate a [`ProgramNode`] struct for an elementwise binary operation.
macro_rules! elementwise_binary_node {
($name:ident, $node_name:literal, $call_fn:expr) => {
#[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")]
pub struct $name;

impl ProgramNode for $name {
type CallError = super::MathNodeError;

fn name(&self) -> &'static str {
$node_name
}
fn namespace(&self) -> &'static str {
"math"
}
fn input_types(&self) -> &DataTree<TensorType> {
&INPUT_TYPES
}
fn output_types(&self) -> &DataTree<TensorType> {
&OUTPUT_TYPES
}
fn implements_call(&self) -> bool {
true
}
fn call(&self, args: &DataTree<Tensor>) -> Result<DataTree<Tensor>, Self::CallError> {
let x = require_leaf_arg(args, "x")?;
let y = require_leaf_arg(args, "y")?;
let out_dtype = promotion(x.dtype(), y.dtype());
Ok(DataTree::new_leaf($call_fn(
&x.cast_ref(out_dtype),
&y.cast_ref(out_dtype),
)?))
}
}
};
}

elementwise_binary_node!(Add, "add", Tensor::add_tensor);
elementwise_binary_node!(Subtract, "subtract", Tensor::sub_tensor);
elementwise_binary_node!(Multiply, "multiply", Tensor::mul_tensor);
elementwise_binary_node!(Divide, "divide", Tensor::div_tensor);
elementwise_binary_node!(Remainder, "remainder", Tensor::rem_tensor);
elementwise_binary_node!(Power, "power", Tensor::pow);

#[cfg(test)]
mod tests {
use super::*;
use crate::math_nodes::MathNodeError;
use crate::program_node::ProgramNodeError;
use crate::tensor::{DType, Tensor};

fn args(x: Tensor, y: Tensor) -> DataTree<Tensor> {
let mut tree = DataTree::new();
tree.insert_leaf("x", x);
tree.insert_leaf("y", y);
tree
}

#[test]
fn test_add_same_dtype() {
let result = Add
.call(&args(
Tensor::from([1.0_f64, 2.0, 3.0]),
Tensor::from([4.0_f64, 5.0, 6.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!("expected f64 leaf")
};
assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
}

#[test]
fn test_add_promotes_dtype() {
let result = Add
.call(&args(
Tensor::from([1.0_f32, 2.0]),
Tensor::from([3.0_f64, 4.0]),
))
.unwrap();
let DataTree::Leaf(tensor) = result else {
panic!("expected leaf")
};
assert_eq!(tensor.dtype(), DType::F64);
let Tensor::F64(arr) = tensor else {
panic!("expected f64")
};
assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]);
}

#[test]
fn test_add_broadcasts_1d_scalar() {
// shape [3] + shape [1] -> shape [3]
let result = Add
.call(&args(
Tensor::from([1.0_f64, 2.0, 3.0]),
Tensor::from([10.0_f64]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!("expected f64 leaf")
};
assert_eq!(arr.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
}

#[test]
fn test_add_broadcasts_2d_with_1d() {
// shape [2, 3] + shape [3] -> shape [2, 3]
use ndarray::arr2;
let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn());
let y = Tensor::from([10.0_f64, 20.0, 30.0]);
let result = Add.call(&args(x, y)).unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!("expected f64 leaf")
};
let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]).into_dyn();
assert_eq!(arr, expected);
}

#[test]
fn test_subtract() {
let result = Subtract
.call(&args(
Tensor::from([5.0_f64, 6.0, 7.0]),
Tensor::from([1.0_f64, 2.0, 3.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]);
}

#[test]
fn test_multiply() {
let result = Multiply
.call(&args(
Tensor::from([2.0_f64, 3.0, 4.0]),
Tensor::from([10.0_f64, 10.0, 10.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]);
}

#[test]
fn test_divide() {
let result = Divide
.call(&args(
Tensor::from([10.0_f64, 9.0, 8.0]),
Tensor::from([2.0_f64, 3.0, 4.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]);
}

#[test]
fn test_remainder() {
let result = Remainder
.call(&args(
Tensor::from([7.0_f64, 8.0, 9.0]),
Tensor::from([3.0_f64, 3.0, 3.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]);
}

#[test]
fn test_power() {
let result = Power
.call(&args(
Tensor::from([2.0_f64, 3.0, 4.0]),
Tensor::from([3.0_f64, 2.0, 1.0]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
for (a, b) in arr.as_slice().unwrap().iter().zip(&[8.0_f64, 9.0, 4.0]) {
assert!(approx::abs_diff_eq!(a, b, epsilon = 1e-12));
}
}

#[test]
fn test_missing_input_returns_error() {
let mut tree = DataTree::new();
tree.insert_leaf("x", Tensor::from([1.0_f64]));
// No "y" input.
let err = Add.call(&tree).unwrap_err();
assert_eq!(
err,
MathNodeError::Input(ProgramNodeError::MissingInput {
key: "y".to_string(),
})
);
}

#[test]
fn test_branch_where_leaf_expected_returns_error() {
let mut tree = DataTree::new();
tree.insert_leaf("x", Tensor::from([1.0_f64]));
// "y" is a branch, not a leaf.
tree.insert_branch("y", DataTree::new());
let err = Add.call(&tree).unwrap_err();
assert_eq!(
err,
MathNodeError::Input(ProgramNodeError::ExpectedLeaf {
key: "y".to_string(),
})
);
}

#[test]
fn test_power_broadcasts() {
// shape [3] ** shape [1] -> shape [3]
let result = Power
.call(&args(
Tensor::from([2.0_f64, 3.0, 4.0]),
Tensor::from([2.0_f64]),
))
.unwrap();
let DataTree::Leaf(Tensor::F64(arr)) = result else {
panic!()
};
for (a, b) in arr.as_slice().unwrap().iter().zip(&[4.0_f64, 9.0, 16.0]) {
assert!(approx::abs_diff_eq!(a, b, epsilon = 1e-12));
}
}
}
Loading
Loading