Michael S. Tsirkin
2021-Aug-02 20:20 UTC
[PATCH 1/4] VSOCK DRIVER: Add multi-cid support for guest
On Mon, Aug 02, 2021 at 08:07:17PM +0800, fuguancheng wrote:> This patch allows the user to specify multiple additional CIDs > for the guest that can be used for communication between host > and guest. > > The guest reads the additional CIDs from the device config space. > The device config space layout can be found at uapi/linux/virtio_vsock.h > The existing ioctl call for device VHOST_VIRTIO with request code > VHOST_VSOCK_SET_GUEST_CID is modified to notify the host for the > additional guest CIDs. > > Signed-off-by: fuguancheng <fuguancheng at bytedance.com> > --- > drivers/vhost/vhost.h | 5 ++ > drivers/vhost/vsock.c | 173 +++++++++++++++++++++++++++++--------- > include/net/af_vsock.h | 1 + > include/uapi/linux/vhost.h | 7 ++ > include/uapi/linux/virtio_vsock.h | 3 +- > net/vmw_vsock/af_vsock.c | 6 +- > net/vmw_vsock/virtio_transport.c | 72 ++++++++++++++-- > net/vmw_vsock/vsock_loopback.c | 8 ++ > 8 files changed, 222 insertions(+), 53 deletions(-) > > diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h > index 638bb640d6b4..52bd143ccf0c 100644 > --- a/drivers/vhost/vhost.h > +++ b/drivers/vhost/vhost.h > @@ -25,6 +25,11 @@ struct vhost_work { > unsigned long flags; > }; > > +struct multi_cid_message { > + u32 number_cid; > + u64 *cid; > +}; > + > /* Poll a file (eventfd or socket) */ > /* Note: there's nothing vhost specific about this structure. 
*/ > struct vhost_poll { > diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c > index f249622ef11b..f66c87de91b8 100644 > --- a/drivers/vhost/vsock.c > +++ b/drivers/vhost/vsock.c > @@ -43,12 +43,25 @@ enum { > static DEFINE_MUTEX(vhost_vsock_mutex); > static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8); > > +struct vhost_vsock_ref { > + struct vhost_vsock *vsock; > + struct hlist_node ref_hash; > + u32 cid; > +}; > + > +static bool vhost_transport_contain_cid(u32 cid) > +{ > + if (cid == VHOST_VSOCK_DEFAULT_HOST_CID) > + return true; > + return false; > +} > + > struct vhost_vsock { > struct vhost_dev dev; > struct vhost_virtqueue vqs[2]; > > /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */ > - struct hlist_node hash; > + struct vhost_vsock_ref *ref_list; > > struct vhost_work send_pkt_work; > spinlock_t send_pkt_list_lock; > @@ -56,7 +69,8 @@ struct vhost_vsock { > > atomic_t queued_replies; > > - u32 guest_cid; > + u32 *cids; > + u32 num_cid; > bool seqpacket_allow; > }; > > @@ -70,23 +84,49 @@ static u32 vhost_transport_get_local_cid(void) > */ > static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) > { > - struct vhost_vsock *vsock; > + struct vhost_vsock_ref *ref; > > - hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) { > - u32 other_cid = vsock->guest_cid; > + hash_for_each_possible_rcu(vhost_vsock_hash, ref, ref_hash, guest_cid) { > + u32 other_cid = ref->cid; > > /* Skip instances that have no CID yet */ > if (other_cid == 0) > continue; > > if (other_cid == guest_cid) > - return vsock; > + return ref->vsock; > > } > > return NULL; > } > > +static int check_if_cid_valid(u64 guest_cid, struct vhost_vsock *vsock) > +{ > + struct vhost_vsock *other; > + > + if (guest_cid <= VMADDR_CID_HOST || guest_cid == U32_MAX) > + return -EINVAL; > + > + /* 64-bit CIDs are not yet supported */ > + if (guest_cid > U32_MAX) > + return -EINVAL; > + /* Refuse if CID is assigned to the guest->host transport (i.e. 
nested > + * VM), to make the loopback work. > + */ > + if (vsock_find_cid(guest_cid)) > + return -EADDRINUSE; > + /* Refuse if CID is already in use */ > + mutex_lock(&vhost_vsock_mutex); > + other = vhost_vsock_get(guest_cid); > + if (other) { > + mutex_unlock(&vhost_vsock_mutex); > + return -EADDRINUSE; > + } > + mutex_unlock(&vhost_vsock_mutex); > + return 0; > +} > + > static void > vhost_transport_do_send_pkt(struct vhost_vsock *vsock, > struct vhost_virtqueue *vq) > @@ -427,6 +467,7 @@ static struct virtio_transport vhost_transport = { > .module = THIS_MODULE, > > .get_local_cid = vhost_transport_get_local_cid, > + .contain_cid = vhost_transport_contain_cid, > > .init = virtio_transport_do_socket_init, > .destruct = virtio_transport_destruct, > @@ -542,9 +583,9 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work) > virtio_transport_deliver_tap_pkt(pkt); > > /* Only accept correctly addressed packets */ > - if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid && > - le64_to_cpu(pkt->hdr.dst_cid) == > - vhost_transport_get_local_cid()) > + if (vsock->num_cid > 0 && > + (pkt->hdr.src_cid) == vsock->cids[0] && > + le64_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid()) > virtio_transport_recv_pkt(&vhost_transport, pkt); > else > virtio_transport_free_pkt(pkt); > @@ -655,6 +696,10 @@ static int vhost_vsock_stop(struct vhost_vsock *vsock) > > static void vhost_vsock_free(struct vhost_vsock *vsock) > { > + if (vsock->ref_list) > + kvfree(vsock->ref_list); > + if (vsock->cids) > + kvfree(vsock->cids); > kvfree(vsock); > } > > @@ -677,7 +722,9 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file) > goto out; > } > > - vsock->guest_cid = 0; /* no CID assigned yet */ > + vsock->ref_list = NULL; > + vsock->cids = NULL; > + vsock->num_cid = 0; > > atomic_set(&vsock->queued_replies, 0); > > @@ -739,11 +786,14 @@ static void vhost_vsock_reset_orphans(struct sock *sk) > > static int vhost_vsock_dev_release(struct inode *inode, 
struct file *file) > { > + int index; > struct vhost_vsock *vsock = file->private_data; > > mutex_lock(&vhost_vsock_mutex); > - if (vsock->guest_cid) > - hash_del_rcu(&vsock->hash); > + if (vsock->num_cid) { > + for (index = 0; index < vsock->num_cid; index++) > + hash_del_rcu(&vsock->ref_list[index].ref_hash); > + } > mutex_unlock(&vhost_vsock_mutex); > > /* Wait for other CPUs to finish using vsock */ > @@ -774,41 +824,80 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file) > return 0; > } > > -static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) > +static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 __user *cids, u32 number_cid) > { > - struct vhost_vsock *other; > + u64 cid; > + int i, ret; > > - /* Refuse reserved CIDs */ > - if (guest_cid <= VMADDR_CID_HOST || > - guest_cid == U32_MAX) > + if (number_cid <= 0) > return -EINVAL; > - > - /* 64-bit CIDs are not yet supported */ > - if (guest_cid > U32_MAX) > - return -EINVAL; > - > - /* Refuse if CID is assigned to the guest->host transport (i.e. nested > - * VM), to make the loopback work. > - */ > - if (vsock_find_cid(guest_cid)) > - return -EADDRINUSE; > - > - /* Refuse if CID is already in use */ > - mutex_lock(&vhost_vsock_mutex); > - other = vhost_vsock_get(guest_cid); > - if (other && other != vsock) { > + /* delete the old CIDs. 
*/ > + if (vsock->num_cid) { > + mutex_lock(&vhost_vsock_mutex); > + for (i = 0; i < vsock->num_cid; i++) > + hash_del_rcu(&vsock->ref_list[i].ref_hash); > mutex_unlock(&vhost_vsock_mutex); > - return -EADDRINUSE; > + kvfree(vsock->ref_list); > + vsock->ref_list = NULL; > + kvfree(vsock->cids); > + vsock->cids = NULL; > + } > + vsock->num_cid = number_cid; > + vsock->cids = kmalloc_array(vsock->num_cid, sizeof(u32), > + GFP_KERNEL | __GFP_RETRY_MAYFAIL); > + if (!vsock->cids) { > + vsock->num_cid = 0; > + ret = -ENOMEM; > + goto out; > + } > + vsock->ref_list = kvmalloc_array(vsock->num_cid, sizeof(*vsock->ref_list), > + GFP_KERNEL | __GFP_RETRY_MAYFAIL); > + if (!vsock->ref_list) { > + vsock->num_cid = 0; > + ret = -ENOMEM; > + goto out; > } > > - if (vsock->guest_cid) > - hash_del_rcu(&vsock->hash); > - > - vsock->guest_cid = guest_cid; > - hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid); > - mutex_unlock(&vhost_vsock_mutex); > + for (i = 0; i < number_cid; i++) { > + if (copy_from_user(&cid, cids + i, sizeof(cid))) { > + /* record where we failed, to clean up the ref in hash table. */ > + vsock->num_cid = i; > + ret = -EFAULT; > + goto out; > + } > + ret = check_if_cid_valid(cid, vsock); > + if (ret) { > + vsock->num_cid = i; > + goto out; > + } > > + vsock->cids[i] = (u32)cid; > + vsock->ref_list[i].cid = vsock->cids[i]; > + vsock->ref_list[i].vsock = vsock; > + mutex_lock(&vhost_vsock_mutex); > + hash_add_rcu(vhost_vsock_hash, &vsock->ref_list[i].ref_hash, > + vsock->cids[i]); > + mutex_unlock(&vhost_vsock_mutex); > + } > return 0; > + > +out: > + /* Handle the memory release here. */ > + if (vsock->num_cid) { > + mutex_lock(&vhost_vsock_mutex); > + for (i = 0; i < vsock->num_cid; i++) > + hash_del_rcu(&vsock->ref_list[i].ref_hash); > + mutex_unlock(&vhost_vsock_mutex); > + vsock->num_cid = 0; > + } > + if (vsock->ref_list) > + kvfree(vsock->ref_list); > + if (vsock->cids) > + kvfree(vsock->cids); > + /* Set it to null to prevent double release. 
*/ > + vsock->ref_list = NULL; > + vsock->cids = NULL; > + return ret; > } > > static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features) > @@ -852,16 +941,16 @@ static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl, > { > struct vhost_vsock *vsock = f->private_data; > void __user *argp = (void __user *)arg; > - u64 guest_cid; > u64 features; > int start; > int r; > + struct multi_cid_message cid_message; > > switch (ioctl) { > case VHOST_VSOCK_SET_GUEST_CID: > - if (copy_from_user(&guest_cid, argp, sizeof(guest_cid))) > + if (copy_from_user(&cid_message, argp, sizeof(cid_message))) > return -EFAULT; > - return vhost_vsock_set_cid(vsock, guest_cid); > + return vhost_vsock_set_cid(vsock, cid_message.cid, cid_message.number_cid); > case VHOST_VSOCK_SET_RUNNING: > if (copy_from_user(&start, argp, sizeof(start))) > return -EFAULT; > diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h > index ab207677e0a8..d0fc08fb9cac 100644 > --- a/include/net/af_vsock.h > +++ b/include/net/af_vsock.h > @@ -170,6 +170,7 @@ struct vsock_transport { > > /* Addressing. */ > u32 (*get_local_cid)(void); > + bool (*contain_cid)(u32 cid); > }; > > /**** CORE ****/ > diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h > index c998860d7bbc..a3ea99f6fc7f 100644 > --- a/include/uapi/linux/vhost.h > +++ b/include/uapi/linux/vhost.h > @@ -17,6 +17,13 @@ > > #define VHOST_FILE_UNBIND -1 > > +/* structs used for hypervisors to send cid info. 
*/ > + > +struct multi_cid_message { > + u32 number_cid; > + u64 *cid; > +}; > + > /* ioctls */ > > #define VHOST_VIRTIO 0xAF > diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h > index 3dd3555b2740..0afc14446b01 100644 > --- a/include/uapi/linux/virtio_vsock.h > +++ b/include/uapi/linux/virtio_vsock.h > @@ -42,7 +42,8 @@ > #define VIRTIO_VSOCK_F_SEQPACKET 1 /* SOCK_SEQPACKET supported */ > > struct virtio_vsock_config { > - __le64 guest_cid; > + __le32 number_cid; > + __le64 cids[];Config space should be generally limited to ~256 bytes. That is < 32 cids. Enough? I would implement an interface where you write a number and read back a cid, instead.> } __attribute__((packed)); >You want a feature bit for this.> enum virtio_vsock_event_id { > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c > index 3e02cc3b24f8..4e1fbe74013f 100644 > --- a/net/vmw_vsock/af_vsock.c > +++ b/net/vmw_vsock/af_vsock.c > @@ -507,13 +507,13 @@ EXPORT_SYMBOL_GPL(vsock_assign_transport); > > bool vsock_find_cid(unsigned int cid) > { > - if (transport_g2h && cid == transport_g2h->get_local_cid()) > + if (transport_g2h && transport_g2h->contain_cid(cid)) > return true; > > - if (transport_h2g && cid == VMADDR_CID_HOST) > + if (transport_h2g && transport_h2g->contain_cid(cid)) > return true; > > - if (transport_local && cid == VMADDR_CID_LOCAL) > + if (transport_local && transport_local->contain_cid(cid)) > return true; > > return false; > diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c > index e0c2c992ad9c..5f256a57d9ae 100644 > --- a/net/vmw_vsock/virtio_transport.c > +++ b/net/vmw_vsock/virtio_transport.c > @@ -61,10 +61,41 @@ struct virtio_vsock { > bool event_run; > struct virtio_vsock_event event_list[8]; > > - u32 guest_cid; > + /* The following fields are used to hold additional cids given by the hypervisor > + * such as qemu. 
> + */ > + u32 number_cid; > + u32 *cids; > + > bool seqpacket_allow; > }; > > +static bool virtio_transport_contain_cid(u32 cid) > +{ > + struct virtio_vsock *vsock; > + bool ret; > + u32 num_cid; > + > + num_cid = 0; > + rcu_read_lock(); > + vsock = rcu_dereference(the_virtio_vsock); > + if (!vsock || !vsock->number_cid) { > + ret = false; > + goto out_rcu; > + } > + > + for (num_cid = 0; num_cid < vsock->number_cid; num_cid++) { > + if (vsock->cids[num_cid] == cid) { > + ret = true; > + goto out_rcu; > + } > + } > + ret = false; > +out_rcu: > + rcu_read_unlock(); > + return ret; > +} > + > static u32 virtio_transport_get_local_cid(void) > { > struct virtio_vsock *vsock; > @@ -72,12 +103,12 @@ static u32 virtio_transport_get_local_cid(void) > > rcu_read_lock(); > vsock = rcu_dereference(the_virtio_vsock); > - if (!vsock) { > + if (!vsock || !vsock->number_cid) { > ret = VMADDR_CID_ANY; > goto out_rcu; > } > > - ret = vsock->guest_cid; > + ret = vsock->cids[0]; > out_rcu: > rcu_read_unlock(); > return ret; > @@ -176,7 +207,7 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) > goto out_rcu; > } > > - if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) { > + if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->cids[0]) { > virtio_transport_free_pkt(pkt); > len = -ENODEV; > goto out_rcu; > @@ -368,10 +399,33 @@ static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) > { > struct virtio_device *vdev = vsock->vdev; > __le64 guest_cid; > + __le32 number_cid; > + u32 index; > > - vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), > - &guest_cid, sizeof(guest_cid)); > - vsock->guest_cid = le64_to_cpu(guest_cid); > + vdev->config->get(vdev, offsetof(struct virtio_vsock_config, number_cid), > + &number_cid, sizeof(number_cid));need to handle existing devices without the feature.> + vsock->number_cid = le32_to_cpu(number_cid); > + > + /* number_cid must be greater than 0 in the config space > + * to use this feature. 
> + */ > + if (vsock->number_cid > 0) { > + vsock->cids = kmalloc_array(vsock->number_cid, sizeof(u32), GFP_KERNEL); > + if (!vsock->cids) { > + /* Space allocated failed, reset number_cid to 0. > + * only use the original guest_cid. > + */ > + vsock->number_cid = 0; > + } > + } > + > + for (index = 0; index < vsock->number_cid; index++) { > + vdev->config->get(vdev, > + offsetof(struct virtio_vsock_config, cids) > + + index * sizeof(uint64_t), > + &guest_cid, sizeof(guest_cid)); > + vsock->cids[index] = le64_to_cpu(guest_cid);You just drop high bits here. Unlikely to behave well if they are not 0.> + } > } > > /* event_lock must be held */ > @@ -451,6 +505,7 @@ static struct virtio_transport virtio_transport = { > .module = THIS_MODULE, > > .get_local_cid = virtio_transport_get_local_cid, > + .contain_cid = virtio_transport_contain_cid, > > .init = virtio_transport_do_socket_init, > .destruct = virtio_transport_destruct, > @@ -594,6 +649,8 @@ static int virtio_vsock_probe(struct virtio_device *vdev) > } > > vsock->vdev = vdev; > + vsock->cids = NULL; > + vsock->number_cid = 0; > > ret = virtio_find_vqs(vsock->vdev, VSOCK_VQ_MAX, > vsock->vqs, callbacks, names, > @@ -713,6 +770,7 @@ static void virtio_vsock_remove(struct virtio_device *vdev) > > mutex_unlock(&the_virtio_vsock_mutex); > > + kfree(vsock->cids); > kfree(vsock); > } > > diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c > index 169a8cf65b39..3abbbaff34eb 100644 > --- a/net/vmw_vsock/vsock_loopback.c > +++ b/net/vmw_vsock/vsock_loopback.c > @@ -63,6 +63,13 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) > return 0; > } > > +static bool vsock_loopback_contain_cid(u32 cid) > +{ > + if (cid == VMADDR_CID_LOCAL) > + return true; > + return false; > +} > + > static bool vsock_loopback_seqpacket_allow(u32 remote_cid); > > static struct virtio_transport loopback_transport = { > @@ -70,6 +77,7 @@ static struct virtio_transport loopback_transport = { > .module = 
THIS_MODULE, > > .get_local_cid = vsock_loopback_get_local_cid, > + .contain_cid = vsock_loopback_contain_cid, > > .init = virtio_transport_do_socket_init, > .destruct = virtio_transport_destruct, > -- > 2.11.0 > >