Call functions exported by dynamic shared object (.so) files from the command line.
git clone http://git.nthia.dev/dlcall
// gcc main.c -ldl -O3 -o dlcall
#include <dlfcn.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Lowercase hexadecimal digit characters: HEX[v] is the digit for a nibble
   value v in 0..15. */
const char *HEX = "0123456789abcdef";
/* Call f with arg_len 64-bit integer/pointer arguments taken from args,
   following the x86-64 System V calling convention, and return %rax.

   args[0..5] go in %rdi, %rsi, %rdx, %rcx, %r8, %r9. The remaining
   arguments are pushed in reverse order, so args[6] ends up at the lowest
   stack address. No vector registers are loaded and %al is set to 0 (the
   vector-register count read by variadic callees). Floating-point arguments
   and return values are therefore not supported. args is not dereferenced
   when arg_len is 0.

   The compiler does not know the asm performs a call, so the asm itself:
   - reads every memory operand before moving %rsp (operands may be
     %rsp-relative);
   - steps over the 128-byte red zone and aligns %rsp to 16 bytes at the
     call instruction;
   - keeps the caller's %rsp in callee-saved %r12 and restores it after the
     call, which also discards the pushed stack arguments;
   - declares all caller-saved registers, the flags and memory clobbered.
   Numeric local labels are used so the asm stays valid if this function
   is inlined or cloned. */
uint64_t call(void *(*f)(void *), size_t arg_len, uint64_t *args) {
    uint64_t ret;
    uint64_t reg_len = arg_len;
    /* Arguments beyond the six integer argument registers go on the stack. */
    uint64_t stack_len = arg_len > 6 ? (uint64_t) arg_len - 6 : 0;
    __asm__ __volatile__(
        "mov %[fn], %%r15\n\t"
        "mov %[argv], %%r13\n\t"
        "mov %[nreg], %%r14\n\t"
        "mov %[nstack], %%r11\n\t"
        "mov %%rsp, %%r12\n\t"
        "sub $128, %%rsp\n\t"
        "and $-16, %%rsp\n\t"
        /* An odd number of 8-byte pushes would leave %rsp misaligned. */
        "test $1, %%r11\n\t"
        "jz 1f\n\t"
        "sub $8, %%rsp\n\t"
        /* Push args[arg_len-1] down to args[6]: address args + 8*(5+r11). */
        "1:\n\t"
        "test %%r11, %%r11\n\t"
        "jz 2f\n\t"
        "pushq 40(%%r13,%%r11,8)\n\t"
        "dec %%r11\n\t"
        "jmp 1b\n\t"
        /* Load only the register arguments that exist. */
        "2:\n\t"
        "cmp $1, %%r14\n\t"
        "jb 3f\n\t"
        "mov (%%r13), %%rdi\n\t"
        "cmp $2, %%r14\n\t"
        "jb 3f\n\t"
        "mov 8(%%r13), %%rsi\n\t"
        "cmp $3, %%r14\n\t"
        "jb 3f\n\t"
        "mov 16(%%r13), %%rdx\n\t"
        "cmp $4, %%r14\n\t"
        "jb 3f\n\t"
        "mov 24(%%r13), %%rcx\n\t"
        "cmp $5, %%r14\n\t"
        "jb 3f\n\t"
        "mov 32(%%r13), %%r8\n\t"
        "cmp $6, %%r14\n\t"
        "jb 3f\n\t"
        "mov 40(%%r13), %%r9\n\t"
        "3:\n\t"
        "xor %%eax, %%eax\n\t"
        "call *%%r15\n\t"
        "mov %%r12, %%rsp\n\t"
        : "=a"(ret)
        : [fn] "m"(f), [argv] "m"(args), [nreg] "m"(reg_len), [nstack] "m"(stack_len)
        : "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11",
          "r12", "r13", "r14", "r15",
          "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7",
          "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15",
          "cc", "memory");
    return ret;
}
/* Name to show in the usage message. A path starting with '.' is returned
   unchanged. Otherwise, everything up to and including the last '/' is
   stripped; a name without '/' is returned as-is. The result points into
   file; nothing is allocated. */
