This corresponds to part of the functionality of the NT syscall NtWaitForMultipleObjects(). Specifically, it implements the behaviour where the third argument (wait_any) is TRUE, and it does not handle alertable waits. Those features have been split out into separate patches to ease review. Signed-off-by: Elizabeth Figura --- drivers/misc/ntsync.c | 229 ++++++++++++++++++++++++++++++++++++ include/uapi/linux/ntsync.h | 13 ++ 2 files changed, 242 insertions(+) diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c index d1c91c2a4f1a..2e8d3c2d51a4 100644 --- a/drivers/misc/ntsync.c +++ b/drivers/misc/ntsync.c @@ -23,6 +23,8 @@ struct ntsync_obj { struct kref refcount; spinlock_t lock; + struct list_head any_waiters; + enum ntsync_type type; /* The following fields are protected by the object lock. */ @@ -34,6 +36,28 @@ struct ntsync_obj { } u; }; +struct ntsync_q_entry { + struct list_head node; + struct ntsync_q *q; + struct ntsync_obj *obj; + __u32 index; +}; + +struct ntsync_q { + struct task_struct *task; + __u32 owner; + + /* + * Protected via atomic_cmpxchg(). Only the thread that wins the + * compare-and-swap may actually change object states and wake this + * task.
*/ atomic_t signaled; + + __u32 count; + struct ntsync_q_entry entries[]; +}; + struct ntsync_device { struct xarray objects; }; @@ -109,6 +133,26 @@ static void init_obj(struct ntsync_obj *obj) { kref_init(&obj->refcount); spin_lock_init(&obj->lock); + INIT_LIST_HEAD(&obj->any_waiters); +} + +static void try_wake_any_sem(struct ntsync_obj *sem) +{ + struct ntsync_q_entry *entry; + + lockdep_assert_held(&sem->lock); + + list_for_each_entry(entry, &sem->any_waiters, node) { + struct ntsync_q *q = entry->q; + + if (!sem->u.sem.count) + break; + + if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) { + sem->u.sem.count--; + wake_up_process(q->task); + } + } } static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp) @@ -194,6 +238,8 @@ static int ntsync_put_sem(struct ntsync_device *dev, void __user *argp) prev_count = sem->u.sem.count; ret = put_sem_state(sem, args.count); + if (!ret) + try_wake_any_sem(sem); spin_unlock(&sem->lock); @@ -205,6 +251,187 @@ static int ntsync_put_sem(struct ntsync_device *dev, void __user *argp) return ret; } +static int ntsync_schedule(const struct ntsync_q *q, ktime_t *timeout) +{ + int ret = 0; + + do { + if (signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + + set_current_state(TASK_INTERRUPTIBLE); + if (atomic_read(&q->signaled) != -1) { + ret = 0; + break; + } + ret = schedule_hrtimeout(timeout, HRTIMER_MODE_ABS); + } while (ret < 0); + __set_current_state(TASK_RUNNING); + + return ret; +} + +/* + * Allocate and initialize the ntsync_q structure, but do not queue us yet. + * Also, read the absolute timeout (if any) from userspace.
*/ static int setup_wait(struct ntsync_device *dev, + const struct ntsync_wait_args *args, + ktime_t *ret_timeout, struct ntsync_q **ret_q) +{ + const __u32 count = args->count; + struct ntsync_q *q; + ktime_t timeout = 0; + __u32 *ids; + __u32 i, j; + + if (!args->owner || args->pad) + return -EINVAL; + + if (args->count > NTSYNC_MAX_WAIT_COUNT) + return -EINVAL; + + if (args->timeout) { + struct timespec64 to; + + if (get_timespec64(&to, u64_to_user_ptr(args->timeout))) + return -EFAULT; + if (!timespec64_valid(&to)) + return -EINVAL; + + timeout = timespec64_to_ns(&to); + } + + ids = kmalloc_array(count, sizeof(*ids), GFP_KERNEL); + if (!ids) + return -ENOMEM; + if (copy_from_user(ids, u64_to_user_ptr(args->objs), + array_size(count, sizeof(*ids)))) { + kfree(ids); + return -EFAULT; + } + + q = kmalloc(struct_size(q, entries, count), GFP_KERNEL); + if (!q) { + kfree(ids); + return -ENOMEM; + } + q->task = current; + q->owner = args->owner; + atomic_set(&q->signaled, -1); + q->count = count; + + for (i = 0; i < count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = get_obj(dev, ids[i]); + + if (!obj) + goto err; + + entry->obj = obj; + entry->q = q; + entry->index = i; + } + + kfree(ids); + + *ret_q = q; + *ret_timeout = timeout; + return 0; + +err: + for (j = 0; j < i; j++) + put_obj(q->entries[j].obj); + kfree(ids); + kfree(q); + return -EINVAL; +} + +static void try_wake_any_obj(struct ntsync_obj *obj) +{ + switch (obj->type) { + case NTSYNC_TYPE_SEM: + try_wake_any_sem(obj); + break; + } +} + +static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_wait_args args; + struct ntsync_q *q; + ktime_t timeout; + int signaled; + __u32 i; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + ret = setup_wait(dev, &args, &timeout, &q); + if (ret < 0) + return ret; + + /* queue ourselves on each object's any_waiters list */ + + for (i = 0; i < args.count; i++) { + struct ntsync_q_entry *entry =
&q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_add_tail(&entry->node, &obj->any_waiters); + spin_unlock(&obj->lock); + } + + /* check whether some object can already satisfy the wait */ + + for (i = 0; i < args.count; i++) { + struct ntsync_obj *obj = q->entries[i].obj; + + if (atomic_read(&q->signaled) != -1) + break; + + spin_lock(&obj->lock); + try_wake_any_obj(obj); + spin_unlock(&obj->lock); + } + + /* sleep until signaled, interrupted by a signal, or timed out */ + + ret = ntsync_schedule(q, args.timeout ? &timeout : NULL); + + /* and finally, unqueue and drop the references taken by setup_wait() */ + + for (i = 0; i < args.count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_del(&entry->node); + spin_unlock(&obj->lock); + + put_obj(obj); + } + + signaled = atomic_read(&q->signaled); + if (signaled != -1) { + struct ntsync_wait_args __user *user_args = argp; + + /* even if we caught a signal, we need to communicate success */ + ret = 0; + + if (put_user(signaled, &user_args->index)) + ret = -EFAULT; + } else if (!ret) { + ret = -ETIMEDOUT; + } + + kfree(q); + return ret; +} + static long ntsync_char_ioctl(struct file *file, unsigned int cmd, unsigned long parm) { @@ -218,6 +445,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd, return ntsync_delete(dev, argp); case NTSYNC_IOC_PUT_SEM: return ntsync_put_sem(dev, argp); + case NTSYNC_IOC_WAIT_ANY: + return ntsync_wait_any(dev, argp); default: return -ENOIOCTLCMD; } diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h index 8c610d65f8ef..10f07da7864e 100644 --- a/include/uapi/linux/ntsync.h +++ b/include/uapi/linux/ntsync.h @@ -16,6 +16,17 @@ struct ntsync_sem_args { __u32 max; }; +struct ntsync_wait_args { + __u64 timeout; /* user pointer to an absolute timespec; 0 = no timeout */ + __u64 objs; /* user pointer to an array of count object ids */ + __u32 count; /* at most NTSYNC_MAX_WAIT_COUNT */ + __u32 owner; /* must be nonzero */ + __u32 index; /* out: index in objs of the object that satisfied the wait */ + __u32 pad; /* must be zero */ +}; + +#define NTSYNC_MAX_WAIT_COUNT 64 + #define NTSYNC_IOC_BASE 0xf7 #define NTSYNC_IOC_CREATE_SEM _IOWR(NTSYNC_IOC_BASE, 0, \ @@ -23,5 +34,7 @@ struct ntsync_sem_args { #define
NTSYNC_IOC_DELETE _IOW (NTSYNC_IOC_BASE, 1, __u32) #define NTSYNC_IOC_PUT_SEM _IOWR(NTSYNC_IOC_BASE, 2, \ struct ntsync_sem_args) +#define NTSYNC_IOC_WAIT_ANY _IOWR(NTSYNC_IOC_BASE, 3, \ + struct ntsync_wait_args) #endif -- 2.43.0