#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <sys/wait.h>
#include <sys/resource.h>
#include <arpa/inet.h>
#include <assert.h>
#include <ctype.h>
#include <fcntl.h>
#include <signal.h>

// SessionId bytes that header() copies into every SMB2 header it builds.
// Starts zeroed; setup() sets the last byte to 1 before replying.
unsigned char session_id[8];


// Connected client socket: assigned from accept() in main(), read by
// readn() and written by every reply builder. -1 until a client connects.
int s = -1;

// Read exactly n bytes from the global socket `s` into bufx, looping over
// partial reads. Returns 0 once all n bytes have arrived, or -1 as soon as
// read() reports end-of-file or an error (bytes read so far stay in bufx).
int
readn(void *bufx, int n)
{
  char *cursor = bufx;
  int remaining;

  for(remaining = n; remaining > 0; ){
    int got = read(s, cursor, remaining);
    if(got < 1)
      return -1;
    cursor += got;
    remaining -= got;
  }
  return 0;
}

// Largest frame (4-byte transport prefix + payload) readmsg() will store.
// readmsg() has no way to learn the caller's buffer size, so every caller
// must supply at least this many bytes; main() passes a 2048-byte buffer.
enum { READMSG_MAX_FRAME = 2048 };

// Read one SMB-over-TCP frame from the global socket into bufx: a 4-byte
// prefix (zero byte + 24-bit big-endian payload length) followed by the
// payload itself.
//
// bufx must point to at least READMSG_MAX_FRAME bytes.
// Returns the total number of bytes stored (payload length + 4), or -1 on
// EOF, read error, or when the peer announces a frame that would not fit.
int
readmsg(void *bufx)
{
  unsigned char *buf = bufx;
  if(readn(buf, 4) < 0)
    return -1;

  // The length comes straight off the wire; decode all three length bytes
  // and refuse anything that would overrun the caller's buffer.
  int n = (buf[1] << 16) | (buf[2] << 8) | buf[3];
  if(n > READMSG_MAX_FRAME - 4){
    fprintf(stderr, "readmsg: %d-byte payload exceeds the %d-byte frame limit\n",
            n, READMSG_MAX_FRAME);
    return -1;
  }

  if(readn(buf+4, n) < 0)
    return -1;
  return n + 4;
}

// Store the low `nbytes` bytes of `v` at buf+off in little-endian order and
// return the offset just past them. Byte-wise stores avoid the misaligned,
// type-punned accesses that casting into a char buffer would perform, and
// do not depend on the host's byte order.
static int
hdr_put_le(char *buf, int off, unsigned long long v, int nbytes)
{
  for(int i = 0; i < nbytes; i++)
    buf[off + i] = (char)(unsigned char)(v >> (8 * i));
  return off + nbytes;
}

// Write the 4-byte SMB-over-TCP prefix followed by the 64-byte SMB2 sync
// header for a reply to `command` at the start of buf.
//
// buf must have room for at least 68 bytes. The 16-bit length at bytes 2-3
// of the prefix is left as zero: this function cannot know the frame size,
// so the caller must fill it in (big-endian, excluding the 4-byte prefix)
// before sending. MessageId is taken from a counter that increments on
// every call; SessionId is copied from the global session_id.
//
// Returns the offset just past the header (68), where the command-specific
// body starts.
int
header(char *buf, int command)
{
  int ii = 0;

  unsigned int status = 0;
  unsigned int flags = 1; // 1=REDIRECT (SMB2_FLAGS_SERVER_TO_REDIR: this is a response)

  // SMB-over-TCP 4-byte header
  buf[ii++] = 0; // must be zero
  buf[ii++] = 0; // high byte of len
  buf[ii++] = 0; // non-inclusive len, placeholder; filled in by the caller
  buf[ii++] = 0;

  // SMB2 SYNC Packet Header, MS-SMB2 2.2.1.2
  buf[ii++] = (char)0xfe;
  buf[ii++] = 'S';
  buf[ii++] = 'M';
  buf[ii++] = 'B';
  ii = hdr_put_le(buf, ii, 64, 2); // StructureSize (of SMB2 header)
  ii = hdr_put_le(buf, ii, 0, 2); // CreditCharge
  ii = hdr_put_le(buf, ii, status, 4); // Status
  ii = hdr_put_le(buf, ii, (unsigned int)command, 2); // Command
  ii = hdr_put_le(buf, ii, 6, 2); // CreditRequest
  ii = hdr_put_le(buf, ii, flags, 4); // Flags
  ii = hdr_put_le(buf, ii, 0, 4); // NextCommand
  static unsigned long long seq = 0;
  ii = hdr_put_le(buf, ii, seq++, 8); // MessageId
  ii = hdr_put_le(buf, ii, 0, 4); // Reserved
  ii = hdr_put_le(buf, ii, 0, 4); // TreeId
  memcpy(buf+ii, session_id, 8); // SessionId
  ii += 8;
  memset(buf+ii, 'd', 16); // Signature (dummy filler, not a real signature)
  ii += 16;

  return ii;
}

