diff --git a/tablex.lua b/tablex.lua index 38de632..ebaa399 100644 --- a/tablex.lua +++ b/tablex.lua @@ -329,41 +329,40 @@ if not tablex.clear then end -- Copy a table --- See shallow_overlay to shallow copy into another table to avoid garbage. +-- See shallow_overlay to shallow copy into an existing table to avoid garbage. function tablex.shallow_copy(t) - assert:type(t, "table", "tablex.copy - t", 1) - local into = {} - for k, v in pairs(t) do - into[k] = v + if type(t) == "table" then + local into = {} + for k, v in pairs(t) do + into[k] = v + end + return into end - return into + return t end local function deep_copy(t, copied) - -- TODO: consider supporting deep_copy(3) so you can always use deep_copy without type checking - local into = {} - for k, v in pairs(t) do - local clone = v - if type(v) == "table" then - if copied[v] then - clone = copied[v] - elseif type(v.copy) == "function" then - clone = v:copy() - assert:type(clone, "table", "copy() didn't return a copy") - else - clone = deep_copy(v, copied) - setmetatable(clone, getmetatable(v)) + local clone = t + if type(t) == "table" then + if copied[t] then + clone = copied[t] + elseif type(t.copy) == "function" then + clone = t:copy() + assert:type(clone, "table", "copy() didn't return a copy") + else + clone = {} + for k, v in pairs(t) do + clone[k] = deep_copy(v, copied) end - copied[v] = clone + setmetatable(clone, getmetatable(t)) + copied[t] = clone end - into[k] = clone end - return into + return clone end -- Recursively copy values of a table. -- Retains the same keys as original table -- they're not cloned. function tablex.deep_copy(t) - assert:type(t, "table", "tablex.deep_copy - t", 1) return deep_copy(t, {}) end diff --git a/tests.lua b/tests.lua index 109deb4..57bcdb4 100644 --- a/tests.lua +++ b/tests.lua @@ -20,6 +20,10 @@ local function test_shallow_copy() x = { a = { b = { 2 }, c = { 3 }, } } r = tablex.shallow_copy(x) assert:equal(r.a, x.a) + + x = 10 + r = tablex.shallow_copy(x) + assert:equal(r, x) end local function test_deep_copy() @@ -35,6 +39,10 @@ local function test_deep_copy() assert(r.a ~= x.a) assert:equal(r.a.b[1], 2) assert:equal(r.a.c[1], 3) + + x = 10 + r = tablex.deep_copy(x) + assert:equal(r, x) end