This patch tries to implement a device IOTLB for vhost. It can be used in cooperation with a userspace IOMMU implementation (qemu) to provide a secure DMA environment (DMAR) in the guest. The idea is simple: when vhost encounters an IOTLB miss, it requests the assistance of userspace to do the translation. This is done as follows: - when there's an IOTLB miss, vhost notifies userspace through the vhost_net fd, and userspace then reads the fault address, size and access type from the vhost fd. - userspace writes the translation result back to the vhost fd, and vhost then updates its IOTLB. The code is optimized for fixed-mapping users, e.g. dpdk in the guest. It will be slow if dynamic mappings are used in the guest; we could do optimizations on top. The code is designed to be architecture independent, so it should be easy to port to any architecture. Stress tested with l2fwd/vfio in the guest with 4K/2M/1G page sizes. In the 1G hugepage case, a 100% TLB hit rate was observed. Benchmark: tests were done with l2fwd in the guest. For 2MB pages, there is no difference in 64B performance, and I notice a 4%-5% drop in 1500B performance compared to UIO in the guest. We could add a shortcut to bypass the IOTLB for virtqueue accesses, but I think it's better to do this on top.
Changes from V1: - Fix i386 build warnings - Drop access parameter for vhost_get_vq_desc() (fix VHOST SCSI build error) Changes from RFC V3: - rebase on latest - minor tweak on commit log - use VHOST_ACCESS_RO instead of VHOST_ACCESS_WO in vhost_copy_from_user() - switch to use atomic userspace access helpers in vhost_get/put_user() - remove debug code in vhost_iotlb_miss() - use FIFO instead of FILO when doing TLB replacement - fix unbalanced lock in vhost_process_iotlb_msg() Changes from RFC V2: - introduce memory accessors for vhost - switch from ioctls to ordinary file read/write for IOTLB miss and updating - do not assume virtqueues are virtually mapped contiguously; all virtqueue accesses are done through the IOTLB - verify memory access during IOTLB update and fail early - introduce a module parameter for the size of the IOTLB Changes from RFC V1: - support any size/range of updating and invalidation by introducing the interval tree. - convert from per-device IOTLB requests to per-virtqueue IOTLB requests; this solves the possible deadlock in V1. - read/write permission check support. Please review. Jason Wang (3): vhost: introduce vhost memory accessors vhost: convert pre sorted vhost memory array to interval tree vhost: device IOTLB API drivers/vhost/net.c | 58 +++- drivers/vhost/vhost.c | 806 +++++++++++++++++++++++++++++++++++++++------ drivers/vhost/vhost.h | 57 +++- include/uapi/linux/vhost.h | 28 ++ 4 files changed, 829 insertions(+), 120 deletions(-) -- 1.8.3.1
This patch introduces vhost memory accessors, which are just wrappers
for userspace address access helpers. This is a requirement for the vhost
device IOTLB implementation, which will add IOTLB translations in those
accessors.
Signed-off-by: Jason Wang <jasowang at redhat.com>
---
 drivers/vhost/vhost.c | 50 +++++++++++++++++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 15 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 669fef1..9f2a63a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -651,6 +651,22 @@ static int memory_access_ok(struct vhost_dev *d, struct
vhost_memory *mem,
 	return 1;
 }
 
+#define vhost_put_user(vq, x, ptr)  __put_user(x, ptr)
+
+static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
+			      const void *from, unsigned size)
+{
+	return __copy_to_user(to, from, size);
+}
+
+#define vhost_get_user(vq, x, ptr) __get_user(x, ptr)
+
+static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
+				void *from, unsigned size)
+{
+	return __copy_from_user(to, from, size);
+}
+
 static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			struct vring_desc __user *desc,
 			struct vring_avail __user *avail,
@@ -1156,7 +1172,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
 {
 	void __user *used;
-	if (__put_user(cpu_to_vhost16(vq, vq->used_flags),
&vq->used->flags) < 0)
+	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
+			   &vq->used->flags) < 0)
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		/* Make sure the flag is seen before log. */
@@ -1174,7 +1191,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue
*vq)
 
 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16
avail_event)
 {
-	if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq)))
+	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
+			   vhost_avail_event(vq)))
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		void __user *used;
@@ -1212,7 +1230,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
 		r = -EFAULT;
 		goto err;
 	}
-	r = __get_user(last_used_idx, &vq->used->idx);
+	r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
 	if (r)
 		goto err;
 	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
@@ -1392,7 +1410,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 
 	/* Check it isn't doing very strange things with descriptor numbers. */
 	last_avail_idx = vq->last_avail_idx;
-	if (unlikely(__get_user(avail_idx, &vq->avail->idx))) {
+	if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
 		vq_err(vq, "Failed to access avail idx at %p\n",
 		       &vq->avail->idx);
 		return -EFAULT;
@@ -1414,8 +1432,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 
 	/* Grab the next descriptor number they're advertising, and increment
 	 * the index we've seen. */
-	if (unlikely(__get_user(ring_head,
-				&vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
+	if (unlikely(vhost_get_user(vq, ring_head,
+		     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
 		vq_err(vq, "Failed to read head: idx %d address %p\n",
 		       last_avail_idx,
 		       &vq->avail->ring[last_avail_idx % vq->num]);
@@ -1450,7 +1468,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 			       i, vq->num, head);
 			return -EINVAL;
 		}
-		ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
+		ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
+					   sizeof desc);
 		if (unlikely(ret)) {
 			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
 			       i, vq->desc + i);
@@ -1538,15 +1557,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue
*vq,
 	start = vq->last_used_idx & (vq->num - 1);
 	used = vq->used->ring + start;
 	if (count == 1) {
-		if (__put_user(heads[0].id, &used->id)) {
+		if (vhost_put_user(vq, heads[0].id, &used->id)) {
 			vq_err(vq, "Failed to write used id");
 			return -EFAULT;
 		}
-		if (__put_user(heads[0].len, &used->len)) {
+		if (vhost_put_user(vq, heads[0].len, &used->len)) {
 			vq_err(vq, "Failed to write used len");
 			return -EFAULT;
 		}
-	} else if (__copy_to_user(used, heads, count * sizeof *used)) {
+	} else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
 		vq_err(vq, "Failed to write used");
 		return -EFAULT;
 	}
@@ -1590,7 +1609,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct
vring_used_elem *heads,
 
 	/* Make sure buffer is written before we update index. */
 	smp_wmb();
-	if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx),
&vq->used->idx)) {
+	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
+			   &vq->used->idx)) {
 		vq_err(vq, "Failed to increment used idx");
 		return -EFAULT;
 	}
@@ -1622,7 +1642,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct
vhost_virtqueue *vq)
 
 	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
 		__virtio16 flags;
-		if (__get_user(flags, &vq->avail->flags)) {
+		if (vhost_get_user(vq, flags, &vq->avail->flags)) {
 			vq_err(vq, "Failed to get flags");
 			return true;
 		}
@@ -1636,7 +1656,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct
vhost_virtqueue *vq)
 	if (unlikely(!v))
 		return true;
 
-	if (__get_user(event, vhost_used_event(vq))) {
+	if (vhost_get_user(vq, event, vhost_used_event(vq))) {
 		vq_err(vq, "Failed to get used event idx");
 		return true;
 	}
@@ -1678,7 +1698,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct
vhost_virtqueue *vq)
 	__virtio16 avail_idx;
 	int r;
 
-	r = __get_user(avail_idx, &vq->avail->idx);
+	r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
 	if (r)
 		return false;
 
@@ -1713,7 +1733,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct
vhost_virtqueue *vq)
 	/* They could have slipped one in as we were doing that: make
 	 * sure it's written, then check again. */
 	smp_mb();
-	r = __get_user(avail_idx, &vq->avail->idx);
+	r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
 	if (r) {
 		vq_err(vq, "Failed to check avail idx at %p: %d\n",
 		       &vq->avail->idx, r);
-- 
1.8.3.1
Jason Wang
2016-Jun-23  06:04 UTC
[PATCH V2 2/3] vhost: convert pre sorted vhost memory array to interval tree
The current pre-sorted memory region array has some limitations for the future
device IOTLB conversion:
1) it needs extra work to add and remove a single region, and this is
   expected to be slow because of sorting or memory re-allocation.
