diff --git a/lib/atomically/query_service.rb b/lib/atomically/query_service.rb index e0cceb4..94d8d4b 100644 --- a/lib/atomically/query_service.rb +++ b/lib/atomically/query_service.rb @@ -88,7 +88,7 @@ def update_all_and_get_ids(*args) @klass.transaction do @relation.connection.execute('SET @ids := NULL') @relation.where("(SELECT @ids := CONCAT_WS(',', #{id_column}, @ids))").update_all(*args) # 撈出有真的被更新的 id,用逗號串在一起 - ids = @klass.from(nil).pluck(Arel.sql('@ids')).first + ids = @klass.from(Arel.sql('DUAL')).pluck(Arel.sql('@ids')).first end return ids.try{|s| s.split(',').map(&:to_i).uniq.sort } || [] # 將 id 從字串取出來 @id 的格式範例: '1,4,12' end diff --git a/test/test_helper.rb b/test/test_helper.rb index a43ca39..7fb3ccd 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -51,3 +51,21 @@ def assert_queries(expected_count, event_key = 'sql.active_record') ensure ActiveSupport::Notifications.unsubscribe(subscriber) end + +def assert_sqls(expected_sqls, event_key = 'sql.active_record') + sqls = [] + subscriber = ActiveSupport::Notifications.subscribe(event_key) do |_, _, _, _, payload| + next if payload[:sql].start_with?('PRAGMA table_info') + next if payload[:sql] =~ /\A(?:BEGIN TRANSACTION|COMMIT TRANSACTION|BEGIN|COMMIT)\z/i + + sqls << payload[:sql] + end + yield + + missing_sqls = expected_sqls - sqls + if missing_sqls.any? + assert_equal "expect #{expected_sqls} queried, but query following sqls:\n#{sqls.join("\n").tr('"', "'")}\n", "\nmissing sqls:\n#{missing_sqls.join("\n").tr('"', "'")}\n" + end +ensure + ActiveSupport::Notifications.unsubscribe(subscriber) +end diff --git a/test/update_all_and_get_ids_test.rb b/test/update_all_and_get_ids_test.rb index 8df33a2..3e9b469 100644 --- a/test/update_all_and_get_ids_test.rb +++ b/test/update_all_and_get_ids_test.rb @@ -65,4 +65,15 @@ def test_on_relation_and_with_race_condition assert_equal ['bomb', '', 'flame thrower'], Item.order('id').pluck(:name) end end + + def test_select_from_dual + skip if not Atomically::AdapterCheckService.new(UserItem).mysql? + + in_sandbox do + assert_sqls(['SELECT @ids FROM DUAL']) do + assert_equal [1, 2], Item.joins(:users).atomically.update_all_and_get_ids('items.name = ""') + assert_equal ['', '', 'flame thrower'], Item.order('id').pluck(:name) + end + end + end end