Sebastian Andrzej Siewior
2023-Feb-24 20:53 UTC
[PATCH 0/1] ZSTD compression support for OpenSSH
I added ZSTD support to OpenSSH roughly three years ago and have been using it ever since. The nice part is that ZSTD achieves reasonable compression (comparable to zlib) while consuming little CPU, so compression is unlikely to become the bottleneck of a transfer. The CPU overhead of compression is negligible even when uncompressed data is tunneled over the SSH connection (SOCKS proxy, port forwarding). With this, I can enable it by default on the client side by setting "Compression zstd" in ssh_config, and it will fall back to none if compression isn't available. Sebastian
Sebastian Andrzej Siewior
2023-Feb-24 20:53 UTC
[PATCH 1/1] Add support for ZSTD compression
From: Sebastian Andrzej Siewior <sebastian at breakpoint.cc> The "zstd at breakpoint.cc" compression algorithm enables ZSTD-based compression as defined in RFC 8478. Compression is delayed until the server sends SSH_MSG_USERAUTH_SUCCESS, the same point at which the "zlib at openssh.com" method starts compressing. Signed-off-by: Sebastian Andrzej Siewior <sebastian at breakpoint.cc> --- cipher.c | 30 +++++- configure.ac | 8 ++ kex.c | 5 + kex.h | 3 + myproposal.h | 2 +- packet.c | 278 +++++++++++++++++++++++++++++++++++++++++++++------ readconf.c | 8 +- servconf.c | 14 +-- ssh.c | 4 +- 9 files changed, 306 insertions(+), 46 deletions(-) diff --git a/cipher.c b/cipher.c index 02aea4089ff91..1634bb4019c86 100644 --- a/cipher.c +++ b/cipher.c @@ -48,6 +48,7 @@ #include "sshbuf.h" #include "ssherr.h" #include "digest.h" +#include "kex.h" #include "openbsd-compat/openssl-compat.h" @@ -142,12 +143,33 @@ cipher_alg_list(char sep, int auth_only) const char * compression_alg_list(int compression) { -#ifdef WITH_ZLIB - return compression ?
"zlib at openssh.com,zlib,none" : - "none,zlib at openssh.com,zlib"; +#ifdef HAVE_LIBZSTD +#define COMP_ZSTD_WITH "zstd at breakpoint.cc," +#define COMP_ZSTD_NONE ",zstd at breakpoint.cc" #else - return "none"; +#define COMP_ZSTD_WITH "" +#define COMP_ZSTD_NONE "" #endif + +#ifdef WITH_ZLIB +#define COMP_ZLIB_C_WITH "zlib at openssh.com,zlib," +#define COMP_ZLIB_S_WITH "zlib at openssh.com," + +#define COMP_ZLIB_C_NONE ",zlib at openssh.com,zlib" +#else +#define COMP_ZLIB_C_WITH "" +#define COMP_ZLIB_S_WITH "" +#define COMP_ZLIB_C_NONE "" +#endif + switch (compression) { + case COMP_ZLIB: return COMP_ZLIB_C_WITH "none"; + case COMP_DELAYED: return COMP_ZLIB_S_WITH "none"; + case COMP_ZSTD: return COMP_ZSTD_WITH "none"; + case COMP_ALL_C: return COMP_ZSTD_WITH COMP_ZLIB_C_WITH "none"; + case COMP_ALL_S: return COMP_ZSTD_WITH COMP_ZLIB_S_WITH "none"; + default: + case 0: return "none" COMP_ZSTD_NONE COMP_ZLIB_C_NONE; + } } u_int diff --git a/configure.ac b/configure.ac index 22fee70f604a2..91ef386788be3 100644 --- a/configure.ac +++ b/configure.ac @@ -1498,6 +1498,14 @@ See http://www.gzip.org/zlib/ for details.]) LIBS="$saved_LIBS" fi +AC_ARG_WITH([libzstd], AS_HELP_STRING([--with-libzstd], [Build with libzstd.])) +AS_IF([test "x$with_libzstd" = "xyes"], + [ + PKG_CHECK_MODULES([LIBZSTD], [libzstd >= 1.4.0], [AC_DEFINE([HAVE_LIBZSTD], [1], [Use LIBZSTD])]) + LIBS="$LIBS ${LIBZSTD_LIBS}" + CFLAGS="$CFLAGS ${LIBZSTD_CFLAGS}" + ]) + dnl UnixWare 2.x AC_CHECK_FUNC([strcasecmp], [], [ AC_CHECK_LIB([resolv], [strcasecmp], [LIBS="$LIBS -lresolv"]) ] diff --git a/kex.c b/kex.c index 7731ca9004fc8..d71fd777a3123 100644 --- a/kex.c +++ b/kex.c @@ -826,6 +826,11 @@ choose_comp(struct sshcomp *comp, char *client, char *server) comp->type = COMP_ZLIB; } else #endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD + if (strcmp(name, "zstd at breakpoint.cc") == 0) { + comp->type = COMP_ZSTD; + } else +#endif /* HAVE_LIBZSTD */ if (strcmp(name, "none") == 0) { comp->type = COMP_NONE; } else {
diff --git a/kex.h b/kex.h index c35329501871a..159cfc794bd67 100644 --- a/kex.h +++ b/kex.h @@ -68,6 +68,9 @@ /* pre-auth compression (COMP_ZLIB) is only supported in the client */ #define COMP_ZLIB 1 #define COMP_DELAYED 2 +#define COMP_ZSTD 3 +#define COMP_ALL_C 4 +#define COMP_ALL_S 5 #define CURVE25519_SIZE 32 diff --git a/myproposal.h b/myproposal.h index ee6e9f7415261..a015190b35d9f 100644 --- a/myproposal.h +++ b/myproposal.h @@ -88,7 +88,7 @@ "rsa-sha2-512," \ "rsa-sha2-256" -#define KEX_DEFAULT_COMP "none,zlib at openssh.com" +#define KEX_DEFAULT_COMP "none,zstd at breakpoint.cc,zlib at openssh.com" #define KEX_DEFAULT_LANG "" #define KEX_CLIENT \ diff --git a/packet.c b/packet.c index 3f64d2d32854a..a39b8d7fbd963 100644 --- a/packet.c +++ b/packet.c @@ -79,6 +79,9 @@ #ifdef WITH_ZLIB #include <zlib.h> #endif +#ifdef HAVE_LIBZSTD +#include <zstd.h> +#endif #include "xmalloc.h" #include "compat.h" @@ -156,6 +159,14 @@ struct session_state { /* Incoming/outgoing compression dictionaries */ z_stream compression_in_stream; z_stream compression_out_stream; +#endif +#ifdef HAVE_LIBZSTD + ZSTD_DCtx *compression_zstd_in_stream; + ZSTD_CCtx *compression_zstd_out_stream; + u_int64_t compress_zstd_in_raw; + u_int64_t compress_zstd_in_comp; + u_int64_t compress_zstd_out_raw; + u_int64_t compress_zstd_out_comp; #endif int compression_in_started; int compression_out_started; @@ -604,11 +615,11 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close) state->newkeys[mode] = NULL; ssh_clear_newkeys(ssh, mode); /* next keys */ } -#ifdef WITH_ZLIB /* compression state is in shared mem, so we can only release it once */ if (do_close && state->compression_buffer) { sshbuf_free(state->compression_buffer); - if (state->compression_out_started) { +#ifdef WITH_ZLIB + if (state->compression_out_started == COMP_ZLIB) { z_streamp stream = &state->compression_out_stream; debug("compress outgoing: " "raw data %llu, compressed %llu, factor %.2f", @@ -619,7 +630,7 @@
ssh_packet_close_internal(struct ssh *ssh, int do_close) if (state->compression_out_failures == 0) deflateEnd(stream); } - if (state->compression_in_started) { + if (state->compression_in_started == COMP_ZLIB) { z_streamp stream = &state->compression_in_stream; debug("compress incoming: " "raw data %llu, compressed %llu, factor %.2f", @@ -630,8 +641,31 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close) if (state->compression_in_failures == 0) inflateEnd(stream); } +#endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD + if (state->compression_out_started == COMP_ZSTD) { + debug("compress outgoing: " + "raw data %llu, compressed %llu, factor %.2f", + (unsigned long long)state->compress_zstd_out_raw, + (unsigned long long)state->compress_zstd_out_comp, + state->compress_zstd_out_raw == 0 ? 0.0 : + (double) state->compress_zstd_out_comp / + state->compress_zstd_out_raw); + } + if (state->compression_in_started == COMP_ZSTD) { + debug("compress incoming: " + "raw data %llu, compressed %llu, factor %.2f", + (unsigned long long)state->compress_zstd_in_raw, + (unsigned long long)state->compress_zstd_in_comp, + state->compress_zstd_in_raw == 0 ?
0.0 : + (double) state->compress_zstd_in_comp / + state->compress_zstd_in_raw); + } + /* Free unconditionally: both functions accept NULL. */ + ZSTD_freeCCtx(state->compression_zstd_out_stream); + ZSTD_freeDCtx(state->compression_zstd_in_stream); +#endif /* HAVE_LIBZSTD */ } -#endif /* WITH_ZLIB */ cipher_free(state->send_context); cipher_free(state->receive_context); state->send_context = state->receive_context = NULL; @@ -696,11 +730,11 @@ start_compression_out(struct ssh *ssh, int level) if (level < 1 || level > 9) return SSH_ERR_INVALID_ARGUMENT; debug("Enabling compression at level %d.", level); - if (ssh->state->compression_out_started == 1) + if (ssh->state->compression_out_started == COMP_ZLIB) deflateEnd(&ssh->state->compression_out_stream); switch (deflateInit(&ssh->state->compression_out_stream, level)) { case Z_OK: - ssh->state->compression_out_started = 1; + ssh->state->compression_out_started = COMP_ZLIB; break; case Z_MEM_ERROR: return SSH_ERR_ALLOC_FAIL; @@ -713,11 +747,11 @@ start_compression_out(struct ssh *ssh, int level) static int start_compression_in(struct ssh *ssh) { - if (ssh->state->compression_in_started == 1) + if (ssh->state->compression_in_started == COMP_ZLIB) inflateEnd(&ssh->state->compression_in_stream); switch (inflateInit(&ssh->state->compression_in_stream)) { case Z_OK: - ssh->state->compression_in_started = 1; + ssh->state->compression_in_started = COMP_ZLIB; break; case Z_MEM_ERROR: return SSH_ERR_ALLOC_FAIL; @@ -734,7 +768,7 @@ compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) u_char buf[4096]; int r, status; - if (ssh->state->compression_out_started != 1) + if (ssh->state->compression_out_started != COMP_ZLIB) return SSH_ERR_INTERNAL_ERROR; /* This case is not handled below.
*/ @@ -780,7 +814,7 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) u_char buf[4096]; int r, status; - if (ssh->state->compression_in_started != 1) + if (ssh->state->compression_in_started != COMP_ZLIB) return SSH_ERR_INTERNAL_ERROR; if ((ssh->state->compression_in_stream.next_in @@ -848,6 +882,145 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) } #endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD +static int +start_compression_zstd_out(struct ssh *ssh) +{ + debug("Enabling ZSTD compression."); + if (ssh->state->compression_out_started == COMP_ZSTD) + ZSTD_CCtx_reset(ssh->state->compression_zstd_out_stream, ZSTD_reset_session_only); + if (!ssh->state->compression_zstd_out_stream) + ssh->state->compression_zstd_out_stream = ZSTD_createCCtx(); + if (!ssh->state->compression_zstd_out_stream) + return SSH_ERR_ALLOC_FAIL; + ssh->state->compression_out_started = COMP_ZSTD; + return 0; +} + +static int +start_compression_zstd_in(struct ssh *ssh) +{ + if (ssh->state->compression_in_started == COMP_ZSTD) + ZSTD_DCtx_reset(ssh->state->compression_zstd_in_stream, ZSTD_reset_session_only); + if (!ssh->state->compression_zstd_in_stream) + ssh->state->compression_zstd_in_stream = ZSTD_createDCtx(); + if (!ssh->state->compression_zstd_in_stream) + return SSH_ERR_ALLOC_FAIL; + + ssh->state->compression_in_started = COMP_ZSTD; + return 0; +} + +static int +compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + u_char buf[4096]; + ZSTD_inBuffer in_buff; + ZSTD_outBuffer out_buff; + size_t comp; + int r; + + if (ssh->state->compression_out_started != COMP_ZSTD) + return SSH_ERR_INTERNAL_ERROR; + + if (sshbuf_len(in) == 0) + return 0; + + in_buff.src = sshbuf_mutable_ptr(in); + if (!in_buff.src) + return SSH_ERR_INTERNAL_ERROR; + in_buff.size = sshbuf_len(in); + in_buff.pos = 0; + + ssh->state->compress_zstd_out_raw += in_buff.size; + out_buff.dst = buf; + out_buff.size = sizeof(buf); + + /* + * Consume input and
immediately flush compressed data. It will loop + * multiple times if the output does not fit into the buffer. + */ + do { + out_buff.pos = 0; + + comp = ZSTD_compressStream2(ssh->state->compression_zstd_out_stream, + &out_buff, &in_buff, ZSTD_e_flush); + if (ZSTD_isError(comp)) + return SSH_ERR_INTERNAL_ERROR; + /* Append compressed data to output_buffer. */ + r = sshbuf_put(out, buf, out_buff.pos); + if (r != 0) + return r; + ssh->state->compress_zstd_out_comp += out_buff.pos; + } while (comp > 0); + return 0; +} + +static int uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, + struct sshbuf *out) +{ + u_char buf[4096]; + ZSTD_inBuffer in_buff; + ZSTD_outBuffer out_buff; + size_t decomp; + int r; + + if (ssh->state->compression_in_started != COMP_ZSTD) + return SSH_ERR_INTERNAL_ERROR; + + in_buff.src = sshbuf_mutable_ptr(in); + if (in_buff.src == NULL) + return SSH_ERR_INTERNAL_ERROR; + in_buff.size = sshbuf_len(in); + in_buff.pos = 0; + ssh->state->compress_zstd_in_comp += in_buff.size; + for (;;) { + /* Set up fixed-size output buffer.
*/ + out_buff.dst = buf; + out_buff.size = sizeof(buf); + out_buff.pos = 0; + + decomp = ZSTD_decompressStream(ssh->state->compression_zstd_in_stream, + &out_buff, &in_buff); + if (ZSTD_isError(decomp)) + return SSH_ERR_INVALID_FORMAT; + + r = sshbuf_put(out, buf, out_buff.pos); + if (r != 0) + return r; + ssh->state->compress_zstd_in_raw += out_buff.pos; + if (in_buff.size == in_buff.pos && + out_buff.pos < sizeof(buf)) + return 0; + } +} +#else /* HAVE_LIBZSTD */ + +static int +start_compression_zstd_out(struct ssh *ssh) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +start_compression_zstd_in(struct ssh *ssh) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + return SSH_ERR_INTERNAL_ERROR; +} +#endif /* HAVE_LIBZSTD */ + void ssh_clear_newkeys(struct ssh *ssh, int mode) { @@ -924,18 +1097,29 @@ ssh_set_newkeys(struct ssh *ssh, int mode) explicit_bzero(enc->key, enc->key_len); explicit_bzero(mac->key, mac->key_len); */ if ((comp->type == COMP_ZLIB || - (comp->type == COMP_DELAYED && + ((comp->type == COMP_DELAYED || comp->type == COMP_ZSTD) && state->after_authentication)) && comp->enabled == 0) { if ((r = ssh_packet_init_compression(ssh)) < 0) return r; - if (mode == MODE_OUT) { - if ((r = start_compression_out(ssh, 6)) != 0) - return r; + if (comp->type == COMP_ZSTD) { + if (mode == MODE_OUT) { + if ((r = start_compression_zstd_out(ssh)) != 0) + return r; + } else { + if ((r = start_compression_zstd_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZSTD; } else { - if ((r = start_compression_in(ssh)) != 0) - return r; + if (mode == MODE_OUT) { + if ((r = start_compression_out(ssh, 6)) != 0) + return r; + } else { + if ((r = start_compression_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZLIB; } - comp->enabled = 1; } /* * The
2^(blocksize*2) limit is too expensive for 3DES, @@ -1022,6 +1206,7 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh) struct session_state *state = ssh->state; struct sshcomp *comp = NULL; int r, mode; + int type = 0; /* * Remember that we are past the authentication step, so rekeying @@ -1033,17 +1218,34 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh) if (state->newkeys[mode] == NULL) continue; comp = &state->newkeys[mode]->comp; - if (comp && !comp->enabled && comp->type == COMP_DELAYED) { - if ((r = ssh_packet_init_compression(ssh)) != 0) + /* Per direction: only start compression that is not yet enabled. */ + type = (comp != NULL && !comp->enabled) ? + comp->type : COMP_NONE; + if (type == COMP_DELAYED || type == COMP_ZSTD) { + if ((r = ssh_packet_init_compression(ssh)) != 0) { return r; - if (mode == MODE_OUT) { - if ((r = start_compression_out(ssh, 6)) != 0) - return r; - } else { - if ((r = start_compression_in(ssh)) != 0) - return r; } - comp->enabled = 1; + if (type == COMP_DELAYED) { + if (mode == MODE_OUT) { + if ((r = start_compression_out(ssh, 6)) != 0) + return r; + } else { + if ((r = start_compression_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZLIB; + } else if (type == COMP_ZSTD) { + if (mode == MODE_OUT) { + if ((r = start_compression_zstd_out(ssh)) != 0) + return r; + } else { + if ((r = start_compression_zstd_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZSTD; + } else { + return SSH_ERR_INTERNAL_ERROR; + } } } return 0; @@ -1104,9 +1306,15 @@ ssh_packet_send2_wrapped(struct ssh *ssh) if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0) goto out; sshbuf_reset(state->compression_buffer); - if ((r = compress_buffer(ssh, state->outgoing_packet, - state->compression_buffer)) != 0) - goto out; + if (comp->enabled == COMP_ZSTD) { + if ((r = compress_buffer_zstd(ssh, state->outgoing_packet, + state->compression_buffer)) != 0) + goto out; + } else { + if ((r = compress_buffer(ssh, state->outgoing_packet, + state->compression_buffer)) != 0) + goto out; + } sshbuf_reset(state->outgoing_packet);
if ((r = sshbuf_put(state->outgoing_packet, "\0\0\0\0\0", 5)) != 0 || @@ -1657,9 +1865,15 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) sshbuf_len(state->incoming_packet))); if (comp && comp->enabled) { sshbuf_reset(state->compression_buffer); - if ((r = uncompress_buffer(ssh, state->incoming_packet, - state->compression_buffer)) != 0) - goto out; + if (comp->enabled == COMP_ZSTD) { + if ((r = uncompress_buffer_zstd(ssh, state->incoming_packet, + state->compression_buffer)) != 0) + goto out; + } else { + if ((r = uncompress_buffer(ssh, state->incoming_packet, + state->compression_buffer)) != 0) + goto out; + } sshbuf_reset(state->incoming_packet); if ((r = sshbuf_putb(state->incoming_packet, state->compression_buffer)) != 0) diff --git a/readconf.c b/readconf.c index cf79498848f6d..f05aab2316c8a 100644 --- a/readconf.c +++ b/readconf.c @@ -899,8 +899,14 @@ static const struct multistate multistate_pubkey_auth[] = { { NULL, -1 } }; static const struct multistate multistate_compression[] = { +#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD) + { "yes", COMP_ALL_C }, +#endif #ifdef WITH_ZLIB - { "yes", COMP_ZLIB }, + { "zlib", COMP_ZLIB }, +#endif +#ifdef HAVE_LIBZSTD + { "zstd", COMP_ZSTD }, #endif { "no", COMP_NONE }, { NULL, -1 } diff --git a/servconf.c b/servconf.c index 2e039da8b95e8..a82ef128c79f7 100644 --- a/servconf.c +++ b/servconf.c @@ -375,11 +375,7 @@ fill_default_server_options(ServerOptions *options) options->permit_user_env_allowlist = NULL; } if (options->compression == -1) -#ifdef WITH_ZLIB - options->compression = COMP_DELAYED; -#else - options->compression = COMP_NONE; -#endif + options->compression = COMP_ALL_S; if (options->rekey_limit == -1) options->rekey_limit = 0; @@ -1303,9 +1299,15 @@ static const struct multistate multistate_permitrootlogin[] = { { NULL, -1 } }; static const struct multistate multistate_compression[] = { +#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD) + { "yes", COMP_ALL_S }, +#endif #ifdef
WITH_ZLIB - { "yes", COMP_DELAYED }, { "delayed", COMP_DELAYED }, + { "zlib", COMP_DELAYED }, +#endif +#ifdef HAVE_LIBZSTD + { "zstd", COMP_ZSTD }, #endif { "no", COMP_NONE }, { NULL, -1 } diff --git a/ssh.c b/ssh.c index 918389bccba25..ae67808a36215 100644 --- a/ssh.c +++ b/ssh.c @@ -1011,8 +1011,8 @@ main(int ac, char **av) break; case 'C': -#ifdef WITH_ZLIB - options.compression = 1; +#if defined(HAVE_LIBZSTD) || defined(WITH_ZLIB) + options.compression = COMP_ALL_C; #else error("Compression not supported, disabling."); #endif -- 2.39.2