dlcall / main.c
call functions from dynamic shared object files
git clone http://git.nthia.dev/dlcall

// gcc main.c -ldl -O3 -o dlcall
#include <dlfcn.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Lower-case hexadecimal digits indexed by nibble value (0-15). The pointer
   itself is const so the table cannot be re-pointed by accident. */
const char *const HEX = "0123456789abcdef";

/* Maximum number of integer/pointer arguments call() can forward. */
enum { CALL_MAX_ARGS = 16 };

/*
 * Call f with arg_len 64-bit integer/pointer arguments taken from args[] and
 * return the raw 64-bit integer return register.
 *
 * Instead of hand-rolling the calling convention in inline assembly, f is
 * invoked through a variadic prototype taking CALL_MAX_ARGS uint64_t values,
 * with unused trailing slots passed as 0. On the x86-64 System V ABI (which
 * the previous inline-asm version already assumed):
 *   - integer arguments are placed identically for variadic and non-variadic
 *     callees (6 in registers, the rest on the stack in order);
 *   - surplus arguments are ignored by the callee and removed by the caller;
 *   - a variadic call sets %al = 0, as required when f is itself variadic
 *     (e.g. printf);
 *   - the compiler handles stack alignment, the red zone and clobbers.
 * Calling f through a type other than its definition is not defined by ISO C;
 * like any dlsym-based caller this relies on the platform ABI. Floating-point
 * arguments/results are not supported. If f returns a narrower integer type,
 * the upper bits of the result are unspecified.
 *
 * args may be NULL when arg_len is 0. Exits with status 1 if arg_len exceeds
 * CALL_MAX_ARGS.
 */
uint64_t call(void *(*f)(void *), size_t arg_len, uint64_t *args) {
  typedef uint64_t (*wide_fn)(uint64_t, ...);
  uint64_t a[CALL_MAX_ARGS] = {0};

  if (arg_len > CALL_MAX_ARGS) {
    fprintf(stderr, "too many arguments: %zu (at most %d supported)\n",
            arg_len, CALL_MAX_ARGS);
    exit(1);
  }
  for (size_t i = 0; i < arg_len; i++) {
    a[i] = args[i];
  }

  wide_fn g = (wide_fn) f;
  return g(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7],
           a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]);
}

/*
 * Return the program name to show in messages, derived from argv[0].
 *
 * A path starting with '.' (e.g. "./dlcall") is returned unchanged;
 * otherwise the part after the last '/' is returned (the whole string if it
 * has no '/'). argv[0] may be NULL when the program is exec'd with an empty
 * argument vector, in which case the fixed name "dlcall" is returned.
 */
char *cmd_name(char *file) {
  if (file == NULL) {
    return "dlcall";
  }
  if (file[0] == '.') {
    return file;
  }
  char *slash = strrchr(file, '/');
  return slash != NULL ? slash + 1 : file;
}

/*
 * Print a raw 64-bit value to stdout, interpreted according to type name srt:
 *   "u8" "i8" "u16" "i16" "u32" "i32" "i64"
 *             ret converted to that integer type, printed in decimal
 *   "str"     ret is a pointer to a NUL-terminated string
 *   "[N]"     ret points to N bytes, printed as 2N lower-case hex digits
 *   "void"    nothing is printed
 *   otherwise (including "u64") ret printed as unsigned 64-bit decimal
 * A NULL pointer for "str", or for "[N]" with N > 0, prints "(null)".
 */
void show(char *srt, uint64_t ret) {
  uint64_t intv;
  if (strcmp(srt, "u8") == 0) {
    printf("%d\n", (uint8_t) ret);
  } else if (strcmp(srt, "i8") == 0) {
    printf("%d\n", (int8_t) ret);
  } else if (strcmp(srt, "u16") == 0) {
    printf("%d\n", (uint16_t) ret);
  } else if (strcmp(srt, "i16") == 0) {
    printf("%d\n", (int16_t) ret);
  } else if (strcmp(srt, "u32") == 0) {
    /* %d would print values above INT32_MAX as negative numbers. */
    printf("%" PRIu32 "\n", (uint32_t) ret);
  } else if (strcmp(srt, "i32") == 0) {
    printf("%" PRId32 "\n", (int32_t) ret);
  } else if (strcmp(srt, "i64") == 0) {
    /* Type main() assigns to negative integer arguments. */
    printf("%" PRId64 "\n", (int64_t) ret);
  } else if (strcmp(srt, "str") == 0) {
    const char *s = (const char *) (uintptr_t) ret;
    printf("%s\n", s != NULL ? s : "(null)");
  } else if (sscanf(srt, "[%" SCNu64 "]", &intv) == 1) {
    const unsigned char *data = (const unsigned char *) (uintptr_t) ret;
    if (data == NULL && intv > 0) {
      printf("(null)\n");
      return;
    }
    /* Stream the dump rather than building it in a heap buffer: no size
       overflow for huge N and nothing to leak. */
    for (uint64_t i = 0; i < intv; i++) {
      putchar(HEX[data[i] >> 4]);
      putchar(HEX[data[i] & 0x0f]);
    }
    putchar('\n');
  } else if (strcmp(srt, "void") != 0) {
    printf("%" PRIu64 "\n", ret);
  }
}

