arm64: bpf: optimize JMP_CALL
[cascardo/linux.git] / arch / arm64 / net / bpf_jit_comp.c
index b405bbb..7ae304e 100644 (file)
@@ -18,6 +18,7 @@
 
 #define pr_fmt(fmt) "bpf_jit: " fmt
 
+#include <linux/bpf.h>
 #include <linux/filter.h>
 #include <linux/printk.h>
 #include <linux/skbuff.h>
@@ -31,8 +32,9 @@
 
 int bpf_jit_enable __read_mostly;
 
-#define TMP_REG_1 (MAX_BPF_REG + 0)
-#define TMP_REG_2 (MAX_BPF_REG + 1)
+#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
+#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
+#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
 
 /* Map BPF registers to A64 registers */
 static const int bpf2a64[] = {
@@ -51,15 +53,18 @@ static const int bpf2a64[] = {
        [BPF_REG_9] = A64_R(22),
        /* read-only frame pointer to access stack */
        [BPF_REG_FP] = A64_R(25),
-       /* temporary register for internal BPF JIT */
-       [TMP_REG_1] = A64_R(23),
-       [TMP_REG_2] = A64_R(24),
+       /* temporary registers for internal BPF JIT */
+       [TMP_REG_1] = A64_R(10),
+       [TMP_REG_2] = A64_R(11),
+       /* tail_call_cnt */
+       [TCALL_CNT] = A64_R(26),
+       /* temporary register for blinding constants */
+       [BPF_REG_AX] = A64_R(9),
 };
 
 struct jit_ctx {
        const struct bpf_prog *prog;
        int idx;
-       int tmp_used;
        int epilogue_offset;
        int *offset;
        u32 *image;
@@ -145,15 +150,18 @@ static inline int epilogue_offset(const struct jit_ctx *ctx)
 
 #define STACK_SIZE STACK_ALIGN(_STACK_SIZE)
 
-static void build_prologue(struct jit_ctx *ctx)
+#define PROLOGUE_OFFSET 8
+
+static int build_prologue(struct jit_ctx *ctx)
 {
        const u8 r6 = bpf2a64[BPF_REG_6];
        const u8 r7 = bpf2a64[BPF_REG_7];
        const u8 r8 = bpf2a64[BPF_REG_8];
        const u8 r9 = bpf2a64[BPF_REG_9];
        const u8 fp = bpf2a64[BPF_REG_FP];
-       const u8 tmp1 = bpf2a64[TMP_REG_1];
-       const u8 tmp2 = bpf2a64[TMP_REG_2];
+       const u8 tcc = bpf2a64[TCALL_CNT];
+       const int idx0 = ctx->idx;
+       int cur_offset;
 
        /*
         * BPF prog stack layout
@@ -163,9 +171,7 @@ static void build_prologue(struct jit_ctx *ctx)
         *                        |FP/LR|
         * current A64_FP =>  -16:+-----+
         *                        | ... | callee saved registers
-        *                        +-----+
-        *                        |     | x25/x26
-        * BPF fp register => -80:+-----+ <= (BPF_FP)
+        * BPF fp register => -64:+-----+ <= (BPF_FP)
         *                        |     |
         *                        | ... | BPF prog stack
         *                        |     |
@@ -184,20 +190,90 @@ static void build_prologue(struct jit_ctx *ctx)
        emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
        emit(A64_MOV(1, A64_FP, A64_SP), ctx);
 
-       /* Save callee-saved register */
+       /* Save callee-saved registers */
        emit(A64_PUSH(r6, r7, A64_SP), ctx);
        emit(A64_PUSH(r8, r9, A64_SP), ctx);
-       if (ctx->tmp_used)
-               emit(A64_PUSH(tmp1, tmp2, A64_SP), ctx);
+       emit(A64_PUSH(fp, tcc, A64_SP), ctx);
 
-       /* Save fp (x25) and x26. SP requires 16 bytes alignment */
-       emit(A64_PUSH(fp, A64_R(26), A64_SP), ctx);
-
-       /* Set up BPF prog stack base register (x25) */
+       /* Set up BPF prog stack base register */
        emit(A64_MOV(1, fp, A64_SP), ctx);
 
+       /* Initialize tail_call_cnt */
+       emit(A64_MOVZ(1, tcc, 0, 0), ctx);
+
        /* Set up function call stack */
        emit(A64_SUB_I(1, A64_SP, A64_SP, STACK_SIZE), ctx);
+
+       cur_offset = ctx->idx - idx0;
+       if (cur_offset != PROLOGUE_OFFSET) {
+               pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
+                           cur_offset, PROLOGUE_OFFSET);
+               return -1;
+       }
+       return 0;
+}
+
+static int out_offset = -1; /* initialized on the first pass of build_body() */
+static int emit_bpf_tail_call(struct jit_ctx *ctx)
+{
+       /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
+       const u8 r2 = bpf2a64[BPF_REG_2];
+       const u8 r3 = bpf2a64[BPF_REG_3];
+
+       const u8 tmp = bpf2a64[TMP_REG_1];
+       const u8 prg = bpf2a64[TMP_REG_2];
+       const u8 tcc = bpf2a64[TCALL_CNT];
+       const int idx0 = ctx->idx;
+#define cur_offset (ctx->idx - idx0)
+#define jmp_offset (out_offset - (cur_offset))
+       size_t off;
+
+       /* if (index >= array->map.max_entries)
+        *     goto out;
+        */
+       off = offsetof(struct bpf_array, map.max_entries);
+       emit_a64_mov_i64(tmp, off, ctx);
+       emit(A64_LDR32(tmp, r2, tmp), ctx);
+       emit(A64_CMP(0, r3, tmp), ctx);
+       emit(A64_B_(A64_COND_GE, jmp_offset), ctx);
+
+       /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
+        *     goto out;
+        * tail_call_cnt++;
+        */
+       emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
+       emit(A64_CMP(1, tcc, tmp), ctx);
+       emit(A64_B_(A64_COND_GT, jmp_offset), ctx);
+       emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
+
+       /* prog = array->ptrs[index];
+        * if (prog == NULL)
+        *     goto out;
+        */
+       off = offsetof(struct bpf_array, ptrs);
+       emit_a64_mov_i64(tmp, off, ctx);
+       emit(A64_LDR64(tmp, r2, tmp), ctx);
+       emit(A64_LDR64(prg, tmp, r3), ctx);
+       emit(A64_CBZ(1, prg, jmp_offset), ctx);
+
+       /* goto *(prog->bpf_func + prologue_size); */
+       off = offsetof(struct bpf_prog, bpf_func);
+       emit_a64_mov_i64(tmp, off, ctx);
+       emit(A64_LDR64(tmp, prg, tmp), ctx);
+       emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
+       emit(A64_BR(tmp), ctx);
+
+       /* out: */
+       if (out_offset == -1)
+               out_offset = cur_offset;
+       if (cur_offset != out_offset) {
+               pr_err_once("tail_call out_offset = %d, expected %d!\n",
+                           cur_offset, out_offset);
+               return -1;
+       }
+       return 0;
+#undef cur_offset
+#undef jmp_offset
 }
 
 static void build_epilogue(struct jit_ctx *ctx)
@@ -208,8 +284,6 @@ static void build_epilogue(struct jit_ctx *ctx)
        const u8 r8 = bpf2a64[BPF_REG_8];
        const u8 r9 = bpf2a64[BPF_REG_9];
        const u8 fp = bpf2a64[BPF_REG_FP];
-       const u8 tmp1 = bpf2a64[TMP_REG_1];
-       const u8 tmp2 = bpf2a64[TMP_REG_2];
 
        /* We're done with BPF stack */
        emit(A64_ADD_I(1, A64_SP, A64_SP, STACK_SIZE), ctx);
@@ -218,8 +292,6 @@ static void build_epilogue(struct jit_ctx *ctx)
        emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
 
        /* Restore callee-saved register */
-       if (ctx->tmp_used)
-               emit(A64_POP(tmp1, tmp2, A64_SP), ctx);
        emit(A64_POP(r8, r9, A64_SP), ctx);
        emit(A64_POP(r6, r7, A64_SP), ctx);
 
@@ -315,7 +387,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
                        emit(A64_UDIV(is64, dst, dst, src), ctx);
                        break;
                case BPF_MOD:
-                       ctx->tmp_used = 1;
                        emit(A64_UDIV(is64, tmp, dst, src), ctx);
                        emit(A64_MUL(is64, tmp, tmp, src), ctx);
                        emit(A64_SUB(is64, dst, dst, tmp), ctx);
@@ -388,49 +459,41 @@ emit_bswap_uxt:
        /* dst = dst OP imm */
        case BPF_ALU | BPF_ADD | BPF_K:
        case BPF_ALU64 | BPF_ADD | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_ADD(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_SUB | BPF_K:
        case BPF_ALU64 | BPF_SUB | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_SUB(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_AND | BPF_K:
        case BPF_ALU64 | BPF_AND | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_AND(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_OR | BPF_K:
        case BPF_ALU64 | BPF_OR | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_ORR(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_XOR | BPF_K:
        case BPF_ALU64 | BPF_XOR | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_EOR(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_MUL | BPF_K:
        case BPF_ALU64 | BPF_MUL | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_MUL(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_DIV | BPF_K:
        case BPF_ALU64 | BPF_DIV | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp, imm, ctx);
                emit(A64_UDIV(is64, dst, dst, tmp), ctx);
                break;
        case BPF_ALU | BPF_MOD | BPF_K:
        case BPF_ALU64 | BPF_MOD | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(is64, tmp2, imm, ctx);
                emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
                emit(A64_MUL(is64, tmp, tmp, tmp2), ctx);
@@ -501,12 +564,10 @@ emit_cond_jmp:
        case BPF_JMP | BPF_JNE | BPF_K:
        case BPF_JMP | BPF_JSGT | BPF_K:
        case BPF_JMP | BPF_JSGE | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(1, tmp, imm, ctx);
                emit(A64_CMP(1, dst, tmp), ctx);
                goto emit_cond_jmp;
        case BPF_JMP | BPF_JSET | BPF_K:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(1, tmp, imm, ctx);
                emit(A64_TST(1, dst, tmp), ctx);
                goto emit_cond_jmp;
@@ -516,15 +577,16 @@ emit_cond_jmp:
                const u8 r0 = bpf2a64[BPF_REG_0];
                const u64 func = (u64)__bpf_call_base + imm;
 
-               ctx->tmp_used = 1;
                emit_a64_mov_i64(tmp, func, ctx);
-               emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
-               emit(A64_MOV(1, A64_FP, A64_SP), ctx);
                emit(A64_BLR(tmp), ctx);
                emit(A64_MOV(1, r0, A64_R(0)), ctx);
-               emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
                break;
        }
+       /* tail call */
+       case BPF_JMP | BPF_CALL | BPF_X:
+               if (emit_bpf_tail_call(ctx))
+                       return -EFAULT;
+               break;
        /* function return */
        case BPF_JMP | BPF_EXIT:
                /* Optimization: when last instruction is EXIT,
@@ -562,7 +624,6 @@ emit_cond_jmp:
        case BPF_LDX | BPF_MEM | BPF_H:
        case BPF_LDX | BPF_MEM | BPF_B:
        case BPF_LDX | BPF_MEM | BPF_DW:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(1, tmp, off, ctx);
                switch (BPF_SIZE(code)) {
                case BPF_W:
@@ -586,7 +647,6 @@ emit_cond_jmp:
        case BPF_ST | BPF_MEM | BPF_B:
        case BPF_ST | BPF_MEM | BPF_DW:
                /* Load imm to a register then store it */
-               ctx->tmp_used = 1;
                emit_a64_mov_i(1, tmp2, off, ctx);
                emit_a64_mov_i(1, tmp, imm, ctx);
                switch (BPF_SIZE(code)) {
@@ -610,7 +670,6 @@ emit_cond_jmp:
        case BPF_STX | BPF_MEM | BPF_H:
        case BPF_STX | BPF_MEM | BPF_B:
        case BPF_STX | BPF_MEM | BPF_DW:
-               ctx->tmp_used = 1;
                emit_a64_mov_i(1, tmp, off, ctx);
                switch (BPF_SIZE(code)) {
                case BPF_W:
@@ -762,33 +821,50 @@ void bpf_jit_compile(struct bpf_prog *prog)
        /* Nothing to do here. We support Internal BPF. */
 }
 
-void bpf_int_jit_compile(struct bpf_prog *prog)
+struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
+       struct bpf_prog *tmp, *orig_prog = prog;
        struct bpf_binary_header *header;
+       bool tmp_blinded = false;
        struct jit_ctx ctx;
        int image_size;
        u8 *image_ptr;
 
        if (!bpf_jit_enable)
-               return;
+               return orig_prog;
 
-       if (!prog || !prog->len)
-               return;
+       tmp = bpf_jit_blind_constants(prog);
+       /* If blinding was requested and we failed during blinding,
+        * we must fall back to the interpreter.
+        */
+       if (IS_ERR(tmp))
+               return orig_prog;
+       if (tmp != prog) {
+               tmp_blinded = true;
+               prog = tmp;
+       }
 
        memset(&ctx, 0, sizeof(ctx));
        ctx.prog = prog;
 
        ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
-       if (ctx.offset == NULL)
-               return;
+       if (ctx.offset == NULL) {
+               prog = orig_prog;
+               goto out;
+       }
 
        /* 1. Initial fake pass to compute ctx->idx. */
 
-       /* Fake pass to fill in ctx->offset and ctx->tmp_used. */
-       if (build_body(&ctx))
-               goto out;
+       /* Fake pass to fill in ctx->offset. */
+       if (build_body(&ctx)) {
+               prog = orig_prog;
+               goto out_off;
+       }
 
-       build_prologue(&ctx);
+       if (build_prologue(&ctx)) {
+               prog = orig_prog;
+               goto out_off;
+       }
 
        ctx.epilogue_offset = ctx.idx;
        build_epilogue(&ctx);
@@ -797,8 +873,10 @@ void bpf_int_jit_compile(struct bpf_prog *prog)
        image_size = sizeof(u32) * ctx.idx;
        header = bpf_jit_binary_alloc(image_size, &image_ptr,
                                      sizeof(u32), jit_fill_hole);
-       if (header == NULL)
-               goto out;
+       if (header == NULL) {
+               prog = orig_prog;
+               goto out_off;
+       }
 
        /* 2. Now, the actual pass. */
 
@@ -809,7 +887,8 @@ void bpf_int_jit_compile(struct bpf_prog *prog)
 
        if (build_body(&ctx)) {
                bpf_jit_binary_free(header);
-               goto out;
+               prog = orig_prog;
+               goto out_off;
        }
 
        build_epilogue(&ctx);
@@ -817,7 +896,8 @@ void bpf_int_jit_compile(struct bpf_prog *prog)
        /* 3. Extra pass to validate JITed code. */
        if (validate_code(&ctx)) {
                bpf_jit_binary_free(header);
-               goto out;
+               prog = orig_prog;
+               goto out_off;
        }
 
        /* And we're done. */
@@ -829,8 +909,14 @@ void bpf_int_jit_compile(struct bpf_prog *prog)
        set_memory_ro((unsigned long)header, header->pages);
        prog->bpf_func = (void *)ctx.image;
        prog->jited = 1;
-out:
+
+out_off:
        kfree(ctx.offset);
+out:
+       if (tmp_blinded)
+               bpf_jit_prog_release_other(prog, prog == orig_prog ?
+                                          tmp : orig_prog);
+       return prog;
 }
 
 void bpf_jit_free(struct bpf_prog *prog)