2) it needs extra work to remove a large range, which may intersect
   several regions of different sizes.
3) it needs tricks to implement a replacement policy such as LRU.
To overcome the above shortcomings, this patch converts it to an interval
tree, which can easily address the above issues with almost no extra
work.
The patch could be used to:
- Extend the current API and only let userspace send diffs of the
  memory table.
- Simplify the device IOTLB implementation.
Signed-off-by: Jason Wang <jasowang at redhat.com>
---
 drivers/vhost/net.c   |   8 +--
 drivers/vhost/vhost.c | 182 ++++++++++++++++++++++++++++----------------------
 drivers/vhost/vhost.h |  27 ++++++--
 3 files changed, 128 insertions(+), 89 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 1d3e45f..fc66956 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1038,20 +1038,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 	struct socket *tx_sock = NULL;
 	struct socket *rx_sock = NULL;
 	long err;
-	struct vhost_memory *memory;
+	struct vhost_umem *umem;
 
 	mutex_lock(&n->dev.mutex);
 	err = vhost_dev_check_owner(&n->dev);
 	if (err)
 		goto done;
-	memory = vhost_dev_reset_owner_prepare();
-	if (!memory) {
+	umem = vhost_dev_reset_owner_prepare();
+	if (!umem) {
 		err = -ENOMEM;
 		goto done;
 	}
 	vhost_net_stop(n, &tx_sock, &rx_sock);
 	vhost_net_flush(n);
-	vhost_dev_reset_owner(&n->dev, memory);
+	vhost_dev_reset_owner(&n->dev, umem);
 	vhost_net_vq_reset(n);
 done:
 	mutex_unlock(&n->dev.mutex);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 9f2a63a..166e779 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
 #include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
+#include <linux/interval_tree_generic.h>
 
 #include "vhost.h"
 
@@ -42,6 +43,10 @@ enum {
 #define vhost_used_event(vq) ((__virtio16 __user
*)&vq->avail->ring[vq->num])
 #define vhost_avail_event(vq) ((__virtio16 __user
*)&vq->used->ring[vq->num])
 
+INTERVAL_TREE_DEFINE(struct vhost_umem_node,
+		     rb, __u64, __subtree_last,
+		     START, LAST, , vhost_umem_interval_tree);
+
 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
 {
@@ -300,10 +305,10 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
-	vq->memory = NULL;
 	vhost_reset_is_le(vq);
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
+	vq->umem = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -407,7 +412,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 	mutex_init(&dev->mutex);
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
-	dev->memory = NULL;
+	dev->umem = NULL;
 	dev->mm = NULL;
 	spin_lock_init(&dev->work_lock);
 	INIT_LIST_HEAD(&dev->work_list);
@@ -512,27 +517,36 @@ err_mm:
 }
 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
 
-struct vhost_memory *vhost_dev_reset_owner_prepare(void)
+static void *vhost_kvzalloc(unsigned long size)
 {
-	return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
+	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+
+	if (!n)
+		n = vzalloc(size);
+	return n;
+}
+
+struct vhost_umem *vhost_dev_reset_owner_prepare(void)
+{
+	return vhost_kvzalloc(sizeof(struct vhost_umem));
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 
 /* Caller should have device mutex */
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 {
 	int i;
 
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
-	memory->nregions = 0;
-	dev->memory = memory;
+	INIT_LIST_HEAD(&umem->umem_list);
+	dev->umem = umem;
 	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
 	 * VQs aren't running.
 	 */
 	for (i = 0; i < dev->nvqs; ++i)
-		dev->vqs[i]->memory = memory;
+		dev->vqs[i]->umem = umem;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -549,6 +563,21 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
+static void vhost_umem_clean(struct vhost_umem *umem)
+{
+	struct vhost_umem_node *node, *tmp;
+
+	if (!umem)
+		return;
+
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
+		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+		list_del(&node->link);
+		kvfree(node);
+	}
+	kvfree(umem);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -575,8 +604,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kvfree(dev->memory);
-	dev->memory = NULL;
+	vhost_umem_clean(dev->umem);
+	dev->umem = NULL;
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -602,25 +631,25 @@ static int log_access_ok(void __user *log_base, u64 addr,
unsigned long sz)
 }
 
 /* Caller should have vq mutex and device mutex. */
-static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
+static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 			       int log_all)
 {
-	int i;
+	struct vhost_umem_node *node;
 
-	if (!mem)
+	if (!umem)
 		return 0;
 
-	for (i = 0; i < mem->nregions; ++i) {
-		struct vhost_memory_region *m = mem->regions + i;
-		unsigned long a = m->userspace_addr;
-		if (m->memory_size > ULONG_MAX)
+	list_for_each_entry(node, &umem->umem_list, link) {
+		unsigned long a = node->userspace_addr;
+
+		if (node->size > ULONG_MAX)
 			return 0;
 		else if (!access_ok(VERIFY_WRITE, (void __user *)a,
-				    m->memory_size))
+				    node->size))
 			return 0;
 		else if (log_all && !log_access_ok(log_base,
-						   m->guest_phys_addr,
-						   m->memory_size))
+						   node->start,
+						   node->size))
 			return 0;
 	}
 	return 1;
