added changes to accept spairs; merged #58

This commit is contained in:
Max Cahill 2022-05-23 14:48:52 +10:00
parent 199b8ded1d
commit 8a57ae458f
3 changed files with 42 additions and 32 deletions

View File

@ -6,7 +6,6 @@ package.path = package.path .. ";../?.lua"
local assert = require("batteries.assert")
local tablex = require("batteries.tablex")
-- tablex {{{
local function test_shallow_copy()
@ -59,7 +58,9 @@ local function test_shallow_overlay()
assert(
tablex.deep_equal(
r,
{ b = { 2 }, c = { 8 }, d = { 9 }, }))
{ b = { 2 }, c = { 8 }, d = { 9 }, }
)
)
end
local function test_deep_overlay()
@ -70,7 +71,9 @@ local function test_deep_overlay()
assert(
tablex.deep_equal(
r,
{ a = 1, b = 2, c = 8, d = 9 }))
{ a = 1, b = 2, c = 8, d = 9 }
)
)
x = { a = { b = { 2 }, c = { 3 }, } }
y = { a = { c = { 8 }, d = { 9 }, } }
@ -78,7 +81,9 @@ local function test_deep_overlay()
assert(
tablex.deep_equal(
r,
{ a = { b = { 2 }, c = { 8 }, d = { 9 }, } }))
{ a = { b = { 2 }, c = { 8 }, d = { 9 }, } }
)
)
end
@ -135,20 +140,18 @@ local function test_spairs()
local sorted_names = {}
local sorted_score = {}
for k, v in tablex.spairs(t, function(t, a, b)
for k, v in tablex.spairs(t, function(a, b)
return t[a].score > t[b].score
end) do
tablex.push(sorted_names, v.name)
tablex.push(sorted_score, v.score)
end
assert(tablex.deep_equal(sorted_names,
{
assert(tablex.deep_equal(sorted_names, {
"John", "Joe", "Robert"
}))
assert(tablex.deep_equal(sorted_score,
{
assert(tablex.deep_equal(sorted_score, {
10, 8, 7
}))
end

View File

@ -97,7 +97,16 @@ function sort._merge_sort_impl(array, workspace, low, high, less)
end
--default comparison; hoisted for clarity
local _sorted_types = {
--a list of types that will be sorted by default_less
--provide a custom sort function to sort other types
["string"] = true,
["number"] = true,
}
local function default_less(a, b)
if not _sorted_types[type(a)] or not _sorted_types[type(b)] then
return false
end
return a < b
end

View File

@ -2,13 +2,17 @@
extra table routines
]]
--apply prototype to module if it isn't the global table
--so it works "as if" it was the global table api
--upgraded with these routines
local path = (...):gsub("tablex", "")
local assert = require(path .. "assert")
--for spairs
--(can be replaced with eg table.sort to use that instead)
local sort = require(path .. "sort")
local spairs_sort = sort.stable_sort
--apply prototype to module if it isn't the global table
--so it works "as if" it was the global table api
--upgraded with these routines
local tablex = setmetatable({}, {
__index = table,
})
@ -78,11 +82,8 @@ function tablex.rotate(t, amount)
return t
end
--default comparison; hoisted for clarity
--(shared with sort.lua and suggests the sorted functions below should maybe be refactored there)
local function default_less(a, b)
return a < b
end
--default comparison from sort.lua
local default_less = sort.default_less
--check if a function is sorted based on a "less" or "comes before" ordering comparison
--if any item is "less" than the item before it, we are not sorted
@ -538,18 +539,15 @@ function tablex.ripairs(t)
end
--works like pairs, but returns sorted table
function tablex.spairs(t, fn)
local keys = {}
for k in pairs(t) do
tablex.push(keys, k)
end
-- generates a fair bit of garbage but very nice for more stable output
-- less function gets keys the of the table as its argument; if you want to sort on the values they map to then
-- you'll likely need a closure
function tablex.spairs(t, less)
less = less or default_less
--gather the keys
local keys = tablex.keys(t)
if fn then
table.sort(keys, function(a,b) return fn(t, a, b) end)
else
-- sort by keys if no function passed
table.sort(keys)
end
spairs_sort(keys, less)
local i = 0
return function()