vfio/pci: Fix racy vfio_device_get_from_dev() call
[cascardo/linux.git] / drivers / vfio / pci / vfio_pci.c
index e9851ad..964ad57 100644 (file)
@@ -1056,19 +1056,21 @@ struct vfio_devices {
 static int vfio_pci_get_devs(struct pci_dev *pdev, void *data)
 {
        struct vfio_devices *devs = data;
-       struct pci_driver *pci_drv = ACCESS_ONCE(pdev->driver);
-
-       if (pci_drv != &vfio_pci_driver)
-               return -EBUSY;
+       struct vfio_device *device;
 
        if (devs->cur_index == devs->max_index)
                return -ENOSPC;
 
-       devs->devices[devs->cur_index] = vfio_device_get_from_dev(&pdev->dev);
-       if (!devs->devices[devs->cur_index])
+       device = vfio_device_get_from_dev(&pdev->dev);
+       if (!device)
                return -EINVAL;
 
-       devs->cur_index++;
+       if (pci_dev_driver(pdev) != &vfio_pci_driver) {
+               vfio_device_put(device);
+               return -EBUSY;
+       }
+
+       devs->devices[devs->cur_index++] = device;
        return 0;
 }