/* Value of a single hexadecimal digit: '0'-'9' -> 0-9, 'a'-'f' and 'A'-'F'
   -> 10-15. Any other character yields 0. */
uint8_t parse_hex_digit(unsigned char d) {
  if (d >= '0' && d <= '9') {
    return d - '0';
  }
  if (d >= 'a' && d <= 'f') {
    return d - 'a' + 10;
  }
  if (d >= 'A' && d <= 'F') {
    return d - 'A' + 10;
  }
  return 0;
}
/*
 * Decode a string of hexadecimal digit pairs into bytes.
 *
 * Each pair of input characters becomes one byte, high nibble first. An odd
 * trailing digit is taken as a high nibble with a zero low nibble
 * ("abc" -> 0xab 0xc0). Characters that are not hex digits decode as 0.
 *
 * The result holds (strlen(data) + 1) / 2 bytes followed by a NUL. Decoded
 * bytes may themselves be 0, so callers should take the length from the
 * input rather than calling strlen() on the result.
 *
 * Returns a malloc'd buffer the caller owns, or NULL if allocation fails.
 */
char *parse_hex(char *data) {
  size_t n_bytes = (strlen(data) + 1) / 2;
  char *out = malloc(n_bytes + 1);
  if (out == NULL) {
    return NULL;
  }
  for (size_t i = 0; i < n_bytes; i++) {
    unsigned char hi = parse_hex_digit(data[2 * i]);
    /* For odd-length input the final low digit is the terminator ('\0'),
       which decodes as 0; the loop never reads beyond it. */
    unsigned char lo = parse_hex_digit(data[2 * i + 1]);
    out[i] = (char) ((hi << 4) | lo);
  }
  out[n_bytes] = '\0';
  return out;
}

/* Return a malloc'd "[len]" type string describing a byte-buffer argument,
   or NULL if allocation fails. 32 bytes hold "[", up to 20 digits, "]" and
   the terminator. */
static char *len_type(size_t len) {
  char *type = malloc(32);
  if (type != NULL) {
    snprintf(type, 32, "[%zu]", len);
  }
  return type;
}

/* Reduce idx modulo len into [0, len), so that -1 names the last element.
   Stores the result in *out and returns 1, or returns 0 when len is 0. */
static int wrap_index(long long idx, size_t len, size_t *out) {
  if (len == 0) {
    return 0;
  }
  long long m = idx % (long long) len;
  if (m < 0) {
    m += (long long) len;
  }
  *out = (size_t) m;
  return 1;
}

/*
 * dlcall SO_FILE [TYPE:]SYMBOL [ARGUMENTS...]
 *
 * Loads SO_FILE, looks up SYMBOL, calls it with the given arguments and
 * prints the return value according to TYPE (see show(); without a TYPE
 * nothing is printed). Each argument word is one of:
 *   [N]        a zero-filled N-byte buffer
 *   hex:HEX    a buffer holding the decoded HEX bytes
 *   str:S      a pointer to the string S
 *   -N / N     a signed / unsigned 64-bit integer
 *   -pK, -p K  after the call, also print argument K, counted among the
 *              arguments given so far (negative K counts from the end; the
 *              index wraps modulo that count)
 *   --         after this, "-p..." and "--..." words are no longer options
 * Errors are reported on stderr with exit status 1.
 */