@@ -628,7 +657,7 @@ static int vq_memory_access_ok(void __user *log_base, struct
vhost_memory *mem,
 
 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
-static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
+static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 			    int log_all)
 {
 	int i;
@@ -641,7 +670,8 @@ static int memory_access_ok(struct vhost_dev *d, struct
vhost_memory *mem,
 		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 		/* If ring is inactive, will check when it's enabled. */
 		if (d->vqs[i]->private_data)
-			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
+			ok = vq_memory_access_ok(d->vqs[i]->log_base,
+						 umem, log);
 		else
 			ok = 1;
 		mutex_unlock(&d->vqs[i]->mutex);
@@ -684,7 +714,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned
int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	return memory_access_ok(dev, dev->memory, 1);
+	return memory_access_ok(dev, dev->umem, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
@@ -695,7 +725,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	return vq_memory_access_ok(log_base, vq->memory,
+	return vq_memory_access_ok(log_base, vq->umem,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
@@ -711,28 +741,12 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
-static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
-{
-	const struct vhost_memory_region *r1 = p1, *r2 = p2;
-	if (r1->guest_phys_addr < r2->guest_phys_addr)
-		return 1;
-	if (r1->guest_phys_addr > r2->guest_phys_addr)
-		return -1;
-	return 0;
-}
-
-static void *vhost_kvzalloc(unsigned long size)
-{
-	void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
-
-	if (!n)
-		n = vzalloc(size);
-	return n;
-}
-
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user
*m)
 {
-	struct vhost_memory mem, *newmem, *oldmem;
+	struct vhost_memory mem, *newmem;
+	struct vhost_memory_region *region;
+	struct vhost_umem_node *node;
+	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;
 
@@ -752,24 +766,52 @@ static long vhost_set_memory(struct vhost_dev *d, struct
vhost_memory __user *m)
 		kvfree(newmem);
 		return -EFAULT;
 	}
-	sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
-		vhost_memory_reg_sort_cmp, NULL);
 
-	if (!memory_access_ok(d, newmem, 0)) {
+	newumem = vhost_kvzalloc(sizeof(*newumem));
+	if (!newumem) {
 		kvfree(newmem);
-		return -EFAULT;
+		return -ENOMEM;
 	}
-	oldmem = d->memory;
-	d->memory = newmem;
+
+	newumem->umem_tree = RB_ROOT;
+	INIT_LIST_HEAD(&newumem->umem_list);
+
+	for (region = newmem->regions;
+	     region < newmem->regions + mem.nregions;
+	     region++) {
+		node = vhost_kvzalloc(sizeof(*node));
+		if (!node)
+			goto err;
+		node->start = region->guest_phys_addr;
+		node->size = region->memory_size;
+		node->last = node->start + node->size - 1;
+		node->userspace_addr = region->userspace_addr;
+		INIT_LIST_HEAD(&node->link);
+		list_add_tail(&node->link, &newumem->umem_list);
+		vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
+	}
+
+	if (!memory_access_ok(d, newumem, 0))
+		goto err;
+
+	oldumem = d->umem;
+	d->umem = newumem;
 
 	/* All memory accesses are done under some VQ mutex. */
 	for (i = 0; i < d->nvqs; ++i) {
 		mutex_lock(&d->vqs[i]->mutex);
-		d->vqs[i]->memory = newmem;
+		d->vqs[i]->umem = newumem;
 		mutex_unlock(&d->vqs[i]->mutex);
 	}
-	kvfree(oldmem);
+
+	kvfree(newmem);
+	vhost_umem_clean(oldumem);
 	return 0;
+
+err:
+	vhost_umem_clean(newumem);
+	kvfree(newmem);
+	return -EFAULT;
 }
 
 long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
@@ -1072,28 +1114,6 @@ done:
 }
 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 
-static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
-						     __u64 addr, __u32 len)
-{
-	const struct vhost_memory_region *reg;
-	int start = 0, end = mem->nregions;
-
-	while (start < end) {
-		int slot = start + (end - start) / 2;
-		reg = mem->regions + slot;
-		if (addr >= reg->guest_phys_addr)
-			end = slot;
-		else
-			start = slot + 1;
-	}
-
-	reg = mem->regions + start;
-	if (addr >= reg->guest_phys_addr &&
-		reg->guest_phys_addr + reg->memory_size > addr)
-		return reg;
-	return NULL;
-}
-
 /* TODO: This is really inefficient.  We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
  * returning -EFAULT). See Documentation/x86/exception-tables.txt.
@@ -1244,29 +1264,29 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
-	const struct vhost_memory_region *reg;
-	struct vhost_memory *mem;
+	const struct vhost_umem_node *node;
+	struct vhost_umem *umem = vq->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;
 
-	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
 			ret = -ENOBUFS;
 			break;
 		}
-		reg = find_region(mem, addr, len);
-		if (unlikely(!reg)) {
+		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							addr, addr + len - 1);
+		if (node == NULL || node->start > addr) {
 			ret = -EFAULT;
 			break;
 		}
 		_iov = iov + ret;
-		size = reg->memory_size - addr + reg->guest_phys_addr;
+		size = node->size - addr + node->start;
 		_iov->iov_len = min((u64)len - s, size);
 		_iov->iov_base = (void __user *)(unsigned long)
-			(reg->userspace_addr + addr - reg->guest_phys_addr);
+			(node->userspace_addr + addr - node->start);
 		s += size;
 		addr += size;
 		++ret;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index d36d8be..b93b6a0 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -53,6 +53,25 @@ struct vhost_log {
 	u64 len;
 };
 
+#define START(node) ((node)->start)
+#define LAST(node) ((node)->last)
+
+struct vhost_umem_node {
+	struct rb_node rb;
+	struct list_head link;
+	__u64 start;
+	__u64 last;
+	__u64 size;
+	__u64 userspace_addr;
+	__u64 flags_padding;
+	__u64 __subtree_last;
+};
+
+struct vhost_umem {
+	struct rb_root umem_tree;
+	struct list_head umem_list;
+};
+
 /* The virtqueue structure describes a queue attached to a device. */
 struct vhost_virtqueue {
 	struct vhost_dev *dev;
@@ -101,7 +120,7 @@ struct vhost_virtqueue {
 	struct iovec *indirect;
 	struct vring_used_elem *heads;
 	/* Protected by virtqueue mutex. */
-	struct vhost_memory *memory;
+	struct vhost_umem *umem;
 	void *private_data;
 	u64 acked_features;
 	/* Log write descriptors */
@@ -119,7 +138,6 @@ struct vhost_virtqueue {
 };
 
 struct vhost_dev {
-	struct vhost_memory *memory;
 	struct mm_struct *mm;
 	struct mutex mutex;
 	struct vhost_virtqueue **vqs;
@@ -129,14 +147,15 @@ struct vhost_dev {
 	spinlock_t work_lock;
 	struct list_head work_list;
 	struct task_struct *worker;
+	struct vhost_umem *umem;
 };
 
 void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int
nvqs);
 long vhost_dev_set_owner(struct vhost_dev *dev);
 bool vhost_dev_has_owner(struct vhost_dev *dev);
 long vhost_dev_check_owner(struct vhost_dev *);
-struct vhost_memory *vhost_dev_reset_owner_prepare(void);
-void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *);
+struct vhost_umem *vhost_dev_reset_owner_prepare(void);
+void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *);
 void vhost_dev_cleanup(struct vhost_dev *, bool locked);
 void vhost_dev_stop(struct vhost_dev *);
 long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user
*argp);
-- 
1.8.3.1
This patch tries to implement a device IOTLB for vhost. This could be
used in cooperation with a userspace (qemu) implementation of DMA
remapping.
The idea is simple: cache the translations in a software device IOTLB
(which is implemented as an interval tree) in vhost, and use the vhost_net
file descriptor for reporting IOTLB misses and for IOTLB
update/invalidation. When vhost encounters an IOTLB miss, the fault
address, size and access type can be read from the file. After userspace
finishes the translation, it writes the translated address to the
vhost_net file to update the device IOTLB.
When the device IOTLB (VHOST_F_DEVICE_IOTLB) is enabled, all vq addresses
set by ioctl are treated as IOVAs instead of virtual addresses, and
accesses can only be done through the IOTLB instead of direct
userspace memory access. Before each round of vq processing, all vq
metadata is prefetched into the device IOTLB to make sure no translation
fault happens during vq processing.
In most cases, virtqueues are mapped contiguously even in virtual
address space, so translating every virtqueue access through the IOTLB may
be a little bit slower than necessary. We can add a fast path on top of this patch.
Signed-off-by: Jason Wang <jasowang at redhat.com>
---
 drivers/vhost/net.c        |  50 +++-
 drivers/vhost/vhost.c      | 634 ++++++++++++++++++++++++++++++++++++++++++---
 drivers/vhost/vhost.h      |  32 ++-
 include/uapi/linux/vhost.h |  28 ++
 4 files changed, 697 insertions(+), 47 deletions(-)
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index fc66956..4ccebad 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy
TX;"
 enum {
 	VHOST_NET_FEATURES = VHOST_FEATURES |
 			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
-			 (1ULL << VIRTIO_NET_F_MRG_RXBUF)
+			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
+			 (1ULL << VHOST_F_DEVICE_IOTLB)
 };
 
 enum {
@@ -334,7 +335,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 {
 	unsigned long uninitialized_var(endtime);
 	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-				    out_num, in_num, NULL, NULL);
+				  out_num, in_num, NULL, NULL);
 
 	if (r == vq->num && vq->busyloop_timeout) {
 		preempt_disable();
@@ -344,7 +345,7 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 			cpu_relax_lowlatency();
 		preempt_enable();
 		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-					out_num, in_num, NULL, NULL);
+				      out_num, in_num, NULL, NULL);
 	}
 
 	return r;
@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net)
 	if (!sock)
 		goto out;
 
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 
 	hdr_size = nvq->vhost_hlen;
@@ -638,6 +642,10 @@ static void handle_rx(struct vhost_net *net)
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
+
+	if (!vq_iotlb_prefetch(vq))
+		goto out;
+
 	vhost_disable_notify(&net->dev, vq);
 	vhost_net_disable_vq(net, vq);
 
@@ -1086,6 +1094,11 @@ static int vhost_net_set_features(struct vhost_net *n,
u64 features)
 		mutex_unlock(&n->dev.mutex);
 		return -EFAULT;
 	}
+	if ((features & (1ULL << VHOST_F_DEVICE_IOTLB)) &&
+	    vhost_init_device_iotlb(&n->dev, true)) {
+		mutex_unlock(&n->dev.mutex);	/* don't leak dev.mutex on error */
+		return -EFAULT;
+	}
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
@@ -1168,9 +1181,40 @@ static long vhost_net_compat_ioctl(struct file *f,
unsigned int ioctl,
 }
 #endif
 
