#!/usr/bin/env lua

-- Every command-line argument names an identifier that must never be
-- reported as a use.
local ignore = { }
for _, name in ipairs(arg) do
  ignore[name] = true
end

local stringf = string.format

-- Format the trailing arguments with string.format and write the result to
-- the file-like object f; returns whatever f:write returns.
local function fprintf(f, ...)
  local text = stringf(...)
  return f:write(text)
end

----------------------------------------------------------------
-- table utilities

-- Turn the list t into a set: a table whose keys are the elements of t and
-- whose values are all true.  Like ipairs, stops at the first nil element.
function table.set(t)
  local members = { }
  local i = 1
  while t[i] ~= nil do
    members[t[i]] = true
    i = i + 1
  end
  return members
end

-- Return a new list holding every key of t, sorted with table.sort.
-- lt is an optional "less than" comparator; when omitted, the keys are
-- compared with the < operator.
function table.sorted_keys(t, lt)
  local keys, n = { }, 0
  for key in pairs(t) do
    n = n + 1
    keys[n] = key
  end
  table.sort(keys, lt)
  return keys
end

----------------------------------------------------------------

-- `context` names the environment currently being processed; `oldcontexts`
-- is the stack of enclosing contexts, innermost last.
local oldcontexts, context = { }, 'document'

