added changes to accept spairs; merged #58

This commit is contained in:
Max Cahill 2022-05-23 14:48:52 +10:00
parent 199b8ded1d
commit 8a57ae458f
3 changed files with 42 additions and 32 deletions

View File

@ -6,7 +6,6 @@ package.path = package.path .. ";../?.lua"
local assert = require("batteries.assert") local assert = require("batteries.assert")
local tablex = require("batteries.tablex") local tablex = require("batteries.tablex")
-- tablex {{{ -- tablex {{{
local function test_shallow_copy() local function test_shallow_copy()
@ -47,8 +46,8 @@ local function test_shallow_overlay()
tablex.deep_equal( tablex.deep_equal(
r, r,
{ a = 1, b = 2, c = 8, d = 9 } { a = 1, b = 2, c = 8, d = 9 }
)
) )
)
x = { b = { 2 }, c = { 3 }, } x = { b = { 2 }, c = { 3 }, }
y = { c = { 8 }, d = { 9 }, } y = { c = { 8 }, d = { 9 }, }
@ -59,7 +58,9 @@ local function test_shallow_overlay()
assert( assert(
tablex.deep_equal( tablex.deep_equal(
r, r,
{ b = { 2 }, c = { 8 }, d = { 9 }, })) { b = { 2 }, c = { 8 }, d = { 9 }, }
)
)
end end
local function test_deep_overlay() local function test_deep_overlay()
@ -70,7 +71,9 @@ local function test_deep_overlay()
assert( assert(
tablex.deep_equal( tablex.deep_equal(
r, r,
{ a = 1, b = 2, c = 8, d = 9 })) { a = 1, b = 2, c = 8, d = 9 }
)
)
x = { a = { b = { 2 }, c = { 3 }, } } x = { a = { b = { 2 }, c = { 3 }, } }
y = { a = { c = { 8 }, d = { 9 }, } } y = { a = { c = { 8 }, d = { 9 }, } }
@ -78,7 +81,9 @@ local function test_deep_overlay()
assert( assert(
tablex.deep_equal( tablex.deep_equal(
r, r,
{ a = { b = { 2 }, c = { 8 }, d = { 9 }, } })) { a = { b = { 2 }, c = { 8 }, d = { 9 }, } }
)
)
end end
@ -135,20 +140,18 @@ local function test_spairs()
local sorted_names = {} local sorted_names = {}
local sorted_score = {} local sorted_score = {}
for k, v in tablex.spairs(t, function(t, a, b) for k, v in tablex.spairs(t, function(a, b)
return t[a].score > t[b].score return t[a].score > t[b].score
end) do end) do
tablex.push(sorted_names, v.name) tablex.push(sorted_names, v.name)
tablex.push(sorted_score, v.score) tablex.push(sorted_score, v.score)
end end
assert(tablex.deep_equal(sorted_names, assert(tablex.deep_equal(sorted_names, {
{
"John", "Joe", "Robert" "John", "Joe", "Robert"
})) }))
assert(tablex.deep_equal(sorted_score, assert(tablex.deep_equal(sorted_score, {
{
10, 8, 7 10, 8, 7
})) }))
end end

View File

@ -97,7 +97,16 @@ function sort._merge_sort_impl(array, workspace, low, high, less)
end end
--default comparison; hoisted for clarity --default comparison; hoisted for clarity
local _sorted_types = {
--a list of types that will be sorted by default_less
--provide a custom sort function to sort other types
["string"] = true,
["number"] = true,
}
local function default_less(a, b) local function default_less(a, b)
if not _sorted_types[type(a)] or not _sorted_types[type(b)] then
return false
end
return a < b return a < b
end end

View File

@ -2,13 +2,17 @@
extra table routines extra table routines
]] ]]
--apply prototype to module if it isn't the global table
--so it works "as if" it was the global table api
--upgraded with these routines
local path = (...):gsub("tablex", "") local path = (...):gsub("tablex", "")
local assert = require(path .. "assert") local assert = require(path .. "assert")
--for spairs
--(can be replaced with eg table.sort to use that instead)
local sort = require(path .. "sort")
local spairs_sort = sort.stable_sort
--apply prototype to module if it isn't the global table
--so it works "as if" it was the global table api
--upgraded with these routines
local tablex = setmetatable({}, { local tablex = setmetatable({}, {
__index = table, __index = table,
}) })
@ -78,11 +82,8 @@ function tablex.rotate(t, amount)
return t return t
end end
--default comparison; hoisted for clarity --default comparison from sort.lua
--(shared with sort.lua and suggests the sorted functions below should maybe be refactored there) local default_less = sort.default_less
local function default_less(a, b)
return a < b
end
--check if a function is sorted based on a "less" or "comes before" ordering comparison --check if a function is sorted based on a "less" or "comes before" ordering comparison
--if any item is "less" than the item before it, we are not sorted --if any item is "less" than the item before it, we are not sorted
@ -537,19 +538,16 @@ function tablex.ripairs(t)
return _ripairs_iter, t, #t + 1 return _ripairs_iter, t, #t + 1
end end
-- works like pairs, but returns sorted table --works like pairs, but returns sorted table
function tablex.spairs(t, fn) -- generates a fair bit of garbage but very nice for more stable output
local keys = {} -- less function gets keys the of the table as its argument; if you want to sort on the values they map to then
for k in pairs(t) do -- you'll likely need a closure
tablex.push(keys, k) function tablex.spairs(t, less)
end less = less or default_less
--gather the keys
local keys = tablex.keys(t)
if fn then spairs_sort(keys, less)
table.sort(keys, function(a,b) return fn(t, a, b) end)
else
-- sort by keys if no function passed
table.sort(keys)
end
local i = 0 local i = 0
return function() return function()