+static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+	int noblock = file->f_flags & O_NONBLOCK;
+
+	return vhost_chr_read_iter(dev, to, noblock);
+}
+
+static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
+					struct iov_iter *from)
+{
+	struct file *file = iocb->ki_filp;
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_write_iter(dev, from);
+}
+
+static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
+{
+	struct vhost_net *n = file->private_data;
+	struct vhost_dev *dev = &n->dev;
+
+	return vhost_chr_poll(file, dev, wait);
+}
+
 static const struct file_operations vhost_net_fops = {
 	.owner          = THIS_MODULE,
 	.release        = vhost_net_release,
+	.read_iter      = vhost_net_chr_read_iter,
+	.write_iter     = vhost_net_chr_write_iter,
+	.poll           = vhost_net_chr_poll,
 	.unlocked_ioctl = vhost_net_ioctl,
 #ifdef CONFIG_COMPAT
 	.compat_ioctl   = vhost_net_compat_ioctl,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 166e779..11d2f55 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -35,6 +35,10 @@ static ushort max_mem_regions = 64;
 module_param(max_mem_regions, ushort, 0444);
 MODULE_PARM_DESC(max_mem_regions,
 	"Maximum number of memory regions in memory map. (default: 64)");
+static int max_iotlb_entries = 2048;
+module_param(max_iotlb_entries, int, 0444);
+MODULE_PARM_DESC(max_iotlb_entries,
+	"Maximum number of iotlb entries. (default: 2048)");
 
 enum {
 	VHOST_MEMORY_F_LOG = 0x1,
@@ -309,6 +313,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vhost_disable_cross_endian(vq);
 	vq->busyloop_timeout = 0;
 	vq->umem = NULL;
+	vq->iotlb = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -413,9 +418,14 @@ void vhost_dev_init(struct vhost_dev *dev,
 	dev->log_ctx = NULL;
 	dev->log_file = NULL;
 	dev->umem = NULL;
+	dev->iotlb = NULL;
 	dev->mm = NULL;
 	spin_lock_init(&dev->work_lock);
 	INIT_LIST_HEAD(&dev->work_list);
+	init_waitqueue_head(&dev->wait);
+	INIT_LIST_HEAD(&dev->read_list);
+	INIT_LIST_HEAD(&dev->pending_list);
+	spin_lock_init(&dev->iotlb_lock);
 	dev->worker = NULL;
 
 	for (i = 0; i < dev->nvqs; ++i) {
@@ -563,6 +573,15 @@ void vhost_dev_stop(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
+static void vhost_umem_free(struct vhost_umem *umem,
+			    struct vhost_umem_node *node)
+{
+	vhost_umem_interval_tree_remove(node, &umem->umem_tree);
+	list_del(&node->link);
+	kvfree(node);	/* nodes may come from vhost_kvzalloc() */
+	umem->numem--;
+}
+
 static void vhost_umem_clean(struct vhost_umem *umem)
 {
 	struct vhost_umem_node *node, *tmp;
@@ -570,14 +589,31 @@ static void vhost_umem_clean(struct vhost_umem *umem)
 	if (!umem)
 		return;
 
-	list_for_each_entry_safe(node, tmp, &umem->umem_list, link) {
-		vhost_umem_interval_tree_remove(node, &umem->umem_tree);
-		list_del(&node->link);
-		kvfree(node);
-	}
+	list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
+		vhost_umem_free(umem, node);
+
 	kvfree(umem);
 }
 
+static void vhost_clear_msg(struct vhost_dev *dev)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&dev->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &dev->read_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
+		list_del(&node->node);
+		kfree(node);
+	}
+
+	spin_unlock(&dev->iotlb_lock);
+}
+
 /* Caller should have device mutex if and only if locked is set */
 void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 {
@@ -606,6 +642,10 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 	/* No one will access memory at this point */
 	vhost_umem_clean(dev->umem);
 	dev->umem = NULL;
+	vhost_umem_clean(dev->iotlb);
+	dev->iotlb = NULL;
+	vhost_clear_msg(dev);
+	wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -681,28 +721,379 @@ static int memory_access_ok(struct vhost_dev *d, struct
vhost_umem *umem,
 	return 1;
 }
 
-#define vhost_put_user(vq, x, ptr)  __put_user(x, ptr)
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
+			  struct iovec iov[], int iov_size, int access);
 
+/* Copy @size bytes from the kernel buffer @from to @to.
+ *
+ * Without a device IOTLB, @to is a userspace pointer and is used directly.
+ * With one, @to is an I/O virtual address: it is translated through the
+ * IOTLB (write access) into vq->iotlb_iov and the data is copied into
+ * those segments.
+ *
+ * Returns 0 on success and non-zero on failure.  On the IOTLB path the
+ * failure value is the negative errno from translate_desc(), or -EFAULT
+ * when the copy comes up short (including when nothing was copied).
+ */
 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
 			      const void *from, unsigned size)
 {
-	return __copy_to_user(to, from, size);
-}
+	struct iov_iter t;
+	int ret;
 
-#define vhost_get_user(vq, x, ptr) __get_user(x, ptr)
+	if (!vq->iotlb)
+		return __copy_to_user(to, from, size);
+
+	/* Expected to be called after the IOTLB was prefetched, so every
+	 * piece of the ring should translate; -EAGAIN is not anticipated.
+	 */
+	ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
+			     ARRAY_SIZE(vq->iotlb_iov), VHOST_ACCESS_WO);
+	if (ret < 0)
+		return ret;
+
+	/* @t is the destination of copy_to_iter(), hence READ direction. */
+	iov_iter_init(&t, READ, vq->iotlb_iov, ret, size);
+	if (copy_to_iter(from, size, &t) != size)
+		return -EFAULT;
+
+	return 0;
+}
 
 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
 				void *from, unsigned size)
 {
-	return __copy_from_user(to, from, size);
+	int ret;
+
+	if (!vq->iotlb)
+		return __copy_from_user(to, from, size);
+	else {
+		/* This function should be called after iotlb
+		 * prefetch, which means we're sure that vq
+		 * could be access through iotlb. So -EAGAIN should
+		 * not happen in this case.
+		 */
+		/* TODO: more fast path */
+		struct iov_iter f;
+		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
+				     ARRAY_SIZE(vq->iotlb_iov),
+				     VHOST_ACCESS_RO);
+		if (ret < 0) {
+			vq_err(vq, "IOTLB translation failure: uaddr "
+			       "%p size 0x%llx\n", from,
+			       (unsigned long long) size);
+			goto out;
+		}
+		iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
+		ret = copy_from_iter(to, size, &f);
+		if (ret == size)
+			ret = 0;
+	}
+
+out:
+	return ret;
+}
+
+/* Translate the IOVA @addr of a @size-byte ring field through the device
+ * IOTLB (read access) and return the matching userspace pointer.
+ *
+ * Single-value accessors need the whole field behind one contiguous
+ * userspace mapping, so NULL is returned both when the translation fails
+ * and when it is split over several segments.  Expected to run after the
+ * IOTLB was prefetched, so -EAGAIN is not anticipated here.
+ */
+static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
+				     void *addr, unsigned size)
+{
+	const struct iovec *seg = vq->iotlb_iov;
+	int nsegs;
+
+	nsegs = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
+			       ARRAY_SIZE(vq->iotlb_iov),
+			       VHOST_ACCESS_RO);
+	if (nsegs < 0) {
+		vq_err(vq, "IOTLB translation failure: uaddr "
+			"%p size 0x%llx\n", addr,
+			(unsigned long long) size);
+		return NULL;
+	}
+
+	if (nsegs == 1 && seg->iov_len == size)
+		return seg->iov_base;
+
+	vq_err(vq, "Non atomic userspace memory access: uaddr "
+		"%p size 0x%llx\n", addr,
+		(unsigned long long) size);
+	return NULL;
+}
+
+/* Store @x into the ring field @ptr; evaluates to 0 or a negative errno.
+ *
+ * Without a device IOTLB @ptr is a userspace pointer.  With one, @ptr is
+ * an IOVA that must translate to a single contiguous userspace mapping
+ * (see __vhost_get_user()); otherwise the result is -EFAULT.
+ *
+ * The locals carry a __vhost_ prefix so they cannot capture identifiers
+ * the caller uses inside @x or @ptr.
+ */
+#define vhost_put_user(vq, x, ptr) \
+({ \
+	int __vhost_ret = -EFAULT; \
+	if (!(vq)->iotlb) { \
+		__vhost_ret = __put_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) __vhost_to = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (__vhost_to != NULL) \
+			__vhost_ret = __put_user(x, __vhost_to); \
+	} \
+	__vhost_ret; \
+})
+
+/* Load the ring field @ptr into @x; evaluates to 0 or a negative errno.
+ *
+ * Without a device IOTLB @ptr is a userspace pointer.  With one, @ptr is
+ * an IOVA that must translate to a single contiguous userspace mapping
+ * (see __vhost_get_user()); otherwise the result is -EFAULT.
+ *
+ * The locals carry a __vhost_ prefix so they cannot capture identifiers
+ * the caller uses inside @x or @ptr.
+ */
+#define vhost_get_user(vq, x, ptr) \
+({ \
+	int __vhost_ret = -EFAULT; \
+	if (!(vq)->iotlb) { \
+		__vhost_ret = __get_user(x, ptr); \
+	} else { \
+		__typeof__(ptr) __vhost_from = \
+			(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+		if (__vhost_from != NULL) \
+			__vhost_ret = __get_user(x, __vhost_from); \
+	} \
+	__vhost_ret; \
+})
+
+/* Take the mutex of every virtqueue of @d, in index order.
+ *
+ * Several mutexes of what is presumably one lockdep class (all vq mutexes)
+ * are held at once, so each is taken with its index as the nesting
+ * subclass.  A plain mutex_lock() would make lockdep report a false
+ * recursive-locking deadlock.  Note that lockdep only supports a small
+ * number of subclasses, so this suits devices with few virtqueues.
+ */
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+	int i;
+
+	for (i = 0; i < d->nvqs; ++i)
+		mutex_lock_nested(&d->vqs[i]->mutex, i);
+}
+
+/* Release the mutex of every virtqueue of @d, in index order. */
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+	int i;
+
+	for (i = 0; i < d->nvqs; i++) {
+		struct vhost_virtqueue *vq = d->vqs[i];
+
+		mutex_unlock(&vq->mutex);
+	}
+}
+
+/* Insert a node for [@start, @end] (@size bytes), mapped at
+ * @userspace_addr with permissions @perm, into @umem.
+ *
+ * Nodes sit on umem_list in insertion order.  Once max_iotlb_entries
+ * nodes exist, the oldest one is evicted first (FIFO replacement).
+ *
+ * Returns 0 on success or -ENOMEM.
+ */
+static int vhost_new_umem_range(struct vhost_umem *umem,
+				u64 start, u64 size, u64 end,
+				u64 userspace_addr, int perm)
+{
+	struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
+
+	if (!node)
+		return -ENOMEM;
+
+	/* Only evict when there is something to evict.  With
+	 * max_iotlb_entries configured as 0 the list can be empty here, and
+	 * list_first_entry() on an empty list yields the list head itself,
+	 * which must never be passed to vhost_umem_free().
+	 */
+	if (umem->numem >= max_iotlb_entries &&
+	    !list_empty(&umem->umem_list)) {
+		tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
+		vhost_umem_free(umem, tmp);
+	}
+
+	node->start = start;
+	node->size = size;
+	node->last = end;
+	node->userspace_addr = userspace_addr;
+	node->perm = perm;
+	INIT_LIST_HEAD(&node->link);
+	list_add_tail(&node->link, &umem->umem_list);
+	vhost_umem_interval_tree_insert(node, &umem->umem_tree);
+	umem->numem++;
+
+	return 0;
+}
+
+/* Remove and free every node of @umem that overlaps [@start, @end]. */
+static void vhost_del_umem_range(struct vhost_umem *umem,
+				 u64 start, u64 end)
+{
+	struct vhost_umem_node *victim;
+
+	for (;;) {
+		victim = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							     start, end);
+		if (!victim)
+			break;
+		vhost_umem_free(umem, victim);
+	}
+}
+
+/* Kick every virtqueue whose pending IOTLB miss falls inside the range
+ * just installed by @msg, and drop that pending request.
+ */
+static void vhost_iotlb_notify_vq(struct vhost_dev *d,
+				  struct vhost_iotlb_msg *msg)
+{
+	struct vhost_msg_node *node, *n;
+
+	spin_lock(&d->iotlb_lock);
+
+	list_for_each_entry_safe(node, n, &d->pending_list, node) {
+		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
+
+		/* The update covers the inclusive range
+		 * [iova, iova + size - 1], so a miss on its very last byte
+		 * is covered too (hence >=, not >).
+		 */
+		if (msg->iova <= vq_msg->iova &&
+		    msg->iova + msg->size - 1 >= vq_msg->iova &&
+		    vq_msg->type == VHOST_IOTLB_MISS) {
+			vhost_poll_queue(&node->vq->poll);
+			list_del(&node->node);
+			kfree(node);
+		}
+	}
+
+	spin_unlock(&d->iotlb_lock);
+}
+
+/* Check that [@uaddr, @uaddr + @size) is a valid userspace range for each
+ * direction requested in @access (VHOST_ACCESS_RO and/or VHOST_ACCESS_WO).
+ * Returns 0 if so, -EFAULT otherwise.
+ */
+static int umem_access_ok(u64 uaddr, u64 size, int access)
+{
+	int ok = 1;
+
+	if (access & VHOST_ACCESS_RO)
+		ok = access_ok(VERIFY_READ, uaddr, size);
+	if (ok && (access & VHOST_ACCESS_WO))
+		ok = access_ok(VERIFY_WRITE, uaddr, size);
+
+	return ok ? 0 : -EFAULT;
+}
+
+/* Apply one IOTLB message written by userspace to the device IOTLB.
+ *
+ * VHOST_IOTLB_UPDATE installs [iova, iova + size - 1] -> uaddr with
+ * permission perm, after checking the userspace range.  It then kicks any
+ * virtqueue whose pending miss the new range covers.
+ * VHOST_IOTLB_INVALIDATE drops every mapping overlapping the range.  All
+ * vq mutexes are held while the IOTLB changes.
+ *
+ * Returns 0 on success.  On failure it returns:
+ *   -EINVAL  unknown type, or an empty/wrapping update range;
+ *   -EFAULT  no device IOTLB, or uaddr not accessible;
+ *   -ENOMEM  the mapping could not be added.
+ */
+int vhost_process_iotlb_msg(struct vhost_dev *dev,
+			    struct vhost_iotlb_msg *msg)
+{
+	int ret = 0;
+
+	vhost_dev_lock_vqs(dev);
+	switch (msg->type) {
+	case VHOST_IOTLB_UPDATE:
+		if (!dev->iotlb) {
+			ret = -EFAULT;
+			break;
+		}
+		/* An empty range or one wrapping past the top of the 64-bit
+		 * space would create a node whose last address lies below
+		 * its start.
+		 */
+		if (!msg->size || msg->iova + msg->size - 1 < msg->iova) {
+			ret = -EINVAL;
+			break;
+		}
+		if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
+			ret = -EFAULT;
+			break;
+		}
+		if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
+					 msg->iova + msg->size - 1,
+					 msg->uaddr, msg->perm)) {
+			ret = -ENOMEM;
+			break;
+		}
+		vhost_iotlb_notify_vq(dev, msg);
+		break;
+	case VHOST_IOTLB_INVALIDATE:
+		/* Without a device IOTLB there is nothing to invalidate, and
+		 * dev->iotlb must not be dereferenced.
+		 */
+		if (!dev->iotlb) {
+			ret = -EFAULT;
+			break;
+		}
+		vhost_del_umem_range(dev->iotlb, msg->iova,
+				     msg->iova + msg->size - 1);
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+	vhost_dev_unlock_vqs(dev);
+	return ret;
+}
+
+/* write() handler for the vhost device fd: consume exactly one
+ * struct vhost_msg from @from and dispatch it.
+ *
+ * Returns sizeof(struct vhost_msg) on success, 0 if fewer bytes than one
+ * message were supplied, -EFAULT if the message could not be copied in,
+ * -EINVAL for an unknown message type, or the negative errno from
+ * processing the message.
+ */
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from)
+{
+	struct vhost_msg msg;
+	unsigned size = sizeof(struct vhost_msg);
+	ssize_t ret;
+	int err;
+
+	if (iov_iter_count(from) < size)
+		return 0;
+
+	/* A short copy means userspace handed in a bad buffer.  Report it
+	 * instead of returning a partial count for a message that was
+	 * never processed.
+	 */
+	if (copy_from_iter(&msg, size, from) != size)
+		return -EFAULT;
+	ret = size;
+
+	switch (msg.type) {
+	case VHOST_IOTLB_MSG:
+		err = vhost_process_iotlb_msg(dev, &msg.iotlb);
+		if (err)
+			ret = err;
+		break;
+	default:
+		ret = -EINVAL;
+		break;
+	}
+
+	return ret;
+}
+EXPORT_SYMBOL(vhost_chr_write_iter);
+
+/* poll() handler for the vhost device fd.  The fd is readable
+ * (POLLIN | POLLRDNORM) while messages are queued on dev->read_list.
+ */
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait)
+{
+	poll_wait(file, &dev->wait, wait);
+
+	return list_empty(&dev->read_list) ? 0 : (POLLIN | POLLRDNORM);
+}
+EXPORT_SYMBOL(vhost_chr_poll);
+
+/* read() handler for the vhost device fd.
+ *
+ * Hands the oldest message on dev->read_list to userspace.  Unless
+ * @noblock is set, it sleeps on dev->wait until a message is queued.
+ *
+ * Returns the number of bytes copied, or:
+ *   0             @to cannot hold a whole message;
+ *   -EAGAIN       non-blocking and nothing is queued;
+ *   -ERESTARTSYS  a signal is pending;
+ *   -EBADFD       the device IOTLB is gone.
+ *
+ * A fully copied IOTLB miss is parked on dev->pending_list until
+ * userspace answers it with an update.  Any other message is freed once
+ * read.
+ */
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock)
+{
+	DEFINE_WAIT(wait);
+	struct vhost_msg_node *node;
+	ssize_t ret = 0;
+	unsigned size = sizeof(struct vhost_msg);
+
+	if (iov_iter_count(to) < size)
+		return 0;
+
+	while (1) {
+		/* Get on the wait queue before looking for work, so a
+		 * wakeup from vhost_enqueue_msg() that lands between the
+		 * check and schedule() is not lost.
+		 */
+		if (!noblock)
+			prepare_to_wait(&dev->wait, &wait,
+					TASK_INTERRUPTIBLE);
+
+		node = vhost_dequeue_msg(dev, &dev->read_list);
+		if (node)
+			break;
+		if (noblock) {
+			ret = -EAGAIN;
+			break;
+		}
+		if (signal_pending(current)) {
+			ret = -ERESTARTSYS;
+			break;
+		}
+		if (!dev->iotlb) {
+			ret = -EBADFD;
+			break;
+		}
+
+		schedule();
+	}
+
+	if (!noblock)
+		finish_wait(&dev->wait, &wait);
+
+	if (node) {
+		ret = copy_to_iter(&node->msg, size, to);
+
+		/* Only IOTLB misses await an answer.  The outer msg.type
+		 * selects the union member (VHOST_IOTLB_MSG).  The miss
+		 * itself is identified by the type inside the iotlb payload.
+		 */
+		if (ret != size || node->msg.type != VHOST_IOTLB_MSG ||
+		    node->msg.iotlb.type != VHOST_IOTLB_MISS) {
+			kfree(node);
+			return ret;
+		}
+
+		vhost_enqueue_msg(dev, &dev->pending_list, node);
+	}
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
+
+/* Queue an IOTLB miss for @iova (needed with @access) on dev->read_list,
+ * so userspace can read it and supply the translation.
+ * Returns 0 or -ENOMEM.
+ */
+static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
+{
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_msg_node *node;
+	struct vhost_iotlb_msg *msg;
+
+	/* The outer message type selects the union member (VHOST_IOTLB_MSG).
+	 * The miss itself is encoded in the iotlb payload's own type below.
+	 */
+	node = vhost_new_msg(vq, VHOST_IOTLB_MSG);
+	if (!node)
+		return -ENOMEM;
+
+	/* The whole struct vhost_msg is copied to userspace by
+	 * vhost_chr_read_iter().  Clear it so fields left unset here (such
+	 * as size) and padding cannot leak stale kernel heap contents.
+	 */
+	memset(&node->msg, 0, sizeof(node->msg));
+	node->msg.type = VHOST_IOTLB_MSG;
+
+	msg = &node->msg.iotlb;
+	msg->type = VHOST_IOTLB_MISS;
+	msg->iova = iova;
+	msg->perm = access;
+
+	vhost_enqueue_msg(dev, &dev->read_list, node);
+
+	return 0;
 }
 
 static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			struct vring_desc __user *desc,
 			struct vring_avail __user *avail,
 			struct vring_used __user *used)