char *cmd_name(char *file) {
    if (*file == '.') {
        return file;
    }
    char *last_slash = strrchr(file, '/');
    return last_slash != NULL ? last_slash + 1 : file;
}
/* Write ret to out, formatted according to the type name srt:
     u8 i8 u16 i16 u32 i32 i64  ret truncated to that width, in decimal
     str                        ret is a char * to a NUL-terminated string
     [N]                        ret points at N bytes, printed as lowercase hex
     void                       nothing is printed
   Any other name (e.g. u64) prints ret as an unsigned 64-bit decimal.
   Each printed value is followed by a newline. A NULL pointer for "str",
   or for a non-empty "[N]", prints "(null)" instead of being dereferenced. */
static void show_to(FILE *out, const char *srt, uint64_t ret) {
    uint64_t intv;
    if (strcmp(srt, "u8") == 0) {
        fprintf(out, "%" PRIu8 "\n", (uint8_t) ret);
    } else if (strcmp(srt, "i8") == 0) {
        fprintf(out, "%" PRId8 "\n", (int8_t) ret);
    } else if (strcmp(srt, "u16") == 0) {
        fprintf(out, "%" PRIu16 "\n", (uint16_t) ret);
    } else if (strcmp(srt, "i16") == 0) {
        fprintf(out, "%" PRId16 "\n", (int16_t) ret);
    } else if (strcmp(srt, "u32") == 0) {
        /* Unsigned format: %d would print values above INT32_MAX as negative. */
        fprintf(out, "%" PRIu32 "\n", (uint32_t) ret);
    } else if (strcmp(srt, "i32") == 0) {
        fprintf(out, "%" PRId32 "\n", (int32_t) ret);
    } else if (strcmp(srt, "i64") == 0) {
        /* main tags negative integer arguments as "i64"; print them signed. */
        fprintf(out, "%" PRId64 "\n", (int64_t) ret);
    } else if (strcmp(srt, "str") == 0) {
        const char *s = (const char *) (uintptr_t) ret;
        fprintf(out, "%s\n", s != NULL ? s : "(null)");
    } else if (sscanf(srt, "[%" SCNu64 "]", &intv) == 1) {
        const unsigned char *data = (const unsigned char *) (uintptr_t) ret;
        if (data == NULL && intv > 0) {
            fputs("(null)\n", out);
            return;
        }
        /* Stream the digits directly; no temporary buffer to allocate or leak. */
        for (uint64_t i = 0; i < intv; i++) {
            fprintf(out, "%02x", data[i]);
        }
        fputc('\n', out);
    } else if (strcmp(srt, "void") != 0) {
        fprintf(out, "%" PRIu64 "\n", ret);
    }
}

/* Print ret to stdout formatted according to the type name srt
   (see show_to for the accepted names). */
void show(char *srt, uint64_t ret) {
    show_to(stdout, srt, ret);
}
/* Numeric value of one hexadecimal digit character ('0'-'9', 'a'-'f',
   'A'-'F'). Any other byte, including NUL, yields 0. */
uint8_t parse_hex_digit(unsigned char d) {
    if ('0' <= d && d <= '9') {
        return (uint8_t) (d - '0');
    }
    if ('a' <= d && d <= 'f') {
        return (uint8_t) (d - 'a' + 10);
    }
    if ('A' <= d && d <= 'F') {
        return (uint8_t) (d - 'A' + 10);
    }
    return 0;
}
/* Decode a string of hex digit pairs into a newly malloc'd byte buffer that
   the caller owns.

   Digits are decoded with parse_hex_digit, so invalid characters count as 0.
   An odd trailing digit becomes the high nibble of a final byte (its low
   nibble comes from the terminating NUL, i.e. 0). The buffer holds
   (strlen(data)+1)/2 decoded bytes followed by a NUL terminator. Because
   decoded bytes may themselves be zero, callers should derive the length
   from the hex text rather than calling strlen() on the result. Returns NULL
   if allocation fails. */