-- Enter context c, remembering the current context so popcontext can
-- restore it.
local function pushcontext(c)
  --io.stderr:write(string.rep('  ', #oldcontexts), 'PUSH ', c, '\n')
  oldcontexts[#oldcontexts + 1] = context
  context = c
end

-- Leave context c and restore the context that enclosed it.  Fails an
-- assertion if c is not the current context or if the stack is empty.
local function popcontext(c)
  --io.stderr:write(string.rep('  ', #oldcontexts-1), 'POP  ', c, '\n')
  local complaint = 'Current context ' .. context .. ' is not popped context ' .. c
  assert(context == c, complaint)
  local enclosing = table.remove(oldcontexts)
  context = assert(enclosing)
end

----------------------------------------------------------------

-- Bookkeeping tables, all keyed by identifier.
local defnrefs = { }  -- maps ident to pageref or lineref of its defn
local defcontext, usecontext = { }, { } -- context in force at the defn / at the first use
local linelabel = { } -- line number of definition if any
local marker = { } -- how the defn was marked ('`' or '^'; nil when registered by add_omit)
local uses = { } -- set of identifiers used
local typeof = { } -- type of identifier if given
local omitted = { } -- set of identifiers registered through add_omit ("% omit" lines)
-- Lexemes that are never recorded as uses.
local reserved =
  table.set { 'where', 'let', 'in', '_', '|', '->', '--', '=', '=>',
              'if', 'then', 'else', 'data', 'type', 'newtype', 'module',
              'class', 'instance', 'family',
              'import', 'case', 'of', ';', '{', '}', 'do', '<-', 'forall',
              '...', -- not really a reserved word, but treated as such
              '@', -- ditto
              'goto', -- from low-level code examples
            }

-- Every line of the file 'hsprelude' (read from the current directory) is an
-- identifier that counts as already defined, in the 'prelude' context.
for id in io.lines 'hsprelude' do
  defnrefs[id] = 'Prelude'
  defcontext[id] = 'prelude'
end

local multiples = { } -- set of identifiers reported as multiply defined

-- Replacements for non-alphanumeric characters that get rewritten when an
-- identifier is embedded in a label name.
local specialnames = { [ '\\' ] = ':bs', ['%'] = ':pe', [':'] = ':co', ['_'] = ':un' }

-- Return id with each character listed in specialnames replaced by its
-- code; every other character is kept as is.
local function labelquote(id)
  local quoted = id:gsub('%W', function(ch) return specialnames[ch] end)
  return quoted
end

-- Name of the label that marks the definition of id.
local function labelstring(id)
  return 'haskell.def.' .. labelquote(id)
end

-- Name of the label that marks the first use of id.
local function uselabelstring(id)
  return 'haskell.firstuse.' .. labelquote(id)
end

-- Contexts in which labels are not written immediately but queued instead.
local figcontexts = table.set { 'figure', 'figure*', 'table', 'table*' }
-- Identifiers whose definition / first-use labels are queued until the
-- enclosing figure ends (flushed by endspecial.figure).
local figwaiting, figusewaiting = { }, { }


local def_written = { } -- set of identifiers whose definition label has been emitted

-- Emit the definition label for id on standard output.
--   * If a label for id was already emitted, nothing is written; this is
--     only legal for definitions marked '^', anything else fails an assertion.
--   * Inside a figure-like context the id is queued on figwaiting instead.
--   * Otherwise a \label line is written and id is recorded in def_written.
local function write_defn(id)
  if def_written[id] then
    -- marker[id] is nil for definitions registered without a marker (e.g.
    -- via add_omit), so it is passed through tostring to keep the message
    -- construction itself from raising an unrelated error.
    assert(marker[id] == '^',
           id .. " multiply defined with marker " .. tostring(marker[id]))
  elseif figcontexts[context] then
    table.insert(figwaiting, id)
  else
    def_written[id] = true
    -- fprintf(io.stderr, 'DEF %q (%s)\n', id, context)
    io.write(stringf([[\label{%s}]], labelstring(id)), '% automated definition\n')
  end
end

-- Names of the environments that contain Haskell code; each one is handled
-- by special.code unless a more specific handler is registered.
local codecontexts =
  table.set { 'code', 'smallcode', 'smallttcode', 'ttcode',
              'numberedcode', 'fuzzcode', 'smallfuzzcode' }
-- Identifiers whose first-use labels are queued while in a code context.
local codeusewaiting = { }

-- Emit the first-use label for id, or queue it when the current context is
-- a figure-like or code context (figure contexts take precedence).
local function write_use(id)
  local pending = (figcontexts[context] and figusewaiting)
               or (codecontexts[context] and codeusewaiting)
  if pending then
    pending[#pending + 1] = id
    return
  end
  -- fprintf(io.stderr, 'USE %q (%s)\n', id, context)
  io.write(stringf([[\label{%s}]], uselabelstring(id)), '% automated use\n')
end

-- Record a definition of d found with marker `mark` ('`', '^' or nil) and
-- optional line label linelab.  The definition is flagged in `multiples`
-- instead of being recorded when d is already known and the new marker is
-- '`', or when the previously recorded marker is '`'.
local function add_defn(d, linelab, mark)
  local clash = (defnrefs[d] and mark == '`') or marker[d] == '`'
  if clash then
    multiples[d] = true
    return
  end
  defnrefs[d], defcontext[d] = labelstring(d), context
  linelabel[d], marker[d] = linelab, mark
end

-- Record the first use of u: remember the context it occurred in and emit
-- (or queue) its first-use label.  Later uses of u are ignored.
local function add_use(u)
  if uses[u] then return end
  uses[u], usecontext[u] = true, context
  write_use(u)
end

-- Register d as a definition without marker or line label, and remember
-- that it was omitted together with its type ty.
local function add_omit(d, ty)
  add_defn(d, nil, nil)
  omitted[d], typeof[d] = true, ty
end

----------------------------------------------------------------

-- Lua patterns used to recognise Haskell lexemes.
local symclass = '[%!%#%$%%%&%*%+%.%/%<%=%>%?%@%\\%^%|%-%~]'
local wordclass = "[%w_'%.]"
local sym = symclass .. '+'   -- a run of operator characters
-- A word may contain dots but its last character must not be one.
local word = wordclass .. '*' .. (wordclass:gsub('%%%.', ''))
--word = wordclass .. '+'
-- Variants of sym and word whose character classes also accept the
-- definition markers ` and ^.  The replacement text is written without '%'
-- escapes: Lua 5.2 and later reject '%' followed by a non-digit in a gsub
-- replacement string, while the plain form yields the same pattern on every
-- Lua version.  The parentheses discard gsub's second result (the count).
local symbacktick = (sym:gsub('^%[', '[`^'))
local wordbacktick = (word:gsub('%[', '[`^'))

--io.stderr:write(word, '\n')
--io.stderr:write(wordbacktick, '\n')

-- True (the match positions from string.find) when id consists solely of
-- decimal digits; nil otherwise.
local function is_literal(id)
  return string.find(id, '^%d+$')
end

-- True (the match positions from string.find) when id is a run of two or
-- more dashes, i.e. a comment introducer; nil otherwise.
local function is_comment(id)
  return string.find(id, '^%-%-+$')
end

-- True when id is a kind expression of the form '*', '*->*', '*->*->*', ...
-- Returns nil for anything else, including a nil id.
local function is_kind(id)
  while id do
    if id == '*' then return true end
    -- peel one leading '*->' and examine what is left
    id = id:match('^%*%-%>(.*)$')
  end
  return nil
end

-- Record a use of every word and every operator symbol occurring in s,
-- skipping reserved lexemes and identifiers named on the command line.
-- Words that are numeric literals, and symbols that are kinds or comment
-- introducers, are skipped as well.  Returns nothing.
local function find_uses(s)
  local text = s:gsub([[\_]], '_') -- copes with \_ in at-sign land

  local function excluded(id)
    return reserved[id] or ignore[id]
  end

  for id in text:gmatch(word) do
    if not (excluded(id) or is_literal(id)) then
      add_use(id)
    end
  end
  for id in text:gmatch(sym) do
    if not (excluded(id) or is_kind(id) or is_comment(id)) then
      add_use(id)
    end
  end
end


-- A comment that follows code: whitespace, '--', then a space or alphanumeric.
local comment_pat = '%s+%-%-[%s%w]'
-- A line that starts with '--'.
local comment_line_pat = '^%-%-.*$'

-- Remove string literals ("...") and character literals ('...') from s, as
-- long as a double quote occurs before the first comment (or there is no
-- comment).  Escaped quotes inside literals are not understood.
--
-- Stripping repeats until no double quote precedes a comment.  If a pass
-- changes nothing (for instance an unmatched double quote), the text is
-- returned as it stands; without that check the function would loop forever
-- on such input.
local function strip_stringlit(s)
  while true do
    local q = s:find('"', 1, true)
    local c = s:find(comment_pat)
    if not q or (c and not (q < c)) then
      return s
    end
    local stripped = s:gsub('".-"', ''):gsub("'.-'", "") -- misses escaped quote
    if stripped == s then
      return s -- nothing left to strip: unmatched quote
    end
    s = stripped
  end
end


-- Return line c with literals removed (see strip_stringlit) and comments
-- dropped.  A line of the form "... text ..." is blanked entirely, as is a
-- line that starts with '--'; a trailing '-- comment' is cut off.
local function strip_comment(c)
  local text = strip_stringlit(c)
  text = text:gsub('^%s*%.%.%.%s+.-%s+%.%.%.%s*$', '')
  text = text:gsub(comment_line_pat, '')
  text = text:gsub(comment_pat .. '.*$', '')
  return text
end


-- Record the uses found in every line of the file `filename`; when that
-- file cannot be opened, `filename .. '.tex'` is tried.  A banner naming the
-- file is written to stderr first.  Raises an error if neither file opens.
local function add_file_uses(filename)
  io.stderr:write('============ uses in ', filename, '=============\n')
  local f = io.open(filename)
  if not f then
    f = io.open(filename .. '.tex')
  end
  assert(f, "Cannot open " .. filename)
  for l in f:lines() do
    local code = strip_comment(l)
    find_uses(code)
  end
  f:close()
end
----------------------------------------------------------------

-- Build a scanner for marked definitions.  The returned function takes a
-- string, finds every lexeme that starts with a definition marker (` or ^),
-- and for each one that is exactly marker + word or marker + symbol:
--   * appends the bare name to the list defs,
--   * registers it with add_defn(name, label, marker),
--   * replaces it in the text by the bare name.
-- Candidates that do not have that exact shape are left in the text, except
-- that any \_ inside them becomes _.  The scanner returns the rewritten
-- string.
local function find_definitions(defs, label)
  local marked_word = '^[%`%^]' .. word .. '$'
  local marked_sym  = '^[%`%^]' .. sym .. '$'

  local function got_def(candidate)
    local text = candidate:gsub([[\_]], '_') -- copes with \_ in at-sign land
    -- fprintf(io.stderr, "GOT? %q\n", text)
    if not (text:find(marked_word) or text:find(marked_sym)) then
      return text
    end
    local mark, name = text:match '^(.)(.*)$'
    table.insert(defs, name)
    add_defn(name, label, mark)
    return name
  end

  return function(s)
    -- symbols are scanned before words, as both may start with a marker
    local after_syms = s:gsub('[`%^]' .. symbacktick, got_def)
    local after_words = after_syms:gsub('[`%^]' .. wordbacktick, got_def)
    return after_words
  end
end

-- Like find_definitions(defs), but the returned scanner wraps its result in
-- a pair of @ signs, so it can serve as the replacement function for
-- '@(.-)@' matches.
local function find_definitions_ats(defs)
  local scan = find_definitions(defs)
  return function(s)
    return '@' .. scan(s) .. '@'
  end
end

-- Return s with the leading marker (` or ^) removed from every lexeme that
-- is exactly marker + word or marker + symbol.  Nothing is recorded.
local function strip_def_markers(s)
  local marked_word = '^[%`%^]' .. word .. '$'
  local marked_sym  = '^[%`%^]' .. sym .. '$'

  local function unmark(candidate)
    if candidate:find(marked_word) or candidate:find(marked_sym) then
      return candidate:sub(2) -- drop the one-character marker
    end
    return candidate
  end

  local result = s:gsub('[`%^]' .. symbacktick, unmark)
  result = result:gsub('[`%^]' .. wordbacktick, unmark)
  return result
end




-- Process one ordinary document line l.  A line starting with % is copied
-- to standard output unchanged.  Otherwise a trailing TeX comment (a %
-- preceded by whitespace or another %) is removed, marked definitions
-- inside @...@ spans are registered and unmarked, the line is written,
-- the definition labels are emitted, and finally the uses inside the
-- @...@ spans are recorded.
local function process_at_signs(l)
  if l:find '^%%' then
    io.write(l, '\n')
    return
  end
  local defns = { }
  local body = l:gsub('[%s%%]%%.*', '')
  body = body:gsub('@(.-)@', find_definitions_ats(defns))
  io.write(body, '\n')
  for i = 1, #defns do
    write_defn(defns[i])
  end
  -- gsub is used only to visit each @...@ span; find_uses returns nothing,
  -- so the text is not modified
  body:gsub('@(.-)@', find_uses)
end

----------------------------------------------------------------

-- Becomes true once \begin{document} has been seen.
local in_document = false

-- Handlers keyed by environment name: special[env] runs on \begin{env},
-- endspecial[env] on \end{env}.
local special, endspecial = { }, { }

-- \begin{document}: copy the line through and start scanning for @...@.
function special.document(line)
  io.write(line, '\n')
  in_document = true
end

-- Handle a whole code environment.  `line` is the \begin{env} line, which
-- is copied through; the following lines are then read from standard input
-- up to the line that starts with \end{env}.  For each code line:
--   * an optional leading !label! is split off and passed on as line label,
--   * marked definitions are registered (comments and literals stripped),
--   * uses are recorded with env as the current context,
--   * the line is written with the definition markers removed.
-- After the closing line is written, the labels of the definitions found
-- are emitted and the queued first-use labels are flushed.
-- Raises an error if the input ends before the closing \end{env} is seen.
-- NOTE(review): definitions are registered before pushcontext(env), so
-- their recorded context is the enclosing one, not env -- confirm intent.
function special.code(line, env)
  local defns = { }
  local endpat = stringf([[^\end{%s}]], env)
  io.write(line, '\n')
  line = io.read()
  while line and not line:find(endpat) do
    -- io.stderr:write('CODE ', line, '\n')
    local label, code = line:match '^!(.-)!(.*)$'
    local prefix = ''
    if label then
      prefix = '!' .. label .. '!'
    else
      code = line
    end
    -- called for its side effects (filling defns, add_defn); the rewritten
    -- text is not needed because strip_def_markers redoes it below
    find_definitions(defns, label)(strip_comment(code))
    pushcontext(env)
    find_uses(strip_comment(code))
    popcontext(env)
    io.write(prefix, strip_def_markers(code), '\n')
    line = io.read()
  end
  if not line then
    error(stringf([[end of input inside \begin{%s}: missing \end{%s}]], env, env), 0)
  end
  io.write(line, '\n')
  for _, d in ipairs(defns) do
    write_defn(d)
  end
  while #codeusewaiting > 0 do
    write_use(table.remove(codeusewaiting))
  end
end

-- Verbatim-like environments: the \begin and \end lines are copied through
-- and the context is switched to 'verbatim' in between, whatever the
-- actual environment name is.
special.verbatim = function(line, env)
  io.write(line, '\n')
  pushcontext('verbatim')
end

endspecial.verbatim = function(line, env)
  io.write(line, '\n')
  popcontext('verbatim')
end

-- the other verbatim-like environments share the same handlers
special.smallverbatim, endspecial.smallverbatim = special.verbatim, endspecial.verbatim
special.alltt, endspecial.alltt = special.verbatim, endspecial.verbatim

-- Every code environment without a handler of its own gets special.code.
for envname in pairs(codecontexts) do
  if not special[envname] then
    special[envname] = special.code
  end
end

-- Figure-like environments.  On \begin the line is copied through and the
-- context becomes 'figure' (for every figure-like environment name).
special.figure = function(line, env)
  io.write(line, '\n')
  pushcontext('figure')
end

-- On \end the context is restored, the queued definition labels and then
-- the queued first-use labels are emitted (most recently queued first),
-- and only then is the \end line itself written.
endspecial.figure = function(line, env)
  popcontext('figure')
  local id = table.remove(figwaiting)
  while id do
    write_defn(id)
    id = table.remove(figwaiting)
  end
  id = table.remove(figusewaiting)
  while id do
    write_use(id)
    id = table.remove(figusewaiting)
  end
  io.write(line, '\n')
end

-- the remaining figure-like environments share the same handlers
for _, name in ipairs { 'figure*', 'table', 'table*' } do
  special[name], endspecial[name] = special.figure, endspecial.figure
end

-- Main pass: read the document from standard input line by line and copy
-- it to standard output, dispatching on what each line contains:
--   * a line that inputs 'dfoptdu' is replaced by a "derived file" banner;
--   * \begin{env} / \end{env} with a registered handler go to that handler;
--   * inside a verbatim context lines are copied untouched;
--   * '% defn X', '% local X' and '% omit X :: T' comment lines register a
--     definition and emit its label (the comment line itself is dropped);
--   * \verbatiminput, \smallverbatiminput and \smallfuzzverbatiminput lines
--     have the named file scanned for uses;
--   * any other line is scanned for @...@ spans once \begin{document} has
--     been seen, and copied through unchanged before that.
local line = io.read()
while line do
  if line:find 'input.*dfoptdu' then -- skip me
    io.write('% ====> THIS IS A DERIVED FILE; DO NOT EDIT IT <======\n')
    io.write('% ===========> SIMON, THIS MEANS YOU! <===============\n')
  else
    local env      = line:match [[^%s*\begin(%b{})]]
    env = env and env:sub(2, -2) -- drop the surrounding braces
    local endenv   = line:match [[^%s*\end{(.*)}]]
    local def      = line:match '^%%%s*defn%s+(%S+)%s*$'
    local localdef = line:match '^%%%s*local%s+(%S+)%s*$'
    local omit, ty = line:match '^%%%s*omit%s+(%S+)%s+::%s+(.-)%s*$'
    -- All three alternatives must be joined with `or`; without it the last
    -- match is evaluated as a separate statement and its result is lost.
    local input    = line:match [[^%s*\verbatiminput{(.*)}]] or
                     line:match [[^%s*\smallverbatiminput{(.*)}]] or
                     line:match [[^%s*\smallfuzzverbatiminput{.-}{(.*)}]]
    if env and special[env] then
      special[env](line, env)
    elseif endenv and special[endenv] then
      if endspecial[endenv] then
        endspecial[endenv](line, endenv)
      else
        io.write(line, '\n')
      end
    elseif context == 'verbatim' then
      io.write(line, '\n')  -- do not look for defs or uses in verbatim environments
    elseif def then
      add_defn(def, nil, '`')
      write_defn(def)
    elseif localdef then
      add_defn(localdef, nil, '^')
      write_defn(localdef)
    elseif omit then
      add_omit(omit, ty)
      write_defn(omit)
    elseif input then
      add_file_uses(input)
      io.write(line, '\n')
    elseif in_document then
      process_at_signs(line)
    else
      io.write(line, '\n')
    end
  end
  line = io.read()
end

-- Ordering used for the index: compare ignoring case and underscores; when
-- two names are equal under that folding, fall back to plain string order.
local function caselt(s1, s2)
  local function fold(s)
    return (s:gsub('_', '')):lower()
  end
  local k1, k2 = fold(s1), fold(s2)
  if k1 ~= k2 then
    return k1 < k2
  end
  return s1 < s2 -- upper case first
end

-- Maps a definition context to the name of the macro used for its index
-- entry.  Every context also gets a starred variant ('figure*', ...), and
-- any context not listed falls back to 'hspagedef'.
local defseq = { figure = 'hsfigdef', table = 'hstabdef', prelude = 'hsprelude',
                 numberedcode = 'hslinedef', document = 'hspagedef' }
do
  -- The starred entries are collected in a separate table first: adding
  -- new keys to a table while traversing it with pairs is undefined
  -- behaviour and could also produce keys such as 'figure**'.
  local starred = { }
  for k, v in pairs(defseq) do starred[k .. '*'] = v end
  for k, v in pairs(starred) do defseq[k] = v end
end
setmetatable(defseq, { __index = function() return 'hspagedef' end })


-- Make s safe to typeset: each backslash becomes \char`\\ and each of the
-- characters & _ ^ $ % gets a backslash in front of it.
local function escape_specials(s)
  local escaped = s:gsub([[\]], [[\char`\\]])
  escaped = escaped:gsub('[%&%_%^%$%%]', '\\%1')
  return escaped
end

-- Suffix appended to the macro name, chosen by the definition's marker.
local suffix = { ['`'] = '', ['^'] = 'll' }

-- Write one index entry per known definition to defuse.tex, in caselt
-- order.  An entry has the form
--   \<macro>[<linelabel>]{<name and type>}{<label>}% context <ctx>
-- where the macro name is prefixed with 'omit' for omitted definitions and
-- the [<linelabel>] part appears only when a line label was recorded.
-- The file is opened with assert so that a failure to create it is
-- reported with the system's error message.
local f = assert(io.open('defuse.tex', 'w'))
for _, k in ipairs(table.sorted_keys(defnrefs, caselt)) do
  local full_k = escape_specials(k)
  local ty = typeof[k]
  if ty == '*' then
    full_k = full_k .. [[\textrm{~(a~type)}]]
  elseif ty and ty:find '%*' then
    -- a type that contains '*' is presented as a kind
    full_k = full_k .. [[\textrm{~(a~type of kind }]] .. ty .. [[\textrm)]]
  elseif ty then
    full_k = full_k .. ' :: ' .. ty
  end
  local prefix = omitted[k] and [[\omit]] or [[\]]
  fprintf(f, '%s%s%s{%s}{%s}%s',
          prefix,
          defseq[defcontext[k]] .. suffix[marker[k] or '`'],
          linelabel[k] and '[' .. linelabel[k] .. ']' or '',
          full_k,
          defnrefs[k], '% context ' .. (defcontext[k] or '??') .. '\n')
end

-- False for names not worth reporting: anything of length 0 or 1, and
-- two-character names made of a lower-case letter or 'L' followed by a
-- digit or a prime.  True for everything else.
local function nontrivial(id)
  if #id <= 1 then return false end
  if id:find('^[%lL]%d$') then return false end
  return not id:find("^[%lL]'$")
end

-- Append one entry for every nontrivial identifier that was used but never
-- defined, in caselt order, then close defuse.tex.
for _, k in ipairs(table.sorted_keys(uses, caselt)) do
  local undefined = not defnrefs[k]
  if undefined and nontrivial(k) then
    local macro = defseq[usecontext[k]]
    fprintf(f, [[\not%s{%s}{%s}%s]], macro,
            escape_specials(k), uselabelstring(k), '% used but not defined\n')
  end
end
f:close()

-- Report every multiply defined identifier on stderr, in sorted order, and
-- exit with status 1 if there was at least one.
local multiply_defined = table.sorted_keys(multiples)
local had_mult = #multiply_defined > 0
for i = 1, #multiply_defined do
  fprintf(io.stderr, 'Label %q multiply defined.\n', multiply_defined[i])
end

if had_mult then os.exit(1) end
