1use fallible_iterator::FallibleIterator;
2use postgres_protocol;
3use postgres_protocol::types;
4use postgres_types::{to_sql_checked, FromSql, IsNull, Kind, ToSql, Type};
5use std::error::Error;
6
7use crate::{Array, Dimension};
8use postgres_types::private::BytesMut;
9
10impl<'de, T> FromSql<'de> for Array<T>
11where
12 T: FromSql<'de>,
13{
14 fn from_sql(ty: &Type, raw: &'de [u8]) -> Result<Array<T>, Box<dyn Error + Sync + Send>> {
15 let element_type = match *ty.kind() {
16 Kind::Array(ref ty) => ty,
17 _ => unreachable!(),
18 };
19
20 let array = types::array_from_sql(raw)?;
21
22 let dimensions = array
23 .dimensions()
24 .map(|d| {
25 Ok(Dimension {
26 len: d.len,
27 lower_bound: d.lower_bound,
28 })
29 })
30 .collect()?;
31
32 let elements = array
33 .values()
34 .map(|v| FromSql::from_sql_nullable(element_type, v))
35 .collect()?;
36
37 Ok(Array::from_parts(elements, dimensions))
38 }
39
40 fn accepts(ty: &Type) -> bool {
41 match ty.kind() {
42 &Kind::Array(ref ty) => <T as FromSql>::accepts(ty),
43 _ => false,
44 }
45 }
46}
47
48impl<T> ToSql for Array<T>
49where
50 T: ToSql,
51{
52 fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
53 let element_type = match ty.kind() {
54 &Kind::Array(ref ty) => ty,
55 _ => unreachable!(),
56 };
57
58 let dimensions = self.dimensions().iter().map(|d| types::ArrayDimension {
59 len: d.len,
60 lower_bound: d.lower_bound,
61 });
62 let elements = self.iter();
63
64 types::array_to_sql(
65 dimensions,
66 element_type.oid(),
67 elements,
68 |v, w| match v.to_sql(element_type, w) {
69 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
70 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
71 Err(e) => Err(e),
72 },
73 w,
74 )?;
75
76 Ok(IsNull::No)
77 }
78
79 fn accepts(ty: &Type) -> bool {
80 match ty.kind() {
81 &Kind::Array(ref ty) => <T as ToSql>::accepts(ty),
82 _ => false,
83 }
84 }
85
86 to_sql_checked!();
87}
88
#[cfg(test)]
mod test {
    use std::fmt;

    use crate::Array;
    use postgres::types::{FromSqlOwned, ToSql};
    use postgres::{Client, NoTls};

    /// Connection string shared by every test so they all authenticate the same way.
    const CONNECTION_STRING: &str = "postgres://postgres:password@localhost";

    /// Opens a fresh connection to the test database.
    fn connect() -> Client {
        Client::connect(CONNECTION_STRING, NoTls).unwrap()
    }

    /// Round-trips each `(value, sql_literal)` pair through the server as `sql_type`.
    ///
    /// For every pair this checks that:
    /// 1. selecting the SQL literal cast to `sql_type` decodes to `value`, and
    /// 2. binding `value` as a parameter and selecting it back yields `value`.
    fn test_type<T: PartialEq + FromSqlOwned + ToSql + Sync, S: fmt::Display>(
        sql_type: &str,
        checks: &[(T, S)],
    ) {
        let mut conn = connect();
        for (val, repr) in checks {
            let result: T = conn
                .query(&*format!("SELECT {}::{}", repr, sql_type), &[])
                .unwrap()[0]
                .get(0);
            assert!(val == &result);

            let result: T = conn
                .query(&*format!("SELECT $1::{}", sql_type), &[val])
                .unwrap()[0]
                .get(0);
            assert!(val == &result);
        }
    }

    // Exercises `$name[]` with a one-dimensional array containing a NULL, a
    // NULL array, and `$name[][]` with a two-dimensional array whose lower
    // bounds are non-default (-1 and 0).
    macro_rules! test_array_params {
        ($name:expr, $v1:expr, $s1:expr, $v2:expr, $s2:expr, $v3:expr, $s3:expr) => {{
            let tests = &[
                (
                    Some(Array::from_vec(vec![Some($v1), Some($v2), None], 1)),
                    format!("'{{{},{},NULL}}'", $s1, $s2),
                ),
                (None, "NULL".to_string()),
            ];
            test_type(&format!("{}[]", $name), tests);
            let mut a = Array::from_vec(vec![Some($v1), Some($v2)], 0);
            a.wrap(-1);
            a.push(Array::from_vec(vec![None, Some($v3)], 0));
            let tests = &[(
                Some(a),
                format!("'[-1:0][0:1]={{{{{},{}}},{{NULL,{}}}}}'", $s1, $s2, $s3),
            )];
            test_type(&format!("{}[][]", $name), tests);
        }};
    }

    #[test]
    fn test_boolarray_params() {
        test_array_params!("BOOL", false, "f", true, "t", true, "t");
    }

    #[test]
    fn test_byteaarray_params() {
        test_array_params!(
            "BYTEA",
            vec![0u8, 1],
            r#""\\x0001""#,
            vec![254u8, 255u8],
            r#""\\xfeff""#,
            vec![10u8, 11u8],
            r#""\\x0a0b""#
        );
    }

    #[test]
    fn test_chararray_params() {
        test_array_params!("\"char\"", 'a' as i8, "a", 'z' as i8, "z", '0' as i8, "0");
    }

    #[test]
    fn test_namearray_params() {
        test_array_params!(
            "NAME",
            "hello".to_string(),
            "hello",
            "world".to_string(),
            "world",
            "!".to_string(),
            "!"
        );
    }

    #[test]
    fn test_int2array_params() {
        test_array_params!("INT2", 0i16, "0", 1i16, "1", 2i16, "2");
    }

    #[test]
    fn test_int4array_params() {
        test_array_params!("INT4", 0i32, "0", 1i32, "1", 2i32, "2");
    }

    #[test]
    fn test_textarray_params() {
        test_array_params!(
            "TEXT",
            "hello".to_string(),
            "hello",
            "world".to_string(),
            "world",
            "!".to_string(),
            "!"
        );
    }

    #[test]
    fn test_charnarray_params() {
        // CHAR(5) values are blank-padded to exactly five characters by the
        // server, so the single "!" comes back as "!" followed by four spaces.
        test_array_params!(
            "CHAR(5)",
            "hello".to_string(),
            "hello",
            "world".to_string(),
            "world",
            "!    ".to_string(),
            "!"
        );
    }

    #[test]
    fn test_varchararray_params() {
        test_array_params!(
            "VARCHAR",
            "hello".to_string(),
            "hello",
            "world".to_string(),
            "world",
            "!".to_string(),
            "!"
        );
    }

    #[test]
    fn test_int8array_params() {
        test_array_params!("INT8", 0i64, "0", 1i64, "1", 2i64, "2");
    }

    #[test]
    fn test_float4array_params() {
        test_array_params!("FLOAT4", 0f32, "0", 1.5f32, "1.5", 0.009f32, ".009");
    }

    #[test]
    fn test_float8array_params() {
        test_array_params!("FLOAT8", 0f64, "0", 1.5f64, "1.5", 0.009f64, ".009");
    }

    #[test]
    fn test_empty_array() {
        let mut conn = connect();
        let array = conn.query("SELECT '{}'::INT4[]", &[]).unwrap()[0].get::<_, Array<i32>>(0);
        // Postgres represents '{}' as a zero-dimensional array with no elements.
        assert!(array.dimensions().is_empty());
    }
}