+
 {
 	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
 	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
 	       access_ok(VERIFY_READ, avail,
 			 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -710,6 +1101,54 @@ static int vq_access_ok(struct vhost_virtqueue *vq,
unsigned int num,
 			sizeof *used + num * sizeof *used->ring + s);
 }
 
+/* Check that every byte of [@addr, @addr + @len) is covered by device
+ * IOTLB entries that grant @access.
+ *
+ * Returns true only if the whole range is covered.  If no entry covers the
+ * next address, an IOTLB miss for it is queued for userspace and false is
+ * returned.  If an entry exists but lacks @access, false is returned
+ * without queueing anything.
+ */
+static int iotlb_access_ok(struct vhost_virtqueue *vq,
+			   int access, u64 addr, u64 len)
+{
+	const struct vhost_umem_node *entry;
+	struct vhost_umem *umem = vq->iotlb;
+	u64 covered = 0;
+
+	while (covered < len) {
+		u64 chunk;
+
+		entry = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
+							    addr,
+							    addr + len - 1);
+		if (!entry || entry->start > addr) {
+			vhost_iotlb_miss(vq, addr, access);
+			return false;
+		}
+		if (!(entry->perm & access))
+			return false;
+
+		/* Bytes of this entry from addr up to its end. */
+		chunk = entry->start + entry->size - addr;
+		covered += chunk;
+		addr += chunk;
+	}
+
+	return true;
+}
+
+/* Make sure the device IOTLB covers the rings of @vq with the access each
+ * needs: the descriptor table and avail ring are read, the used ring is
+ * written.  With VIRTIO_RING_F_EVENT_IDX the rings include the 2-byte
+ * event-index field.
+ *
+ * Returns non-zero when every ring is covered, or when no IOTLB is in use.
+ * Otherwise returns 0; a missing entry also has a miss request queued for
+ * userspace.
+ */
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+{
+	size_t event = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	unsigned int num = vq->num;
+
+	if (!vq->iotlb)
+		return 1;
+
+	if (!iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
+			     num * sizeof *vq->desc))
+		return 0;
+	if (!iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
+			     sizeof *vq->avail +
+			     num * sizeof *vq->avail->ring + event))
+		return 0;
+	return iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
+			       sizeof *vq->used +
+			       num * sizeof *vq->used->ring + event);
+}
+EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+
 /* Can we log writes? */
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
@@ -736,16 +1175,35 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 /* Caller should have vq mutex and device mutex */
 int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
