Diffstat (limited to 'arch')
-rw-r--r--  arch/arm64/include/asm/patching.h |   2
-rw-r--r--  arch/arm64/kernel/patching.c      |  75
-rw-r--r--  arch/arm64/kernel/stacktrace.c    |  26
-rw-r--r--  arch/arm64/net/bpf_jit_comp.c     | 226
-rw-r--r--  arch/riscv/net/bpf_jit.h          | 134
-rw-r--r--  arch/riscv/net/bpf_jit_comp64.c   | 215
6 files changed, 496 insertions, 182 deletions
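Editor's note (not part of the commit): the arm64 portion of this diff adds two text-patching helpers, aarch64_insn_copy() and aarch64_insn_set(), and the BPF JIT wraps them as bpf_arch_text_copy() and bpf_arch_text_invalidate(). The sketch below restates that wrapper pattern under the signatures declared in patching.h; the helper names jit_copy_to_rx()/jit_invalidate_rx() and the exact include list are illustrative assumptions, not code from the patch.

/* Illustrative sketch -- mirrors the wrappers added to bpf_jit_comp.c below. */
#include <linux/err.h>
#include <linux/errno.h>
#include <linux/types.h>
#include <asm/debug-monitors.h>	/* AARCH64_BREAK_FAULT */
#include <asm/patching.h>	/* aarch64_insn_copy(), aarch64_insn_set() */

/* Copy freshly JITed instructions from the RW buffer into the RX region. */
static void *jit_copy_to_rx(void *rx_dst, void *rw_src, size_t len)
{
	/* aarch64_insn_copy() returns NULL when rx_dst is not 4-byte aligned */
	if (!aarch64_insn_copy(rx_dst, rw_src, len))
		return ERR_PTR(-EINVAL);
	return rx_dst;
}

/* Poison an unused RX range with BRK so stale bytes can never be executed. */
static int jit_invalidate_rx(void *rx_dst, size_t len)
{
	if (!aarch64_insn_set(rx_dst, AARCH64_BREAK_FAULT, len))
		return -EINVAL;
	return 0;
}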
diff --git a/arch/arm64/include/asm/patching.h b/arch/arm64/include/asm/patching.h index 68908b82b168..587bdb91ab7a 100644 --- a/arch/arm64/include/asm/patching.h +++ b/arch/arm64/include/asm/patching.h @@ -8,6 +8,8 @@ int aarch64_insn_read(void *addr, u32 *insnp); int aarch64_insn_write(void *addr, u32 insn); int aarch64_insn_write_literal_u64(void *addr, u64 val); +void *aarch64_insn_set(void *dst, u32 insn, size_t len); +void *aarch64_insn_copy(void *dst, void *src, size_t len); int aarch64_insn_patch_text_nosync(void *addr, u32 insn); int aarch64_insn_patch_text(void *addrs[], u32 insns[], int cnt); diff --git a/arch/arm64/kernel/patching.c b/arch/arm64/kernel/patching.c index b4835f6d594b..255534930368 100644 --- a/arch/arm64/kernel/patching.c +++ b/arch/arm64/kernel/patching.c @@ -105,6 +105,81 @@ noinstr int aarch64_insn_write_literal_u64(void *addr, u64 val) return ret; } +typedef void text_poke_f(void *dst, void *src, size_t patched, size_t len); + +static void *__text_poke(text_poke_f func, void *addr, void *src, size_t len) +{ + unsigned long flags; + size_t patched = 0; + size_t size; + void *waddr; + void *ptr; + + raw_spin_lock_irqsave(&patch_lock, flags); + + while (patched < len) { + ptr = addr + patched; + size = min_t(size_t, PAGE_SIZE - offset_in_page(ptr), + len - patched); + + waddr = patch_map(ptr, FIX_TEXT_POKE0); + func(waddr, src, patched, size); + patch_unmap(FIX_TEXT_POKE0); + + patched += size; + } + raw_spin_unlock_irqrestore(&patch_lock, flags); + + flush_icache_range((uintptr_t)addr, (uintptr_t)addr + len); + + return addr; +} + +static void text_poke_memcpy(void *dst, void *src, size_t patched, size_t len) +{ + copy_to_kernel_nofault(dst, src + patched, len); +} + +static void text_poke_memset(void *dst, void *src, size_t patched, size_t len) +{ + u32 c = *(u32 *)src; + + memset32(dst, c, len / 4); +} + +/** + * aarch64_insn_copy - Copy instructions into (an unused part of) RX memory + * @dst: address to modify + * @src: source of the copy + * @len: length to copy + * + * Useful for JITs to dump new code blocks into unused regions of RX memory. + */ +noinstr void *aarch64_insn_copy(void *dst, void *src, size_t len) +{ + /* A64 instructions must be word aligned */ + if ((uintptr_t)dst & 0x3) + return NULL; + + return __text_poke(text_poke_memcpy, dst, src, len); +} + +/** + * aarch64_insn_set - memset for RX memory regions. + * @dst: address to modify + * @insn: value to set + * @len: length of memory region. + * + * Useful for JITs to fill regions of RX memory with illegal instructions. 
+ */ +noinstr void *aarch64_insn_set(void *dst, u32 insn, size_t len) +{ + if ((uintptr_t)dst & 0x3) + return NULL; + + return __text_poke(text_poke_memset, dst, &insn, len); +} + int __kprobes aarch64_insn_patch_text_nosync(void *addr, u32 insn) { u32 *tp = addr; diff --git a/arch/arm64/kernel/stacktrace.c b/arch/arm64/kernel/stacktrace.c index 7f88028a00c0..66cffc5fc0be 100644 --- a/arch/arm64/kernel/stacktrace.c +++ b/arch/arm64/kernel/stacktrace.c @@ -7,6 +7,7 @@ #include <linux/kernel.h> #include <linux/efi.h> #include <linux/export.h> +#include <linux/filter.h> #include <linux/ftrace.h> #include <linux/kprobes.h> #include <linux/sched.h> @@ -266,6 +267,31 @@ noinline noinstr void arch_stack_walk(stack_trace_consume_fn consume_entry, kunwind_stack_walk(arch_kunwind_consume_entry, &data, task, regs); } +struct bpf_unwind_consume_entry_data { + bool (*consume_entry)(void *cookie, u64 ip, u64 sp, u64 fp); + void *cookie; +}; + +static bool +arch_bpf_unwind_consume_entry(const struct kunwind_state *state, void *cookie) +{ + struct bpf_unwind_consume_entry_data *data = cookie; + + return data->consume_entry(data->cookie, state->common.pc, 0, + state->common.fp); +} + +noinline noinstr void arch_bpf_stack_walk(bool (*consume_entry)(void *cookie, u64 ip, u64 sp, + u64 fp), void *cookie) +{ + struct bpf_unwind_consume_entry_data data = { + .consume_entry = consume_entry, + .cookie = cookie, + }; + + kunwind_stack_walk(arch_bpf_unwind_consume_entry, &data, current, NULL); +} + static bool dump_backtrace_entry(void *arg, unsigned long where) { char *loglvl = arg; diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c index cfd5434de483..5afc7a525eca 100644 --- a/arch/arm64/net/bpf_jit_comp.c +++ b/arch/arm64/net/bpf_jit_comp.c @@ -76,6 +76,7 @@ struct jit_ctx { int *offset; int exentry_idx; __le32 *image; + __le32 *ro_image; u32 stack_size; int fpb_offset; }; @@ -205,6 +206,14 @@ static void jit_fill_hole(void *area, unsigned int size) *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT); } +int bpf_arch_text_invalidate(void *dst, size_t len) +{ + if (!aarch64_insn_set(dst, AARCH64_BREAK_FAULT, len)) + return -EINVAL; + + return 0; +} + static inline int epilogue_offset(const struct jit_ctx *ctx) { int to = ctx->epilogue_offset; @@ -285,7 +294,8 @@ static bool is_lsi_offset(int offset, int scale) /* Tail call offset to jump into */ #define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 8) -static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf) +static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf, + bool is_exception_cb) { const struct bpf_prog *prog = ctx->prog; const bool is_main_prog = !bpf_is_subprog(prog); @@ -333,19 +343,34 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf) emit(A64_MOV(1, A64_R(9), A64_LR), ctx); emit(A64_NOP, ctx); - /* Sign lr */ - if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)) - emit(A64_PACIASP, ctx); - - /* Save FP and LR registers to stay align with ARM64 AAPCS */ - emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx); - emit(A64_MOV(1, A64_FP, A64_SP), ctx); - - /* Save callee-saved registers */ - emit(A64_PUSH(r6, r7, A64_SP), ctx); - emit(A64_PUSH(r8, r9, A64_SP), ctx); - emit(A64_PUSH(fp, tcc, A64_SP), ctx); - emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx); + if (!is_exception_cb) { + /* Sign lr */ + if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)) + emit(A64_PACIASP, ctx); + /* Save FP and LR registers to stay align with ARM64 AAPCS */ + emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx); + emit(A64_MOV(1, A64_FP, A64_SP), ctx); 
+ + /* Save callee-saved registers */ + emit(A64_PUSH(r6, r7, A64_SP), ctx); + emit(A64_PUSH(r8, r9, A64_SP), ctx); + emit(A64_PUSH(fp, tcc, A64_SP), ctx); + emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx); + } else { + /* + * Exception callback receives FP of Main Program as third + * parameter + */ + emit(A64_MOV(1, A64_FP, A64_R(2)), ctx); + /* + * Main Program already pushed the frame record and the + * callee-saved registers. The exception callback will not push + * anything and re-use the main program's stack. + * + * 10 registers are on the stack + */ + emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx); + } /* Set up BPF prog stack base register */ emit(A64_MOV(1, fp, A64_SP), ctx); @@ -365,6 +390,20 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf) emit_bti(A64_BTI_J, ctx); } + /* + * Program acting as exception boundary should save all ARM64 + * Callee-saved registers as the exception callback needs to recover + * all ARM64 Callee-saved registers in its epilogue. + */ + if (prog->aux->exception_boundary) { + /* + * As we are pushing two more registers, BPF_FP should be moved + * 16 bytes + */ + emit(A64_SUB_I(1, fp, fp, 16), ctx); + emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx); + } + emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx); /* Stack must be multiples of 16B */ @@ -653,7 +692,7 @@ static void build_plt(struct jit_ctx *ctx) plt->target = (u64)&dummy_tramp; } -static void build_epilogue(struct jit_ctx *ctx) +static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb) { const u8 r0 = bpf2a64[BPF_REG_0]; const u8 r6 = bpf2a64[BPF_REG_6]; @@ -666,6 +705,15 @@ static void build_epilogue(struct jit_ctx *ctx) /* We're done with BPF stack */ emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx); + /* + * Program acting as exception boundary pushes R23 and R24 in addition + * to BPF callee-saved registers. Exception callback uses the boundary + * program's stack frame, so recover these extra registers in the above + * two cases. + */ + if (ctx->prog->aux->exception_boundary || is_exception_cb) + emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx); + /* Restore x27 and x28 */ emit(A64_POP(fpb, A64_R(28), A64_SP), ctx); /* Restore fs (x25) and x26 */ @@ -707,7 +755,8 @@ static int add_exception_handler(const struct bpf_insn *insn, struct jit_ctx *ctx, int dst_reg) { - off_t offset; + off_t ins_offset; + off_t fixup_offset; unsigned long pc; struct exception_table_entry *ex; @@ -724,12 +773,17 @@ static int add_exception_handler(const struct bpf_insn *insn, return -EINVAL; ex = &ctx->prog->aux->extable[ctx->exentry_idx]; - pc = (unsigned long)&ctx->image[ctx->idx - 1]; + pc = (unsigned long)&ctx->ro_image[ctx->idx - 1]; - offset = pc - (long)&ex->insn; - if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN)) + /* + * This is the relative offset of the instruction that may fault from + * the exception table itself. This will be written to the exception + * table and if this instruction faults, the destination register will + * be set to '0' and the execution will jump to the next instruction. + */ + ins_offset = pc - (long)&ex->insn; + if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN)) return -ERANGE; - ex->insn = offset; /* * Since the extable follows the program, the fixup offset is always @@ -738,12 +792,25 @@ static int add_exception_handler(const struct bpf_insn *insn, * bits. We don't need to worry about buildtime or runtime sort * modifying the upper bits because the table is already sorted, and * isn't part of the main exception table. 
+ * + * The fixup_offset is set to the next instruction from the instruction + * that may fault. The execution will jump to this after handling the + * fault. */ - offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE); - if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset)) + fixup_offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE); + if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset)) return -ERANGE; - ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) | + /* + * The offsets above have been calculated using the RO buffer but we + * need to use the R/W buffer for writes. + * switch ex to rw buffer for writing. + */ + ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image); + + ex->insn = ins_offset; + + ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) | FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg); ex->type = EX_TYPE_BPF; @@ -1511,7 +1578,8 @@ static inline void bpf_flush_icache(void *start, void *end) struct arm64_jit_data { struct bpf_binary_header *header; - u8 *image; + u8 *ro_image; + struct bpf_binary_header *ro_header; struct jit_ctx ctx; }; @@ -1520,12 +1588,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) int image_size, prog_size, extable_size, extable_align, extable_offset; struct bpf_prog *tmp, *orig_prog = prog; struct bpf_binary_header *header; + struct bpf_binary_header *ro_header; struct arm64_jit_data *jit_data; bool was_classic = bpf_prog_was_classic(prog); bool tmp_blinded = false; bool extra_pass = false; struct jit_ctx ctx; u8 *image_ptr; + u8 *ro_image_ptr; if (!prog->jit_requested) return orig_prog; @@ -1552,8 +1622,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) } if (jit_data->ctx.offset) { ctx = jit_data->ctx; - image_ptr = jit_data->image; + ro_image_ptr = jit_data->ro_image; + ro_header = jit_data->ro_header; header = jit_data->header; + image_ptr = (void *)header + ((void *)ro_image_ptr + - (void *)ro_header); extra_pass = true; prog_size = sizeof(u32) * ctx.idx; goto skip_init_ctx; @@ -1575,7 +1648,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) * BPF line info needs ctx->offset[i] to be the offset of * instruction[i] in jited image, so build prologue first. */ - if (build_prologue(&ctx, was_classic)) { + if (build_prologue(&ctx, was_classic, prog->aux->exception_cb)) { prog = orig_prog; goto out_off; } @@ -1586,7 +1659,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) } ctx.epilogue_offset = ctx.idx; - build_epilogue(&ctx); + build_epilogue(&ctx, prog->aux->exception_cb); build_plt(&ctx); extable_align = __alignof__(struct exception_table_entry); @@ -1598,63 +1671,81 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) /* also allocate space for plt target */ extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align); image_size = extable_offset + extable_size; - header = bpf_jit_binary_alloc(image_size, &image_ptr, - sizeof(u32), jit_fill_hole); - if (header == NULL) { + ro_header = bpf_jit_binary_pack_alloc(image_size, &ro_image_ptr, + sizeof(u32), &header, &image_ptr, + jit_fill_hole); + if (!ro_header) { prog = orig_prog; goto out_off; } /* 2. Now, the actual pass. */ + /* + * Use the image(RW) for writing the JITed instructions. But also save + * the ro_image(RX) for calculating the offsets in the image. The RW + * image will be later copied to the RX image from where the program + * will run. The bpf_jit_binary_pack_finalize() will do this copy in the + * final step. 
+ */ ctx.image = (__le32 *)image_ptr; + ctx.ro_image = (__le32 *)ro_image_ptr; if (extable_size) - prog->aux->extable = (void *)image_ptr + extable_offset; + prog->aux->extable = (void *)ro_image_ptr + extable_offset; skip_init_ctx: ctx.idx = 0; ctx.exentry_idx = 0; - build_prologue(&ctx, was_classic); + build_prologue(&ctx, was_classic, prog->aux->exception_cb); if (build_body(&ctx, extra_pass)) { - bpf_jit_binary_free(header); prog = orig_prog; - goto out_off; + goto out_free_hdr; } - build_epilogue(&ctx); + build_epilogue(&ctx, prog->aux->exception_cb); build_plt(&ctx); /* 3. Extra pass to validate JITed code. */ if (validate_ctx(&ctx)) { - bpf_jit_binary_free(header); prog = orig_prog; - goto out_off; + goto out_free_hdr; } /* And we're done. */ if (bpf_jit_enable > 1) bpf_jit_dump(prog->len, prog_size, 2, ctx.image); - bpf_flush_icache(header, ctx.image + ctx.idx); - if (!prog->is_func || extra_pass) { if (extra_pass && ctx.idx != jit_data->ctx.idx) { pr_err_once("multi-func JIT bug %d != %d\n", ctx.idx, jit_data->ctx.idx); - bpf_jit_binary_free(header); prog->bpf_func = NULL; prog->jited = 0; prog->jited_len = 0; + goto out_free_hdr; + } + if (WARN_ON(bpf_jit_binary_pack_finalize(prog, ro_header, + header))) { + /* ro_header has been freed */ + ro_header = NULL; + prog = orig_prog; goto out_off; } - bpf_jit_binary_lock_ro(header); + /* + * The instructions have now been copied to the ROX region from + * where they will execute. Now the data cache has to be cleaned to + * the PoU and the I-cache has to be invalidated for the VAs. + */ + bpf_flush_icache(ro_header, ctx.ro_image + ctx.idx); } else { jit_data->ctx = ctx; - jit_data->image = image_ptr; + jit_data->ro_image = ro_image_ptr; jit_data->header = header; + jit_data->ro_header = ro_header; } - prog->bpf_func = (void *)ctx.image; + + prog->bpf_func = (void *)ctx.ro_image; prog->jited = 1; prog->jited_len = prog_size; @@ -1675,6 +1766,14 @@ out: bpf_jit_prog_release_other(prog, prog == orig_prog ? tmp : orig_prog); return prog; + +out_free_hdr: + if (header) { + bpf_arch_text_copy(&ro_header->size, &header->size, + sizeof(header->size)); + bpf_jit_binary_pack_free(ro_header, header); + } + goto out_off; } bool bpf_jit_supports_kfunc_call(void) @@ -1682,6 +1781,13 @@ bool bpf_jit_supports_kfunc_call(void) return true; } +void *bpf_arch_text_copy(void *dst, void *src, size_t len) +{ + if (!aarch64_insn_copy(dst, src, len)) + return ERR_PTR(-EINVAL); + return dst; +} + u64 bpf_jit_alloc_exec_limit(void) { return VMALLOC_END - VMALLOC_START; @@ -2310,3 +2416,37 @@ bool bpf_jit_supports_ptr_xchg(void) { return true; } + +bool bpf_jit_supports_exceptions(void) +{ + /* We unwind through both kernel frames starting from within bpf_throw + * call and BPF frames. Therefore we require FP unwinder to be enabled + * to walk kernel frames and reach BPF frames in the stack trace. + * ARM64 kernel is aways compiled with CONFIG_FRAME_POINTER=y + */ + return true; +} + +void bpf_jit_free(struct bpf_prog *prog) +{ + if (prog->jited) { + struct arm64_jit_data *jit_data = prog->aux->jit_data; + struct bpf_binary_header *hdr; + + /* + * If we fail the final pass of JIT (from jit_subprogs), + * the program may not be finalized yet. Call finalize here + * before freeing it. 
+ */ + if (jit_data) { + bpf_arch_text_copy(&jit_data->ro_header->size, &jit_data->header->size, + sizeof(jit_data->header->size)); + kfree(jit_data); + } + hdr = bpf_jit_binary_pack_hdr(prog); + bpf_jit_binary_pack_free(hdr, NULL); + WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog)); + } + + bpf_prog_unlock_free(prog); +} diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h index a5ce1ab76ece..8b35f12a4452 100644 --- a/arch/riscv/net/bpf_jit.h +++ b/arch/riscv/net/bpf_jit.h @@ -18,6 +18,11 @@ static inline bool rvc_enabled(void) return IS_ENABLED(CONFIG_RISCV_ISA_C); } +static inline bool rvzbb_enabled(void) +{ + return IS_ENABLED(CONFIG_RISCV_ISA_ZBB) && riscv_has_extension_likely(RISCV_ISA_EXT_ZBB); +} + enum { RV_REG_ZERO = 0, /* The constant value 0 */ RV_REG_RA = 1, /* Return address */ @@ -730,6 +735,33 @@ static inline u16 rvc_swsp(u32 imm8, u8 rs2) return rv_css_insn(0x6, imm, rs2, 0x2); } +/* RVZBB instrutions. */ +static inline u32 rvzbb_sextb(u8 rd, u8 rs1) +{ + return rv_i_insn(0x604, rs1, 1, rd, 0x13); +} + +static inline u32 rvzbb_sexth(u8 rd, u8 rs1) +{ + return rv_i_insn(0x605, rs1, 1, rd, 0x13); +} + +static inline u32 rvzbb_zexth(u8 rd, u8 rs) +{ + if (IS_ENABLED(CONFIG_64BIT)) + return rv_i_insn(0x80, rs, 4, rd, 0x3b); + + return rv_i_insn(0x80, rs, 4, rd, 0x33); +} + +static inline u32 rvzbb_rev8(u8 rd, u8 rs) +{ + if (IS_ENABLED(CONFIG_64BIT)) + return rv_i_insn(0x6b8, rs, 5, rd, 0x13); + + return rv_i_insn(0x698, rs, 5, rd, 0x13); +} + /* * RV64-only instructions. * @@ -1087,6 +1119,108 @@ static inline void emit_subw(u8 rd, u8 rs1, u8 rs2, struct rv_jit_context *ctx) emit(rv_subw(rd, rs1, rs2), ctx); } +static inline void emit_sextb(u8 rd, u8 rs, struct rv_jit_context *ctx) +{ + if (rvzbb_enabled()) { + emit(rvzbb_sextb(rd, rs), ctx); + return; + } + + emit_slli(rd, rs, 56, ctx); + emit_srai(rd, rd, 56, ctx); +} + +static inline void emit_sexth(u8 rd, u8 rs, struct rv_jit_context *ctx) +{ + if (rvzbb_enabled()) { + emit(rvzbb_sexth(rd, rs), ctx); + return; + } + + emit_slli(rd, rs, 48, ctx); + emit_srai(rd, rd, 48, ctx); +} + +static inline void emit_sextw(u8 rd, u8 rs, struct rv_jit_context *ctx) +{ + emit_addiw(rd, rs, 0, ctx); +} + +static inline void emit_zexth(u8 rd, u8 rs, struct rv_jit_context *ctx) +{ + if (rvzbb_enabled()) { + emit(rvzbb_zexth(rd, rs), ctx); + return; + } + + emit_slli(rd, rs, 48, ctx); + emit_srli(rd, rd, 48, ctx); +} + +static inline void emit_zextw(u8 rd, u8 rs, struct rv_jit_context *ctx) +{ + emit_slli(rd, rs, 32, ctx); + emit_srli(rd, rd, 32, ctx); +} + +static inline void emit_bswap(u8 rd, s32 imm, struct rv_jit_context *ctx) +{ + if (rvzbb_enabled()) { + int bits = 64 - imm; + + emit(rvzbb_rev8(rd, rd), ctx); + if (bits) + emit_srli(rd, rd, bits, ctx); + return; + } + + emit_li(RV_REG_T2, 0, ctx); + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + if (imm == 16) + goto out_be; + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + if (imm == 32) + goto out_be; + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + + 
emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); + + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); + emit_srli(rd, rd, 8, ctx); +out_be: + emit_andi(RV_REG_T1, rd, 0xff, ctx); + emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); + + emit_mv(rd, RV_REG_T2, ctx); +} + #endif /* __riscv_xlen == 64 */ void bpf_jit_build_prologue(struct rv_jit_context *ctx); diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c index 719a97e7edb2..869e4282a2c4 100644 --- a/arch/riscv/net/bpf_jit_comp64.c +++ b/arch/riscv/net/bpf_jit_comp64.c @@ -141,6 +141,19 @@ static bool in_auipc_jalr_range(s64 val) val < ((1L << 31) - (1L << 11)); } +/* Modify rd pointer to alternate reg to avoid corrupting original reg */ +static void emit_sextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx) +{ + emit_sextw(ra, *rd, ctx); + *rd = ra; +} + +static void emit_zextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx) +{ + emit_zextw(ra, *rd, ctx); + *rd = ra; +} + /* Emit fixed-length instructions for address */ static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx) { @@ -326,12 +339,6 @@ static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff, emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx); } -static void emit_zext_32(u8 reg, struct rv_jit_context *ctx) -{ - emit_slli(reg, reg, 32, ctx); - emit_srli(reg, reg, 32, ctx); -} - static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx) { int tc_ninsn, off, start_insn = ctx->ninsns; @@ -346,7 +353,7 @@ static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx) */ tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] : ctx->offset[0]; - emit_zext_32(RV_REG_A2, ctx); + emit_zextw(RV_REG_A2, RV_REG_A2, ctx); off = offsetof(struct bpf_array, map.max_entries); if (is_12b_check(off, insn)) @@ -405,38 +412,6 @@ static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn, *rs = bpf_to_rv_reg(insn->src_reg, ctx); } -static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx) -{ - emit_mv(RV_REG_T2, *rd, ctx); - emit_zext_32(RV_REG_T2, ctx); - emit_mv(RV_REG_T1, *rs, ctx); - emit_zext_32(RV_REG_T1, ctx); - *rd = RV_REG_T2; - *rs = RV_REG_T1; -} - -static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx) -{ - emit_addiw(RV_REG_T2, *rd, 0, ctx); - emit_addiw(RV_REG_T1, *rs, 0, ctx); - *rd = RV_REG_T2; - *rs = RV_REG_T1; -} - -static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx) -{ - emit_mv(RV_REG_T2, *rd, ctx); - emit_zext_32(RV_REG_T2, ctx); - emit_zext_32(RV_REG_T1, ctx); - *rd = RV_REG_T2; -} - -static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx) -{ - emit_addiw(RV_REG_T2, *rd, 0, ctx); - *rd = RV_REG_T2; -} - static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr, struct rv_jit_context *ctx) { @@ -519,32 +494,32 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64, emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) : rv_amoadd_w(rs, rs, rd, 0, 0), ctx); if (!is64) - emit_zext_32(rs, ctx); + emit_zextw(rs, rs, ctx); break; case BPF_AND | BPF_FETCH: emit(is64 ? 
rv_amoand_d(rs, rs, rd, 0, 0) : rv_amoand_w(rs, rs, rd, 0, 0), ctx); if (!is64) - emit_zext_32(rs, ctx); + emit_zextw(rs, rs, ctx); break; case BPF_OR | BPF_FETCH: emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) : rv_amoor_w(rs, rs, rd, 0, 0), ctx); if (!is64) - emit_zext_32(rs, ctx); + emit_zextw(rs, rs, ctx); break; case BPF_XOR | BPF_FETCH: emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) : rv_amoxor_w(rs, rs, rd, 0, 0), ctx); if (!is64) - emit_zext_32(rs, ctx); + emit_zextw(rs, rs, ctx); break; /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */ case BPF_XCHG: emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) : rv_amoswap_w(rs, rs, rd, 0, 0), ctx); if (!is64) - emit_zext_32(rs, ctx); + emit_zextw(rs, rs, ctx); break; /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */ case BPF_CMPXCHG: @@ -1091,7 +1066,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_ALU64 | BPF_MOV | BPF_X: if (imm == 1) { /* Special mov32 for zext */ - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; } switch (insn->off) { @@ -1099,16 +1074,17 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, emit_mv(rd, rs, ctx); break; case 8: + emit_sextb(rd, rs, ctx); + break; case 16: - emit_slli(RV_REG_T1, rs, 64 - insn->off, ctx); - emit_srai(rd, RV_REG_T1, 64 - insn->off, ctx); + emit_sexth(rd, rs, ctx); break; case 32: - emit_addiw(rd, rs, 0, ctx); + emit_sextw(rd, rs, ctx); break; } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; /* dst = dst OP src */ @@ -1116,7 +1092,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_ALU64 | BPF_ADD | BPF_X: emit_add(rd, rd, rs, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_SUB | BPF_X: case BPF_ALU64 | BPF_SUB | BPF_X: @@ -1126,31 +1102,31 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, emit_subw(rd, rd, rs, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_AND | BPF_X: case BPF_ALU64 | BPF_AND | BPF_X: emit_and(rd, rd, rs, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_OR | BPF_X: case BPF_ALU64 | BPF_OR | BPF_X: emit_or(rd, rd, rs, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_XOR | BPF_X: case BPF_ALU64 | BPF_XOR | BPF_X: emit_xor(rd, rd, rs, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_MUL | BPF_X: case BPF_ALU64 | BPF_MUL | BPF_X: emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_DIV | BPF_X: case BPF_ALU64 | BPF_DIV | BPF_X: @@ -1159,7 +1135,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, else emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_MOD | BPF_X: case BPF_ALU64 | BPF_MOD | BPF_X: @@ -1168,25 +1144,25 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, else emit(is64 ? 
rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_LSH | BPF_X: case BPF_ALU64 | BPF_LSH | BPF_X: emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_RSH | BPF_X: case BPF_ALU64 | BPF_RSH | BPF_X: emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_ARSH | BPF_X: case BPF_ALU64 | BPF_ARSH | BPF_X: emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; /* dst = -dst */ @@ -1194,73 +1170,27 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_ALU64 | BPF_NEG: emit_sub(rd, RV_REG_ZERO, rd, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; /* dst = BSWAP##imm(dst) */ case BPF_ALU | BPF_END | BPF_FROM_LE: switch (imm) { case 16: - emit_slli(rd, rd, 48, ctx); - emit_srli(rd, rd, 48, ctx); + emit_zexth(rd, rd, ctx); break; case 32: if (!aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case 64: /* Do nothing */ break; } break; - case BPF_ALU | BPF_END | BPF_FROM_BE: case BPF_ALU64 | BPF_END | BPF_FROM_LE: - emit_li(RV_REG_T2, 0, ctx); - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - if (imm == 16) - goto out_be; - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - if (imm == 32) - goto out_be; - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); - - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx); - emit_srli(rd, rd, 8, ctx); -out_be: - emit_andi(RV_REG_T1, rd, 0xff, ctx); - emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx); - - emit_mv(rd, RV_REG_T2, ctx); + emit_bswap(rd, imm, ctx); break; /* dst = imm */ @@ -1268,7 +1198,7 @@ out_be: case BPF_ALU64 | BPF_MOV | BPF_K: emit_imm(rd, imm, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; /* dst = dst OP imm */ @@ -1281,7 +1211,7 @@ out_be: emit_add(rd, rd, RV_REG_T1, ctx); } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_SUB | BPF_K: case BPF_ALU64 | BPF_SUB | BPF_K: @@ -1292,7 +1222,7 @@ out_be: emit_sub(rd, rd, RV_REG_T1, ctx); } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_AND | BPF_K: case BPF_ALU64 | BPF_AND | BPF_K: @@ 
-1303,7 +1233,7 @@ out_be: emit_and(rd, rd, RV_REG_T1, ctx); } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_OR | BPF_K: case BPF_ALU64 | BPF_OR | BPF_K: @@ -1314,7 +1244,7 @@ out_be: emit_or(rd, rd, RV_REG_T1, ctx); } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_XOR | BPF_K: case BPF_ALU64 | BPF_XOR | BPF_K: @@ -1325,7 +1255,7 @@ out_be: emit_xor(rd, rd, RV_REG_T1, ctx); } if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_MUL | BPF_K: case BPF_ALU64 | BPF_MUL | BPF_K: @@ -1333,7 +1263,7 @@ out_be: emit(is64 ? rv_mul(rd, rd, RV_REG_T1) : rv_mulw(rd, rd, RV_REG_T1), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_DIV | BPF_K: case BPF_ALU64 | BPF_DIV | BPF_K: @@ -1345,7 +1275,7 @@ out_be: emit(is64 ? rv_divu(rd, rd, RV_REG_T1) : rv_divuw(rd, rd, RV_REG_T1), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_MOD | BPF_K: case BPF_ALU64 | BPF_MOD | BPF_K: @@ -1357,14 +1287,14 @@ out_be: emit(is64 ? rv_remu(rd, rd, RV_REG_T1) : rv_remuw(rd, rd, RV_REG_T1), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_LSH | BPF_K: case BPF_ALU64 | BPF_LSH | BPF_K: emit_slli(rd, rd, imm, ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_RSH | BPF_K: case BPF_ALU64 | BPF_RSH | BPF_K: @@ -1374,7 +1304,7 @@ out_be: emit(rv_srliw(rd, rd, imm), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; case BPF_ALU | BPF_ARSH | BPF_K: case BPF_ALU64 | BPF_ARSH | BPF_K: @@ -1384,7 +1314,7 @@ out_be: emit(rv_sraiw(rd, rd, imm), ctx); if (!is64 && !aux->verifier_zext) - emit_zext_32(rd, ctx); + emit_zextw(rd, rd, ctx); break; /* JUMP off */ @@ -1425,10 +1355,13 @@ out_be: rvoff = rv_offset(i, off, ctx); if (!is64) { s = ctx->ninsns; - if (is_signed_bpf_cond(BPF_OP(code))) - emit_sext_32_rd_rs(&rd, &rs, ctx); - else - emit_zext_32_rd_rs(&rd, &rs, ctx); + if (is_signed_bpf_cond(BPF_OP(code))) { + emit_sextw_alt(&rs, RV_REG_T1, ctx); + emit_sextw_alt(&rd, RV_REG_T2, ctx); + } else { + emit_zextw_alt(&rs, RV_REG_T1, ctx); + emit_zextw_alt(&rd, RV_REG_T2, ctx); + } e = ctx->ninsns; /* Adjust for extra insns */ @@ -1439,8 +1372,7 @@ out_be: /* Adjust for and */ rvoff -= 4; emit_and(RV_REG_T1, rd, rs, ctx); - emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, - ctx); + emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx); } else { emit_branch(BPF_OP(code), rd, rs, rvoff, ctx); } @@ -1469,18 +1401,18 @@ out_be: case BPF_JMP32 | BPF_JSLE | BPF_K: rvoff = rv_offset(i, off, ctx); s = ctx->ninsns; - if (imm) { + if (imm) emit_imm(RV_REG_T1, imm, ctx); - rs = RV_REG_T1; - } else { - /* If imm is 0, simply use zero register. */ - rs = RV_REG_ZERO; - } + rs = imm ? RV_REG_T1 : RV_REG_ZERO; if (!is64) { - if (is_signed_bpf_cond(BPF_OP(code))) - emit_sext_32_rd(&rd, ctx); - else - emit_zext_32_rd_t1(&rd, ctx); + if (is_signed_bpf_cond(BPF_OP(code))) { + emit_sextw_alt(&rd, RV_REG_T2, ctx); + /* rs has been sign extended */ + } else { + emit_zextw_alt(&rd, RV_REG_T2, ctx); + if (imm) + emit_zextw(rs, rs, ctx); + } } e = ctx->ninsns; @@ -1504,7 +1436,7 @@ out_be: * as t1 is used only in comparison against zero. 
*/ if (!is64 && imm < 0) - emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx); + emit_sextw(RV_REG_T1, RV_REG_T1, ctx); e = ctx->ninsns; rvoff -= ninsns_rvoff(e - s); emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx); @@ -1874,3 +1806,8 @@ bool bpf_jit_supports_kfunc_call(void) { return true; } + +bool bpf_jit_supports_ptr_xchg(void) +{ + return true; +}
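Editor's note: because the JIT now writes into a RW shadow buffer while all offsets are computed against the final RX image, the arm64 diff converts RX addresses back to RW ones in two places (the exception-table entry and the image pointer recovered on the extra pass). Below is a minimal sketch of that offset arithmetic with a hypothetical helper name; the patch itself open-codes it, e.g. ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image).

/* Minimal sketch, assuming the RW and RX allocations have identical layout. */
static inline void *rx_addr_to_rw(void *rw_base, void *rx_base, void *rx_addr)
{
	/*
	 * Carry the byte offset from the RX base over to the RW base
	 * (void * arithmetic is the usual GCC/kernel idiom).
	 */
	return rw_base + (rx_addr - rx_base);
}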
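Editor's note: on the RISC-V side the open-coded byte-swap loop is folded into emit_bswap(), which on Zbb-capable cores emits a single rev8 followed, for imm < 64, by a logical right shift of (64 - imm) bits, so the shift doubles as the zero extension. The plain C program below is a userspace model of that lowering (not kernel code), useful for convincing oneself of the semantics.

#include <stdint.h>
#include <stdio.h>

/* Byte-reverse a full 64-bit value, i.e. what the Zbb rev8 instruction does. */
static uint64_t rev8(uint64_t x)
{
	uint64_t r = 0;
	int i;

	for (i = 0; i < 8; i++)
		r |= ((x >> (i * 8)) & 0xff) << ((7 - i) * 8);
	return r;
}

/* Model of emit_bswap() with Zbb: rev8, then srli by (64 - imm) when imm < 64. */
static uint64_t bswap_model(uint64_t x, unsigned int imm)
{
	unsigned int bits = 64 - imm;
	uint64_t r = rev8(x);

	return bits ? r >> bits : r;
}

int main(void)
{
	uint64_t x = 0x1122334455667788ULL;

	printf("%llx\n", (unsigned long long)bswap_model(x, 16)); /* 8877 */
	printf("%llx\n", (unsigned long long)bswap_model(x, 32)); /* 88776655 */
	printf("%llx\n", (unsigned long long)bswap_model(x, 64)); /* 8877665544332211 */
	return 0;
}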