diff options
Diffstat (limited to 'kernel/drivers/misc/cxl/fault.c')
-rw-r--r-- | kernel/drivers/misc/cxl/fault.c | 129 |
1 files changed, 95 insertions, 34 deletions
diff --git a/kernel/drivers/misc/cxl/fault.c b/kernel/drivers/misc/cxl/fault.c index 25a5418c5..81c3f75b7 100644 --- a/kernel/drivers/misc/cxl/fault.c +++ b/kernel/drivers/misc/cxl/fault.c @@ -166,13 +166,92 @@ static void cxl_handle_page_fault(struct cxl_context *ctx, cxl_ack_irq(ctx, CXL_PSL_TFC_An_R, 0); } +/* + * Returns the mm_struct corresponding to the context ctx via ctx->pid + * In case the task has exited we use the task group leader accessible + * via ctx->glpid to find the next task in the thread group that has a + * valid mm_struct associated with it. If a task with valid mm_struct + * is found the ctx->pid is updated to use the task struct for subsequent + * translations. In case no valid mm_struct is found in the task group to + * service the fault a NULL is returned. + */ +static struct mm_struct *get_mem_context(struct cxl_context *ctx) +{ + struct task_struct *task = NULL; + struct mm_struct *mm = NULL; + struct pid *old_pid = ctx->pid; + + if (old_pid == NULL) { + pr_warn("%s: Invalid context for pe=%d\n", + __func__, ctx->pe); + return NULL; + } + + task = get_pid_task(old_pid, PIDTYPE_PID); + + /* + * pid_alive may look racy but this saves us from costly + * get_task_mm when the task is a zombie. In worst case + * we may think a task is alive, which is about to die + * but get_task_mm will return NULL. + */ + if (task != NULL && pid_alive(task)) + mm = get_task_mm(task); + + /* release the task struct that was taken earlier */ + if (task) + put_task_struct(task); + else + pr_devel("%s: Context owning pid=%i for pe=%i dead\n", + __func__, pid_nr(old_pid), ctx->pe); + + /* + * If we couldn't find the mm context then use the group + * leader to iterate over the task group and find a task + * that gives us mm_struct. + */ + if (unlikely(mm == NULL && ctx->glpid != NULL)) { + + rcu_read_lock(); + task = pid_task(ctx->glpid, PIDTYPE_PID); + if (task) + do { + mm = get_task_mm(task); + if (mm) { + ctx->pid = get_task_pid(task, + PIDTYPE_PID); + break; + } + task = next_thread(task); + } while (task && !thread_group_leader(task)); + rcu_read_unlock(); + + /* check if we switched pid */ + if (ctx->pid != old_pid) { + if (mm) + pr_devel("%s:pe=%i switch pid %i->%i\n", + __func__, ctx->pe, pid_nr(old_pid), + pid_nr(ctx->pid)); + else + pr_devel("%s:Cannot find mm for pid=%i\n", + __func__, pid_nr(old_pid)); + + /* drop the reference to older pid */ + put_pid(old_pid); + } + } + + return mm; +} + + + void cxl_handle_fault(struct work_struct *fault_work) { struct cxl_context *ctx = container_of(fault_work, struct cxl_context, fault_work); u64 dsisr = ctx->dsisr; u64 dar = ctx->dar; - struct task_struct *task = NULL; struct mm_struct *mm = NULL; if (cxl_p2n_read(ctx->afu, CXL_PSL_DSISR_An) != dsisr || @@ -195,17 +274,17 @@ void cxl_handle_fault(struct work_struct *fault_work) "DSISR: %#llx DAR: %#llx\n", ctx->pe, dsisr, dar); if (!ctx->kernel) { - if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { - pr_devel("cxl_handle_fault unable to get task %i\n", - pid_nr(ctx->pid)); + + mm = get_mem_context(ctx); + /* indicates all the thread in task group have exited */ + if (mm == NULL) { + pr_devel("%s: unable to get mm for pe=%d pid=%i\n", + __func__, ctx->pe, pid_nr(ctx->pid)); cxl_ack_ae(ctx); return; - } - if (!(mm = get_task_mm(task))) { - pr_devel("cxl_handle_fault unable to get mm %i\n", - pid_nr(ctx->pid)); - cxl_ack_ae(ctx); - goto out; + } else { + pr_devel("Handling page fault for pe=%d pid=%i\n", + ctx->pe, pid_nr(ctx->pid)); } } @@ -218,33 +297,22 @@ void cxl_handle_fault(struct work_struct *fault_work) if (mm) mmput(mm); -out: - if (task) - put_task_struct(task); } static void cxl_prefault_one(struct cxl_context *ctx, u64 ea) { - int rc; - struct task_struct *task; struct mm_struct *mm; - if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { - pr_devel("cxl_prefault_one unable to get task %i\n", - pid_nr(ctx->pid)); - return; - } - if (!(mm = get_task_mm(task))) { + mm = get_mem_context(ctx); + if (mm == NULL) { pr_devel("cxl_prefault_one unable to get mm %i\n", pid_nr(ctx->pid)); - put_task_struct(task); return; } - rc = cxl_fault_segment(ctx, mm, ea); + cxl_fault_segment(ctx, mm, ea); mmput(mm); - put_task_struct(task); } static u64 next_segment(u64 ea, u64 vsid) @@ -263,18 +331,13 @@ static void cxl_prefault_vma(struct cxl_context *ctx) struct copro_slb slb; struct vm_area_struct *vma; int rc; - struct task_struct *task; struct mm_struct *mm; - if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { - pr_devel("cxl_prefault_vma unable to get task %i\n", - pid_nr(ctx->pid)); - return; - } - if (!(mm = get_task_mm(task))) { + mm = get_mem_context(ctx); + if (mm == NULL) { pr_devel("cxl_prefault_vm unable to get mm %i\n", pid_nr(ctx->pid)); - goto out1; + return; } down_read(&mm->mmap_sem); @@ -295,8 +358,6 @@ static void cxl_prefault_vma(struct cxl_context *ctx) up_read(&mm->mmap_sem); mmput(mm); -out1: - put_task_struct(task); } void cxl_prefault(struct cxl_context *ctx, u64 wed) |