+	if (vq->iotlb) {
+		/* With a device IOTLB the ring addresses are I/O virtual
+		 * addresses, not userspace pointers, so access_ok() cannot
+		 * validate them here.  They are checked against the IOTLB
+		 * when the rings are prefetched (vq_iotlb_prefetch()).  The
+		 * log access check is skipped on this path as well.
+		 */
+		return 1;
+	}
 	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used)
&&
 		vq_log_access_ok(vq, vq->log_base);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
+/* Allocate an empty umem: empty interval tree, empty node list and a zero
+ * node count.  Returns NULL on allocation failure.
+ */
+static struct vhost_umem *vhost_umem_alloc(void)
+{
+	struct vhost_umem *umem;
+
+	umem = vhost_kvzalloc(sizeof(*umem));
+	if (umem) {
+		umem->umem_tree = RB_ROOT;
+		umem->numem = 0;
+		INIT_LIST_HEAD(&umem->umem_list);
+	}
+
+	return umem;
+}
+
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user
*m)
 {
 	struct vhost_memory mem, *newmem;
 	struct vhost_memory_region *region;
-	struct vhost_umem_node *node;
 	struct vhost_umem *newumem, *oldumem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
 	int i;
@@ -767,28 +1225,23 @@ static long vhost_set_memory(struct vhost_dev *d, struct
vhost_memory __user *m)
 		return -EFAULT;
 	}
 