char *parse_hex(char *data) {
    size_t in_len = strlen(data);
    size_t out_len = (in_len + 1) / 2;
    char *out = malloc(out_len + 1);
    if (out == NULL) {
        return NULL;
    }
    /* Iterate by output byte so an odd-length input never steps past the
       terminator. data[2*i+1] is at most data[in_len], the NUL itself. */
    for (size_t i = 0; i < out_len; i++) {
        out[i] = (char) (parse_hex_digit((unsigned char) data[2 * i]) * 16 +
                         parse_hex_digit((unsigned char) data[2 * i + 1]));
    }
    out[out_len] = '\0';
    return out;
}
/* Build the "[N]" type string that describes an N-byte buffer argument.
   The result is malloc'd and sized to fit; NULL on allocation failure. */
static char *buffer_type(size_t len) {
    int n = snprintf(NULL, 0, "[%zu]", len);
    if (n < 0) {
        return NULL;
    }
    char *type = malloc((size_t) n + 1);
    if (type != NULL) {
        snprintf(type, (size_t) n + 1, "[%zu]", len);
    }
    return type;
}

/* Reduce index k into [0, len), wrapping negative values from the end
   (-1 is the last element). len must be non-zero. */
static size_t wrap_index(long long k, size_t len) {
    long long m = (long long) len;
    long long r = k % m;
    return (size_t) (r < 0 ? r + m : r);
}

/* Parse a signed decimal integer from the start of s into *out.
   Returns 1 if at least one digit was read (trailing text is ignored),
   0 otherwise. */
static int parse_index(const char *s, long long *out) {
    char *end;
    long long v = strtoll(s, &end, 10);
    if (end == s) {
        return 0;
    }
    *out = v;
    return 1;
}

/* dlcall SO_FILE [RETTYPE:]SYMBOL [ARGUMENTS...]

   Loads SO_FILE with dlopen, looks up SYMBOL, calls it with ARGUMENTS and
   prints the return value formatted as RETTYPE (default "void", which prints
   nothing). Each argument is one of:
     [N]        a zero-filled N-byte buffer, passed as a pointer
     hex:HEX    a buffer holding the decoded bytes of HEX
     str:TEXT   a pointer to TEXT
     -N / N     a negative / non-negative 64-bit integer
   Options, not recognised after "--":
     -pK / -p K after the call, print argument K, an index into the arguments
                given so far (negative values count back from the latest)
   Errors are written to stdout (fd 1) and the exit status is 1. */
