Michael S. Tsirkin
2021-May-17 23:39 UTC
[RFC PATCH 17/17] virtio_ring: Add validation for used length
On Mon, May 17, 2021 at 05:08:36PM +0800, Xie Yongji wrote:> This adds validation for used length (might come > from an untrusted device) when it will be used by > virtio device driver. > > Signed-off-by: Xie Yongji <xieyongji at bytedance.com> > --- > drivers/virtio/virtio_ring.c | 22 +++++++++++++++++++--- > 1 file changed, 19 insertions(+), 3 deletions(-) > > diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c > index d999a1d6d271..7d4845d06f21 100644 > --- a/drivers/virtio/virtio_ring.c > +++ b/drivers/virtio/virtio_ring.c > @@ -68,11 +68,13 @@ > struct vring_desc_state_split { > void *data; /* Data for callback. */ > struct vring_desc *indir_desc; /* Indirect descriptor, if any. */ > + u32 in_len; /* Total length of writable buffer */ > }; > > struct vring_desc_state_packed { > void *data; /* Data for callback. */ > struct vring_packed_desc *indir_desc; /* Indirect descriptor, if any. */ > + u32 in_len; /* Total length of writable buffer */ > u16 num; /* Descriptor list length. */ > u16 last; /* The last desc state in a list. */ > };Hmm for packed it's aligned to 64 bit anyway, so we are not making it any worse. But for split this pushes struct size up by 1/3 increasing cache pressure.> @@ -486,7 +488,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > struct vring_virtqueue *vq = to_vvq(_vq); > struct scatterlist *sg; > struct vring_desc *desc; > - unsigned int i, n, avail, descs_used, prev, err_idx; > + unsigned int i, n, avail, descs_used, prev, err_idx, in_len = 0; > int head; > bool indirect; > > @@ -570,6 +572,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > VRING_DESC_F_NEXT | > VRING_DESC_F_WRITE, > indirect); > + in_len += sg->length; > } > } > /* Last one doesn't continue. */ > @@ -604,6 +607,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, > > /* Store token and indirect buffer state. 
*/ > vq->split.desc_state[head].data = data; > + vq->split.desc_state[head].in_len = in_len; > if (indirect) > vq->split.desc_state[head].indir_desc = desc; > else > @@ -784,6 +788,10 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq, > BAD_RING(vq, "id %u is not a head!\n", i); > return NULL; > } > + if (unlikely(len && vq->split.desc_state[i].in_len < *len)) { > + BAD_RING(vq, "id %u has invalid length: %u!\n", i, *len); > + return NULL; > + } > > /* detach_buf_split clears data, so grab it now. */ > ret = vq->split.desc_state[i].data; > @@ -1059,7 +1067,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > { > struct vring_packed_desc *desc; > struct scatterlist *sg; > - unsigned int i, n, err_idx; > + unsigned int i, n, err_idx, in_len = 0; > u16 head, id; > dma_addr_t addr; > > @@ -1084,6 +1092,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > if (vring_mapping_error(vq, addr)) > goto unmap_release; > > + in_len += (n < out_sgs) ? 0 : sg->length; > desc[i].flags = cpu_to_le16(n < out_sgs ? > 0 : VRING_DESC_F_WRITE); > desc[i].addr = cpu_to_le64(addr); > @@ -1141,6 +1150,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, > vq->packed.desc_state[id].data = data; > vq->packed.desc_state[id].indir_desc = desc; > vq->packed.desc_state[id].last = id; > + vq->packed.desc_state[id].in_len = in_len; > > vq->num_added += 1; > > @@ -1173,7 +1183,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > struct vring_virtqueue *vq = to_vvq(_vq); > struct vring_packed_desc *desc; > struct scatterlist *sg; > - unsigned int i, n, c, descs_used, err_idx; > + unsigned int i, n, c, descs_used, err_idx, in_len = 0; > __le16 head_flags, flags; > u16 head, id, prev, curr, avail_used_flags; > > @@ -1223,6 +1233,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > if (vring_mapping_error(vq, addr)) > goto unmap_release; > > + in_len += (n < out_sgs) ? 
0 : sg->length; > flags = cpu_to_le16(vq->packed.avail_used_flags | > (++c == total_sg ? 0 : VRING_DESC_F_NEXT) | > (n < out_sgs ? 0 : VRING_DESC_F_WRITE)); > @@ -1268,6 +1279,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, > vq->packed.desc_state[id].data = data; > vq->packed.desc_state[id].indir_desc = ctx; > vq->packed.desc_state[id].last = prev; > + vq->packed.desc_state[id].in_len = in_len; > > /* > * A driver MUST NOT make the first descriptor in the list > @@ -1456,6 +1468,10 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, > BAD_RING(vq, "id %u is not a head!\n", id); > return NULL; > } > + if (unlikely(len && vq->packed.desc_state[id].in_len < *len)) { > + BAD_RING(vq, "id %u has invalid length: %u!\n", id, *len); > + return NULL; > + } > > /* detach_buf_packed clears data, so grab it now. */ > ret = vq->packed.desc_state[id].data; > -- > 2.11.0
Jason Wang
2021-May-25 01:31 UTC
[RFC PATCH 17/17] virtio_ring: Add validation for used length
On 2021/5/18 7:39 AM, Michael S. Tsirkin wrote:
> On Mon, May 17, 2021 at 05:08:36PM +0800, Xie Yongji wrote: >> This adds validation for used length (might come >> from an untrusted device) when it will be used by >> virtio device driver. >> >> Signed-off-by: Xie Yongji <xieyongji at bytedance.com> >> --- >> drivers/virtio/virtio_ring.c | 22 +++++++++++++++++++--- >> 1 file changed, 19 insertions(+), 3 deletions(-) >> >> diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c >> index d999a1d6d271..7d4845d06f21 100644 >> --- a/drivers/virtio/virtio_ring.c >> +++ b/drivers/virtio/virtio_ring.c >> @@ -68,11 +68,13 @@ >> struct vring_desc_state_split { >> void *data; /* Data for callback. */ >> struct vring_desc *indir_desc; /* Indirect descriptor, if any. */ >> + u32 in_len; /* Total length of writable buffer */ >> }; >> >> struct vring_desc_state_packed { >> void *data; /* Data for callback. */ >> struct vring_packed_desc *indir_desc; /* Indirect descriptor, if any. */ >> + u32 in_len; /* Total length of writable buffer */ >> u16 num; /* Descriptor list length. */ >> u16 last; /* The last desc state in a list. */ >> }; > > Hmm for packed it's aligned to 64 bit anyway, so we are not making it > any worse. But for split this pushes struct size up by 1/3 increasing > cache pressure.

We can eliminate this by validating in the virtio device driver instead of the virtio core. E.g. for virtio-net we already know the rx buffer size, so there's no need to store it twice in the core.
Thanks> > >> @@ -486,7 +488,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, >> struct vring_virtqueue *vq = to_vvq(_vq); >> struct scatterlist *sg; >> struct vring_desc *desc; >> - unsigned int i, n, avail, descs_used, prev, err_idx; >> + unsigned int i, n, avail, descs_used, prev, err_idx, in_len = 0; >> int head; >> bool indirect; >> >> @@ -570,6 +572,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, >> VRING_DESC_F_NEXT | >> VRING_DESC_F_WRITE, >> indirect); >> + in_len += sg->length; >> } >> } >> /* Last one doesn't continue. */ >> @@ -604,6 +607,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq, >> >> /* Store token and indirect buffer state. */ >> vq->split.desc_state[head].data = data; >> + vq->split.desc_state[head].in_len = in_len; >> if (indirect) >> vq->split.desc_state[head].indir_desc = desc; >> else >> @@ -784,6 +788,10 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq, >> BAD_RING(vq, "id %u is not a head!\n", i); >> return NULL; >> } >> + if (unlikely(len && vq->split.desc_state[i].in_len < *len)) { >> + BAD_RING(vq, "id %u has invalid length: %u!\n", i, *len); >> + return NULL; >> + } >> >> /* detach_buf_split clears data, so grab it now. */ >> ret = vq->split.desc_state[i].data; >> @@ -1059,7 +1067,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, >> { >> struct vring_packed_desc *desc; >> struct scatterlist *sg; >> - unsigned int i, n, err_idx; >> + unsigned int i, n, err_idx, in_len = 0; >> u16 head, id; >> dma_addr_t addr; >> >> @@ -1084,6 +1092,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, >> if (vring_mapping_error(vq, addr)) >> goto unmap_release; >> >> + in_len += (n < out_sgs) ? 0 : sg->length; >> desc[i].flags = cpu_to_le16(n < out_sgs ? 
>> 0 : VRING_DESC_F_WRITE); >> desc[i].addr = cpu_to_le64(addr); >> @@ -1141,6 +1150,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq, >> vq->packed.desc_state[id].data = data; >> vq->packed.desc_state[id].indir_desc = desc; >> vq->packed.desc_state[id].last = id; >> + vq->packed.desc_state[id].in_len = in_len; >> >> vq->num_added += 1; >> >> @@ -1173,7 +1183,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, >> struct vring_virtqueue *vq = to_vvq(_vq); >> struct vring_packed_desc *desc; >> struct scatterlist *sg; >> - unsigned int i, n, c, descs_used, err_idx; >> + unsigned int i, n, c, descs_used, err_idx, in_len = 0; >> __le16 head_flags, flags; >> u16 head, id, prev, curr, avail_used_flags; >> >> @@ -1223,6 +1233,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, >> if (vring_mapping_error(vq, addr)) >> goto unmap_release; >> >> + in_len += (n < out_sgs) ? 0 : sg->length; >> flags = cpu_to_le16(vq->packed.avail_used_flags | >> (++c == total_sg ? 0 : VRING_DESC_F_NEXT) | >> (n < out_sgs ? 0 : VRING_DESC_F_WRITE)); >> @@ -1268,6 +1279,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq, >> vq->packed.desc_state[id].data = data; >> vq->packed.desc_state[id].indir_desc = ctx; >> vq->packed.desc_state[id].last = prev; >> + vq->packed.desc_state[id].in_len = in_len; >> >> /* >> * A driver MUST NOT make the first descriptor in the list >> @@ -1456,6 +1468,10 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq, >> BAD_RING(vq, "id %u is not a head!\n", id); >> return NULL; >> } >> + if (unlikely(len && vq->packed.desc_state[id].in_len < *len)) { >> + BAD_RING(vq, "id %u has invalid length: %u!\n", id, *len); >> + return NULL; >> + } >> >> /* detach_buf_packed clears data, so grab it now. */ >> ret = vq->packed.desc_state[id].data; >> -- >> 2.11.0