diff --git a/library/std/src/io/buffered/linewritershim.rs b/library/std/src/io/buffered/linewritershim.rs index 3d04ccd1c7d81..5ebeada59bb53 100644 --- a/library/std/src/io/buffered/linewritershim.rs +++ b/library/std/src/io/buffered/linewritershim.rs @@ -119,7 +119,14 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> { // the buffer? // - If not, scan for the last newline that *does* fit in the buffer let tail = if flushed >= newline_idx { - &buf[flushed..] + let tail = &buf[flushed..]; + // Avoid unnecessary short writes by not splitting the remaining + // bytes if they're larger than the buffer. + // They can be written in full by the next call to write. + if tail.len() >= self.buffer.capacity() { + return Ok(flushed); + } + tail } else if newline_idx - flushed <= self.buffer.capacity() { &buf[flushed..newline_idx] } else { diff --git a/library/std/src/io/buffered/tests.rs b/library/std/src/io/buffered/tests.rs index bff0f823c4b5a..17f6107aa030c 100644 --- a/library/std/src/io/buffered/tests.rs +++ b/library/std/src/io/buffered/tests.rs @@ -847,8 +847,7 @@ fn long_line_flushed() { } /// Test that, given a very long partial line *after* successfully -/// flushing a complete line, the very long partial line is buffered -/// unconditionally, and no additional writes take place. This assures +/// flushing a complete line, no additional writes take place. This assures /// the property that `write` should make at-most-one attempt to write /// new data. #[test] @@ -856,13 +855,22 @@ fn line_long_tail_not_flushed() { let writer = ProgrammableSink::default(); let mut writer = LineWriter::with_capacity(5, writer); - // Assert that Line 1\n is flushed, and 01234 is buffered - assert_eq!(writer.write(b"Line 1\n0123456789").unwrap(), 12); + // Assert that Line 1\n is flushed and the long tail isn't. 
+ let bytes = b"Line 1\n0123456789"; + writer.write(bytes).unwrap(); assert_eq!(&writer.get_ref().buffer, b"Line 1\n"); +} + +// Test that appending to a full buffer emits a single write, flushing the buffer. +#[test] +fn line_full_buffer_flushed() { + let writer = ProgrammableSink::default(); + let mut writer = LineWriter::with_capacity(5, writer); + assert_eq!(writer.write(b"01234").unwrap(), 5); // Because the buffer is full, this subsequent write will flush it assert_eq!(writer.write(b"5").unwrap(), 1); - assert_eq!(&writer.get_ref().buffer, b"Line 1\n01234"); + assert_eq!(&writer.get_ref().buffer, b"01234"); } /// Test that, if an attempt to pre-flush buffered data returns Ok(0), diff --git a/library/std/src/sys/pal/windows/c/bindings.txt b/library/std/src/sys/pal/windows/c/bindings.txt index 248ce3c9ff624..06d192587832a 100644 --- a/library/std/src/sys/pal/windows/c/bindings.txt +++ b/library/std/src/sys/pal/windows/c/bindings.txt @@ -2425,6 +2425,7 @@ Windows.Win32.System.Console.ENABLE_VIRTUAL_TERMINAL_PROCESSING Windows.Win32.System.Console.ENABLE_WINDOW_INPUT Windows.Win32.System.Console.ENABLE_WRAP_AT_EOL_OUTPUT Windows.Win32.System.Console.GetConsoleMode +Windows.Win32.System.Console.GetConsoleOutputCP Windows.Win32.System.Console.GetStdHandle Windows.Win32.System.Console.ReadConsoleW Windows.Win32.System.Console.STD_ERROR_HANDLE diff --git a/library/std/src/sys/pal/windows/c/windows_sys.rs b/library/std/src/sys/pal/windows/c/windows_sys.rs index 19925e59dfe9c..a3c4cf6b4220e 100644 --- a/library/std/src/sys/pal/windows/c/windows_sys.rs +++ b/library/std/src/sys/pal/windows/c/windows_sys.rs @@ -34,6 +34,7 @@ windows_targets::link!("kernel32.dll" "system" fn FreeEnvironmentStringsW(penv : windows_targets::link!("kernel32.dll" "system" fn GetActiveProcessorCount(groupnumber : u16) -> u32); windows_targets::link!("kernel32.dll" "system" fn GetCommandLineW() -> PCWSTR); windows_targets::link!("kernel32.dll" "system" fn GetConsoleMode(hconsolehandle : HANDLE, 
lpmode : *mut CONSOLE_MODE) -> BOOL); +windows_targets::link!("kernel32.dll" "system" fn GetConsoleOutputCP() -> u32); windows_targets::link!("kernel32.dll" "system" fn GetCurrentDirectoryW(nbufferlength : u32, lpbuffer : PWSTR) -> u32); windows_targets::link!("kernel32.dll" "system" fn GetCurrentProcess() -> HANDLE); windows_targets::link!("kernel32.dll" "system" fn GetCurrentProcessId() -> u32); @@ -3317,6 +3318,7 @@ pub struct XSAVE_FORMAT { pub XmmRegisters: [M128A; 8], pub Reserved4: [u8; 224], } + #[cfg(target_arch = "arm")] #[repr(C)] pub struct WSADATA { diff --git a/library/std/src/sys/pal/windows/stdio.rs b/library/std/src/sys/pal/windows/stdio.rs index 642c8bc4df7d1..bb3dee7c8b7c7 100644 --- a/library/std/src/sys/pal/windows/stdio.rs +++ b/library/std/src/sys/pal/windows/stdio.rs @@ -1,13 +1,11 @@ #![unstable(issue = "none", feature = "windows_stdio")] -use core::str::utf8_char_width; - use super::api::{self, WinError}; use crate::mem::MaybeUninit; use crate::os::windows::io::{FromRawHandle, IntoRawHandle}; use crate::sys::handle::Handle; use crate::sys::{c, cvt}; -use crate::{cmp, io, ptr, str}; +use crate::{cmp, io, ptr}; #[cfg(test)] mod tests; @@ -19,13 +17,9 @@ pub struct Stdin { incomplete_utf8: IncompleteUtf8, } -pub struct Stdout { - incomplete_utf8: IncompleteUtf8, -} +pub struct Stdout {} -pub struct Stderr { - incomplete_utf8: IncompleteUtf8, -} +pub struct Stderr {} struct IncompleteUtf8 { bytes: [u8; 4], @@ -84,140 +78,69 @@ fn is_console(handle: c::HANDLE) -> bool { unsafe { c::GetConsoleMode(handle, &mut mode) != 0 } } -fn write(handle_id: u32, data: &[u8], incomplete_utf8: &mut IncompleteUtf8) -> io::Result<usize> { +/// Returns true if the attached console's code page is currently UTF-8.
+#[cfg(not(target_vendor = "win7"))] +fn is_utf8_console() -> bool { + unsafe { c::GetConsoleOutputCP() == c::CP_UTF8 } +} + +#[cfg(target_vendor = "win7")] +fn is_utf8_console() -> bool { + // Windows 7 has a fun "feature" where WriteFile on a console handle will return + // the number of UTF-16 code units written and not the number of bytes from the input string. + // So we always claim the console isn't UTF-8 to trigger the WriteConsole fallback code. + false +} + +fn write(handle_id: u32, data: &[u8]) -> io::Result<usize> { if data.is_empty() { return Ok(0); } let handle = get_handle(handle_id)?; - if !is_console(handle) { + if !is_console(handle) || is_utf8_console() { unsafe { let handle = Handle::from_raw_handle(handle); let ret = handle.write(data); let _ = handle.into_raw_handle(); // Don't close the handle return ret; } + } else { + write_console_utf16(data, handle) } - - if incomplete_utf8.len > 0 { - assert!( - incomplete_utf8.len < 4, - "Unexpected number of bytes for incomplete UTF-8 codepoint."
- ); - if data[0] >> 6 != 0b10 { - // not a continuation byte - reject - incomplete_utf8.len = 0; - return Err(io::const_error!( - io::ErrorKind::InvalidData, - "Windows stdio in console mode does not support writing non-UTF-8 byte sequences", - )); - } - incomplete_utf8.bytes[incomplete_utf8.len as usize] = data[0]; - incomplete_utf8.len += 1; - let char_width = utf8_char_width(incomplete_utf8.bytes[0]); - if (incomplete_utf8.len as usize) < char_width { - // more bytes needed - return Ok(1); - } - let s = str::from_utf8(&incomplete_utf8.bytes[0..incomplete_utf8.len as usize]); - incomplete_utf8.len = 0; - match s { - Ok(s) => { - assert_eq!(char_width, s.len()); - let written = write_valid_utf8_to_console(handle, s)?; - assert_eq!(written, s.len()); // guaranteed by write_valid_utf8_to_console() for single codepoint writes - return Ok(1); - } - Err(_) => { - return Err(io::const_error!( - io::ErrorKind::InvalidData, - "Windows stdio in console mode does not support writing non-UTF-8 byte sequences", - )); - } - } - } - - // As the console is meant for presenting text, we assume bytes of `data` are encoded as UTF-8, - // which needs to be encoded as UTF-16. - // - // If the data is not valid UTF-8 we write out as many bytes as are valid. - // If the first byte is invalid it is either first byte of a multi-byte sequence but the - // provided byte slice is too short or it is the first byte of an invalid multi-byte sequence. 
- let len = cmp::min(data.len(), MAX_BUFFER_SIZE / 2); - let utf8 = match str::from_utf8(&data[..len]) { - Ok(s) => s, - Err(ref e) if e.valid_up_to() == 0 => { - let first_byte_char_width = utf8_char_width(data[0]); - if first_byte_char_width > 1 && data.len() < first_byte_char_width { - incomplete_utf8.bytes[0] = data[0]; - incomplete_utf8.len = 1; - return Ok(1); - } else { - return Err(io::const_error!( - io::ErrorKind::InvalidData, - "Windows stdio in console mode does not support writing non-UTF-8 byte sequences", - )); - } - } - Err(e) => str::from_utf8(&data[..e.valid_up_to()]).unwrap(), - }; - - write_valid_utf8_to_console(handle, utf8) } -fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> { - debug_assert!(!utf8.is_empty()); - - let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2]; - let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())]; +fn write_console_utf16(data: &[u8], handle: c::HANDLE) -> io::Result<usize> { + let mut buffer = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2]; + let data = &data[..data.len().min(buffer.len())]; + + // Split off any trailing incomplete UTF-8 from the end of the input. + let utf8 = trim_last_char_boundary(data); + let utf16 = utf8_to_utf16_lossy(utf8, &mut buffer); + debug_assert!(!utf16.is_empty()); + + // Write the UTF-16 chars to the console. + // This will succeed in one write so long as our [u16] slice is smaller than the console's buffer, + // which we've ensured by truncating the input (see `MAX_BUFFER_SIZE`). + let written = write_u16s(handle, &utf16)?; + debug_assert_eq!(written, utf16.len()); + Ok(utf8.len()) +} - let utf16: &[u16] = unsafe { - // Note that this theoretically checks validity twice in the (most common) case - // where the underlying byte sequence is valid utf-8 (given the check in `write()`).
+fn utf8_to_utf16_lossy<'a>(utf8: &[u8], utf16: &'a mut [MaybeUninit<u16>]) -> &'a [u16] { + unsafe { let result = c::MultiByteToWideChar( c::CP_UTF8, // CodePage - c::MB_ERR_INVALID_CHARS, // dwFlags + 0, // dwFlags utf8.as_ptr(), // lpMultiByteStr utf8.len() as i32, // cbMultiByte utf16.as_mut_ptr() as *mut c::WCHAR, // lpWideCharStr utf16.len() as i32, // cchWideChar ); - assert!(result != 0, "Unexpected error in MultiByteToWideChar"); - + // The only way an error can happen here is if we've messed up. + debug_assert!(result != 0, "Unexpected error in MultiByteToWideChar"); // Safety: MultiByteToWideChar initializes `result` values. MaybeUninit::slice_assume_init_ref(&utf16[..result as usize]) - }; - - let mut written = write_u16s(handle, utf16)?; - - // Figure out how many bytes of as UTF-8 were written away as UTF-16. - if written == utf16.len() { - Ok(utf8.len()) - } else { - // Make sure we didn't end up writing only half of a surrogate pair (even though the chance - // is tiny). Because it is not possible for user code to re-slice `data` in such a way that - // a missing surrogate can be produced (and also because of the UTF-8 validation above), - // write the missing surrogate out now. - // Buffering it would mean we have to lie about the number of bytes written. - let first_code_unit_remaining = utf16[written]; - if matches!(first_code_unit_remaining, 0xDCEE..=0xDFFF) { - // low surrogate - // We just hope this works, and give up otherwise - let _ = write_u16s(handle, &utf16[written..written + 1]); - written += 1; - } - // Calculate the number of bytes of `utf8` that were actually written. - let mut count = 0; - for ch in utf16[..written].iter() { - count += match ch { - 0x0000..=0x007F => 1, - 0x0080..=0x07FF => 2, - 0xDCEE..=0xDFFF => 1, // Low surrogate. We already counted 3 bytes for the other.
- _ => 3, - }; - } - debug_assert!(String::from_utf16(&utf16[..written]).unwrap() == utf8[..count]); - Ok(count) } } @@ -410,13 +333,13 @@ impl IncompleteUtf8 { impl Stdout { pub const fn new() -> Stdout { - Stdout { incomplete_utf8: IncompleteUtf8::new() } + Stdout {} } } impl io::Write for Stdout { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - write(c::STD_OUTPUT_HANDLE, buf, &mut self.incomplete_utf8) + write(c::STD_OUTPUT_HANDLE, buf) } fn flush(&mut self) -> io::Result<()> { @@ -426,13 +349,13 @@ impl io::Write for Stdout { impl Stderr { pub const fn new() -> Stderr { - Stderr { incomplete_utf8: IncompleteUtf8::new() } + Stderr {} } } impl io::Write for Stderr { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - write(c::STD_ERROR_HANDLE, buf, &mut self.incomplete_utf8) + write(c::STD_ERROR_HANDLE, buf) } fn flush(&mut self) -> io::Result<()> { @@ -447,3 +370,50 @@ pub fn is_ebadf(err: &io::Error) -> bool { pub fn panic_output() -> Option<impl io::Write> { Some(Stderr::new()) } + +/// Trim one incomplete UTF-8 char from the end of a byte slice. +/// +/// If trimming would lead to an empty slice then it returns `bytes` instead. +/// +/// Note: This function is optimized for size rather than speed. +pub fn trim_last_char_boundary(bytes: &[u8]) -> &[u8] { + // UTF-8's multiple-byte encoding uses the leading bits to encode the length of a code point. + // The bits of a multi-byte sequence are (where `n` is a placeholder for any bit): + // + // 11110nnn 10nnnnnn 10nnnnnn 10nnnnnn + // 1110nnnn 10nnnnnn 10nnnnnn + // 110nnnnn 10nnnnnn + // + // So it follows that an incomplete sequence is one of these: + // 11110nnn 10nnnnnn 10nnnnnn + // 11110nnn 10nnnnnn + // 1110nnnn 10nnnnnn + // 11110nnn + // 1110nnnn + // 110nnnnn + + // Get up to three bytes from the end of the slice and encode them as a u32 + // because it turns out the compiler is very good at optimizing numbers.
+ let u = match bytes { + [.., b1, b2, b3] => (*b1 as u32) << 16 | (*b2 as u32) << 8 | *b3 as u32, + [.., b1, b2] => (*b1 as u32) << 8 | *b2 as u32, + // If it's just a single byte or empty then we return the full slice + _ => return bytes, + }; + if (u & 0b_11111000_11000000_11000000 == 0b_11110000_10000000_10000000) && bytes.len() >= 4 { + &bytes[..bytes.len() - 3] + } else if (u & 0b_11111000_11000000 == 0b_11110000_10000000 + || u & 0b_11110000_11000000 == 0b_11100000_10000000) + && bytes.len() >= 3 + { + &bytes[..bytes.len() - 2] + } else if (u & 0b_1111_1000 == 0b_1111_0000 + || u & 0b_11110000 == 0b_11100000 + || u & 0b_11100000 == 0b_11000000) + && bytes.len() >= 2 + { + &bytes[..bytes.len() - 1] + } else { + bytes + } +}