diff --git a/.test/tests.lua b/.test/tests.lua index a8cd0c3..6a4f1f6 100644 --- a/.test/tests.lua +++ b/.test/tests.lua @@ -6,7 +6,6 @@ package.path = package.path .. ";../?.lua" local assert = require("batteries.assert") local tablex = require("batteries.tablex") - -- tablex {{{ local function test_shallow_copy() @@ -47,8 +46,8 @@ local function test_shallow_overlay() tablex.deep_equal( r, { a = 1, b = 2, c = 8, d = 9 } - ) ) + ) x = { b = { 2 }, c = { 3 }, } y = { c = { 8 }, d = { 9 }, } @@ -59,7 +58,9 @@ local function test_shallow_overlay() assert( tablex.deep_equal( r, - { b = { 2 }, c = { 8 }, d = { 9 }, })) + { b = { 2 }, c = { 8 }, d = { 9 }, } + ) + ) end local function test_deep_overlay() @@ -70,7 +71,9 @@ local function test_deep_overlay() assert( tablex.deep_equal( r, - { a = 1, b = 2, c = 8, d = 9 })) + { a = 1, b = 2, c = 8, d = 9 } + ) + ) x = { a = { b = { 2 }, c = { 3 }, } } y = { a = { c = { 8 }, d = { 9 }, } } @@ -78,7 +81,9 @@ local function test_deep_overlay() assert( tablex.deep_equal( r, - { a = { b = { 2 }, c = { 8 }, d = { 9 }, } })) + { a = { b = { 2 }, c = { 8 }, d = { 9 }, } } + ) + ) end @@ -135,20 +140,18 @@ local function test_spairs() local sorted_names = {} local sorted_score = {} - for k, v in tablex.spairs(t, function(t, a, b) + for k, v in tablex.spairs(t, function(a, b) return t[a].score > t[b].score end) do tablex.push(sorted_names, v.name) tablex.push(sorted_score, v.score) end - assert(tablex.deep_equal(sorted_names, - { + assert(tablex.deep_equal(sorted_names, { "John", "Joe", "Robert" })) - assert(tablex.deep_equal(sorted_score, - { + assert(tablex.deep_equal(sorted_score, { 10, 8, 7 })) -end \ No newline at end of file +end diff --git a/sort.lua b/sort.lua index 6c421e5..b05fe02 100644 --- a/sort.lua +++ b/sort.lua @@ -97,7 +97,16 @@ function sort._merge_sort_impl(array, workspace, low, high, less) end --default comparison; hoisted for clarity +local _sorted_types = { + --a list of types that will be sorted by default_less + --provide a custom sort function to sort other types + ["string"] = true, + ["number"] = true, +} local function default_less(a, b) + if not _sorted_types[type(a)] or not _sorted_types[type(b)] then + return false + end return a < b end diff --git a/tablex.lua b/tablex.lua index 8624dfb..29ce078 100644 --- a/tablex.lua +++ b/tablex.lua @@ -2,13 +2,17 @@ extra table routines ]] ---apply prototype to module if it isn't the global table ---so it works "as if" it was the global table api ---upgraded with these routines - local path = (...):gsub("tablex", "") local assert = require(path .. "assert") +--for spairs +--(can be replaced with eg table.sort to use that instead) +local sort = require(path .. "sort") +local spairs_sort = sort.stable_sort + +--apply prototype to module if it isn't the global table +--so it works "as if" it was the global table api +--upgraded with these routines local tablex = setmetatable({}, { __index = table, }) @@ -78,11 +82,8 @@ function tablex.rotate(t, amount) return t end ---default comparison; hoisted for clarity ---(shared with sort.lua and suggests the sorted functions below should maybe be refactored there) -local function default_less(a, b) - return a < b -end +--default comparison from sort.lua +local default_less = sort.default_less --check if a function is sorted based on a "less" or "comes before" ordering comparison --if any item is "less" than the item before it, we are not sorted @@ -537,19 +538,16 @@ function tablex.ripairs(t) return _ripairs_iter, t, #t + 1 end --- works like pairs, but returns sorted table -function tablex.spairs(t, fn) - local keys = {} - for k in pairs(t) do - tablex.push(keys, k) - end +--works like pairs, but returns sorted table +-- generates a fair bit of garbage but very nice for more stable output +-- less function gets keys the of the table as its argument; if you want to sort on the values they map to then +-- you'll likely need a closure +function tablex.spairs(t, less) + less = less or default_less + --gather the keys + local keys = tablex.keys(t) - if fn then - table.sort(keys, function(a,b) return fn(t, a, b) end) - else - -- sort by keys if no function passed - table.sort(keys) - end + spairs_sort(keys, less) local i = 0 return function()