diff --git a/library/std/src/sys/net/connection/uefi/mod.rs b/library/std/src/sys/net/connection/uefi/mod.rs index da2174396266f..46d67c8e51019 100644 --- a/library/std/src/sys/net/connection/uefi/mod.rs +++ b/library/std/src/sys/net/connection/uefi/mod.rs @@ -4,11 +4,14 @@ use crate::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; use crate::sys::unsupported; use crate::time::Duration; -pub struct TcpStream(!); +mod tcp; +pub(crate) mod tcp4; + +pub struct TcpStream(#[expect(dead_code)] tcp::Tcp); impl TcpStream { - pub fn connect(_: io::Result<&SocketAddr>) -> io::Result { - unsupported() + pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result { + tcp::Tcp::connect(addr?).map(Self) } pub fn connect_timeout(_: &SocketAddr, _: Duration) -> io::Result { @@ -16,105 +19,105 @@ impl TcpStream { } pub fn set_read_timeout(&self, _: Option) -> io::Result<()> { - self.0 + unsupported() } pub fn set_write_timeout(&self, _: Option) -> io::Result<()> { - self.0 + unsupported() } pub fn read_timeout(&self) -> io::Result> { - self.0 + unsupported() } pub fn write_timeout(&self) -> io::Result> { - self.0 + unsupported() } pub fn peek(&self, _: &mut [u8]) -> io::Result { - self.0 + unsupported() } pub fn read(&self, _: &mut [u8]) -> io::Result { - self.0 + unsupported() } pub fn read_buf(&self, _buf: BorrowedCursor<'_>) -> io::Result<()> { - self.0 + unsupported() } pub fn read_vectored(&self, _: &mut [IoSliceMut<'_>]) -> io::Result { - self.0 + unsupported() } pub fn is_read_vectored(&self) -> bool { - self.0 + false } pub fn write(&self, _: &[u8]) -> io::Result { - self.0 + unsupported() } pub fn write_vectored(&self, _: &[IoSlice<'_>]) -> io::Result { - self.0 + unsupported() } pub fn is_write_vectored(&self) -> bool { - self.0 + false } pub fn peer_addr(&self) -> io::Result { - self.0 + unsupported() } pub fn socket_addr(&self) -> io::Result { - self.0 + unsupported() } pub fn shutdown(&self, _: Shutdown) -> io::Result<()> { - self.0 + unsupported() } pub fn duplicate(&self) -> io::Result { - self.0 + unsupported() } pub fn set_linger(&self, _: Option) -> io::Result<()> { - self.0 + unsupported() } pub fn linger(&self) -> io::Result> { - self.0 + unsupported() } pub fn set_nodelay(&self, _: bool) -> io::Result<()> { - self.0 + unsupported() } pub fn nodelay(&self) -> io::Result { - self.0 + unsupported() } pub fn set_ttl(&self, _: u32) -> io::Result<()> { - self.0 + unsupported() } pub fn ttl(&self) -> io::Result { - self.0 + unsupported() } pub fn take_error(&self) -> io::Result> { - self.0 + unsupported() } pub fn set_nonblocking(&self, _: bool) -> io::Result<()> { - self.0 + unsupported() } } impl fmt::Debug for TcpStream { fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0 + todo!() } } diff --git a/library/std/src/sys/net/connection/uefi/tcp.rs b/library/std/src/sys/net/connection/uefi/tcp.rs new file mode 100644 index 0000000000000..f87accdc41de8 --- /dev/null +++ b/library/std/src/sys/net/connection/uefi/tcp.rs @@ -0,0 +1,21 @@ +use super::tcp4; +use crate::io; +use crate::net::SocketAddr; + +pub(crate) enum Tcp { + V4(#[expect(dead_code)] tcp4::Tcp4), +} + +impl Tcp { + pub(crate) fn connect(addr: &SocketAddr) -> io::Result { + match addr { + SocketAddr::V4(x) => { + let temp = tcp4::Tcp4::new()?; + temp.configure(true, Some(x), None)?; + temp.connect()?; + Ok(Tcp::V4(temp)) + } + SocketAddr::V6(_) => todo!(), + } + } +} diff --git a/library/std/src/sys/net/connection/uefi/tcp4.rs b/library/std/src/sys/net/connection/uefi/tcp4.rs new file mode 100644 index 0000000000000..64f4a9c97b0b6 --- /dev/null +++ b/library/std/src/sys/net/connection/uefi/tcp4.rs @@ -0,0 +1,124 @@ +use r_efi::efi::{self, Status}; +use r_efi::protocols::tcp4; + +use crate::io; +use crate::net::SocketAddrV4; +use crate::ptr::NonNull; +use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::sys::pal::helpers; + +const TYPE_OF_SERVICE: u8 = 8; +const TIME_TO_LIVE: u8 = 255; + +pub(crate) struct Tcp4 { + protocol: NonNull, + flag: AtomicBool, + #[expect(dead_code)] + service_binding: helpers::ServiceProtocol, +} + +const DEFAULT_ADDR: efi::Ipv4Address = efi::Ipv4Address { addr: [0u8; 4] }; + +impl Tcp4 { + pub(crate) fn new() -> io::Result { + let service_binding = helpers::ServiceProtocol::open(tcp4::SERVICE_BINDING_PROTOCOL_GUID)?; + let protocol = helpers::open_protocol(service_binding.child_handle(), tcp4::PROTOCOL_GUID)?; + + Ok(Self { service_binding, protocol, flag: AtomicBool::new(false) }) + } + + pub(crate) fn configure( + &self, + active: bool, + remote_address: Option<&SocketAddrV4>, + station_address: Option<&SocketAddrV4>, + ) -> io::Result<()> { + let protocol = self.protocol.as_ptr(); + + let (remote_address, remote_port) = if let Some(x) = remote_address { + (helpers::ipv4_to_r_efi(*x.ip()), x.port()) + } else { + (DEFAULT_ADDR, 0) + }; + + // FIXME: Remove when passive connections with proper subnet handling are added + assert!(station_address.is_none()); + let use_default_address = efi::Boolean::TRUE; + let (station_address, station_port) = (DEFAULT_ADDR, 0); + let subnet_mask = helpers::ipv4_to_r_efi(crate::net::Ipv4Addr::new(0, 0, 0, 0)); + + let mut config_data = tcp4::ConfigData { + type_of_service: TYPE_OF_SERVICE, + time_to_live: TIME_TO_LIVE, + access_point: tcp4::AccessPoint { + use_default_address, + remote_address, + remote_port, + active_flag: active.into(), + station_address, + station_port, + subnet_mask, + }, + control_option: crate::ptr::null_mut(), + }; + + let r = unsafe { ((*protocol).configure)(protocol, &mut config_data) }; + if r.is_error() { Err(crate::io::Error::from_raw_os_error(r.as_usize())) } else { Ok(()) } + } + + pub(crate) fn connect(&self) -> io::Result<()> { + let evt = unsafe { self.create_evt() }?; + let completion_token = + tcp4::CompletionToken { event: evt.as_ptr(), status: Status::SUCCESS }; + + let protocol = self.protocol.as_ptr(); + let mut conn_token = tcp4::ConnectionToken { completion_token }; + + let r = unsafe { ((*protocol).connect)(protocol, &mut conn_token) }; + if r.is_error() { + return Err(io::Error::from_raw_os_error(r.as_usize())); + } + + self.wait_for_flag().unwrap(); + + if completion_token.status.is_error() { + Err(io::Error::from_raw_os_error(completion_token.status.as_usize())) + } else { + Ok(()) + } + } + + unsafe fn create_evt(&self) -> io::Result { + self.flag.store(false, Ordering::Relaxed); + helpers::OwnedEvent::new( + efi::EVT_NOTIFY_SIGNAL, + efi::TPL_CALLBACK, + Some(toggle_atomic_flag), + Some(unsafe { NonNull::new_unchecked(self.flag.as_ptr().cast()) }), + ) + } + + fn wait_for_flag(&self) -> io::Result<()> { + while !self.flag.load(Ordering::Relaxed) { + self.poll()?; + } + + Ok(()) + } + + fn poll(&self) -> io::Result<()> { + let protocol = self.protocol.as_ptr(); + let r = unsafe { ((*protocol).poll)(protocol) }; + + if r.is_error() { + return Err(io::Error::from_raw_os_error(r.as_usize())); + } else { + Ok(()) + } + } +} + +extern "efiapi" fn toggle_atomic_flag(_: r_efi::efi::Event, ctx: *mut crate::ffi::c_void) { + let flag = unsafe { AtomicBool::from_ptr(ctx.cast()) }; + flag.store(true, Ordering::Relaxed); +} diff --git a/library/std/src/sys/pal/uefi/helpers.rs b/library/std/src/sys/pal/uefi/helpers.rs index 309022bcccf27..7bcf37ac09ea0 100644 --- a/library/std/src/sys/pal/uefi/helpers.rs +++ b/library/std/src/sys/pal/uefi/helpers.rs @@ -653,7 +653,6 @@ pub(crate) struct ServiceProtocol { } impl ServiceProtocol { - #[expect(dead_code)] pub(crate) fn open(service_guid: r_efi::efi::Guid) -> io::Result { let handles = locate_handles(service_guid)?; @@ -670,7 +669,6 @@ impl ServiceProtocol { Err(io::const_error!(io::ErrorKind::NotFound, "no service binding protocol found")) } - #[expect(dead_code)] pub(crate) fn child_handle(&self) -> NonNull { self.child_handle } @@ -732,6 +730,10 @@ impl OwnedEvent { } } + pub(crate) fn as_ptr(&self) -> efi::Event { + self.0.as_ptr() + } + pub(crate) fn into_raw(self) -> *mut crate::ffi::c_void { let r = self.0.as_ptr(); crate::mem::forget(self); @@ -755,3 +757,7 @@ impl Drop for OwnedEvent { } } } + +pub(crate) const fn ipv4_to_r_efi(addr: crate::net::Ipv4Addr) -> efi::Ipv4Address { + efi::Ipv4Address { addr: addr.octets() } +}