// Send a canned SMB2 NEGOTIATE response on the global socket `s`.
//
// The reply is always the full 192-byte buffer: header() output, the fixed
// response fields, and one SMB2_PREAUTH_INTEGRITY_CAPABILITIES negotiate
// context; everything not explicitly written keeps the 0x01 fill. A failed
// write is reported with perror() and otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host that
// tolerates unaligned stores.
void
negotiate()
{
  // SMB2 NEGOTIATE Response, MS-SMB2 2.2.4
  char buf[128+64];
  memset(buf, 1, sizeof(buf));

  int ii = header(buf, 0x0000);
  
  *(short*)(buf+ii) = 65; // StructureSize
  ii += 2;
  *(short*)(buf+ii) = 0; // SecurityMode
  ii += 2;
  *(short*)(buf+ii) = 0x0311; // DialectRevision
  ii += 2;
  *(short*)(buf+ii) = 1; // NegotiateContextCount
  ii += 2;
  ii += 16; // ServerGuid from MS-DTYP (left as 0x01 fill)
  *(int*)(buf+ii) = 0x3e; // Capabilities
  ii += 4;
  *(int*)(buf+ii) = 8192; // MaxTransactSize
  ii += 4;
  *(int*)(buf+ii) = 8192; // MaxReadSize
  ii += 4;
  *(int*)(buf+ii) = 8192; // MaxWriteSize
  ii += 4;
  *(long long *)(buf+ii) = 1; // SystemTime
  ii += 8;
  *(long long *)(buf+ii) = 1; // ServerStartTime
  ii += 8;
  *(short*)(buf+ii) = 0; // SecurityBufferOffset
  ii += 2;
  *(short*)(buf+ii) = 0; // SecurityBufferLength
  ii += 2;
  int context_offset_i = ii; // patched below once the list position is known
  *(int*)(buf+ii) = 0; // NegotiateContextOffset
  ii += 4;

  // SecurityBuffer (empty)

  // Pad so the context list starts 8-byte aligned relative to the SMB2
  // header, which begins 4 bytes into buf.
  while((ii - 4) % 8)
    ii++;

  // NegotiateContextList
  *(int*)(buf+context_offset_i) = ii - 4; // NegotiateContextOffset
  
  {
    *(short*)(buf+ii) = 0x0001; // SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    ii += 2;
    int data_length_i = ii; // patched below with the size of the data field
    *(short*)(buf+ii) = 0; // DataLength
    ii += 2;
    *(int*)(buf+ii) = 0; // Reserved
    ii += 4;
    int data_field_i = ii;
    *(short*)(buf+ii) = 1; // HashAlgorithmCount
    ii += 2;
    *(short*)(buf+ii) = 2; // SaltLength
    ii += 2;
    *(short*)(buf+ii) = 1; // SHA-512
    ii += 2;
    *(short*)(buf+ii) = 1; // Salt
    ii += 2;
    *(short*)(buf+data_length_i) = ii - data_field_i;
  }
  
  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("negotiate writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}

// Send a canned SMB2 SESSION_SETUP response on the global socket `s`.
//
// Sets the last byte of the global session_id to 1 first, so this and every
// later header carries that SessionId. The reply is always the full
// 192-byte buffer; bytes not explicitly written keep the 0x01 fill. A
// failed write is reported with perror() and otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
setup()
{
  // SMB2 SESSION_SETUP, MS-SMB2 2.2.6
  char buf[128+64];
  memset(buf, 1, sizeof(buf));

  session_id[7] = 1;

  int ii = header(buf, 0x0001);
  
  *(short*)(buf+ii) = 9; // StructureSize
  ii += 2;
  *(short*)(buf+ii) = 1; // SessionFlags, 1=guest
  ii += 2;
  // Offset is relative to the SMB2 header (-4 for the transport prefix) and
  // points just past the two 2-byte fields written here (+4).
  *(short*)(buf+ii) = (ii - 4 + 4); // SecurityBufferOffset
  ii += 2;
  *(short*)(buf+ii) = 16; // SecurityBufferLength
  ii += 2;

  // Only 8 of the 16 advertised security-buffer bytes are written (the
  // string plus its NUL); the other 8 keep the 0x01 fill.
  memcpy(buf+ii, "NTLMSSP", 8);
  
  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("setup writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}

// Send a canned SMB2 TREE_CONNECT response on the global socket `s`,
// describing a disk share with MaximalAccess 0xffffffff.
//
// The reply is always the full 192-byte buffer; bytes not explicitly
// written keep the 0x01 fill. A failed write is reported with perror() and
// otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
tree_connect()
{
  // SMB2 TREE_CONNECT, MS-SMB2 2.2.10
  char buf[128+64];
  memset(buf, 1, sizeof(buf));

  int ii = header(buf, 0x0003);
  
  *(short*)(buf+ii) = 16; // StructureSize
  ii += 2;
  buf[ii++] = 0x01; // ShareType, 1=disk, 2=pipe, 3=print
  buf[ii++] = 0; // Reserved
  *(int*)(buf+ii) = 0x33; // ShareFlags
  ii += 4;
  *(int*)(buf+ii) = 0x1f8; // Capabilities
  ii += 4;
  *(int*)(buf+ii) = 0xffffffff; // MaximalAccess
  ii += 4;
  
  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("tree_connect writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}

// Send a canned SMB2 IOCTL response on the global socket `s`.
//
// The output area advertises 4 bytes and is followed by a short payload
// whose field names (consumed_ucs, num_referrals) suggest a DFS referral
// reply — presumably deliberately minimal; confirm against the client under
// test. The reply is always the full 192-byte buffer; bytes not explicitly
// written (including FileId) keep the 0x01 fill. A failed write is reported
// with perror() and otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
smb_ioctl()
{
  // SMB2 IOCTL, MS-SMB2 2.2.32
  char buf[128+64];
  memset(buf, 1, sizeof(buf));

  int ii = header(buf, 0x000b);
  
  *(short*)(buf+ii) = 49; // StructureSize
  ii += 2;
  *(short*)(buf+ii) = 0; // Reserved
  ii += 2;
  *(int*)(buf+ii) = 0; // CtlCode
  ii += 4;
  ii += 16; // FileId (left as 0x01 fill)
  *(int*)(buf+ii) = 0; // InputOffset
  ii += 4;
  *(int*)(buf+ii) = 0; // InputCount
  ii += 4;
  // Relative to the SMB2 header (-4); +16 skips the four 4-byte fields
  // from OutputOffset through Reserved, landing on the payload below.
  *(int*)(buf+ii) = (ii-4+16); // OutputOffset
  ii += 4;
  *(int*)(buf+ii) = 4; // OutputCount
  ii += 4;
  *(int*)(buf+ii) = 0; // Flags
  ii += 4;
  *(int*)(buf+ii) = 0; // Reserved
  ii += 4;

  // Output payload
  *(short*)(buf+ii) = 2; // consumed_ucs
  ii += 2;
  *(short*)(buf+ii) = 0; // num_referrals
  ii += 2;
  buf[ii++] = 0;
  buf[ii++] = 'x';

  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("ioctl writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}

// Send a canned SMB2 CREATE response on the global socket `s`, describing a
// directory and carrying one create context (a maximal-access response).
//
// The reply is always the full 192-byte buffer; bytes not explicitly
// written (including FileId and the context padding) keep the 0xff fill.
// A failed write is reported with perror() and otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
smb_create()
{
  // SMB2 CREATE, MS-SMB2 2.2.14
  char buf[128+64];
  memset(buf, 0xff, sizeof(buf));

  int ii = header(buf, 0x0005);
  
  *(short*)(buf+ii) = 89; // StructureSize
  ii += 2;
  buf[ii++] = 0; // OplockLevel
  buf[ii++] = 0; // Flags
  *(int*)(buf+ii) = 0; // CreateAction
  ii += 4;
  *(long long*)(buf+ii) = 1; // CreationTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastAccessTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastWriteTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // ChangeTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // AllocationSize
  ii += 8;
  *(long long*)(buf+ii) = 1; // EndofFile
  ii += 8;
  // 0x10 directory
  *(int*)(buf+ii) = 0x10; // FileAttributes
  ii += 4;
  *(int*)(buf+ii) = 0; // Reserved2
  ii += 4;
  ii += 16; // FileId (left as 0xff fill)
  // Relative to the SMB2 header (-4); +8 skips the offset and length
  // fields themselves, landing on the context written below.
  *(int*)(buf+ii) = (ii - 4 + 8); // CreateContextsOffset
  ii += 4;
  *(int*)(buf+ii) = 32; // CreateContextsLength
  ii += 4;

  // one context
  *(int*)(buf+ii) = 0; // Next
  ii += 4;
  *(short*)(buf+ii) = 16; // NameOffset
  ii += 2;
  *(short*)(buf+ii) = 4; // NameLength
  ii += 2;
  *(short*)(buf+ii) = 0; // Reserved
  ii += 2;
  *(short*)(buf+ii) = 24; // DataOffset
  ii += 2;
  *(int*)(buf+ii) = 8; // DataLength
  ii += 4;
  

  // SMB2_CREATE_QUERY_MAXIMAL_ACCESS_RESPONSE
  // NOTE(review): stored as an int, so on a little-endian host the four
  // name bytes go out as 63 41 78 4d ("cAxM"), not "MxAc" — confirm
  // whether the reversed order is intended.
  *(int*)(buf+ii) = 0x4d784163; // Name
  ii += 4;
  ii += 4; // pad (left as 0xff fill)
  *(int*)(buf+ii) = 0; // QueryStatus
  ii += 4;
  *(int*)(buf+ii) = 0xffffffff; // MaximalAccess
  ii += 4;

  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("create writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}

// Send a canned SMB2 CLOSE response (command 0x0006) on the global socket
// `s`, reporting directory attributes and placeholder timestamps/sizes.
//
// The reply is always the full 192-byte buffer; bytes not explicitly
// written keep the 0xff fill. A failed write is reported with perror() and
// otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
smb_close()
{
  char buf[128+64];
  memset(buf, 0xff, sizeof(buf));

  int ii = header(buf, 0x0006);
  
  *(short*)(buf+ii) = 60; // StructureSize
  ii += 2;
  *(short*)(buf+ii) = 0; // Flags
  ii += 2;
  *(int*)(buf+ii) = 0; // Reserved
  ii += 4;
  *(long long*)(buf+ii) = 1; // CreationTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastAccessTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastWriteTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // ChangeTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // AllocationSize
  ii += 8;
  *(long long*)(buf+ii) = 1; // EndofFile
  ii += 8;
  // 0x10 directory
  *(int*)(buf+ii) = 0x10; // FileAttributes
  ii += 4;

  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("close writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}
  
// Send a canned SMB2 QUERY_DIRECTORY response (command 0x000e) on the
// global socket `s`.
//
// any != 0: success status, OutputBufferLength 106, one directory entry
//           named "f".
// any == 0: header Status is overwritten with STATUS_NO_MORE_FILES and
//           OutputBufferLength is 0; the entry bytes are still written
//           into the buffer after the fixed fields.
//
// The reply is always the full 384-byte buffer; bytes not explicitly
// written keep the 0xff fill. A failed write is reported with perror() and
// otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
query_directory(int any)
{
  char buf[256+128];
  memset(buf, 0xff, sizeof(buf));

  int ii = header(buf, 0x000e);

  if(any == 0){
    // Byte 12 of buf is the Status field of the SMB2 header
    // (4-byte transport prefix + 8 bytes into the header).
    *(int*)(buf+12) = 0x80000006; // STATUS_NO_MORE_FILES
  }
  
  *(short*)(buf+ii) = 9; // StructureSize
  ii += 2;
  // Relative to the SMB2 header (-4); +6 skips this 2-byte field and the
  // 4-byte length that follows.
  *(short*)(buf+ii) = (ii - 4 + 6); // OutputBufferOffset
  ii += 2;
  // NOTE(review): the entry written below spans 96 bytes while the length
  // advertised here is 106 — confirm whether the overshoot is intentional.
  *(int*)(buf+ii) = any ? 106 : 0; // OutputBufferLength
  ii += 4;

  // probably FileBothDirectoryInformation, MS-FSCC 2.4.8
  *(int*)(buf+ii) = 0; // NextEntryOffset
  ii += 4;
  *(int*)(buf+ii) = 1; // FileIndex
  ii += 4;
  *(long long*)(buf+ii) = 1; // CreationTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastAccessTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastWriteTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // ChangeTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // EndofFile
  ii += 8;
  *(long long*)(buf+ii) = 1; // AllocationSize
  ii += 8;
  // 0x10 directory
  *(int*)(buf+ii) = 0x10; // FileAttributes
  ii += 4;
  *(int*)(buf+ii) = 2; // FileNameLength
  ii += 4;
  *(int*)(buf+ii) = 0; // EaSize
  ii += 4;
  buf[ii++] = 2; // ShortNameLength
  buf[ii++] = 0; // Reserved
  buf[ii++] = 'f'; // ShortName
  buf[ii++] = 0;
  ii += 22; // rest of the 24-byte ShortName field (left as 0xff fill)
  buf[ii++] = 'f'; // FileName
  buf[ii++] = 0;

  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("query_directory writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}
  
// Send a canned SMB2 QUERY_INFO response (command 0x0010) on the global
// socket `s`. The output buffer reuses the same directory-entry layout
// that query_directory() sends, with OutputBufferLength fixed at 106.
//
// The reply is always the full 384-byte buffer; bytes not explicitly
// written keep the 0xff fill. A failed write is reported with perror() and
// otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
query_info()
{
  char buf[256+128];
  memset(buf, 0xff, sizeof(buf));

  int ii = header(buf, 0x0010);

  *(short*)(buf+ii) = 9; // StructureSize
  ii += 2;
  // Relative to the SMB2 header (-4); +6 skips this 2-byte field and the
  // 4-byte length that follows.
  *(short*)(buf+ii) = (ii - 4 + 6); // OutputBufferOffset
  ii += 2;
  *(int*)(buf+ii) = 106; // OutputBufferLength
  ii += 4;

  // probably FileBothDirectoryInformation, MS-FSCC 2.4.8
  *(int*)(buf+ii) = 0; // NextEntryOffset
  ii += 4;
  *(int*)(buf+ii) = 1; // FileIndex
  ii += 4;
  *(long long*)(buf+ii) = 1; // CreationTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastAccessTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // LastWriteTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // ChangeTime
  ii += 8;
  *(long long*)(buf+ii) = 1; // EndofFile
  ii += 8;
  *(long long*)(buf+ii) = 1; // AllocationSize
  ii += 8;
  // 0x10 directory
  *(int*)(buf+ii) = 0x10; // FileAttributes
  ii += 4;
  *(int*)(buf+ii) = 2; // FileNameLength
  ii += 4;
  *(int*)(buf+ii) = 0; // EaSize
  ii += 4;
  buf[ii++] = 2; // ShortNameLength
  buf[ii++] = 0; // Reserved
  buf[ii++] = 'f'; // ShortName
  buf[ii++] = 0;
  ii += 22; // rest of the 24-byte ShortName field (left as 0xff fill)
  buf[ii++] = 'f'; // FileName
  buf[ii++] = 0;

  assert((size_t)ii <= sizeof(buf));
  
  // Send the whole buffer, not just the populated prefix.
  ii = sizeof(buf);
  
  *(short*)(buf+2) = htons(ii-4); // non-inclusive len
  
  printf("query_info writing %d bytes\n", ii); fflush(stdout);
  
  if(write(s, buf, ii) <= 0){
    perror("write");
  }
}
  
// Send a canned SMB2 READ response (command 0x0008) on the global socket
// `s`. It advertises 64 bytes of data at offset 80 whose first five bytes
// are a deliberately malformed RPC reply; the rest of the zero-filled
// 256-byte buffer follows. A failed write is reported with perror() and
// otherwise ignored.
//
// NOTE(review): fields are stored in host byte order through casted
// pointers, so the wire image is only right on a little-endian host.
void
smb_read()
{
  char buf[256];
  memset(buf, 0, sizeof(buf));

  int ii = header(buf, 0x0008);

  *(short*)(buf+ii) = 17; // StructureSize
  ii += 2;
  // One-byte field, relative to the SMB2 header (-4); +14 skips the rest
  // of the fixed response (DataOffset through Flags).
  buf[ii] = ii - 4 + 14; // DataOffset=80
  ii++;
  buf[ii++] = 0; // Reserved
  *(int*)(buf+ii) = 64; // DataLength
  ii += 4;
  *(int*)(buf+ii) = 0; // DataRemaining
  ii += 4;
  *(int*)(buf+ii) = 0; // Flags
  ii += 4;
  
  //
  // broken RPC reply
  //
  buf[ii++] = 0x10;
  buf[ii++] = 0x00;
  buf[ii++] = 0x00;
  buf[ii++] = 0x00;
  buf[ii++] = 0x10;

  *(short*)(buf+2) = htons(sizeof(buf)-4); // non-inclusive len
  
  // %zu is the conversion that matches size_t.
  printf("read writing %zu bytes\n", sizeof(buf)); fflush(stdout);

  if(write(s, buf, sizeof(buf)) <= 0){
    perror("write");
  }
}

// Human-readable names for SMB2 command codes, indexed by the code itself.
// Used only for log output. The table has 20 entries (0x00..0x13).
// NOTE(review): the final placeholder reads "0x0012" but sits at index
// 0x13 — confirm which was meant.
char *command_names[] = {
  [0x00] = "NEGOTIATE",
  [0x01] = "SESSION_SETUP",
  [0x02] = "LOGOFF",
  [0x03] = "TREE_CONNECT",
  [0x04] = "TREE_DISCONNECT",
  [0x05] = "CREATE",
  [0x06] = "CLOSE",
  [0x07] = "FLUSH",
  [0x08] = "READ",
  [0x09] = "WRITE",
  [0x0a] = "LOCK",
  [0x0b] = "IOCTL",
  [0x0c] = "CANCEL",
  [0x0d] = "ECHO",
  [0x0e] = "QUERY_DIRECTORY",
  [0x0f] = "CHANGE_NOTIFY",
  [0x10] = "QUERY_INFO",
  [0x11] = "SET_INFO",
  [0x12] = "OPLOCK_BREAK",
  [0x13] = "0x0012",
};

int
main()
{
  int pid = -1;

  signal(SIGPIPE, SIG_IGN);

  struct sockaddr_in sin;
  memset(&sin, 0, sizeof(sin));
  sin.sin_family = AF_INET;
  sin.sin_port = htons(445); // SMB over TCP

  int ss = socket(AF_INET, SOCK_STREAM, 0);
  int yes = 1;
  setsockopt(ss, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
  if(bind(ss, (struct sockaddr *)&sin, sizeof(sin)) < 0){
    perror("bind");
    exit(1);
  }
  listen(ss, 10);

  socklen_t sinlen = sizeof(sin);
  s = accept(ss, (struct sockaddr *)&sin, &sinlen);
  if(s < 0){
    perror("accept");
    exit(1);
  }
  close(ss);

  sleep(1);

  unsigned int command;
  int n;
  char ibuf[2048];

  memset(ibuf, 0, sizeof(ibuf));
  n = readmsg(ibuf);
  command = *(short*)(ibuf+16);
  printf("read %d, command %d %s\n", n, command, command_names[command]);
  negotiate();

  memset(ibuf, 0, sizeof(ibuf));
  n = readmsg(ibuf);
  command = *(short*)(ibuf+16);
  printf("read %d, command %d %s\n", n, command, command_names[command]);
  if(command == 1){
    setup();
  }

  memset(ibuf, 0, sizeof(ibuf));
  n = readmsg(ibuf);
  command = *(short*)(ibuf+16);
  printf("read %d, command %d %s\n", n, command, command_names[command]);
  if(command == 3){
    tree_connect();
  }

  while(1){
    memset(ibuf, 0, sizeof(ibuf));
    n = readmsg(ibuf);
    if(n < 0)
      break;
    command = *(short*)(ibuf+16);
    printf("command 0x%02x %s\n", command, command_names[command]);
    if(command == 0xb){
      smb_ioctl();
    } else if(command == 5){
      smb_create();
    } else if(command == 6){
      smb_close();
    } else if(command == 0xe){
      static int nn = 0;
      query_directory(nn == 0 ? 1 : 0);
      nn++;
    } else if(command == 0x10){
      query_info();
    } else if(command == 0x08){
      smb_read();
      break;
    } else {
      break;
    }
  }
  
  sleep(1);
  close(s);

  int st = 0;
  int xpid = wait(&st);
  if(!WIFEXITED(st)){ printf("child %d crashed, wanted %d, st %d\n", xpid, pid, st); while(1){} }
}