-	newumem = vhost_kvzalloc(sizeof(*newumem));
+	newumem = vhost_umem_alloc();
 	if (!newumem) {
 		kvfree(newmem);
 		return -ENOMEM;
 	}
 
-	newumem->umem_tree = RB_ROOT;
-	INIT_LIST_HEAD(&newumem->umem_list);
-
 	for (region = newmem->regions;
 	     region < newmem->regions + mem.nregions;
 	     region++) {
-		node = vhost_kvzalloc(sizeof(*node));
-		if (!node)
+		if (vhost_new_umem_range(newumem,
+					 region->guest_phys_addr,
+					 region->memory_size,
+					 region->guest_phys_addr +
+					 region->memory_size - 1,
+					 region->userspace_addr,
+					 VHOST_ACCESS_RW))
 			goto err;
-		node->start = region->guest_phys_addr;
-		node->size = region->memory_size;
-		node->last = node->start + node->size - 1;
-		node->userspace_addr = region->userspace_addr;
-		INIT_LIST_HEAD(&node->link);
-		list_add_tail(&node->link, &newumem->umem_list);
-		vhost_umem_interval_tree_insert(node, &newumem->umem_tree);
 	}
 
 	if (!memory_access_ok(d, newumem, 0))
@@ -1032,6 +1485,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl,
void __user *argp)
 }
 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
 
+/* Install a fresh, empty device IOTLB on @d and point every virtqueue at
+ * it, then free the previous IOTLB (if any).  @enabled is currently
+ * unused.  Returns 0 or -ENOMEM.
+ */
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+{
+	struct vhost_umem *fresh = vhost_umem_alloc();
+	struct vhost_umem *stale;
+	int i;
+
+	if (!fresh)
+		return -ENOMEM;
+
+	stale = d->iotlb;
+	d->iotlb = fresh;
+
+	/* vq->iotlb is switched under each vq mutex, alongside vq->umem. */
+	for (i = 0; i < d->nvqs; ++i) {
+		struct vhost_virtqueue *vq = d->vqs[i];
+
+		mutex_lock(&vq->mutex);
+		vq->iotlb = fresh;
+		mutex_unlock(&vq->mutex);
+	}
+
+	vhost_umem_clean(stale);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
+
 /* Caller must have device mutex */
 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user
*argp)
 {
@@ -1246,15 +1723,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
 	if (r)
 		goto err;
 	vq->signalled_used_valid = false;
-	if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof
vq->used->idx)) {
+	if (!vq->iotlb &&
+	    !access_ok(VERIFY_READ, &vq->used->idx, sizeof
vq->used->idx)) {
 		r = -EFAULT;
 		goto err;
 	}
 	r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
-	if (r)
+	if (r) {
+		vq_err(vq, "Can't access used idx at %p\n",
+		       &vq->used->idx);
 		goto err;
+	}
 	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
 	return 0;
+
 err:
 	vq->is_le = is_le;
 	return r;
@@ -1262,10 +1744,11 @@ err:
 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
 
+/* Translate the range starting at @addr, @len bytes long, into userspace
+ * iovecs in @iov (at most @iov_size entries), requiring @access on each
+ * mapping used.
+ *
+ * The device IOTLB is used when one exists, otherwise the guest memory
+ * table.  Returns the number of iovecs filled, or a negative errno:
+ *   -EFAULT  unmapped address in the memory table;
+ *   -EAGAIN  IOTLB miss (a miss request is queued for userspace);
+ *   -EPERM   the mapping lacks @access;
+ *   -ENOBUFS iov ran out (condition not shown in this hunk).
+ */
 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
-			  struct iovec iov[], int iov_size)
+			  struct iovec iov[], int iov_size, int access)
 {
 	const struct vhost_umem_node *node;
-	struct vhost_umem *umem = vq->umem;
+	struct vhost_dev *dev = vq->dev;
+	struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
 	struct iovec *_iov;
 	u64 s = 0;
 	int ret = 0;
@@ -1276,12 +1759,21 @@ static int translate_desc(struct vhost_virtqueue *vq,
u64 addr, u32 len,
 			ret = -ENOBUFS;
 			break;
 		}
+
 		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
 							addr, addr + len - 1);
 		if (node == NULL || node->start > addr) {
+			/* A hole in the memory table is a hard fault.  A hole
+			 * in the device IOTLB is a miss userspace can fill.
+			 */
-			ret = -EFAULT;
+			if (umem != dev->iotlb) {
+				ret = -EFAULT;
+				break;
+			}
+			ret = -EAGAIN;
+			break;
+		} else if (!(node->perm & access)) {
+			ret = -EPERM;
 			break;
 		}
+
 		_iov = iov + ret;
 		size = node->size - addr + node->start;
 		_iov->iov_len = min((u64)len - s, size);
@@ -1292,6 +1784,8 @@ static int translate_desc(struct vhost_virtqueue *vq, u64
addr, u32 len,
 		++ret;
 	}
 
+	/* Ask userspace for the missing translation; the caller retries. */
+	if (ret == -EAGAIN)
+		vhost_iotlb_miss(vq, addr, access);
 	return ret;
 }
 
@@ -1326,7 +1820,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
 	unsigned int i = 0, count, found = 0;
 	u32 len = vhost32_to_cpu(vq, indirect->len);
 	struct iov_iter from;
-	int ret;
+	int ret, access;
 
 	/* Sanity check */
 	if (unlikely(len % sizeof desc)) {
@@ -1338,9 +1832,10 @@ static int get_indirect(struct vhost_virtqueue *vq,
 	}
 
 	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len,
vq->indirect,
-			     UIO_MAXIOV);
+			     UIO_MAXIOV, VHOST_ACCESS_RO);
 	if (unlikely(ret < 0)) {
-		vq_err(vq, "Translation failure %d in indirect.\n", ret);
+		if (ret != -EAGAIN)
+			vq_err(vq, "Translation failure %d in indirect.\n", ret);
 		return ret;
 	}
 	iov_iter_init(&from, READ, vq->indirect, ret, len);
@@ -1378,16 +1873,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
 			return -EINVAL;
 		}
 
+		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
+			access = VHOST_ACCESS_WO;
+		else
+			access = VHOST_ACCESS_RO;
+
 		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
 				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
-				     iov_size - iov_count);
+				     iov_size - iov_count, access);
 		if (unlikely(ret < 0)) {
-			vq_err(vq, "Translation failure %d indirect idx %d\n",
-			       ret, i);
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d indirect idx %d\n",
+					ret, i);
 			return ret;
 		}
 		/* If this is an input descriptor, increment that count. */
-		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
+		if (access == VHOST_ACCESS_WO) {
 			*in_num += ret;
 			if (unlikely(log)) {
 				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
@@ -1426,7 +1927,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 	u16 last_avail_idx;
 	__virtio16 avail_idx;
 	__virtio16 ring_head;
-	int ret;
+	int ret, access;
 
 	/* Check it isn't doing very strange things with descriptor numbers. */
 	last_avail_idx = vq->last_avail_idx;
@@ -1500,22 +2001,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 					   out_num, in_num,
 					   log, log_num, &desc);
 			if (unlikely(ret < 0)) {
-				vq_err(vq, "Failure detected "
-				       "in indirect descriptor at idx %d\n", i);
+				if (ret != -EAGAIN)
+					vq_err(vq, "Failure detected "
+						"in indirect descriptor at idx %d\n", i);
 				return ret;
 			}
 			continue;
 		}
 
