thrift/transport/
socket.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::convert::From;
19use std::io;
20use std::io::{ErrorKind, Read, Write};
21use std::net::{Shutdown, TcpStream, ToSocketAddrs};
22
23#[cfg(unix)]
24use std::os::unix::net::UnixStream;
25
26use super::{ReadHalf, TIoChannel, WriteHalf};
27use crate::{new_transport_error, TransportErrorKind};
28
/// A two-way channel backed by a TCP/IP socket.
///
/// The channel either owns a `TcpStream` or holds nothing at all; in the
/// latter state every read, write, flush or close attempt is rejected with
/// an `ErrorKind::NotConnected` I/O error.
///
/// # Examples
///
/// Build an unconnected `TTcpChannel` and connect it afterwards.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use thrift::transport::TTcpChannel;
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut buf = vec![0u8; 4];
/// c.read(&mut buf).unwrap();
/// c.write(&[0, 1, 2]).unwrap();
/// ```
///
/// Build a `TTcpChannel` on top of a `TcpStream` that is already connected.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
/// use thrift::transport::TTcpChannel;
///
/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
///
/// // the stream is connected already, so c.open() must not be called
/// let mut c = TTcpChannel::with_stream(stream);
///
/// let mut buf = vec![0u8; 4];
/// c.read(&mut buf).unwrap();
/// c.write(&[0, 1, 2]).unwrap();
/// ```
#[derive(Debug, Default)]
pub struct TTcpChannel {
    /// The wrapped socket; `None` until the channel is opened or handed a stream.
    stream: Option<TcpStream>,
}
67
68impl TTcpChannel {
69    /// Create an uninitialized `TTcpChannel`.
70    ///
71    /// The returned instance must be opened using `TTcpChannel::open(...)`
72    /// before it can be used.
73    pub fn new() -> TTcpChannel {
74        TTcpChannel { stream: None }
75    }
76
77    /// Create a `TTcpChannel` that wraps an existing `TcpStream`.
78    ///
79    /// The passed-in stream is assumed to have been opened before being wrapped
80    /// by the created `TTcpChannel` instance.
81    pub fn with_stream(stream: TcpStream) -> TTcpChannel {
82        TTcpChannel {
83            stream: Some(stream),
84        }
85    }
86
87    /// Connect to `remote_address`, which should implement `ToSocketAddrs` trait.
88    pub fn open<A: ToSocketAddrs>(&mut self, remote_address: A) -> crate::Result<()> {
89        if self.stream.is_some() {
90            Err(new_transport_error(
91                TransportErrorKind::AlreadyOpen,
92                "tcp connection previously opened",
93            ))
94        } else {
95            match TcpStream::connect(&remote_address) {
96                Ok(s) => {
97                    self.stream = Some(s);
98                    Ok(())
99                }
100                Err(e) => Err(From::from(e)),
101            }
102        }
103    }
104
105    /// Shut down this channel.
106    ///
107    /// Both send and receive halves are closed, and this instance can no
108    /// longer be used to communicate with another endpoint.
109    pub fn close(&mut self) -> crate::Result<()> {
110        self.if_set(|s| s.shutdown(Shutdown::Both))
111            .map_err(From::from)
112    }
113
114    fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T>
115    where
116        F: FnMut(&mut TcpStream) -> io::Result<T>,
117    {
118        if let Some(ref mut s) = self.stream {
119            stream_operation(s)
120        } else {
121            Err(io::Error::new(
122                ErrorKind::NotConnected,
123                "tcp endpoint not connected",
124            ))
125        }
126    }
127}
128
129impl TIoChannel for TTcpChannel {
130    fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
131    where
132        Self: Sized,
133    {
134        let mut s = self;
135
136        s.stream
137            .as_mut()
138            .and_then(|s| s.try_clone().ok())
139            .map(|cloned| {
140                let read_half = ReadHalf::new(TTcpChannel {
141                    stream: s.stream.take(),
142                });
143                let write_half = WriteHalf::new(TTcpChannel {
144                    stream: Some(cloned),
145                });
146                (read_half, write_half)
147            })
148            .ok_or_else(|| {
149                new_transport_error(
150                    TransportErrorKind::Unknown,
151                    "cannot clone underlying tcp stream",
152                )
153            })
154    }
155}
156
157impl Read for TTcpChannel {
158    fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
159        self.if_set(|s| s.read(b))
160    }
161}
162
163impl Write for TTcpChannel {
164    fn write(&mut self, b: &[u8]) -> io::Result<usize> {
165        self.if_set(|s| s.write(b))
166    }
167
168    fn flush(&mut self) -> io::Result<()> {
169        self.if_set(|s| s.flush())
170    }
171}
172
173#[cfg(unix)]
174impl TIoChannel for UnixStream {
175    fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
176    where
177        Self: Sized,
178    {
179        let socket_rx = self.try_clone().unwrap();
180
181        Ok((ReadHalf::new(self), WriteHalf::new(socket_rx)))
182    }
183}