diff --git a/core/commands/pin.go b/core/commands/pin.go
index 9311159ba58683a5cfba616cc25a5ca098fb3740..61956207816c3c98b30a93ba99e1dc5060ddae95 100644
--- a/core/commands/pin.go
+++ b/core/commands/pin.go
@@ -86,18 +86,16 @@ var addPinCmd = &cmds.Command{
 			return
 		}
 
-		if !recursive {
+		switch {
+		case !recursive:
 			maxDepth = 0
-		}
-
-		if recursive && maxDepth == 0 {
+		case recursive && maxDepth == 0:
 			res.SetError(
 				errors.New("invalid --max-depth=0. Use a direct pin instead"),
 				cmdkit.ErrNormal,
 			)
-		}
-
-		if recursive && maxDepth <= 0 {
+			return
+		case !recursive && maxDepth < 0:
 			maxDepth = -1
 		}
 
diff --git a/core/commands/refs.go b/core/commands/refs.go
index dda61153654c47c0108b0a4b793ca154aafcf9b3..03cdc8ddb78967c845f982831aecd5e4f886ebb1 100644
--- a/core/commands/refs.go
+++ b/core/commands/refs.go
@@ -11,6 +11,7 @@ import (
 	"github.com/ipfs/go-ipfs/core"
 	e "github.com/ipfs/go-ipfs/core/commands/e"
 	path "github.com/ipfs/go-ipfs/path"
+	"github.com/ipfs/go-ipfs/thirdparty/recpinset"
 
 	ipld "gx/ipfs/QmWi2BYBL5gJ3CiAiQchg6rn1A8iBsrWy51EYxvHVjFvLb/go-ipld-format"
 	cid "gx/ipfs/QmapdYm1b22Frv3k17fqrBYTFRxwiaVJkB299Mfn33edeB/go-cid"
@@ -64,6 +65,7 @@ NOTE: List all references recursively by using the flag '-r'.
 		cmdkit.BoolOption("edges", "e", "Emit edge format: `<from> -> <to>`."),
 		cmdkit.BoolOption("unique", "u", "Omit duplicate refs from output."),
 		cmdkit.BoolOption("recursive", "r", "Recursively list links of child nodes."),
+		cmdkit.IntOption("max-depth", "Only for recursive depths, list down to a maximum branch depth").WithDefault(-1),
 	},
 	Run: func(req cmds.Request, res cmds.Response) {
 		ctx := req.Context()
@@ -85,6 +87,25 @@ NOTE: List all references recursively by using the flag '-r'.
 			return
 		}
 
+		maxDepth, _, err := req.Option("max-depth").Int()
+		if err != nil {
+			res.SetError(err, cmdkit.ErrNormal)
+			return
+		}
+
+		switch {
+		case !recursive:
+			maxDepth = 0
+		case recursive && maxDepth == 0:
+			res.SetError(
+				errors.New("invalid --max-depth=0 for recursive references"),
+				cmdkit.ErrNormal,
+			)
+			return
+		case !recursive && maxDepth < 0:
+			maxDepth = -1
+		}
+
 		format, _, err := req.Option("format").String()
 		if err != nil {
 			res.SetError(err, cmdkit.ErrNormal)
@@ -125,6 +146,7 @@ NOTE: List all references recursively by using the flag '-r'.
 				Unique:    unique,
 				PrintFmt:  format,
 				Recursive: recursive,
+				MaxDepth:  maxDepth,
 			}
 
 			for _, o := range objs {
@@ -233,31 +255,45 @@ type RefWriter struct {
 
 	Unique    bool
 	Recursive bool
+	MaxDepth  int
 	PrintFmt  string
 
-	seen *cid.Set
+	explored *recpinset.Set
 }
 
 // WriteRefs writes refs of the given object to the underlying writer.
 func (rw *RefWriter) WriteRefs(n ipld.Node) (int, error) {
+	maxDepth := 1 // single
 	if rw.Recursive {
-		return rw.writeRefsRecursive(n)
+		maxDepth = rw.MaxDepth
 	}
-	return rw.writeRefsSingle(n)
+
+	return rw.writeRefsRecursive(n, maxDepth)
 }
 
-func (rw *RefWriter) writeRefsRecursive(n ipld.Node) (int, error) {
+func (rw *RefWriter) writeRefsRecursive(n ipld.Node, maxDepth int) (int, error) {
+	if maxDepth == 0 {
+		return 0, nil
+	}
+
+	if maxDepth > 0 {
+		maxDepth--
+	}
+
 	nc := n.Cid()
 
 	var count int
 	for i, ng := range ipld.GetDAG(rw.Ctx, rw.DAG, n) {
 		lc := n.Links()[i].Cid
-		if rw.skip(lc) {
+		skipWrite, skipExplore := rw.skip(lc, maxDepth)
+		if skipExplore {
 			continue
 		}
 
-		if err := rw.WriteEdge(nc, lc, n.Links()[i].Name); err != nil {
-			return count, err
+		if !skipWrite {
+			if err := rw.WriteEdge(nc, lc, n.Links()[i].Name); err != nil {
+				return count, err
+			}
 		}
 
 		nd, err := ng.Get(rw.Ctx)
@@ -265,7 +301,7 @@ func (rw *RefWriter) writeRefsRecursive(n ipld.Node) (int, error) {
 			return count, err
 		}
 
-		c, err := rw.writeRefsRecursive(nd)
+		c, err := rw.writeRefsRecursive(nd, maxDepth)
 		count += c
 		if err != nil {
 			return count, err
@@ -274,43 +310,20 @@ func (rw *RefWriter) writeRefsRecursive(n ipld.Node) (int, error) {
 	return count, nil
 }
 
-func (rw *RefWriter) writeRefsSingle(n ipld.Node) (int, error) {
-	c := n.Cid()
-
-	if rw.skip(c) {
-		return 0, nil
-	}
-
-	count := 0
-	for _, l := range n.Links() {
-		lc := l.Cid
-		if rw.skip(lc) {
-			continue
-		}
-
-		if err := rw.WriteEdge(c, lc, l.Name); err != nil {
-			return count, err
-		}
-		count++
-	}
-	return count, nil
-}
-
-// skip returns whether to skip a cid
-func (rw *RefWriter) skip(c *cid.Cid) bool {
+// skip returns whether a cid has been seen (skip write) and whether
+// a cid branch has been explored (skip explore)
+func (rw *RefWriter) skip(c *cid.Cid, depth int) (bool, bool) {
 	if !rw.Unique {
-		return false
+		return false, false
 	}
 
-	if rw.seen == nil {
-		rw.seen = cid.NewSet()
+	if rw.explored == nil {
+		rw.explored = recpinset.New()
 	}
 
-	has := rw.seen.Has(c)
-	if !has {
-		rw.seen.Add(c)
-	}
-	return has
+	skipWrite := rw.explored.Has(c)
+
+	return skipWrite, !rw.explored.Visit(c, depth)
 }
 
 // Write one edge