int main(int argc, char **argv) {
  if (argc < 3) {
    /* argv[0] is NULL when exec'd with an empty argument vector. */
    char *prog = (argc > 0 && argv[0] != NULL) ? argv[0] : "dlcall";
    fprintf(stderr, "usage: %s SO_FILE SYMBOL [ARGUMENTS...]\n", cmd_name(prog));
    return 1;
  }
  char *so_file = argv[1];
  char *symbol = NULL, *srt = NULL, *error;
  /* "%m" makes sscanf allocate each matched part. */
  int n = sscanf(argv[2], "%m[^:]:%m[^:]", &srt, &symbol);
  int owns_names = (n == 2);
  if (!owns_names) {
    /* Release whatever a partial match allocated, then treat the whole word
       as a bare symbol whose return value is not printed. */
    free(srt);
    free(symbol);
    symbol = argv[2];
    srt = "void";
  }

  void *so = dlopen(so_file, RTLD_LAZY);
  if (so == NULL) {
    error = dlerror();
    fprintf(stderr, "error in dlopen: %s\n", error != NULL ? error : "unknown error");
    return 1;
  }
  /* Clear any stale error first: dlsym's result alone does not signal
     failure, dlerror() afterwards does. */
  dlerror();
  void *(*fn)(void *) = dlsym(so, symbol);
  if ((error = dlerror()) != NULL) {
    fprintf(stderr, "error in dlsym for symbol \"%s\": %s\n", symbol, error);
    return 1;
  }
  if (fn == NULL) {
    fprintf(stderr, "symbol \"%s\" resolved to NULL\n", symbol);
    return 1;
  }

  /* Each word after SYMBOL yields at most one argument or one print request;
     +1 keeps the allocations non-zero when there are none. */
  size_t cap = (size_t) (argc - 3) + 1;
  uint64_t *args = calloc(cap, sizeof *args);
  char **arg_types = calloc(cap, sizeof *arg_types);
  size_t *print_args = calloc(cap, sizeof *print_args);
  if (args == NULL || arg_types == NULL || print_args == NULL) {
    goto oom;
  }

  int br = 0; /* set once "--" has been seen */
  uint64_t intv;
  long long idx;
  size_t print_arg_len = 0;
  size_t arg_len = 0;
  for (int i = 3; i < argc; i++) {
    char *word = argv[i];
    if (!br && strcmp(word, "--") == 0) {
      br = 1;
      continue;
    }
    if (sscanf(word, "[%" SCNu64 "]", &intv) == 1) {
      /* Output buffer for the callee; zeroed so printing it with -p never
         reads uninitialised memory. */
      char *data = calloc(intv > 0 ? intv : 1, 1);
      if (data == NULL) {
        goto oom;
      }
      arg_types[arg_len] = word;
      args[arg_len++] = (uint64_t) (uintptr_t) data;
    } else if (strncmp(word, "hex:", 4) == 0) {
      char *hex = word + 4;
      char *data = parse_hex(hex);
      /* parse_hex yields one byte per digit pair (rounded up). Decoded bytes
         may contain 0, so the length comes from the input, not strlen(). */
      char *type = len_type((strlen(hex) + 1) / 2);
      if (data == NULL || type == NULL) {
        goto oom;
      }
      arg_types[arg_len] = type;
      args[arg_len++] = (uint64_t) (uintptr_t) data;
    } else if (strncmp(word, "str:", 4) == 0) {
      char *data = word + 4;
      char *type = len_type(strlen(data));
      if (type == NULL) {
        goto oom;
      }
      arg_types[arg_len] = type;
      args[arg_len++] = (uint64_t) (uintptr_t) data;
    } else if (sscanf(word, "-%" SCNu64, &intv) == 1) {
      arg_types[arg_len] = "i64";
      args[arg_len++] = 0 - intv; /* two's-complement bit pattern of -intv */
    } else if (sscanf(word, "%" SCNu64, &intv) == 1) {
      arg_types[arg_len] = "u64";
      args[arg_len++] = intv;
    } else if (!br && strncmp(word, "--", 2) == 0) {
      fprintf(stderr, "todo: %s\n", word);
      return 1;
    } else if (!br && sscanf(word, "-p%lld", &idx) == 1) {
      if (!wrap_index(idx, arg_len, &print_args[print_arg_len])) {
        fprintf(stderr, "no argument to print for %s\n", word);
        return 1;
      }
      print_arg_len++;
    } else if (!br && strcmp(word, "-p") == 0 && i + 1 < argc) {
      char *num = argv[++i];
      if (sscanf(num, "%lld", &idx) != 1) {
        fprintf(stderr, "unexpected argument: %s\n", num);
        return 1;
      }
      if (!wrap_index(idx, arg_len, &print_args[print_arg_len])) {
        fprintf(stderr, "no argument to print for -p %s\n", num);
        return 1;
      }
      print_arg_len++;
    } else {
      fprintf(stderr, "unexpected argument: %s\n", word);
      return 1;
    }
  }

  uint64_t ret = call(fn, arg_len, args);
  /* Print before dlclose(): a returned pointer may point into the library. */
  show(srt, ret);
  for (size_t i = 0; i < print_arg_len; i++) {
    size_t j = print_args[i];
    show(arg_types[j], args[j]);
  }
  dlclose(so);
  /* Per-argument buffers and type strings are reclaimed at process exit. */
  free(print_args);
  free(arg_types);
  free(args);
  if (owns_names) {
    free(srt);
    free(symbol);
  }
  return 0;

oom:
  fprintf(stderr, "out of memory\n");
  return 1;
}