+		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
+			access = VHOST_ACCESS_WO;
+		else
+			access = VHOST_ACCESS_RO;
 		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
 				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
-				     iov_size - iov_count);
+				     iov_size - iov_count, access);
 		if (unlikely(ret < 0)) {
-			vq_err(vq, "Translation failure %d descriptor idx %d\n",
-			       ret, i);
+			if (ret != -EAGAIN)
+				vq_err(vq, "Translation failure %d descriptor idx %d\n",
+					ret, i);
 			return ret;
 		}
-		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) {
+		if (access == VHOST_ACCESS_WO) {
 			/* If this is an input descriptor,
 			 * increment that count. */
 			*in_num += ret;
@@ -1781,6 +2288,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct
vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_disable_notify);
 
+/* Allocate a message node for @vq whose outer message type is @type.
+ *
+ * The message body is zeroed.  It is copied verbatim to userspace by
+ * vhost_chr_read_iter(), so any field or padding the caller leaves unset
+ * must not carry stale kernel heap contents.  Returns NULL on allocation
+ * failure.
+ */
+struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
+{
+	struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
+
+	if (!node)
+		return NULL;
+
+	memset(&node->msg, 0, sizeof node->msg);
+	node->vq = vq;
+	node->msg.type = type;
+	return node;
+}
+EXPORT_SYMBOL_GPL(vhost_new_msg);
+
+/* Append @node to the message list @head under dev->iotlb_lock.  Callers
+ * in this file pass read_list or pending_list.  Then wake readers and
+ * pollers sleeping on dev->wait.
+ */
+void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
+		       struct vhost_msg_node *node)
+{
+	spinlock_t *lock = &dev->iotlb_lock;
+
+	spin_lock(lock);
+	list_add_tail(&node->node, head);
+	spin_unlock(lock);
+
+	wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
+}
+EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
+
+/* Detach and return the oldest node on @head, or NULL if the list is
+ * empty.  The list is manipulated under dev->iotlb_lock.
+ */
+struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
+					 struct list_head *head)
+{
+	struct vhost_msg_node *node;
+
+	spin_lock(&dev->iotlb_lock);
+	node = list_first_entry_or_null(head, struct vhost_msg_node, node);
+	if (node)
+		list_del(&node->node);
+	spin_unlock(&dev->iotlb_lock);
+
+	return node;
+}
+EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
+
 static int __init vhost_init(void)
 {
 	return 0;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index b93b6a0..8601fc6 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -63,13 +63,15 @@ struct vhost_umem_node {
 	__u64 last;
 	__u64 size;
 	__u64 userspace_addr;
-	__u64 flags_padding;
+	__u32 perm;
+	__u32 flags_padding;
 	__u64 __subtree_last;
 };
 
 struct vhost_umem {
 	struct rb_root umem_tree;
 	struct list_head umem_list;
+	int numem;	/* number of nodes currently on umem_list */
 };
 
 /* The virtqueue structure describes a queue attached to a device. */
@@ -117,10 +119,12 @@ struct vhost_virtqueue {
 	u64 log_addr;
 
 	struct iovec iov[UIO_MAXIOV];
+	struct iovec iotlb_iov[64];
 	struct iovec *indirect;
 	struct vring_used_elem *heads;
 	/* Protected by virtqueue mutex. */
 	struct vhost_umem *umem;
+	struct vhost_umem *iotlb;
 	void *private_data;
 	u64 acked_features;
 	/* Log write descriptors */
@@ -137,6 +141,12 @@ struct vhost_virtqueue {
 	u32 busyloop_timeout;
 };
 
+/* A message queued on vhost_dev's read_list or pending_list. */
+struct vhost_msg_node {
+	struct vhost_msg msg;		/* payload exchanged with userspace */
+	struct vhost_virtqueue *vq;	/* virtqueue that issued the message */
+	struct list_head node;		/* link in read_list/pending_list */
+};
+
 struct vhost_dev {
 	struct mm_struct *mm;
 	struct mutex mutex;
@@ -148,6 +158,11 @@ struct vhost_dev {
 	struct list_head work_list;
 	struct task_struct *worker;
 	struct vhost_umem *umem;
+	struct vhost_umem *iotlb;	/* device IOTLB; NULL when not in use */
+	spinlock_t iotlb_lock;		/* protects read_list and pending_list */
+	struct list_head read_list;	/* messages not yet read by userspace */
+	struct list_head pending_list;	/* misses read, awaiting an update */
+	wait_queue_head_t wait;		/* readers/pollers of read_list */
 };
 
 void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int
nvqs);
@@ -184,6 +199,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct
vhost_virtqueue *);
 
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
 		    unsigned int log_num, u64 len);
+/* Non-zero when the vq rings are covered by the device IOTLB (or no IOTLB
+ * is in use); 0 otherwise, with miss requests queued for missing entries.
+ */
+int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
+
+/* IOTLB message queue helpers; the lists are guarded by dev->iotlb_lock. */
+struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
+void vhost_enqueue_msg(struct vhost_dev *dev,
+		       struct list_head *head,
+		       struct vhost_msg_node *node);
+struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
+					 struct list_head *head);
+/* File-operation helpers carrying struct vhost_msg over the device fd. */
+unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
+			    poll_table *wait);
+ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
+			    int noblock);
+ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
+			     struct iov_iter *from);
+int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
 
 #define vq_err(vq, fmt, ...) do {                                  \
 		pr_debug(pr_fmt(fmt), ##__VA_ARGS__);       \
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index 61a8777..8cb0a65 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -47,6 +47,32 @@ struct vhost_vring_addr {
 	__u64 log_guest_addr;
 };
 
+/* no alignment requirement */
+/* IOTLB request/reply carried in struct vhost_msg (type VHOST_IOTLB_MSG).
+ * MISS is sent from the kernel to userspace.  UPDATE and INVALIDATE are
+ * written by userspace.
+ */
+struct vhost_iotlb_msg {
+	__u64 iova;	/* first I/O virtual address of the range */
+	__u64 size;	/* length in bytes: range is [iova, iova + size - 1] */
+	__u64 uaddr;	/* userspace address iova maps to (UPDATE) */
+#define VHOST_ACCESS_RO      0x1
+#define VHOST_ACCESS_WO      0x2
+#define VHOST_ACCESS_RW      0x3
+	__u8 perm;	/* VHOST_ACCESS_* */
+#define VHOST_IOTLB_MISS           1
+#define VHOST_IOTLB_UPDATE         2
+#define VHOST_IOTLB_INVALIDATE     3
+#define VHOST_IOTLB_ACCESS_FAIL    4
+	__u8 type;	/* VHOST_IOTLB_* */
+};
+
+#define VHOST_IOTLB_MSG 0x1
+
+/* Message exchanged with userspace through read()/write() on the vhost
+ * device fd.  @type selects the union member (VHOST_IOTLB_MSG -> iotlb).
+ *
+ * @reserved makes the padding after @type explicit, so the union starts
+ * at offset 8 on every ABI.  With a bare int and 64-bit union members, the
+ * union lands at offset 8 on x86_64 but at offset 4 on i386, which breaks
+ * 32-bit userspace on a 64-bit kernel.  @reserved should be 0.
+ */
+struct vhost_msg {
+	__u32 type;
+	__u32 reserved;
+	union {
+		struct vhost_iotlb_msg iotlb;
+		__u8 padding[64];
+	};
+};
+
 struct vhost_memory_region {
 	__u64 guest_phys_addr;
 	__u64 memory_size; /* bytes */
@@ -146,6 +172,8 @@ struct vhost_memory {
 #define VHOST_F_LOG_ALL 26
 /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
 #define VHOST_NET_F_VIRTIO_NET_HDR 27
+/* vhost supports a device IOTLB (see struct vhost_msg, VHOST_IOTLB_MSG). */
+#define VHOST_F_DEVICE_IOTLB 63
 
 /* VHOST_SCSI specific definitions */
 
-- 
1.8.3.1