int main(int argc, char **argv) {
    /* argv[2] (SYMBOL) is read below, so both positional arguments are required. */
    if (argc < 3) {
        dprintf(1, "usage: %s SO_FILE SYMBOL [ARGUMENTS...]\n",
                cmd_name(argc > 0 ? argv[0] : "dlcall"));
        return 1;
    }
    char *so_file = argv[1];
    char *symbol = NULL, *srt = NULL, *error;
    /* "RETTYPE:SYMBOL" form; the %m modifier makes sscanf malloc both strings. */
    int n = sscanf(argv[2], "%m[^:]:%m[^:]", &srt, &symbol);
    int owns_names = (n == 2);
    if (!owns_names) {
        /* A partial match (n == 1) has still allocated srt. */
        free(srt);
        free(symbol);
        symbol = argv[2];
        srt = "void";
    }

    void *so = dlopen(so_file, RTLD_LAZY);
    if (so == NULL) {
        error = dlerror();
        dprintf(1, "error in dlopen: %s\n", error != NULL ? error : "unknown error");
        return 1;
    }
    /* A symbol's value may legitimately be NULL, so success is judged by
       dlerror(). Clear any earlier error first. */
    dlerror();
    void *(*fn)(void *) = dlsym(so, symbol);
    if ((error = dlerror()) != NULL) {
        dprintf(1, "error in dlsym for symbol \"%s\": %s\n", symbol, error);
        return 1;
    }

    /* Each argv entry after SYMBOL yields at most one argument or one -p
       index; the +1 keeps the allocation size non-zero. */
    size_t cap = (size_t) (argc - 3) + 1;
    uint64_t *args = malloc(cap * sizeof *args);
    char **arg_types = malloc(cap * sizeof *arg_types);
    size_t *print_args = malloc(cap * sizeof *print_args);
    if (args == NULL || arg_types == NULL || print_args == NULL) {
        dprintf(1, "out of memory\n");
        return 1;
    }
    size_t arg_len = 0;
    size_t print_arg_len = 0;
    int br = 0;
    uint64_t intv;
    long long idx;
    for (int i = 3; i < argc; i++) {
        if (!br && strcmp(argv[i], "--") == 0) {
            br = 1;
            continue;
        }
        if (sscanf(argv[i], "[%" SCNu64 "]", &intv) == 1) {
            /* Zero-filled so -p never prints bytes the callee left untouched
               as uninitialised garbage. */
            char *data = calloc(intv > 0 ? (size_t) intv : 1, 1);
            if (data == NULL) {
                dprintf(1, "out of memory\n");
                return 1;
            }
            arg_types[arg_len] = argv[i];
            args[arg_len++] = (uint64_t) (uintptr_t) data;
        } else if (strncmp(argv[i], "hex:", 4) == 0) {
            char *hex = argv[i] + 4;
            char *data = parse_hex(hex);
            /* Take the byte count from the hex text: strlen() of the decoded
               bytes would stop at the first zero byte. */
            char *type = buffer_type((strlen(hex) + 1) / 2);
            if (data == NULL || type == NULL) {
                dprintf(1, "out of memory\n");
                return 1;
            }
            arg_types[arg_len] = type;
            args[arg_len++] = (uint64_t) (uintptr_t) data;
        } else if (strncmp(argv[i], "str:", 4) == 0) {
            char *data = argv[i] + 4;
            char *type = buffer_type(strlen(data));
            if (type == NULL) {
                dprintf(1, "out of memory\n");
                return 1;
            }
            arg_types[arg_len] = type;
            args[arg_len++] = (uint64_t) (uintptr_t) data;
        } else if (sscanf(argv[i], "-%" SCNu64, &intv) == 1) {
            arg_types[arg_len] = "i64";
            /* Unsigned negation yields the two's-complement bits of -intv. */
            args[arg_len++] = -intv;
        } else if (sscanf(argv[i], "%" SCNu64, &intv) == 1) {
            arg_types[arg_len] = "u64";
            args[arg_len++] = intv;
        } else if (!br && strncmp(argv[i], "--", 2) == 0) {
            dprintf(1, "todo: %s\n", argv[i]);
            return 1;
        } else if (!br && strncmp(argv[i], "-p", 2) == 0 &&
                   (argv[i][2] != '\0' || i + 1 < argc)) {
            /* "-pK" carries the index inline; "-p K" takes the next argv entry. */
            const char *text = argv[i][2] != '\0' ? argv[i] + 2 : argv[++i];
            if (!parse_index(text, &idx)) {
                dprintf(1, "unexpected argument: %s\n", argv[i]);
                return 1;
            }
            /* The index is taken modulo the arguments seen so far. */
            if (arg_len == 0) {
                dprintf(1, "-p needs a preceding argument: %s\n", argv[i]);
                return 1;
            }
            print_args[print_arg_len++] = wrap_index(idx, arg_len);
        } else {
            dprintf(1, "unexpected argument: %s\n", argv[i]);
            return 1;
        }
    }

    uint64_t ret = call(fn, arg_len, args);
    /* Print before dlclose: a returned string may live inside the library. */
    show(srt, ret);
    for (size_t i = 0; i < print_arg_len; i++) {
        size_t j = print_args[i];
        show(arg_types[j], args[j]);
    }
    dlclose(so);
    /* Per-argument buffers and type strings are not freed individually;
       main returns immediately afterwards. */
    if (owns_names) {
        free(srt);
        free(symbol);
    }
    free(print_args);
    free(arg_types);
    free(args);
    return 0;
}