@@ -24,7 +24,7 @@ import           Data.Generics
2424import            Data.Hashable 
2525import            Data.HashSet                           (HashSet )
2626import  qualified  Data.HashSet                           as  HS 
27- import            Data.List.Extra 
27+ import            Data.List.Extra                         hiding  ( length ) 
2828import  qualified  Data.Map                               as  M 
2929import            Data.Maybe 
3030import            Data.Mod.Word 
@@ -42,7 +42,6 @@ import           Development.IDE.GHC.ExactPrint
4242import            Development.IDE.Spans.AtPoint 
4343import            Development.IDE.Types.Location 
4444import            HieDb.Query 
45- import            Ide.Plugin.Config 
4645import            Ide.Plugin.Properties 
4746import            Ide.PluginUtils 
4847import            Ide.Types 
@@ -65,16 +64,28 @@ descriptor pluginId = (defaultPluginDescriptor pluginId)
6564renameProvider  ::  PluginMethodHandler  IdeState  TextDocumentRename 
6665renameProvider state pluginId (RenameParams  (TextDocumentIdentifier  uri) pos _prog newNameText) = 
6766    pluginResponse $  do 
68-         nfp <-  safeUriToNfp uri
69-         oldName <-  getNameAtPos state nfp pos
70-         refLocs <-  refsAtName state nfp oldName
67+         nfp <-  handleUriToNfp uri
68+         directOldNames <-  getNamesAtPos state nfp pos
69+         directRefs <-  concat  <$>  mapM  (refsAtName state nfp) directOldNames
70+ 
71+         {-  References in HieDB are not necessarily transitive. With `NamedFieldPuns`, we can have
72+            indirect references through punned names. To find the transitive closure, we do a pass of 
73+            the direct references to find the references for any punned names. 
74+            See the `IndirectPuns` test for an example. -}  
75+         indirectOldNames <-  concat  .  filter  ((> 1 ) .  Prelude. length ) <$> 
76+             mapM  (uncurry  (getNamesAtPos state) .  locToFilePos) directRefs
77+         let  oldNames =  indirectOldNames ++  directOldNames
78+         refs <-  HS. fromList .  concat  <$>  mapM  (refsAtName state nfp) oldNames
79+ 
80+         --  Validate rename
7181        crossModuleEnabled <-  lift $  usePropertyLsp # crossModule pluginId properties
72-         unless crossModuleEnabled $  failWhenImportOrExport state nfp refLocs oldName
73-         when (isBuiltInSyntax oldName) $ 
74-             throwE (" Invalid rename of built-in syntax: \" " ++  showName oldName ++  " \" " 
82+         unless crossModuleEnabled $  failWhenImportOrExport state nfp refs oldNames
83+         when (any  isBuiltInSyntax oldNames) $  throwE " Invalid rename of built-in syntax" 
84+ 
85+         --  Perform rename
7586        let  newName =  mkTcOcc $  T. unpack newNameText
76-             filesRefs =  collectWith locToUri refLocs 
77-             getFileEdit =  flip  $  getSrcEdit state .  renameRefs  newName
87+             filesRefs =  collectWith locToUri refs 
88+             getFileEdit =  flip  $  getSrcEdit state .  replaceRefs  newName
7889        fileEdits <-  mapM  (uncurry  getFileEdit) filesRefs
7990        pure  $  foldl' (<>)  mempty  fileEdits
8091
@@ -84,16 +95,16 @@ failWhenImportOrExport ::
8495    IdeState  -> 
8596    NormalizedFilePath  -> 
8697    HashSet  Location  -> 
87-     Name  -> 
98+     [ Name ]  -> 
8899    ExceptT  String m  () 
89- failWhenImportOrExport state nfp refLocs name  =  do 
100+ failWhenImportOrExport state nfp refLocs names  =  do 
90101    pm <-  handleMaybeM (" No parsed module for: " ++  show  nfp) $  liftIO $  runAction
91102        " Rename.GetParsedModule" 
92103        state
93104        (use GetParsedModule  nfp)
94105    let  hsMod =  unLoc $  pm_parsed_source pm
95106    case  (unLoc <$>  hsmodName hsMod, hsmodExports hsMod) of 
96-         (mbModName, _) |  not  $  nameIsLocalOrFrom (replaceModName name  mbModName) name 
107+         (mbModName, _) |  not  $  any  ( \ n  ->   nameIsLocalOrFrom (replaceModName n  mbModName) n) names 
97108            ->  throwE " Renaming of an imported name is unsupported" 
98109        (_, Just  (L  _ exports)) |  any  ((`HS.member`  refLocs) .  unsafeSrcSpanToLoc .  getLoc) exports
99110            ->  throwE " Renaming of an exported name is unsupported" 
@@ -112,7 +123,7 @@ getSrcEdit ::
112123    ExceptT  String m  WorkspaceEdit 
113124getSrcEdit state updatePs uri =  do 
114125    ccs <-  lift getClientCapabilities
115-     nfp <-  safeUriToNfp  uri
126+     nfp <-  handleUriToNfp  uri
116127    annAst <-  handleMaybeM (" No parsed source for: " ++  show  nfp) $  liftIO $  runAction
117128        " Rename.GetAnnotatedParsedSource" 
118129        state
@@ -128,13 +139,13 @@ getSrcEdit state updatePs uri = do
128139    pure  $  diffText ccs (uri, src) res IncludeDeletions 
129140
130141--  |  Replace names at every given `Location` (in a given `ParsedSource`) with a given new name. 
131- renameRefs  :: 
142+ replaceRefs  :: 
132143    OccName  -> 
133144    HashSet  Location  -> 
134145    ParsedSource  -> 
135146    ParsedSource 
136147#if  MIN_VERSION_ghc(9,2,1)
137- renameRefs  newName refs =  everywhere $ 
148+ replaceRefs  newName refs =  everywhere $ 
138149    --  there has to be a better way...
139150    mkT (replaceLoc @ AnnListItem ) `extT` 
140151    --  replaceLoc @AnnList `extT` -- not needed
@@ -149,14 +160,13 @@ renameRefs newName refs = everywhere $
149160            |  isRef (locA srcSpan) =  L  srcSpan $  replace oldRdrName
150161        replaceLoc lOldRdrName =  lOldRdrName
151162#else 
152- renameRefs  newName refs =  everywhere $  mkT replaceLoc
163+ replaceRefs  newName refs =  everywhere $  mkT replaceLoc
153164    where 
154165        replaceLoc  ::  Located  RdrName  ->  Located  RdrName 
155166        replaceLoc (L  srcSpan oldRdrName)
156167            |  isRef srcSpan =  L  srcSpan $  replace oldRdrName
157168        replaceLoc lOldRdrName =  lOldRdrName
158169#endif 
159- 
160170        replace  ::  RdrName  ->  RdrName 
161171        replace (Qual  modName _) =  Qual  modName newName
162172        replace _                =  Unqual  newName
@@ -173,10 +183,10 @@ refsAtName ::
173183    IdeState  -> 
174184    NormalizedFilePath  -> 
175185    Name  -> 
176-     ExceptT  String m  ( HashSet   Location ) 
186+     ExceptT  String m  [ Location ] 
177187refsAtName state nfp name =  do 
178188    ShakeExtras {withHieDb} <-  liftIO $  runAction " Rename.HieDb" 
179-     ast <-  safeGetHieAst  state nfp
189+     ast <-  handleGetHieAst  state nfp
180190    dbRefs <-  case  nameModule_maybe name of 
181191        Nothing  ->  pure  [] 
182192        Just  mod  ->  liftIO $  mapMaybe rowToLoc <$>  withHieDb (\ hieDb -> 
@@ -188,32 +198,32 @@ refsAtName state nfp name = do
188198                (Just  $  moduleUnit mod )
189199                [fromNormalizedFilePath nfp]
190200            )
191-     pure  $  HS. fromList  $  getNameLocs  name ast ++  dbRefs
201+     pure  $  nameLocs  name ast ++  dbRefs
192202
193- getNameLocs  ::  Name  ->  (HieAstResult , PositionMapping ) ->  [Location ]
194- getNameLocs  name (HAR  _ _ rm _ _, pm) = 
203+ nameLocs  ::  Name  ->  (HieAstResult , PositionMapping ) ->  [Location ]
204+ nameLocs  name (HAR  _ _ rm _ _, pm) = 
195205    mapMaybe (toCurrentLocation pm .  realSrcSpanToLocation .  fst )
196206             (concat  $  M. lookup  (Right 
197207
198208--------------------------------------------------------------------------------------------------- 
199209--  Util
200210
201- getNameAtPos  ::  IdeState  ->  NormalizedFilePath  ->  Position  ->  ExceptT  String ( LspT   Config   IO )  Name 
202- getNameAtPos  state nfp pos =  do 
203-     (HAR {hieAst}, pm) <-  safeGetHieAst  state nfp
204-     handleMaybe ( " No name at  "   ++  showPos pos)  $  listToMaybe  $  getNamesAtPoint hieAst pos pm
211+ getNamesAtPos  ::  MonadIO   m   =>   IdeState  ->  NormalizedFilePath  ->  Position  ->  ExceptT  String m  [ Name ] 
212+ getNamesAtPos  state nfp pos =  do 
213+     (HAR {hieAst}, pm) <-  handleGetHieAst  state nfp
214+     pure  $  getNamesAtPoint hieAst pos pm
205215
206- safeGetHieAst  :: 
216+ handleGetHieAst  :: 
207217    MonadIO  m  => 
208218    IdeState  -> 
209219    NormalizedFilePath  -> 
210220    ExceptT  String m  (HieAstResult , PositionMapping )
211- safeGetHieAst  state nfp =  handleMaybeM
221+ handleGetHieAst  state nfp =  handleMaybeM
212222    (" No AST for file: " ++  show  nfp)
213223    (liftIO $  runAction " Rename.GetHieAst" $  useWithStale GetHieAst  nfp)
214224
215- safeUriToNfp  ::  (Monad m ) =>  Uri  ->  ExceptT  String m  NormalizedFilePath 
216- safeUriToNfp  uri =  handleMaybe
225+ handleUriToNfp  ::  (Monad m ) =>  Uri  ->  ExceptT  String m  NormalizedFilePath 
226+ handleUriToNfp  uri =  handleMaybe
217227    (" No filepath for uri: " ++  show  uri)
218228    (toNormalizedFilePath <$>  uriToFilePath uri)
219229
@@ -230,15 +240,17 @@ nfpToUri = filePathToUri . fromNormalizedFilePath
230240showName  ::  Name  ->  String 
231241showName =  occNameString .  getOccName
232242
233- showPos  ::  Position  ->  String 
234- showPos Position {_line, _character} =  " line: " ++  show  _line ++  "  - character: " ++  show  _character
235- 
236243unsafeSrcSpanToLoc  ::  SrcSpan  ->  Location 
237244unsafeSrcSpanToLoc srcSpan = 
238245    case  srcSpanToLocation srcSpan of 
239246        Nothing        ->  error  " Invalid conversion from UnhelpfulSpan to Location" 
240247        Just  location ->  location
241248
249+ locToFilePos  ::  Location  ->  (NormalizedFilePath , Position )
250+ locToFilePos (Location  uri (Range  pos _)) =  (nfp, pos)
251+     where 
252+         Just  nfp =  (uriToNormalizedFilePath .  toNormalizedUri) uri
253+ 
242254replaceModName  ::  Name  ->  Maybe ModuleName  ->  Module 
243255replaceModName name mbModName = 
244256    mkModule (moduleUnit $  nameModule name) (fromMaybe (mkModuleName " Main" 
0 commit comments