use std::marker::PhantomData;
use std::mem;
use crate::buffer::Buffer;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
/// Computes the strides, in bytes, required to access a tensor of the given
/// `shape` stored in row-major (C) order.
///
/// # Errors
/// Returns `ArrowError::ComputeError` if the total byte size overflows
/// `usize`, or if any dimension is zero (which would otherwise cause a
/// divide-by-zero panic while deriving the strides).
fn compute_row_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
    // First pass: total number of bytes in the tensor, with overflow checked.
    let mut remaining_bytes = mem::size_of::<T::Native>();
    for i in shape {
        if let Some(val) = remaining_bytes.checked_mul(*i) {
            remaining_bytes = val;
        } else {
            return Err(ArrowError::ComputeError(
                "overflow occurred when computing row major strides.".to_string(),
            ));
        }
    }
    // Second pass: peel one dimension off at a time; the stride of dimension k
    // is the byte size of the sub-tensor spanned by dimensions k+1..n.
    let mut strides = Vec::with_capacity(shape.len());
    for i in shape {
        // Guard the division below: a zero-length dimension is reported as an
        // error rather than panicking with a divide-by-zero.
        if *i == 0 {
            return Err(ArrowError::ComputeError(
                "shape dimensions must be non-zero when computing row major strides."
                    .to_string(),
            ));
        }
        remaining_bytes /= *i;
        strides.push(remaining_bytes);
    }
    Ok(strides)
}
/// Computes the strides, in bytes, required to access a tensor of the given
/// `shape` stored in column-major (Fortran) order.
///
/// # Errors
/// Returns `ArrowError::ComputeError` if the running byte size overflows
/// `usize`.
fn compute_column_major_strides<T: ArrowPrimitiveType>(shape: &[usize]) -> Result<Vec<usize>> {
    // In column-major order the stride of dimension k is the element size
    // multiplied by every dimension preceding k, so the running product is
    // pushed *before* folding in the current dimension.
    let mut stride = mem::size_of::<T::Native>();
    let mut strides = Vec::with_capacity(shape.len());
    for &dim in shape {
        strides.push(stride);
        stride = stride.checked_mul(dim).ok_or_else(|| {
            ArrowError::ComputeError(
                "overflow occurred when computing column major strides.".to_string(),
            )
        })?;
    }
    Ok(strides)
}
/// A multi-dimensional, fixed-size view over a contiguous [`Buffer`] of
/// primitive values of type `T`.
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    /// The Arrow [`DataType`] of the elements (always `T::DATA_TYPE`).
    data_type: DataType,
    /// Contiguous memory holding the tensor's elements.
    buffer: Buffer,
    /// Length of each dimension; `None` for a shapeless (single-element) tensor.
    shape: Option<Vec<usize>>,
    /// Byte offset between consecutive elements along each dimension;
    /// `None` only when `shape` is `None`.
    strides: Option<Vec<usize>>,
    /// Optional human-readable name for each dimension.
    names: Option<Vec<&'a str>>,
    /// Ties the element type `T` to the tensor without storing a `T` value.
    _marker: PhantomData<T>,
}
// Convenience aliases: one `Tensor` specialization per supported Arrow
// primitive type.
pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
pub type Date32Tensor<'a> = Tensor<'a, Date32Type>;
pub type Date64Tensor<'a> = Tensor<'a, Date64Type>;
pub type Decimal128Tensor<'a> = Tensor<'a, Decimal128Type>;
pub type Decimal256Tensor<'a> = Tensor<'a, Decimal256Type>;
pub type DurationMicrosecondTensor<'a> = Tensor<'a, DurationMicrosecondType>;
pub type DurationMillisecondTensor<'a> = Tensor<'a, DurationMillisecondType>;
pub type DurationNanosecondTensor<'a> = Tensor<'a, DurationNanosecondType>;
pub type DurationSecondTensor<'a> = Tensor<'a, DurationSecondType>;
pub type Float16Tensor<'a> = Tensor<'a, Float16Type>;
pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;
pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
pub type IntervalDayTimeTensor<'a> = Tensor<'a, IntervalDayTimeType>;
pub type IntervalMonthDayNanoTensor<'a> = Tensor<'a, IntervalMonthDayNanoType>;
pub type IntervalYearMonthTensor<'a> = Tensor<'a, IntervalYearMonthType>;
pub type Time32MillisecondTensor<'a> = Tensor<'a, Time32MillisecondType>;
pub type Time32SecondTensor<'a> = Tensor<'a, Time32SecondType>;
pub type Time64MicrosecondTensor<'a> = Tensor<'a, Time64MicrosecondType>;
pub type Time64NanosecondTensor<'a> = Tensor<'a, Time64NanosecondType>;
pub type TimestampMicrosecondTensor<'a> = Tensor<'a, TimestampMicrosecondType>;
pub type TimestampMillisecondTensor<'a> = Tensor<'a, TimestampMillisecondType>;
pub type TimestampNanosecondTensor<'a> = Tensor<'a, TimestampNanosecondType>;
pub type TimestampSecondTensor<'a> = Tensor<'a, TimestampSecondType>;
pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;
impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
    /// Creates a new `Tensor` backed by `buffer`, validating all metadata.
    ///
    /// * With no `shape`, the buffer must hold exactly one element and both
    ///   `strides` and `names` must be `None`.
    /// * With a `shape`, explicit `strides`/`names` (when given) must have one
    ///   entry per dimension, the buffer must hold exactly the product of the
    ///   dimensions, and explicit strides must describe either the row-major
    ///   or the column-major layout of `shape`.
    ///
    /// When `strides` is `None` but a shape is given, row-major strides are
    /// computed and stored.
    ///
    /// # Errors
    /// Returns `ArrowError::InvalidArgumentError` for any inconsistent
    /// combination of arguments, or `ArrowError::ComputeError` if stride
    /// computation fails.
    pub fn try_new(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        strides: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        match shape {
            None => {
                // A shapeless tensor holds a single element; per-dimension
                // metadata is meaningless for it.
                if buffer.len() != mem::size_of::<T::Native>() {
                    return Err(ArrowError::InvalidArgumentError(
                        "underlying buffer should only contain a single tensor element".to_string(),
                    ));
                }
                if strides.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None strides for tensor with no shape".to_string(),
                    ));
                }
                if names.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None names for tensor with no shape".to_string(),
                    ));
                }
            }
            Some(ref s) => {
                if let Some(ref st) = strides {
                    if st.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "shape and stride dimensions differ".to_string(),
                        ));
                    }
                }
                if let Some(ref n) = names {
                    if n.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "number of dimensions and number of dimension names differ".to_string(),
                        ));
                    }
                }
                // The buffer must hold exactly `product(shape)` elements.
                let total_elements: usize = s.iter().product();
                if total_elements != (buffer.len() / mem::size_of::<T::Native>()) {
                    return Err(ArrowError::InvalidArgumentError(
                        "number of elements in buffer does not match dimensions".to_string(),
                    ));
                }
            }
        };
        // Accept explicit strides only when they describe one of the two
        // contiguous layouts of `shape`; otherwise derive row-major strides.
        let tensor_strides = {
            if let Some(st) = strides {
                if let Some(ref s) = shape {
                    if compute_row_major_strides::<T>(s)? == st
                        || compute_column_major_strides::<T>(s)? == st
                    {
                        Some(st)
                    } else {
                        return Err(ArrowError::InvalidArgumentError(
                            "the input stride does not match the selected shape".to_string(),
                        ));
                    }
                } else {
                    // Unreachable in practice: the match above already rejects
                    // strides without a shape.
                    Some(st)
                }
            } else if let Some(ref s) = shape {
                Some(compute_row_major_strides::<T>(s)?)
            } else {
                None
            }
        };
        Ok(Self {
            data_type: T::DATA_TYPE,
            buffer,
            shape,
            strides: tensor_strides,
            names,
            _marker: PhantomData,
        })
    }

    /// Creates a row-major (C-order) tensor.
    ///
    /// # Errors
    /// Returns `ArrowError::InvalidArgumentError` if `shape` is `None`, plus
    /// any error [`Tensor::try_new`] can return.
    pub fn new_row_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_row_major_strides::<T>(s)?);
            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create row major tensor".to_string(),
            ))
        }
    }

    /// Creates a column-major (Fortran-order) tensor.
    ///
    /// # Errors
    /// Returns `ArrowError::InvalidArgumentError` if `shape` is `None`, plus
    /// any error [`Tensor::try_new`] can return.
    pub fn new_column_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        if let Some(ref s) = shape {
            let strides = Some(compute_column_major_strides::<T>(s)?);
            Self::try_new(buffer, shape, strides, names)
        } else {
            Err(ArrowError::InvalidArgumentError(
                "shape required to create column major tensor".to_string(),
            ))
        }
    }

    /// Returns the Arrow data type of the tensor's elements.
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// Returns the length of each dimension, or `None` for a shapeless tensor.
    pub fn shape(&self) -> Option<&Vec<usize>> {
        self.shape.as_ref()
    }

    /// Returns a reference to the underlying data buffer.
    pub fn data(&self) -> &Buffer {
        &self.buffer
    }

    /// Returns the byte strides of each dimension, or `None` for a shapeless
    /// tensor.
    pub fn strides(&self) -> Option<&Vec<usize>> {
        self.strides.as_ref()
    }

    /// Returns the dimension names, if any were provided.
    pub fn names(&self) -> Option<&Vec<&'a str>> {
        self.names.as_ref()
    }

    /// Returns the number of dimensions (0 for a shapeless tensor).
    pub fn ndim(&self) -> usize {
        match &self.shape {
            None => 0,
            Some(v) => v.len(),
        }
    }

    /// Returns the name of dimension `i`, or `None` if the tensor has no
    /// dimension names or `i` is out of range.
    pub fn dim_name(&self, i: usize) -> Option<&'a str> {
        // `get` avoids the panic that direct indexing (`names[i]`) would
        // cause for an out-of-range dimension index.
        self.names.as_ref().and_then(|names| names.get(i).copied())
    }

    /// Returns the total number of elements (the product of the dimensions).
    /// A shapeless tensor reports a size of 0.
    pub fn size(&self) -> usize {
        match self.shape {
            None => 0,
            Some(ref s) => s.iter().product(),
        }
    }

    /// Returns whether the tensor's layout is contiguous (either row-major or
    /// column-major).
    pub fn is_contiguous(&self) -> Result<bool> {
        Ok(self.is_row_major()? || self.is_column_major()?)
    }

    /// Returns whether the strides describe a row-major (C-order) layout.
    /// A shapeless tensor is considered neither row- nor column-major.
    pub fn is_row_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_row_major_strides::<T>(s)?) == self.strides),
        }
    }

    /// Returns whether the strides describe a column-major (Fortran-order)
    /// layout. A shapeless tensor is considered neither row- nor column-major.
    pub fn is_column_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => Ok(Some(compute_column_major_strides::<T>(s)?) == self.strides),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::*;

    #[test]
    fn test_compute_row_major_strides() {
        assert_eq!(
            vec![48_usize, 8],
            compute_row_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![24_usize, 4],
            compute_row_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![6_usize, 1],
            compute_row_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_compute_column_major_strides() {
        assert_eq!(
            vec![8_usize, 32],
            compute_column_major_strides::<Int64Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![4_usize, 16],
            compute_column_major_strides::<Int32Type>(&[4_usize, 6]).unwrap()
        );
        assert_eq!(
            vec![1_usize, 4],
            compute_column_major_strides::<Int8Type>(&[4_usize, 6]).unwrap()
        );
    }

    #[test]
    fn test_zero_dim() {
        // A shapeless tensor: single element, no dimension metadata.
        let buf = Buffer::from(&[1]);
        let tensor = UInt8Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());

        let buf = Buffer::from(&[1, 2, 2, 2]);
        let tensor = Int32Tensor::try_new(buf, None, None, None).unwrap();
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_tensor() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        // Strides omitted: try_new must derive row-major strides.
        let tensor = Int32Tensor::try_new(buf, Some(vec![2, 8]), None, None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(2, tensor.ndim());
        assert_eq!(None, tensor.names());
    }

    #[test]
    fn test_new_row_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_row_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![32_usize, 4]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let tensor = Int32Tensor::new_column_major(buf, Some(vec![2, 8]), None).unwrap();
        assert_eq!(16, tensor.size());
        assert_eq!(Some(vec![2_usize, 8]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![4_usize, 8]).as_ref(), tensor.strides());
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut builder = Int64BufferBuilder::new(8);
        for i in 0..8 {
            builder.append(i);
        }
        let buf = builder.finish();
        let names = vec!["Dim 1", "Dim 2"];
        let tensor = Int64Tensor::new_column_major(buf, Some(vec![2, 4]), Some(names)).unwrap();
        assert_eq!(8, tensor.size());
        assert_eq!(Some(vec![2_usize, 4]).as_ref(), tensor.shape());
        assert_eq!(Some(vec![8_usize, 16]).as_ref(), tensor.strides());
        assert_eq!("Dim 1", tensor.dim_name(0).unwrap());
        assert_eq!("Dim 2", tensor.dim_name(1).unwrap());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![2, 8, 1]), None);
        assert!(
            result.is_err(),
            "try_new should reject strides whose length differs from the shape"
        );
    }

    #[test]
    fn test_inconsistent_names() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let result = Int32Tensor::try_new(
            buf,
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );
        assert!(
            result.is_err(),
            "try_new should reject names whose length differs from the shape"
        );
    }

    #[test]
    fn test_incorrect_shape() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let result = Int32Tensor::try_new(buf, Some(vec![2, 6]), None, None);
        assert!(
            result.is_err(),
            "try_new should reject a shape whose element count does not match the buffer"
        );
    }

    #[test]
    fn test_incorrect_stride() {
        let mut builder = Int32BufferBuilder::new(16);
        for i in 0..16 {
            builder.append(i);
        }
        let buf = builder.finish();
        let result = Int32Tensor::try_new(buf, Some(vec![2, 8]), Some(vec![30, 4]), None);
        assert!(
            result.is_err(),
            "try_new should reject strides that match neither row- nor column-major layout"
        );
    }
}